├── .env
├── .idea
│   ├── .gitignore
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── robot_by_qkr_v1.iml
│   └── workspace.xml
├── README.md
├── VITS
│   ├── VC_inference.py
│   ├── __pycache__
│   │   ├── VC_inference.cpython-310.pyc
│   │   ├── attentions.cpython-310.pyc
│   │   ├── commons.cpython-310.pyc
│   │   ├── mel_processing.cpython-310.pyc
│   │   ├── models.cpython-310.pyc
│   │   ├── modules.cpython-310.pyc
│   │   ├── transforms.cpython-310.pyc
│   │   └── utils.cpython-310.pyc
│   ├── attentions.py
│   ├── commons.py
│   ├── configs
│   │   ├── modified_finetune_speaker.json
│   │   └── uma_trilingual.json
│   ├── mel_processing.py
│   ├── models.py
│   ├── models_infer.py
│   ├── modules.py
│   ├── monotonic_align
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   └── __init__.cpython-310.pyc
│   │   ├── build
│   │   │   ├── lib.win-amd64-3.10
│   │   │   │   └── monotonic_align
│   │   │   │       └── core.cp310-win_amd64.pyd
│   │   │   └── temp.win-amd64-3.10
│   │   │       └── Release
│   │   │           ├── core.cp310-win_amd64.exp
│   │   │           ├── core.cp310-win_amd64.lib
│   │   │           └── core.obj
│   │   ├── core.c
│   │   ├── core.pyx
│   │   ├── monotonic_align
│   │   │   └── core.cp310-win_amd64.pyd
│   │   └── setup.py
│   ├── text
│   │   ├── LICENSE
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-310.pyc
│   │   │   ├── cleaners.cpython-310.pyc
│   │   │   ├── english.cpython-310.pyc
│   │   │   ├── japanese.cpython-310.pyc
│   │   │   ├── korean.cpython-310.pyc
│   │   │   ├── mandarin.cpython-310.pyc
│   │   │   ├── sanskrit.cpython-310.pyc
│   │   │   ├── symbols.cpython-310.pyc
│   │   │   └── thai.cpython-310.pyc
│   │   ├── cantonese.py
│   │   ├── cleaners.py
│   │   ├── english.py
│   │   ├── japanese.py
│   │   ├── korean.py
│   │   ├── mandarin.py
│   │   ├── ngu_dialect.py
│   │   ├── sanskrit.py
│   │   ├── shanghainese.py
│   │   ├── symbols.py
│   │   └── thai.py
│   ├── transforms.py
│   └── utils.py
├── baiduApi.py
├── gptApi.py
├── gptApi_duolun.py
├── input
│   └── input.wav
├── output
│   └── output.wav
├── requirements.txt
└── robot.py

/.env:
--------------------------------------------------------------------------------
1 | api_id=****
2 | api_key=*****
3 | api_secert=*****
4 | openai_api=*****
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml, /.idea/misc.xml, /.idea/modules.xml,
/.idea/robot_by_qkr_v1.iml, /.idea/workspace.xml:
--------------------------------------------------------------------------------
[JetBrains IDE project metadata. The XML markup of these files was stripped when this
dump was extracted; only bare line numbers and a timestamp survive, so their content is
not recoverable here.]
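The .env file above holds the credentials the README in the next section asks you to fill in: the Baidu speech-API fields (api_id, api_key, api_secert; the key name is spelled that way in the repo) and the OpenAI key (openai_api). The repo's own loading code (baiduApi.py / gptApi.py) is not part of this section, so the snippet below is only a hedged sketch of how such a .env file is commonly read, assuming the python-dotenv package:

```python
# Illustrative only: how a .env file like the one above is typically loaded.
# Assumes `pip install python-dotenv`; the actual repo may read the keys differently.
import os
from dotenv import load_dotenv

load_dotenv()  # pulls key=value pairs from ./.env into the process environment

baidu_app_id = os.getenv("api_id")
baidu_api_key = os.getenv("api_key")
baidu_secret_key = os.getenv("api_secert")   # key name kept exactly as spelled in .env
openai_api_key = os.getenv("openai_api")
```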
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ### About the project
2 | This is a voice chat bot written in Python: it takes voice input and produces voice output to converse with ChatGPT.
3 |
4 | ### How to use
5 | * Install the dependencies: `pip install -r requirements.txt`
6 | * Download a VITS-trained model and its config file and put them in the VITS directory. Here I provide the Amiya model I trained myself, for learning and exchange only:
7 | * https://pan.baidu.com/s/1OIqAHzItdxgYkAma942VvQ?pwd=fuou
8 | * Apply for a Baidu Cloud API key and an OpenAI API key and fill them into the .env file.
9 | * The entry point is robot.py; have fun chatting.
10 |
11 | PS: the reverse-proxy server has expired QAQ. If you want to use this, go through your own proxy instead and change the URL in the API files.
12 |
--------------------------------------------------------------------------------
/VITS/VC_inference.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import no_grad, LongTensor
4 | import argparse
5 | from VITS import commons  # local module
6 | #from mel_processing import spectrogram_torch  # local module
7 | from VITS import utils  # local module
8 | from VITS.models import SynthesizerTrn  # local module
9 | import librosa
10 | from playsound import playsound
11 | import logging
12 |
13 |
14 | from VITS.text import text_to_sequence, _clean_text
15 |
16 | from scipy.io.wavfile import write
17 |
18 | def to_wav(audio, path):
19 |     # Convert the ndarray to int16 (must be 16-bit signed integers)
20 |     # data = audio * 32767
21 |     # data = data.astype(np.int16)
22 |     data = audio
23 |     # Set the sample rate
24 |     sample_rate = 24100
25 |
26 |     # Write the data to a .wav file
27 |     write(path, sample_rate, data)
28 |
29 | device = "cuda:0" if torch.cuda.is_available() else "cpu"
30 | language_marks = {
31 |     "Japanese": "",
32 |     "日本語": "[JA]",
33 |     "简体中文": "[ZH]",
34 |     "English": "[EN]",
35 |     "Mix": "",
36 | }
37 | lang = ['日本語', '简体中文', 'English', 'Mix']
38 | def get_text(text, hps, is_symbol):
39 |     text_norm = text_to_sequence(text, hps.symbols, [] if is_symbol else hps.data.text_cleaners)
40 |     if hps.data.add_blank:
41 |         text_norm = commons.intersperse(text_norm, 0)
42 |     text_norm = LongTensor(text_norm)
43 |     return text_norm
44 |
45 | def create_tts_fn(model, hps, speaker_ids):
46 |     def tts_fn(text, speaker, language, speed, path):
47 |         if language is not None:
48 |             text = language_marks[language] + text + language_marks[language]  # select the inference language
49 |         speaker_id = speaker_ids[speaker]  # select the speaker's id
50 |         stn_tst = get_text(text, hps, False)  # convert the text using the model config (hps)
51 |         with no_grad():
52 |             x_tst = stn_tst.unsqueeze(0).to(device)
53 |             x_tst_lengths = LongTensor([stn_tst.size(0)]).to(device)
54 |             sid = LongTensor([speaker_id]).to(device)
55 |             #src="http://127.0.0.1:7860/file=C:\Users\15093289086\AppData\Local\Temp\tmpjz4wxooq.wav"
56 |             audio = model.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=.667, noise_scale_w=0.8,
57 |                                 length_scale=1.0 / speed)[0][0, 0].data.cpu().float().numpy()#
58 |             ##print(type(audio))
59 |             to_wav(audio, path)
60 |
61 |         del stn_tst, x_tst, x_tst_lengths, sid
62 |         ##print("Success", (hps.data.sampling_rate, audio))
63 |         return "Success", (hps.data.sampling_rate, audio)
64 |
65 |     return tts_fn
66 |
67 | # def create_vc_fn(model, hps, speaker_ids):
68 | #     def vc_fn(original_speaker, target_speaker, record_audio, upload_audio):
69 | #         input_audio = record_audio if record_audio is not None else upload_audio
70 | #         if input_audio is None:
71 | #             return "You need to record or upload an audio", None
72 | #         sampling_rate, audio = input_audio
73 | #         original_speaker_id = speaker_ids[original_speaker]
74 | #
target_speaker_id = speaker_ids[target_speaker] 75 | # 76 | # audio = (audio / np.iinfo(audio.dtype).max).astype(np.float32) 77 | # if len(audio.shape) > 1: 78 | # audio = librosa.to_mono(audio.transpose(1, 0)) 79 | # if sampling_rate != hps.data.sampling_rate: 80 | # audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=hps.data.sampling_rate) 81 | # with no_grad(): 82 | # y = torch.FloatTensor(audio) 83 | # y = y / max(-y.min(), y.max()) / 0.99 84 | # y = y.to(device) 85 | # y = y.unsqueeze(0) 86 | # spec = spectrogram_torch(y, hps.data.filter_length, 87 | # hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, 88 | # center=False).to(device) 89 | # spec_lengths = LongTensor([spec.size(-1)]).to(device) 90 | # sid_src = LongTensor([original_speaker_id]).to(device) 91 | # sid_tgt = LongTensor([target_speaker_id]).to(device) 92 | # audio = model.voice_conversion(spec, spec_lengths, sid_src=sid_src, sid_tgt=sid_tgt)[0][ 93 | # 0, 0].data.cpu().float().numpy() 94 | # del y, spec, spec_lengths, sid_src, sid_tgt 95 | # return "Success", (hps.data.sampling_rate, audio) 96 | # 97 | # return vc_fn 98 | 99 | 100 | 101 | def getVoice(text,path): 102 | logging.getLogger('jieba').disabled = True#禁用日志输出 103 | parser = argparse.ArgumentParser() 104 | parser.add_argument("--model_dir", default="VITS/G_latest.pth", help="directory to your fine-tuned model") 105 | parser.add_argument("--config_dir", default="VITS/finetune_speaker.json", help="directory to your model config file") 106 | parser.add_argument("--share", default=False, help="make link public (used in colab)") 107 | 108 | args = parser.parse_args() 109 | hps = utils.get_hparams_from_file(args.config_dir) 110 | 111 | net_g = SynthesizerTrn( 112 | len(hps.symbols), 113 | hps.data.filter_length // 2 + 1, 114 | hps.train.segment_size // hps.data.hop_length, 115 | n_speakers=hps.data.n_speakers, 116 | **hps.model).to(device) 117 | _ = net_g.eval() 118 | 119 | _ = utils.load_checkpoint(args.model_dir, net_g, None) 120 | speaker_ids = hps.speakers 121 | speakers = list(hps.speakers.keys()) 122 | tts_fn = create_tts_fn(net_g, hps, speaker_ids) 123 | # vc_fn = create_vc_fn(net_g, hps, speaker_ids) 124 | 125 | # 本地运行,不挂web 126 | textbox = text 127 | char_dropdown = speakers[0] 128 | language_dropdown = lang[1] 129 | duration_slider = 0.8 130 | 131 | tts_fn(textbox, char_dropdown, language_dropdown, duration_slider,path) 132 | 133 | #print('okk...') 134 | # playsound('../output/output.wav') 135 | 136 | if __name__ == "__main__": 137 | 138 | #logging.getLogger().setLevel(logging.ERROR) 139 | logging.getLogger('jieba').disabled = True 140 | parser = argparse.ArgumentParser() 141 | parser.add_argument("--model_dir", default="./G_latest.pth", help="directory to your fine-tuned model") 142 | parser.add_argument("--config_dir", default="./finetune_speaker.json", help="directory to your model config file") 143 | parser.add_argument("--share", default=False, help="make link public (used in colab)") 144 | 145 | args = parser.parse_args() 146 | hps = utils.get_hparams_from_file(args.config_dir) 147 | 148 | 149 | net_g = SynthesizerTrn( 150 | len(hps.symbols), 151 | hps.data.filter_length // 2 + 1, 152 | hps.train.segment_size // hps.data.hop_length, 153 | n_speakers=hps.data.n_speakers, 154 | **hps.model).to(device) 155 | _ = net_g.eval() 156 | 157 | _ = utils.load_checkpoint(args.model_dir, net_g, None) 158 | speaker_ids = hps.speakers 159 | speakers = list(hps.speakers.keys()) 160 | tts_fn = create_tts_fn(net_g, hps, speaker_ids) 161 | 
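# Usage sketch: getVoice() defined above is the piece the rest of the project calls;
# per the README, robot.py is the entry point. The exact call in robot.py is not shown
# in this section, so the lines below are an assumed, illustrative example only:
#
#     from VITS.VC_inference import getVoice
#     getVoice("你好,博士。", "output/output.wav")  # synthesize the reply into a wav file
#
# Note that getVoice() rebuilds SynthesizerTrn and reloads VITS/G_latest.pth on every
# call, so each call pays the full model-loading cost.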
#vc_fn = create_vc_fn(net_g, hps, speaker_ids) 162 | 163 | #本地运行,不挂web 164 | textbox = "你好,博士。" 165 | char_dropdown = speakers[0] 166 | language_dropdown = lang[1] 167 | duration_slider = 0.8 168 | 169 | tts_fn(textbox, char_dropdown, language_dropdown, duration_slider, './output.wav') 170 | 171 | print('okk...') 172 | playsound('../output/output.wav') 173 | 174 | 175 | # app = gr.Blocks() 176 | # with app: 177 | # with gr.Tab("Text-to-Speech"): 178 | # with gr.Row(): 179 | # with gr.Column(): 180 | # textbox = gr.TextArea(label="Text", 181 | # placeholder="Type your sentence here", 182 | # value="博士,很高兴见到你,欢迎回来。", elem_id=f"tts-input") 183 | # # select character 184 | # char_dropdown = gr.Dropdown(choices=speakers, value=speakers[0], label='character') 185 | # #print(char_dropdown) 186 | # language_dropdown = gr.Dropdown(choices=lang, value=lang[1], label='language') 187 | # #print(language_dropdown) 188 | # duration_slider = gr.Slider(minimum=0.1, maximum=5, value=1, step=0.1, 189 | # label='速度 Speed') 190 | # with gr.Column(): 191 | # text_output = gr.Textbox(label="Message") 192 | # audio_output = gr.Audio(label="Output Audio", elem_id="tts-audio") 193 | # btn = gr.Button("Generate!") 194 | # #res=tts_fn(textbox, char_dropdown, language_dropdown, duration_slider) 195 | # ##print(res) 196 | # btn.click(tts_fn,#发包函数 197 | # inputs=[textbox, char_dropdown, language_dropdown, duration_slider,], 198 | # outputs=[text_output, audio_output])#这是绑定了函数 199 | # 200 | # #可以写成tts_fn(textbox, char_dropdown, language_dropdown, duration_slider) 201 | # 202 | # with gr.Tab("Voice Conversion"):#这个是声音转声音的 203 | # gr.Markdown(""" 204 | # 录制或上传声音,并选择要转换的音色。User代表的音色是你自己。 205 | # """) 206 | # with gr.Column(): 207 | # record_audio = gr.Audio(label="record your voice", source="microphone") 208 | # upload_audio = gr.Audio(label="or upload audio here", source="upload") 209 | # source_speaker = gr.Dropdown(choices=speakers, value="User", label="source speaker") 210 | # target_speaker = gr.Dropdown(choices=speakers, value=speakers[0], label="target speaker") 211 | # with gr.Column(): 212 | # message_box = gr.Textbox(label="Message") 213 | # converted_audio = gr.Audio(label='converted audio') 214 | # btn = gr.Button("Convert!") 215 | # btn.click(vc_fn, inputs=[source_speaker, target_speaker, record_audio, upload_audio], 216 | # outputs=[message_box, converted_audio]) 217 | # webbrowser.open("http://127.0.0.1:7860") 218 | # app.launch(share=args.share) 219 | # 220 | -------------------------------------------------------------------------------- /VITS/__pycache__/VC_inference.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/__pycache__/VC_inference.cpython-310.pyc -------------------------------------------------------------------------------- /VITS/__pycache__/attentions.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/__pycache__/attentions.cpython-310.pyc -------------------------------------------------------------------------------- /VITS/__pycache__/commons.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/__pycache__/commons.cpython-310.pyc -------------------------------------------------------------------------------- /VITS/__pycache__/mel_processing.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/__pycache__/mel_processing.cpython-310.pyc -------------------------------------------------------------------------------- /VITS/__pycache__/models.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/__pycache__/models.cpython-310.pyc -------------------------------------------------------------------------------- /VITS/__pycache__/modules.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/__pycache__/modules.cpython-310.pyc -------------------------------------------------------------------------------- /VITS/__pycache__/transforms.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/__pycache__/transforms.cpython-310.pyc -------------------------------------------------------------------------------- /VITS/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /VITS/attentions.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | from VITS import commons 9 | #from VITS import modules 10 | from VITS.modules import LayerNorm 11 | 12 | 13 | class Encoder(nn.Module): 14 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, **kwargs): 15 | super().__init__() 16 | self.hidden_channels = hidden_channels 17 | self.filter_channels = filter_channels 18 | self.n_heads = n_heads 19 | self.n_layers = n_layers 20 | self.kernel_size = kernel_size 21 | self.p_dropout = p_dropout 22 | self.window_size = window_size 23 | 24 | self.drop = nn.Dropout(p_dropout) 25 | self.attn_layers = nn.ModuleList() 26 | self.norm_layers_1 = nn.ModuleList() 27 | self.ffn_layers = nn.ModuleList() 28 | self.norm_layers_2 = nn.ModuleList() 29 | for i in range(self.n_layers): 30 | self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size)) 31 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 32 | self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)) 33 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 34 | 35 | def forward(self, x, x_mask): 36 | attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 37 | x = x * x_mask 38 | 
for i in range(self.n_layers): 39 | y = self.attn_layers[i](x, x, attn_mask) 40 | y = self.drop(y) 41 | x = self.norm_layers_1[i](x + y) 42 | 43 | y = self.ffn_layers[i](x, x_mask) 44 | y = self.drop(y) 45 | x = self.norm_layers_2[i](x + y) 46 | x = x * x_mask 47 | return x 48 | 49 | 50 | class Decoder(nn.Module): 51 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., proximal_bias=False, proximal_init=True, **kwargs): 52 | super().__init__() 53 | self.hidden_channels = hidden_channels 54 | self.filter_channels = filter_channels 55 | self.n_heads = n_heads 56 | self.n_layers = n_layers 57 | self.kernel_size = kernel_size 58 | self.p_dropout = p_dropout 59 | self.proximal_bias = proximal_bias 60 | self.proximal_init = proximal_init 61 | 62 | self.drop = nn.Dropout(p_dropout) 63 | self.self_attn_layers = nn.ModuleList() 64 | self.norm_layers_0 = nn.ModuleList() 65 | self.encdec_attn_layers = nn.ModuleList() 66 | self.norm_layers_1 = nn.ModuleList() 67 | self.ffn_layers = nn.ModuleList() 68 | self.norm_layers_2 = nn.ModuleList() 69 | for i in range(self.n_layers): 70 | self.self_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, proximal_bias=proximal_bias, proximal_init=proximal_init)) 71 | self.norm_layers_0.append(LayerNorm(hidden_channels)) 72 | self.encdec_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)) 73 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 74 | self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True)) 75 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 76 | 77 | def forward(self, x, x_mask, h, h_mask): 78 | """ 79 | x: decoder input 80 | h: encoder output 81 | """ 82 | self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype) 83 | encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 84 | x = x * x_mask 85 | for i in range(self.n_layers): 86 | y = self.self_attn_layers[i](x, x, self_attn_mask) 87 | y = self.drop(y) 88 | x = self.norm_layers_0[i](x + y) 89 | 90 | y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) 91 | y = self.drop(y) 92 | x = self.norm_layers_1[i](x + y) 93 | 94 | y = self.ffn_layers[i](x, x_mask) 95 | y = self.drop(y) 96 | x = self.norm_layers_2[i](x + y) 97 | x = x * x_mask 98 | return x 99 | 100 | 101 | class MultiHeadAttention(nn.Module): 102 | def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False): 103 | super().__init__() 104 | assert channels % n_heads == 0 105 | 106 | self.channels = channels 107 | self.out_channels = out_channels 108 | self.n_heads = n_heads 109 | self.p_dropout = p_dropout 110 | self.window_size = window_size 111 | self.heads_share = heads_share 112 | self.block_length = block_length 113 | self.proximal_bias = proximal_bias 114 | self.proximal_init = proximal_init 115 | self.attn = None 116 | 117 | self.k_channels = channels // n_heads 118 | self.conv_q = nn.Conv1d(channels, channels, 1) 119 | self.conv_k = nn.Conv1d(channels, channels, 1) 120 | self.conv_v = nn.Conv1d(channels, channels, 1) 121 | self.conv_o = nn.Conv1d(channels, out_channels, 1) 122 | self.drop = nn.Dropout(p_dropout) 123 | 124 | if window_size is not None: 125 | n_heads_rel = 1 if heads_share else n_heads 126 | rel_stddev = self.k_channels**-0.5 127 | self.emb_rel_k = 
nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) 128 | self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) 129 | 130 | nn.init.xavier_uniform_(self.conv_q.weight) 131 | nn.init.xavier_uniform_(self.conv_k.weight) 132 | nn.init.xavier_uniform_(self.conv_v.weight) 133 | if proximal_init: 134 | with torch.no_grad(): 135 | self.conv_k.weight.copy_(self.conv_q.weight) 136 | self.conv_k.bias.copy_(self.conv_q.bias) 137 | 138 | def forward(self, x, c, attn_mask=None): 139 | q = self.conv_q(x) 140 | k = self.conv_k(c) 141 | v = self.conv_v(c) 142 | 143 | x, self.attn = self.attention(q, k, v, mask=attn_mask) 144 | 145 | x = self.conv_o(x) 146 | return x 147 | 148 | def attention(self, query, key, value, mask=None): 149 | # reshape [b, d, t] -> [b, n_h, t, d_k] 150 | b, d, t_s, t_t = (*key.size(), query.size(2)) 151 | query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) 152 | key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 153 | value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 154 | 155 | scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) 156 | if self.window_size is not None: 157 | assert t_s == t_t, "Relative attention is only available for self-attention." 158 | key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) 159 | rel_logits = self._matmul_with_relative_keys(query /math.sqrt(self.k_channels), key_relative_embeddings) 160 | scores_local = self._relative_position_to_absolute_position(rel_logits) 161 | scores = scores + scores_local 162 | if self.proximal_bias: 163 | assert t_s == t_t, "Proximal bias is only available for self-attention." 164 | scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) 165 | if mask is not None: 166 | scores = scores.masked_fill(mask == 0, -1e4) 167 | if self.block_length is not None: 168 | assert t_s == t_t, "Local attention is only available for self-attention." 169 | block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length) 170 | scores = scores.masked_fill(block_mask == 0, -1e4) 171 | p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] 172 | p_attn = self.drop(p_attn) 173 | output = torch.matmul(p_attn, value) 174 | if self.window_size is not None: 175 | relative_weights = self._absolute_position_to_relative_position(p_attn) 176 | value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) 177 | output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) 178 | output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t] 179 | return output, p_attn 180 | 181 | def _matmul_with_relative_values(self, x, y): 182 | """ 183 | x: [b, h, l, m] 184 | y: [h or 1, m, d] 185 | ret: [b, h, l, d] 186 | """ 187 | ret = torch.matmul(x, y.unsqueeze(0)) 188 | return ret 189 | 190 | def _matmul_with_relative_keys(self, x, y): 191 | """ 192 | x: [b, h, l, d] 193 | y: [h or 1, m, d] 194 | ret: [b, h, l, m] 195 | """ 196 | ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) 197 | return ret 198 | 199 | def _get_relative_embeddings(self, relative_embeddings, length): 200 | max_relative_position = 2 * self.window_size + 1 201 | # Pad first before slice to avoid using cond ops. 
202 | pad_length = max(length - (self.window_size + 1), 0) 203 | slice_start_position = max((self.window_size + 1) - length, 0) 204 | slice_end_position = slice_start_position + 2 * length - 1 205 | if pad_length > 0: 206 | padded_relative_embeddings = F.pad( 207 | relative_embeddings, 208 | commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])) 209 | else: 210 | padded_relative_embeddings = relative_embeddings 211 | used_relative_embeddings = padded_relative_embeddings[:,slice_start_position:slice_end_position] 212 | return used_relative_embeddings 213 | 214 | def _relative_position_to_absolute_position(self, x): 215 | """ 216 | x: [b, h, l, 2*l-1] 217 | ret: [b, h, l, l] 218 | """ 219 | batch, heads, length, _ = x.size() 220 | # Concat columns of pad to shift from relative to absolute indexing. 221 | x = F.pad(x, commons.convert_pad_shape([[0,0],[0,0],[0,0],[0,1]])) 222 | 223 | # Concat extra elements so to add up to shape (len+1, 2*len-1). 224 | x_flat = x.view([batch, heads, length * 2 * length]) 225 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0,0],[0,0],[0,length-1]])) 226 | 227 | # Reshape and slice out the padded elements. 228 | x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:] 229 | return x_final 230 | 231 | def _absolute_position_to_relative_position(self, x): 232 | """ 233 | x: [b, h, l, l] 234 | ret: [b, h, l, 2*l-1] 235 | """ 236 | batch, heads, length, _ = x.size() 237 | # padd along column 238 | x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]])) 239 | x_flat = x.view([batch, heads, length**2 + length*(length -1)]) 240 | # add 0's in the beginning that will skew the elements after reshape 241 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) 242 | x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:] 243 | return x_final 244 | 245 | def _attention_bias_proximal(self, length): 246 | """Bias for self-attention to encourage attention to close positions. 247 | Args: 248 | length: an integer scalar. 
249 | Returns: 250 | a Tensor with shape [1, 1, length, length] 251 | """ 252 | r = torch.arange(length, dtype=torch.float32) 253 | diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) 254 | return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) 255 | 256 | 257 | class FFN(nn.Module): 258 | def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None, causal=False): 259 | super().__init__() 260 | self.in_channels = in_channels 261 | self.out_channels = out_channels 262 | self.filter_channels = filter_channels 263 | self.kernel_size = kernel_size 264 | self.p_dropout = p_dropout 265 | self.activation = activation 266 | self.causal = causal 267 | 268 | if causal: 269 | self.padding = self._causal_padding 270 | else: 271 | self.padding = self._same_padding 272 | 273 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) 274 | self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) 275 | self.drop = nn.Dropout(p_dropout) 276 | 277 | def forward(self, x, x_mask): 278 | x = self.conv_1(self.padding(x * x_mask)) 279 | if self.activation == "gelu": 280 | x = x * torch.sigmoid(1.702 * x) 281 | else: 282 | x = torch.relu(x) 283 | x = self.drop(x) 284 | x = self.conv_2(self.padding(x * x_mask)) 285 | return x * x_mask 286 | 287 | def _causal_padding(self, x): 288 | if self.kernel_size == 1: 289 | return x 290 | pad_l = self.kernel_size - 1 291 | pad_r = 0 292 | padding = [[0, 0], [0, 0], [pad_l, pad_r]] 293 | x = F.pad(x, commons.convert_pad_shape(padding)) 294 | return x 295 | 296 | def _same_padding(self, x): 297 | if self.kernel_size == 1: 298 | return x 299 | pad_l = (self.kernel_size - 1) // 2 300 | pad_r = self.kernel_size // 2 301 | padding = [[0, 0], [0, 0], [pad_l, pad_r]] 302 | x = F.pad(x, commons.convert_pad_shape(padding)) 303 | return x 304 | -------------------------------------------------------------------------------- /VITS/commons.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | 8 | def init_weights(m, mean=0.0, std=0.01): 9 | classname = m.__class__.__name__ 10 | if classname.find("Conv") != -1: 11 | m.weight.data.normal_(mean, std) 12 | 13 | 14 | def get_padding(kernel_size, dilation=1): 15 | return int((kernel_size*dilation - dilation)/2) 16 | 17 | 18 | def convert_pad_shape(pad_shape): 19 | l = pad_shape[::-1] 20 | pad_shape = [item for sublist in l for item in sublist] 21 | return pad_shape 22 | 23 | 24 | def intersperse(lst, item): 25 | result = [item] * (len(lst) * 2 + 1) 26 | result[1::2] = lst 27 | return result 28 | 29 | 30 | def kl_divergence(m_p, logs_p, m_q, logs_q): 31 | """KL(P||Q)""" 32 | kl = (logs_q - logs_p) - 0.5 33 | kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. 
* logs_q) 34 | return kl 35 | 36 | 37 | def rand_gumbel(shape): 38 | """Sample from the Gumbel distribution, protect from overflows.""" 39 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 40 | return -torch.log(-torch.log(uniform_samples)) 41 | 42 | 43 | def rand_gumbel_like(x): 44 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) 45 | return g 46 | 47 | 48 | def slice_segments(x, ids_str, segment_size=4): 49 | ret = torch.zeros_like(x[:, :, :segment_size]) 50 | for i in range(x.size(0)): 51 | idx_str = ids_str[i] 52 | idx_end = idx_str + segment_size 53 | try: 54 | ret[i] = x[i, :, idx_str:idx_end] 55 | except RuntimeError: 56 | print("?") 57 | return ret 58 | 59 | 60 | def rand_slice_segments(x, x_lengths=None, segment_size=4): 61 | b, d, t = x.size() 62 | if x_lengths is None: 63 | x_lengths = t 64 | ids_str_max = x_lengths - segment_size + 1 65 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) 66 | ret = slice_segments(x, ids_str, segment_size) 67 | return ret, ids_str 68 | 69 | 70 | def get_timing_signal_1d( 71 | length, channels, min_timescale=1.0, max_timescale=1.0e4): 72 | position = torch.arange(length, dtype=torch.float) 73 | num_timescales = channels // 2 74 | log_timescale_increment = ( 75 | math.log(float(max_timescale) / float(min_timescale)) / 76 | (num_timescales - 1)) 77 | inv_timescales = min_timescale * torch.exp( 78 | torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment) 79 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) 80 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) 81 | signal = F.pad(signal, [0, 0, 0, channels % 2]) 82 | signal = signal.view(1, channels, length) 83 | return signal 84 | 85 | 86 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): 87 | b, channels, length = x.size() 88 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 89 | return x + signal.to(dtype=x.dtype, device=x.device) 90 | 91 | 92 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): 93 | b, channels, length = x.size() 94 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 95 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) 96 | 97 | 98 | def subsequent_mask(length): 99 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) 100 | return mask 101 | 102 | 103 | @torch.jit.script 104 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 105 | n_channels_int = n_channels[0] 106 | in_act = input_a + input_b 107 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 108 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 109 | acts = t_act * s_act 110 | return acts 111 | 112 | 113 | def convert_pad_shape(pad_shape): 114 | l = pad_shape[::-1] 115 | pad_shape = [item for sublist in l for item in sublist] 116 | return pad_shape 117 | 118 | 119 | def shift_1d(x): 120 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] 121 | return x 122 | 123 | 124 | def sequence_mask(length, max_length=None): 125 | if max_length is None: 126 | max_length = length.max() 127 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 128 | return x.unsqueeze(0) < length.unsqueeze(1) 129 | 130 | 131 | def generate_path(duration, mask): 132 | """ 133 | duration: [b, 1, t_x] 134 | mask: [b, 1, t_y, t_x] 135 | """ 136 | device = duration.device 137 | 138 | b, _, t_y, t_x = mask.shape 139 | cum_duration = 
torch.cumsum(duration, -1) 140 | 141 | cum_duration_flat = cum_duration.view(b * t_x) 142 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 143 | path = path.view(b, t_x, t_y) 144 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 145 | path = path.unsqueeze(1).transpose(2,3) * mask 146 | return path 147 | 148 | 149 | def clip_grad_value_(parameters, clip_value, norm_type=2): 150 | if isinstance(parameters, torch.Tensor): 151 | parameters = [parameters] 152 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 153 | norm_type = float(norm_type) 154 | if clip_value is not None: 155 | clip_value = float(clip_value) 156 | 157 | total_norm = 0 158 | for p in parameters: 159 | param_norm = p.grad.data.norm(norm_type) 160 | total_norm += param_norm.item() ** norm_type 161 | if clip_value is not None: 162 | p.grad.data.clamp_(min=-clip_value, max=clip_value) 163 | total_norm = total_norm ** (1. / norm_type) 164 | return total_norm 165 | -------------------------------------------------------------------------------- /VITS/configs/modified_finetune_speaker.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "log_interval": 10, 4 | "eval_interval": 100, 5 | "seed": 1234, 6 | "epochs": 10000, 7 | "learning_rate": 0.0002, 8 | "betas": [ 9 | 0.8, 10 | 0.99 11 | ], 12 | "eps": 1e-09, 13 | "batch_size": 16, 14 | "fp16_run": true, 15 | "lr_decay": 0.999875, 16 | "segment_size": 8192, 17 | "init_lr_ratio": 1, 18 | "warmup_epochs": 0, 19 | "c_mel": 45, 20 | "c_kl": 1.0 21 | }, 22 | "data": { 23 | "training_files": "final_annotation_train.txt", 24 | "validation_files": "final_annotation_val.txt", 25 | "text_cleaners": [ 26 | "chinese_cleaners" 27 | ], 28 | "max_wav_value": 32768.0, 29 | "sampling_rate": 22050, 30 | "filter_length": 1024, 31 | "hop_length": 256, 32 | "win_length": 1024, 33 | "n_mel_channels": 80, 34 | "mel_fmin": 0.0, 35 | "mel_fmax": null, 36 | "add_blank": true, 37 | "n_speakers": 2, 38 | "cleaned_text": true 39 | }, 40 | "model": { 41 | "inter_channels": 192, 42 | "hidden_channels": 192, 43 | "filter_channels": 768, 44 | "n_heads": 2, 45 | "n_layers": 6, 46 | "kernel_size": 3, 47 | "p_dropout": 0.1, 48 | "resblock": "1", 49 | "resblock_kernel_sizes": [ 50 | 3, 51 | 7, 52 | 11 53 | ], 54 | "resblock_dilation_sizes": [ 55 | [ 56 | 1, 57 | 3, 58 | 5 59 | ], 60 | [ 61 | 1, 62 | 3, 63 | 5 64 | ], 65 | [ 66 | 1, 67 | 3, 68 | 5 69 | ] 70 | ], 71 | "upsample_rates": [ 72 | 8, 73 | 8, 74 | 2, 75 | 2 76 | ], 77 | "upsample_initial_channel": 512, 78 | "upsample_kernel_sizes": [ 79 | 16, 80 | 16, 81 | 4, 82 | 4 83 | ], 84 | "n_layers_q": 3, 85 | "use_spectral_norm": false, 86 | "gin_channels": 256 87 | }, 88 | "symbols": [ 89 | "_", 90 | "\uff1b", 91 | "\uff1a", 92 | "\uff0c", 93 | "\u3002", 94 | "\uff01", 95 | "\uff1f", 96 | "-", 97 | "\u201c", 98 | "\u201d", 99 | "\u300a", 100 | "\u300b", 101 | "\u3001", 102 | "\uff08", 103 | "\uff09", 104 | "\u2026", 105 | "\u2014", 106 | " ", 107 | "A", 108 | "B", 109 | "C", 110 | "D", 111 | "E", 112 | "F", 113 | "G", 114 | "H", 115 | "I", 116 | "J", 117 | "K", 118 | "L", 119 | "M", 120 | "N", 121 | "O", 122 | "P", 123 | "Q", 124 | "R", 125 | "S", 126 | "T", 127 | "U", 128 | "V", 129 | "W", 130 | "X", 131 | "Y", 132 | "Z", 133 | "a", 134 | "b", 135 | "c", 136 | "d", 137 | "e", 138 | "f", 139 | "g", 140 | "h", 141 | "i", 142 | "j", 143 | "k", 144 | "l", 145 | "m", 146 | "n", 147 | "o", 148 | "p", 149 | "q", 150 | "r", 151 | "s", 152 | "t", 153 | 
"u", 154 | "v", 155 | "w", 156 | "x", 157 | "y", 158 | "z", 159 | "1", 160 | "2", 161 | "3", 162 | "4", 163 | "5", 164 | "0", 165 | "\uff22", 166 | "\uff30" 167 | ], 168 | "speakers": { 169 | "dingzhen": 0, 170 | "taffy": 1 171 | } 172 | } -------------------------------------------------------------------------------- /VITS/configs/uma_trilingual.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "log_interval": 200, 4 | "eval_interval": 1000, 5 | "seed": 1234, 6 | "epochs": 10000, 7 | "learning_rate": 2e-4, 8 | "betas": [0.8, 0.99], 9 | "eps": 1e-9, 10 | "batch_size": 16, 11 | "fp16_run": true, 12 | "lr_decay": 0.999875, 13 | "segment_size": 8192, 14 | "init_lr_ratio": 1, 15 | "warmup_epochs": 0, 16 | "c_mel": 45, 17 | "c_kl": 1.0 18 | }, 19 | "data": { 20 | "training_files":"../CH_JA_EN_mix_voice/clipped_3_vits_trilingual_annotations.train.txt.cleaned", 21 | "validation_files":"../CH_JA_EN_mix_voice/clipped_3_vits_trilingual_annotations.val.txt.cleaned", 22 | "text_cleaners":["cjke_cleaners2"], 23 | "max_wav_value": 32768.0, 24 | "sampling_rate": 22050, 25 | "filter_length": 1024, 26 | "hop_length": 256, 27 | "win_length": 1024, 28 | "n_mel_channels": 80, 29 | "mel_fmin": 0.0, 30 | "mel_fmax": null, 31 | "add_blank": true, 32 | "n_speakers": 999, 33 | "cleaned_text": true 34 | }, 35 | "model": { 36 | "inter_channels": 192, 37 | "hidden_channels": 192, 38 | "filter_channels": 768, 39 | "n_heads": 2, 40 | "n_layers": 6, 41 | "kernel_size": 3, 42 | "p_dropout": 0.1, 43 | "resblock": "1", 44 | "resblock_kernel_sizes": [3,7,11], 45 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 46 | "upsample_rates": [8,8,2,2], 47 | "upsample_initial_channel": 512, 48 | "upsample_kernel_sizes": [16,16,4,4], 49 | "n_layers_q": 3, 50 | "use_spectral_norm": false, 51 | "gin_channels": 256 52 | }, 53 | "symbols": ["_", ",", ".", "!", "?", "-", "~", "\u2026", "N", "Q", "a", "b", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "s", "t", "u", "v", "w", "x", "y", "z", "\u0251", "\u00e6", "\u0283", "\u0291", "\u00e7", "\u026f", "\u026a", "\u0254", "\u025b", "\u0279", "\u00f0", "\u0259", "\u026b", "\u0265", "\u0278", "\u028a", "\u027e", "\u0292", "\u03b8", "\u03b2", "\u014b", "\u0266", "\u207c", "\u02b0", "`", "^", "#", "*", "=", "\u02c8", "\u02cc", "\u2192", "\u2193", "\u2191", " "] 54 | } -------------------------------------------------------------------------------- /VITS/mel_processing.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | import torch.utils.data 8 | import numpy as np 9 | import librosa 10 | import librosa.util as librosa_util 11 | from librosa.util import normalize, pad_center, tiny 12 | from scipy.signal import get_window 13 | from scipy.io.wavfile import read 14 | from librosa.filters import mel as librosa_mel_fn 15 | 16 | MAX_WAV_VALUE = 32768.0 17 | 18 | 19 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 20 | """ 21 | PARAMS 22 | ------ 23 | C: compression factor 24 | """ 25 | return torch.log(torch.clamp(x, min=clip_val) * C) 26 | 27 | 28 | def dynamic_range_decompression_torch(x, C=1): 29 | """ 30 | PARAMS 31 | ------ 32 | C: compression factor used to compress 33 | """ 34 | return torch.exp(x) / C 35 | 36 | 37 | def spectral_normalize_torch(magnitudes): 38 | output = dynamic_range_compression_torch(magnitudes) 39 | return output 
40 | 41 | 42 | def spectral_de_normalize_torch(magnitudes): 43 | output = dynamic_range_decompression_torch(magnitudes) 44 | return output 45 | 46 | 47 | mel_basis = {} 48 | hann_window = {} 49 | 50 | 51 | def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): 52 | if torch.min(y) < -1.: 53 | print('min value is ', torch.min(y)) 54 | if torch.max(y) > 1.: 55 | print('max value is ', torch.max(y)) 56 | 57 | global hann_window 58 | dtype_device = str(y.dtype) + '_' + str(y.device) 59 | wnsize_dtype_device = str(win_size) + '_' + dtype_device 60 | if wnsize_dtype_device not in hann_window: 61 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) 62 | 63 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') 64 | y = y.squeeze(1) 65 | 66 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], 67 | center=center, pad_mode='reflect', normalized=False, onesided=True) 68 | 69 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 70 | return spec 71 | 72 | 73 | def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): 74 | global mel_basis 75 | dtype_device = str(spec.dtype) + '_' + str(spec.device) 76 | fmax_dtype_device = str(fmax) + '_' + dtype_device 77 | if fmax_dtype_device not in mel_basis: 78 | mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) 79 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) 80 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 81 | spec = spectral_normalize_torch(spec) 82 | return spec 83 | 84 | 85 | def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): 86 | if torch.min(y) < -1.: 87 | print('min value is ', torch.min(y)) 88 | if torch.max(y) > 1.: 89 | print('max value is ', torch.max(y)) 90 | 91 | global mel_basis, hann_window 92 | dtype_device = str(y.dtype) + '_' + str(y.device) 93 | fmax_dtype_device = str(fmax) + '_' + dtype_device 94 | wnsize_dtype_device = str(win_size) + '_' + dtype_device 95 | if fmax_dtype_device not in mel_basis: 96 | mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) 97 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) 98 | if wnsize_dtype_device not in hann_window: 99 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) 100 | 101 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') 102 | y = y.squeeze(1) 103 | 104 | spec = torch.stft(y.float(), n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], 105 | center=center, pad_mode='reflect', normalized=False, onesided=True) 106 | 107 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 108 | 109 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 110 | spec = spectral_normalize_torch(spec) 111 | 112 | return spec 113 | -------------------------------------------------------------------------------- /VITS/models.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from VITS import commons 8 | from VITS import modules#自建文件 9 | from VITS import attentions#自建文件 10 | from VITS import monotonic_align 11 | 12 | from torch.nn import 
Conv1d, ConvTranspose1d, AvgPool1d, Conv2d 13 | from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm 14 | from VITS.commons import init_weights, get_padding 15 | 16 | 17 | class StochasticDurationPredictor(nn.Module): 18 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0): 19 | super().__init__() 20 | filter_channels = in_channels # it needs to be removed from future version. 21 | self.in_channels = in_channels 22 | self.filter_channels = filter_channels 23 | self.kernel_size = kernel_size 24 | self.p_dropout = p_dropout 25 | self.n_flows = n_flows 26 | self.gin_channels = gin_channels 27 | 28 | self.log_flow = modules.Log() 29 | self.flows = nn.ModuleList() 30 | self.flows.append(modules.ElementwiseAffine(2)) 31 | for i in range(n_flows): 32 | self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) 33 | self.flows.append(modules.Flip()) 34 | 35 | self.post_pre = nn.Conv1d(1, filter_channels, 1) 36 | self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1) 37 | self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) 38 | self.post_flows = nn.ModuleList() 39 | self.post_flows.append(modules.ElementwiseAffine(2)) 40 | for i in range(4): 41 | self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) 42 | self.post_flows.append(modules.Flip()) 43 | 44 | self.pre = nn.Conv1d(in_channels, filter_channels, 1) 45 | self.proj = nn.Conv1d(filter_channels, filter_channels, 1) 46 | self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) 47 | if gin_channels != 0: 48 | self.cond = nn.Conv1d(gin_channels, filter_channels, 1) 49 | 50 | def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0): 51 | x = torch.detach(x) 52 | x = self.pre(x) 53 | if g is not None: 54 | g = torch.detach(g) 55 | x = x + self.cond(g) 56 | x = self.convs(x, x_mask) 57 | x = self.proj(x) * x_mask 58 | 59 | if not reverse: 60 | flows = self.flows 61 | assert w is not None 62 | 63 | logdet_tot_q = 0 64 | h_w = self.post_pre(w) 65 | h_w = self.post_convs(h_w, x_mask) 66 | h_w = self.post_proj(h_w) * x_mask 67 | e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask 68 | z_q = e_q 69 | for flow in self.post_flows: 70 | z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) 71 | logdet_tot_q += logdet_q 72 | z_u, z1 = torch.split(z_q, [1, 1], 1) 73 | u = torch.sigmoid(z_u) * x_mask 74 | z0 = (w - u) * x_mask 75 | logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2]) 76 | logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q 77 | 78 | logdet_tot = 0 79 | z0, logdet = self.log_flow(z0, x_mask) 80 | logdet_tot += logdet 81 | z = torch.cat([z0, z1], 1) 82 | for flow in flows: 83 | z, logdet = flow(z, x_mask, g=x, reverse=reverse) 84 | logdet_tot = logdet_tot + logdet 85 | nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot 86 | return nll + logq # [b] 87 | else: 88 | flows = list(reversed(self.flows)) 89 | flows = flows[:-2] + [flows[-1]] # remove a useless vflow 90 | z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale 91 | for flow in flows: 92 | z = flow(z, x_mask, g=x, reverse=reverse) 93 | z0, z1 = torch.split(z, [1, 1], 1) 94 | logw = z0 95 | return logw 96 | 97 | 98 | class DurationPredictor(nn.Module): 99 | def __init__(self, in_channels, 
filter_channels, kernel_size, p_dropout, gin_channels=0): 100 | super().__init__() 101 | 102 | self.in_channels = in_channels 103 | self.filter_channels = filter_channels 104 | self.kernel_size = kernel_size 105 | self.p_dropout = p_dropout 106 | self.gin_channels = gin_channels 107 | 108 | self.drop = nn.Dropout(p_dropout) 109 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size//2) 110 | self.norm_1 = modules.LayerNorm(filter_channels) 111 | self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2) 112 | self.norm_2 = modules.LayerNorm(filter_channels) 113 | self.proj = nn.Conv1d(filter_channels, 1, 1) 114 | 115 | if gin_channels != 0: 116 | self.cond = nn.Conv1d(gin_channels, in_channels, 1) 117 | 118 | def forward(self, x, x_mask, g=None): 119 | x = torch.detach(x) 120 | if g is not None: 121 | g = torch.detach(g) 122 | x = x + self.cond(g) 123 | x = self.conv_1(x * x_mask) 124 | x = torch.relu(x) 125 | x = self.norm_1(x) 126 | x = self.drop(x) 127 | x = self.conv_2(x * x_mask) 128 | x = torch.relu(x) 129 | x = self.norm_2(x) 130 | x = self.drop(x) 131 | x = self.proj(x * x_mask) 132 | return x * x_mask 133 | 134 | 135 | class TextEncoder(nn.Module): 136 | def __init__(self, 137 | n_vocab, 138 | out_channels, 139 | hidden_channels, 140 | filter_channels, 141 | n_heads, 142 | n_layers, 143 | kernel_size, 144 | p_dropout): 145 | super().__init__() 146 | self.n_vocab = n_vocab 147 | self.out_channels = out_channels 148 | self.hidden_channels = hidden_channels 149 | self.filter_channels = filter_channels 150 | self.n_heads = n_heads 151 | self.n_layers = n_layers 152 | self.kernel_size = kernel_size 153 | self.p_dropout = p_dropout 154 | 155 | self.emb = nn.Embedding(n_vocab, hidden_channels) 156 | nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) 157 | 158 | self.encoder = attentions.Encoder( 159 | hidden_channels, 160 | filter_channels, 161 | n_heads, 162 | n_layers, 163 | kernel_size, 164 | p_dropout) 165 | self.proj= nn.Conv1d(hidden_channels, out_channels * 2, 1) 166 | 167 | def forward(self, x, x_lengths): 168 | x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] 169 | x = torch.transpose(x, 1, -1) # [b, h, t] 170 | x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) 171 | 172 | x = self.encoder(x * x_mask, x_mask) 173 | stats = self.proj(x) * x_mask 174 | 175 | m, logs = torch.split(stats, self.out_channels, dim=1) 176 | return x, m, logs, x_mask 177 | 178 | 179 | class ResidualCouplingBlock(nn.Module): 180 | def __init__(self, 181 | channels, 182 | hidden_channels, 183 | kernel_size, 184 | dilation_rate, 185 | n_layers, 186 | n_flows=4, 187 | gin_channels=0): 188 | super().__init__() 189 | self.channels = channels 190 | self.hidden_channels = hidden_channels 191 | self.kernel_size = kernel_size 192 | self.dilation_rate = dilation_rate 193 | self.n_layers = n_layers 194 | self.n_flows = n_flows 195 | self.gin_channels = gin_channels 196 | 197 | self.flows = nn.ModuleList() 198 | for i in range(n_flows): 199 | self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True)) 200 | self.flows.append(modules.Flip()) 201 | 202 | def forward(self, x, x_mask, g=None, reverse=False): 203 | if not reverse: 204 | for flow in self.flows: 205 | x, _ = flow(x, x_mask, g=g, reverse=reverse) 206 | else: 207 | for flow in reversed(self.flows): 208 | x = flow(x, x_mask, g=g, 
reverse=reverse) 209 | return x 210 | 211 | 212 | class PosteriorEncoder(nn.Module): 213 | def __init__(self, 214 | in_channels, 215 | out_channels, 216 | hidden_channels, 217 | kernel_size, 218 | dilation_rate, 219 | n_layers, 220 | gin_channels=0): 221 | super().__init__() 222 | self.in_channels = in_channels 223 | self.out_channels = out_channels 224 | self.hidden_channels = hidden_channels 225 | self.kernel_size = kernel_size 226 | self.dilation_rate = dilation_rate 227 | self.n_layers = n_layers 228 | self.gin_channels = gin_channels 229 | 230 | self.pre = nn.Conv1d(in_channels, hidden_channels, 1) 231 | self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels) 232 | self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) 233 | 234 | def forward(self, x, x_lengths, g=None): 235 | x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) 236 | x = self.pre(x) * x_mask 237 | x = self.enc(x, x_mask, g=g) 238 | stats = self.proj(x) * x_mask 239 | m, logs = torch.split(stats, self.out_channels, dim=1) 240 | z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask 241 | return z, m, logs, x_mask 242 | 243 | 244 | class Generator(torch.nn.Module): 245 | def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0): 246 | super(Generator, self).__init__() 247 | self.num_kernels = len(resblock_kernel_sizes) 248 | self.num_upsamples = len(upsample_rates) 249 | self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3) 250 | resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2 251 | 252 | self.ups = nn.ModuleList() 253 | for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): 254 | self.ups.append(weight_norm( 255 | ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)), 256 | k, u, padding=(k-u)//2))) 257 | 258 | self.resblocks = nn.ModuleList() 259 | for i in range(len(self.ups)): 260 | ch = upsample_initial_channel//(2**(i+1)) 261 | for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): 262 | self.resblocks.append(resblock(ch, k, d)) 263 | 264 | self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) 265 | self.ups.apply(init_weights) 266 | 267 | if gin_channels != 0: 268 | self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) 269 | 270 | def forward(self, x, g=None): 271 | x = self.conv_pre(x) 272 | if g is not None: 273 | x = x + self.cond(g) 274 | 275 | for i in range(self.num_upsamples): 276 | x = F.leaky_relu(x, modules.LRELU_SLOPE) 277 | x = self.ups[i](x) 278 | xs = None 279 | for j in range(self.num_kernels): 280 | if xs is None: 281 | xs = self.resblocks[i*self.num_kernels+j](x) 282 | else: 283 | xs += self.resblocks[i*self.num_kernels+j](x) 284 | x = xs / self.num_kernels 285 | x = F.leaky_relu(x) 286 | x = self.conv_post(x) 287 | x = torch.tanh(x) 288 | 289 | return x 290 | 291 | def remove_weight_norm(self): 292 | #print('Removing weight norm...') 293 | for l in self.ups: 294 | remove_weight_norm(l) 295 | for l in self.resblocks: 296 | l.remove_weight_norm() 297 | 298 | 299 | class DiscriminatorP(torch.nn.Module): 300 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): 301 | super(DiscriminatorP, self).__init__() 302 | self.period = period 303 | self.use_spectral_norm = use_spectral_norm 304 | norm_f = weight_norm if use_spectral_norm == 
False else spectral_norm 305 | self.convs = nn.ModuleList([ 306 | norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), 307 | norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), 308 | norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), 309 | norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), 310 | norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))), 311 | ]) 312 | self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) 313 | 314 | def forward(self, x): 315 | fmap = [] 316 | 317 | # 1d to 2d 318 | b, c, t = x.shape 319 | if t % self.period != 0: # pad first 320 | n_pad = self.period - (t % self.period) 321 | x = F.pad(x, (0, n_pad), "reflect") 322 | t = t + n_pad 323 | x = x.view(b, c, t // self.period, self.period) 324 | 325 | for l in self.convs: 326 | x = l(x) 327 | x = F.leaky_relu(x, modules.LRELU_SLOPE) 328 | fmap.append(x) 329 | x = self.conv_post(x) 330 | fmap.append(x) 331 | x = torch.flatten(x, 1, -1) 332 | 333 | return x, fmap 334 | 335 | 336 | class DiscriminatorS(torch.nn.Module): 337 | def __init__(self, use_spectral_norm=False): 338 | super(DiscriminatorS, self).__init__() 339 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 340 | self.convs = nn.ModuleList([ 341 | norm_f(Conv1d(1, 16, 15, 1, padding=7)), 342 | norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), 343 | norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), 344 | norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), 345 | norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), 346 | norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), 347 | ]) 348 | self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) 349 | 350 | def forward(self, x): 351 | fmap = [] 352 | 353 | for l in self.convs: 354 | x = l(x) 355 | x = F.leaky_relu(x, modules.LRELU_SLOPE) 356 | fmap.append(x) 357 | x = self.conv_post(x) 358 | fmap.append(x) 359 | x = torch.flatten(x, 1, -1) 360 | 361 | return x, fmap 362 | 363 | 364 | class MultiPeriodDiscriminator(torch.nn.Module): 365 | def __init__(self, use_spectral_norm=False): 366 | super(MultiPeriodDiscriminator, self).__init__() 367 | periods = [2,3,5,7,11] 368 | 369 | discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] 370 | discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods] 371 | self.discriminators = nn.ModuleList(discs) 372 | 373 | def forward(self, y, y_hat): 374 | y_d_rs = [] 375 | y_d_gs = [] 376 | fmap_rs = [] 377 | fmap_gs = [] 378 | for i, d in enumerate(self.discriminators): 379 | y_d_r, fmap_r = d(y) 380 | y_d_g, fmap_g = d(y_hat) 381 | y_d_rs.append(y_d_r) 382 | y_d_gs.append(y_d_g) 383 | fmap_rs.append(fmap_r) 384 | fmap_gs.append(fmap_g) 385 | 386 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 387 | 388 | 389 | 390 | class SynthesizerTrn(nn.Module): 391 | """ 392 | Synthesizer for Training 393 | """ 394 | 395 | def __init__(self, 396 | n_vocab, 397 | spec_channels, 398 | segment_size, 399 | inter_channels, 400 | hidden_channels, 401 | filter_channels, 402 | n_heads, 403 | n_layers, 404 | kernel_size, 405 | p_dropout, 406 | resblock, 407 | resblock_kernel_sizes, 408 | resblock_dilation_sizes, 409 | upsample_rates, 410 | upsample_initial_channel, 411 | upsample_kernel_sizes, 412 | n_speakers=0, 413 | gin_channels=0, 414 | use_sdp=True, 415 | **kwargs): 416 | 417 | super().__init__() 418 
| self.n_vocab = n_vocab 419 | self.spec_channels = spec_channels 420 | self.inter_channels = inter_channels 421 | self.hidden_channels = hidden_channels 422 | self.filter_channels = filter_channels 423 | self.n_heads = n_heads 424 | self.n_layers = n_layers 425 | self.kernel_size = kernel_size 426 | self.p_dropout = p_dropout 427 | self.resblock = resblock 428 | self.resblock_kernel_sizes = resblock_kernel_sizes 429 | self.resblock_dilation_sizes = resblock_dilation_sizes 430 | self.upsample_rates = upsample_rates 431 | self.upsample_initial_channel = upsample_initial_channel 432 | self.upsample_kernel_sizes = upsample_kernel_sizes 433 | self.segment_size = segment_size 434 | self.n_speakers = n_speakers 435 | self.gin_channels = gin_channels 436 | 437 | self.use_sdp = use_sdp 438 | 439 | self.enc_p = TextEncoder(n_vocab, 440 | inter_channels, 441 | hidden_channels, 442 | filter_channels, 443 | n_heads, 444 | n_layers, 445 | kernel_size, 446 | p_dropout) 447 | self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels) 448 | self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels) 449 | self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels) 450 | 451 | if use_sdp: 452 | self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels) 453 | else: 454 | self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels) 455 | 456 | if n_speakers >= 1: 457 | self.emb_g = nn.Embedding(n_speakers, gin_channels) 458 | 459 | def forward(self, x, x_lengths, y, y_lengths, sid=None): 460 | 461 | x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths) 462 | if self.n_speakers > 0: 463 | g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] 464 | else: 465 | g = None 466 | 467 | z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g) 468 | z_p = self.flow(z, y_mask, g=g) 469 | 470 | with torch.no_grad(): 471 | # negative cross-entropy 472 | s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t] 473 | neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True) # [b, 1, t_s] 474 | neg_cent2 = torch.matmul(-0.5 * (z_p ** 2).transpose(1, 2), s_p_sq_r) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s] 475 | neg_cent3 = torch.matmul(z_p.transpose(1, 2), (m_p * s_p_sq_r)) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s] 476 | neg_cent4 = torch.sum(-0.5 * (m_p ** 2) * s_p_sq_r, [1], keepdim=True) # [b, 1, t_s] 477 | neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4 478 | 479 | attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) 480 | attn = monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach() 481 | 482 | w = attn.sum(2) 483 | if self.use_sdp: 484 | l_length = self.dp(x, x_mask, w, g=g) 485 | l_length = l_length / torch.sum(x_mask) 486 | else: 487 | logw_ = torch.log(w + 1e-6) * x_mask 488 | logw = self.dp(x, x_mask, g=g) 489 | l_length = torch.sum((logw - logw_)**2, [1,2]) / torch.sum(x_mask) # for averaging 490 | 491 | # expand prior 492 | m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) 493 | logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) 494 | 495 | z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size) 496 | o = self.dec(z_slice, g=g) 497 | return o, l_length, attn, ids_slice, x_mask, y_mask, (z, z_p, m_p, 
logs_p, m_q, logs_q) 498 | 499 | def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None): 500 | x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths) 501 | if self.n_speakers > 0: 502 | g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] 503 | else: 504 | g = None 505 | 506 | if self.use_sdp: 507 | logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) 508 | else: 509 | logw = self.dp(x, x_mask, g=g) 510 | w = torch.exp(logw) * x_mask * length_scale 511 | w_ceil = torch.ceil(w) 512 | y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() 513 | y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype) 514 | attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) 515 | attn = commons.generate_path(w_ceil, attn_mask) 516 | 517 | m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] 518 | logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] 519 | 520 | z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale 521 | z = self.flow(z_p, y_mask, g=g, reverse=True) 522 | o = self.dec((z * y_mask)[:,:,:max_len], g=g) 523 | return o, attn, y_mask, (z, z_p, m_p, logs_p) 524 | 525 | def voice_conversion(self, y, y_lengths, sid_src, sid_tgt): 526 | assert self.n_speakers > 0, "n_speakers have to be larger than 0." 527 | g_src = self.emb_g(sid_src).unsqueeze(-1) 528 | g_tgt = self.emb_g(sid_tgt).unsqueeze(-1) 529 | z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src) 530 | z_p = self.flow(z, y_mask, g=g_src) 531 | z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True) 532 | o_hat = self.dec(z_hat * y_mask, g=g_tgt) 533 | return o_hat, y_mask, (z, z_p, z_hat) 534 | -------------------------------------------------------------------------------- /VITS/models_infer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | import commons 7 | import modules 8 | import attentions 9 | 10 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d 11 | from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm 12 | from commons import init_weights, get_padding 13 | 14 | 15 | class StochasticDurationPredictor(nn.Module): 16 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0): 17 | super().__init__() 18 | filter_channels = in_channels # it needs to be removed from future version. 
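        # The layers built below form two flow stacks: `self.flows` (an ElementwiseAffine plus
        # ConvFlow/Flip pairs) models the duration distribution, while `self.post_flows` together
        # with post_pre/post_convs/post_proj define a posterior over the observed durations `w`,
        # used only in the training (non-reverse) branch of forward() to compute the NLL.
        # pre/convs/proj project the text-encoder output into the conditioning signal, and
        # `self.cond` adds a speaker embedding when gin_channels != 0.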
19 | self.in_channels = in_channels 20 | self.filter_channels = filter_channels 21 | self.kernel_size = kernel_size 22 | self.p_dropout = p_dropout 23 | self.n_flows = n_flows 24 | self.gin_channels = gin_channels 25 | 26 | self.log_flow = modules.Log() 27 | self.flows = nn.ModuleList() 28 | self.flows.append(modules.ElementwiseAffine(2)) 29 | for i in range(n_flows): 30 | self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) 31 | self.flows.append(modules.Flip()) 32 | 33 | self.post_pre = nn.Conv1d(1, filter_channels, 1) 34 | self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1) 35 | self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) 36 | self.post_flows = nn.ModuleList() 37 | self.post_flows.append(modules.ElementwiseAffine(2)) 38 | for i in range(4): 39 | self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) 40 | self.post_flows.append(modules.Flip()) 41 | 42 | self.pre = nn.Conv1d(in_channels, filter_channels, 1) 43 | self.proj = nn.Conv1d(filter_channels, filter_channels, 1) 44 | self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) 45 | if gin_channels != 0: 46 | self.cond = nn.Conv1d(gin_channels, filter_channels, 1) 47 | 48 | def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0): 49 | x = torch.detach(x) 50 | x = self.pre(x) 51 | if g is not None: 52 | g = torch.detach(g) 53 | x = x + self.cond(g) 54 | x = self.convs(x, x_mask) 55 | x = self.proj(x) * x_mask 56 | 57 | if not reverse: 58 | flows = self.flows 59 | assert w is not None 60 | 61 | logdet_tot_q = 0 62 | h_w = self.post_pre(w) 63 | h_w = self.post_convs(h_w, x_mask) 64 | h_w = self.post_proj(h_w) * x_mask 65 | e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask 66 | z_q = e_q 67 | for flow in self.post_flows: 68 | z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) 69 | logdet_tot_q += logdet_q 70 | z_u, z1 = torch.split(z_q, [1, 1], 1) 71 | u = torch.sigmoid(z_u) * x_mask 72 | z0 = (w - u) * x_mask 73 | logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2]) 74 | logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q 75 | 76 | logdet_tot = 0 77 | z0, logdet = self.log_flow(z0, x_mask) 78 | logdet_tot += logdet 79 | z = torch.cat([z0, z1], 1) 80 | for flow in flows: 81 | z, logdet = flow(z, x_mask, g=x, reverse=reverse) 82 | logdet_tot = logdet_tot + logdet 83 | nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot 84 | return nll + logq # [b] 85 | else: 86 | flows = list(reversed(self.flows)) 87 | flows = flows[:-2] + [flows[-1]] # remove a useless vflow 88 | z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale 89 | for flow in flows: 90 | z = flow(z, x_mask, g=x, reverse=reverse) 91 | z0, z1 = torch.split(z, [1, 1], 1) 92 | logw = z0 93 | return logw 94 | 95 | 96 | class DurationPredictor(nn.Module): 97 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0): 98 | super().__init__() 99 | 100 | self.in_channels = in_channels 101 | self.filter_channels = filter_channels 102 | self.kernel_size = kernel_size 103 | self.p_dropout = p_dropout 104 | self.gin_channels = gin_channels 105 | 106 | self.drop = nn.Dropout(p_dropout) 107 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size//2) 108 | self.norm_1 = 
modules.LayerNorm(filter_channels) 109 | self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2) 110 | self.norm_2 = modules.LayerNorm(filter_channels) 111 | self.proj = nn.Conv1d(filter_channels, 1, 1) 112 | 113 | if gin_channels != 0: 114 | self.cond = nn.Conv1d(gin_channels, in_channels, 1) 115 | 116 | def forward(self, x, x_mask, g=None): 117 | x = torch.detach(x) 118 | if g is not None: 119 | g = torch.detach(g) 120 | x = x + self.cond(g) 121 | x = self.conv_1(x * x_mask) 122 | x = torch.relu(x) 123 | x = self.norm_1(x) 124 | x = self.drop(x) 125 | x = self.conv_2(x * x_mask) 126 | x = torch.relu(x) 127 | x = self.norm_2(x) 128 | x = self.drop(x) 129 | x = self.proj(x * x_mask) 130 | return x * x_mask 131 | 132 | 133 | class TextEncoder(nn.Module): 134 | def __init__(self, 135 | n_vocab, 136 | out_channels, 137 | hidden_channels, 138 | filter_channels, 139 | n_heads, 140 | n_layers, 141 | kernel_size, 142 | p_dropout): 143 | super().__init__() 144 | self.n_vocab = n_vocab 145 | self.out_channels = out_channels 146 | self.hidden_channels = hidden_channels 147 | self.filter_channels = filter_channels 148 | self.n_heads = n_heads 149 | self.n_layers = n_layers 150 | self.kernel_size = kernel_size 151 | self.p_dropout = p_dropout 152 | 153 | self.emb = nn.Embedding(n_vocab, hidden_channels) 154 | nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) 155 | 156 | self.encoder = attentions.Encoder( 157 | hidden_channels, 158 | filter_channels, 159 | n_heads, 160 | n_layers, 161 | kernel_size, 162 | p_dropout) 163 | self.proj= nn.Conv1d(hidden_channels, out_channels * 2, 1) 164 | 165 | def forward(self, x, x_lengths): 166 | x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] 167 | x = torch.transpose(x, 1, -1) # [b, h, t] 168 | x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) 169 | 170 | x = self.encoder(x * x_mask, x_mask) 171 | stats = self.proj(x) * x_mask 172 | 173 | m, logs = torch.split(stats, self.out_channels, dim=1) 174 | return x, m, logs, x_mask 175 | 176 | 177 | class ResidualCouplingBlock(nn.Module): 178 | def __init__(self, 179 | channels, 180 | hidden_channels, 181 | kernel_size, 182 | dilation_rate, 183 | n_layers, 184 | n_flows=4, 185 | gin_channels=0): 186 | super().__init__() 187 | self.channels = channels 188 | self.hidden_channels = hidden_channels 189 | self.kernel_size = kernel_size 190 | self.dilation_rate = dilation_rate 191 | self.n_layers = n_layers 192 | self.n_flows = n_flows 193 | self.gin_channels = gin_channels 194 | 195 | self.flows = nn.ModuleList() 196 | for i in range(n_flows): 197 | self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True)) 198 | self.flows.append(modules.Flip()) 199 | 200 | def forward(self, x, x_mask, g=None, reverse=False): 201 | if not reverse: 202 | for flow in self.flows: 203 | x, _ = flow(x, x_mask, g=g, reverse=reverse) 204 | else: 205 | for flow in reversed(self.flows): 206 | x = flow(x, x_mask, g=g, reverse=reverse) 207 | return x 208 | 209 | 210 | class PosteriorEncoder(nn.Module): 211 | def __init__(self, 212 | in_channels, 213 | out_channels, 214 | hidden_channels, 215 | kernel_size, 216 | dilation_rate, 217 | n_layers, 218 | gin_channels=0): 219 | super().__init__() 220 | self.in_channels = in_channels 221 | self.out_channels = out_channels 222 | self.hidden_channels = hidden_channels 223 | self.kernel_size = kernel_size 224 | 
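        # PosteriorEncoder: a 1x1 `pre` conv, a WaveNet-style `WN` stack (optionally
        # speaker-conditioned via gin_channels), and a 1x1 `proj` conv producing mean and
        # log-variance; forward() samples z = m + eps * exp(logs) under the sequence mask.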
self.dilation_rate = dilation_rate 225 | self.n_layers = n_layers 226 | self.gin_channels = gin_channels 227 | 228 | self.pre = nn.Conv1d(in_channels, hidden_channels, 1) 229 | self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels) 230 | self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) 231 | 232 | def forward(self, x, x_lengths, g=None): 233 | x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) 234 | x = self.pre(x) * x_mask 235 | x = self.enc(x, x_mask, g=g) 236 | stats = self.proj(x) * x_mask 237 | m, logs = torch.split(stats, self.out_channels, dim=1) 238 | z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask 239 | return z, m, logs, x_mask 240 | 241 | 242 | class Generator(torch.nn.Module): 243 | def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0): 244 | super(Generator, self).__init__() 245 | self.num_kernels = len(resblock_kernel_sizes) 246 | self.num_upsamples = len(upsample_rates) 247 | self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3) 248 | resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2 249 | 250 | self.ups = nn.ModuleList() 251 | for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): 252 | self.ups.append(weight_norm( 253 | ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)), 254 | k, u, padding=(k-u)//2))) 255 | 256 | self.resblocks = nn.ModuleList() 257 | for i in range(len(self.ups)): 258 | ch = upsample_initial_channel//(2**(i+1)) 259 | for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): 260 | self.resblocks.append(resblock(ch, k, d)) 261 | 262 | self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) 263 | self.ups.apply(init_weights) 264 | 265 | if gin_channels != 0: 266 | self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) 267 | 268 | def forward(self, x, g=None): 269 | x = self.conv_pre(x) 270 | if g is not None: 271 | x = x + self.cond(g) 272 | 273 | for i in range(self.num_upsamples): 274 | x = F.leaky_relu(x, modules.LRELU_SLOPE) 275 | x = self.ups[i](x) 276 | xs = None 277 | for j in range(self.num_kernels): 278 | if xs is None: 279 | xs = self.resblocks[i*self.num_kernels+j](x) 280 | else: 281 | xs += self.resblocks[i*self.num_kernels+j](x) 282 | x = xs / self.num_kernels 283 | x = F.leaky_relu(x) 284 | x = self.conv_post(x) 285 | x = torch.tanh(x) 286 | 287 | return x 288 | 289 | def remove_weight_norm(self): 290 | #print('Removing weight norm...') 291 | for l in self.ups: 292 | remove_weight_norm(l) 293 | for l in self.resblocks: 294 | l.remove_weight_norm() 295 | 296 | 297 | 298 | class SynthesizerTrn(nn.Module): 299 | """ 300 | Synthesizer for Training 301 | """ 302 | 303 | def __init__(self, 304 | n_vocab, 305 | spec_channels, 306 | segment_size, 307 | inter_channels, 308 | hidden_channels, 309 | filter_channels, 310 | n_heads, 311 | n_layers, 312 | kernel_size, 313 | p_dropout, 314 | resblock, 315 | resblock_kernel_sizes, 316 | resblock_dilation_sizes, 317 | upsample_rates, 318 | upsample_initial_channel, 319 | upsample_kernel_sizes, 320 | n_speakers=0, 321 | gin_channels=0, 322 | use_sdp=True, 323 | **kwargs): 324 | 325 | super().__init__() 326 | self.n_vocab = n_vocab 327 | self.spec_channels = spec_channels 328 | self.inter_channels = inter_channels 329 | self.hidden_channels = 
hidden_channels 330 | self.filter_channels = filter_channels 331 | self.n_heads = n_heads 332 | self.n_layers = n_layers 333 | self.kernel_size = kernel_size 334 | self.p_dropout = p_dropout 335 | self.resblock = resblock 336 | self.resblock_kernel_sizes = resblock_kernel_sizes 337 | self.resblock_dilation_sizes = resblock_dilation_sizes 338 | self.upsample_rates = upsample_rates 339 | self.upsample_initial_channel = upsample_initial_channel 340 | self.upsample_kernel_sizes = upsample_kernel_sizes 341 | self.segment_size = segment_size 342 | self.n_speakers = n_speakers 343 | self.gin_channels = gin_channels 344 | 345 | self.use_sdp = use_sdp 346 | 347 | self.enc_p = TextEncoder(n_vocab, 348 | inter_channels, 349 | hidden_channels, 350 | filter_channels, 351 | n_heads, 352 | n_layers, 353 | kernel_size, 354 | p_dropout) 355 | self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels) 356 | self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels) 357 | self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels) 358 | 359 | if use_sdp: 360 | self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels) 361 | else: 362 | self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels) 363 | 364 | if n_speakers > 1: 365 | self.emb_g = nn.Embedding(n_speakers, gin_channels) 366 | 367 | def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None): 368 | x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths) 369 | if self.n_speakers > 0: 370 | g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] 371 | else: 372 | g = None 373 | 374 | if self.use_sdp: 375 | logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) 376 | else: 377 | logw = self.dp(x, x_mask, g=g) 378 | w = torch.exp(logw) * x_mask * length_scale 379 | w_ceil = torch.ceil(w) 380 | y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() 381 | y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype) 382 | attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) 383 | attn = commons.generate_path(w_ceil, attn_mask) 384 | 385 | m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] 386 | logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] 387 | 388 | z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale 389 | z = self.flow(z_p, y_mask, g=g, reverse=True) 390 | o = self.dec((z * y_mask)[:,:,:max_len], g=g) 391 | return o, attn, y_mask, (z, z_p, m_p, logs_p) 392 | 393 | def voice_conversion(self, y, y_lengths, sid_src, sid_tgt): 394 | assert self.n_speakers > 0, "n_speakers have to be larger than 0." 
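        # Voice conversion: encode the source utterance with the source speaker's embedding,
        # map the latent into the shared prior space with the forward flow, run the flow in
        # reverse conditioned on the target speaker, then decode with the generator.
        # Hedged usage sketch (the names `net_g` and `spec` are assumptions, not defined here):
        #   sid_src, sid_tgt = torch.LongTensor([0]), torch.LongTensor([1])
        #   spec_lengths = torch.LongTensor([spec.size(2)])   # spec: [1, spec_channels, T]
        #   audio, _, _ = net_g.voice_conversion(spec, spec_lengths, sid_src, sid_tgt)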
395 | g_src = self.emb_g(sid_src).unsqueeze(-1) 396 | g_tgt = self.emb_g(sid_tgt).unsqueeze(-1) 397 | z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src) 398 | z_p = self.flow(z, y_mask, g=g_src) 399 | z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True) 400 | o_hat = self.dec(z_hat * y_mask, g=g_tgt) 401 | return o_hat, y_mask, (z, z_p, z_hat) 402 | 403 | -------------------------------------------------------------------------------- /VITS/modules.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import numpy as np 4 | import scipy 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d 10 | from torch.nn.utils import weight_norm, remove_weight_norm 11 | 12 | from VITS import commons 13 | from VITS.commons import init_weights, get_padding 14 | from VITS.transforms import piecewise_rational_quadratic_transform 15 | 16 | 17 | LRELU_SLOPE = 0.1 18 | 19 | 20 | class LayerNorm(nn.Module): 21 | def __init__(self, channels, eps=1e-5): 22 | super().__init__() 23 | self.channels = channels 24 | self.eps = eps 25 | 26 | self.gamma = nn.Parameter(torch.ones(channels)) 27 | self.beta = nn.Parameter(torch.zeros(channels)) 28 | 29 | def forward(self, x): 30 | x = x.transpose(1, -1) 31 | x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) 32 | return x.transpose(1, -1) 33 | 34 | 35 | class ConvReluNorm(nn.Module): 36 | def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): 37 | super().__init__() 38 | self.in_channels = in_channels 39 | self.hidden_channels = hidden_channels 40 | self.out_channels = out_channels 41 | self.kernel_size = kernel_size 42 | self.n_layers = n_layers 43 | self.p_dropout = p_dropout 44 | assert n_layers > 1, "Number of layers should be larger than 0." 
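        # ConvReluNorm: n_layers of Conv1d -> LayerNorm -> ReLU -> Dropout, with the result fed
        # through a zero-initialized 1x1 projection and added back to the input, so the block
        # behaves as an identity at initialization.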
45 | 46 | self.conv_layers = nn.ModuleList() 47 | self.norm_layers = nn.ModuleList() 48 | self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2)) 49 | self.norm_layers.append(LayerNorm(hidden_channels)) 50 | self.relu_drop = nn.Sequential( 51 | nn.ReLU(), 52 | nn.Dropout(p_dropout)) 53 | for _ in range(n_layers-1): 54 | self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2)) 55 | self.norm_layers.append(LayerNorm(hidden_channels)) 56 | self.proj = nn.Conv1d(hidden_channels, out_channels, 1) 57 | self.proj.weight.data.zero_() 58 | self.proj.bias.data.zero_() 59 | 60 | def forward(self, x, x_mask): 61 | x_org = x 62 | for i in range(self.n_layers): 63 | x = self.conv_layers[i](x * x_mask) 64 | x = self.norm_layers[i](x) 65 | x = self.relu_drop(x) 66 | x = x_org + self.proj(x) 67 | return x * x_mask 68 | 69 | 70 | class DDSConv(nn.Module): 71 | """ 72 | Dialted and Depth-Separable Convolution 73 | """ 74 | def __init__(self, channels, kernel_size, n_layers, p_dropout=0.): 75 | super().__init__() 76 | self.channels = channels 77 | self.kernel_size = kernel_size 78 | self.n_layers = n_layers 79 | self.p_dropout = p_dropout 80 | 81 | self.drop = nn.Dropout(p_dropout) 82 | self.convs_sep = nn.ModuleList() 83 | self.convs_1x1 = nn.ModuleList() 84 | self.norms_1 = nn.ModuleList() 85 | self.norms_2 = nn.ModuleList() 86 | for i in range(n_layers): 87 | dilation = kernel_size ** i 88 | padding = (kernel_size * dilation - dilation) // 2 89 | self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size, 90 | groups=channels, dilation=dilation, padding=padding 91 | )) 92 | self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) 93 | self.norms_1.append(LayerNorm(channels)) 94 | self.norms_2.append(LayerNorm(channels)) 95 | 96 | def forward(self, x, x_mask, g=None): 97 | if g is not None: 98 | x = x + g 99 | for i in range(self.n_layers): 100 | y = self.convs_sep[i](x * x_mask) 101 | y = self.norms_1[i](y) 102 | y = F.gelu(y) 103 | y = self.convs_1x1[i](y) 104 | y = self.norms_2[i](y) 105 | y = F.gelu(y) 106 | y = self.drop(y) 107 | x = x + y 108 | return x * x_mask 109 | 110 | 111 | class WN(torch.nn.Module): 112 | def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0): 113 | super(WN, self).__init__() 114 | assert(kernel_size % 2 == 1) 115 | self.hidden_channels =hidden_channels 116 | self.kernel_size = kernel_size, 117 | self.dilation_rate = dilation_rate 118 | self.n_layers = n_layers 119 | self.gin_channels = gin_channels 120 | self.p_dropout = p_dropout 121 | 122 | self.in_layers = torch.nn.ModuleList() 123 | self.res_skip_layers = torch.nn.ModuleList() 124 | self.drop = nn.Dropout(p_dropout) 125 | 126 | if gin_channels != 0: 127 | cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1) 128 | self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') 129 | 130 | for i in range(n_layers): 131 | dilation = dilation_rate ** i 132 | padding = int((kernel_size * dilation - dilation) / 2) 133 | in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size, 134 | dilation=dilation, padding=padding) 135 | in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') 136 | self.in_layers.append(in_layer) 137 | 138 | # last one is not necessary 139 | if i < n_layers - 1: 140 | res_skip_channels = 2 * hidden_channels 141 | else: 142 | res_skip_channels = hidden_channels 143 | 144 | res_skip_layer = 
torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) 145 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight') 146 | self.res_skip_layers.append(res_skip_layer) 147 | 148 | def forward(self, x, x_mask, g=None, **kwargs): 149 | output = torch.zeros_like(x) 150 | n_channels_tensor = torch.IntTensor([self.hidden_channels]) 151 | 152 | if g is not None: 153 | g = self.cond_layer(g) 154 | 155 | for i in range(self.n_layers): 156 | x_in = self.in_layers[i](x) 157 | if g is not None: 158 | cond_offset = i * 2 * self.hidden_channels 159 | g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:] 160 | else: 161 | g_l = torch.zeros_like(x_in) 162 | 163 | acts = commons.fused_add_tanh_sigmoid_multiply( 164 | x_in, 165 | g_l, 166 | n_channels_tensor) 167 | acts = self.drop(acts) 168 | 169 | res_skip_acts = self.res_skip_layers[i](acts) 170 | if i < self.n_layers - 1: 171 | res_acts = res_skip_acts[:,:self.hidden_channels,:] 172 | x = (x + res_acts) * x_mask 173 | output = output + res_skip_acts[:,self.hidden_channels:,:] 174 | else: 175 | output = output + res_skip_acts 176 | return output * x_mask 177 | 178 | def remove_weight_norm(self): 179 | if self.gin_channels != 0: 180 | torch.nn.utils.remove_weight_norm(self.cond_layer) 181 | for l in self.in_layers: 182 | torch.nn.utils.remove_weight_norm(l) 183 | for l in self.res_skip_layers: 184 | torch.nn.utils.remove_weight_norm(l) 185 | 186 | 187 | class ResBlock1(torch.nn.Module): 188 | def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): 189 | super(ResBlock1, self).__init__() 190 | self.convs1 = nn.ModuleList([ 191 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 192 | padding=get_padding(kernel_size, dilation[0]))), 193 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 194 | padding=get_padding(kernel_size, dilation[1]))), 195 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], 196 | padding=get_padding(kernel_size, dilation[2]))) 197 | ]) 198 | self.convs1.apply(init_weights) 199 | 200 | self.convs2 = nn.ModuleList([ 201 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 202 | padding=get_padding(kernel_size, 1))), 203 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 204 | padding=get_padding(kernel_size, 1))), 205 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 206 | padding=get_padding(kernel_size, 1))) 207 | ]) 208 | self.convs2.apply(init_weights) 209 | 210 | def forward(self, x, x_mask=None): 211 | for c1, c2 in zip(self.convs1, self.convs2): 212 | xt = F.leaky_relu(x, LRELU_SLOPE) 213 | if x_mask is not None: 214 | xt = xt * x_mask 215 | xt = c1(xt) 216 | xt = F.leaky_relu(xt, LRELU_SLOPE) 217 | if x_mask is not None: 218 | xt = xt * x_mask 219 | xt = c2(xt) 220 | x = xt + x 221 | if x_mask is not None: 222 | x = x * x_mask 223 | return x 224 | 225 | def remove_weight_norm(self): 226 | for l in self.convs1: 227 | remove_weight_norm(l) 228 | for l in self.convs2: 229 | remove_weight_norm(l) 230 | 231 | 232 | class ResBlock2(torch.nn.Module): 233 | def __init__(self, channels, kernel_size=3, dilation=(1, 3)): 234 | super(ResBlock2, self).__init__() 235 | self.convs = nn.ModuleList([ 236 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 237 | padding=get_padding(kernel_size, dilation[0]))), 238 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 239 | padding=get_padding(kernel_size, dilation[1]))) 240 | ]) 
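        # ResBlock2 is the lighter residual block (two dilated convs rather than ResBlock1's
        # three conv pairs); Generator selects it when the config's `resblock` field is not '1'.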
241 | self.convs.apply(init_weights) 242 | 243 | def forward(self, x, x_mask=None): 244 | for c in self.convs: 245 | xt = F.leaky_relu(x, LRELU_SLOPE) 246 | if x_mask is not None: 247 | xt = xt * x_mask 248 | xt = c(xt) 249 | x = xt + x 250 | if x_mask is not None: 251 | x = x * x_mask 252 | return x 253 | 254 | def remove_weight_norm(self): 255 | for l in self.convs: 256 | remove_weight_norm(l) 257 | 258 | 259 | class Log(nn.Module): 260 | def forward(self, x, x_mask, reverse=False, **kwargs): 261 | if not reverse: 262 | y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask 263 | logdet = torch.sum(-y, [1, 2]) 264 | return y, logdet 265 | else: 266 | x = torch.exp(x) * x_mask 267 | return x 268 | 269 | 270 | class Flip(nn.Module): 271 | def forward(self, x, *args, reverse=False, **kwargs): 272 | x = torch.flip(x, [1]) 273 | if not reverse: 274 | logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) 275 | return x, logdet 276 | else: 277 | return x 278 | 279 | 280 | class ElementwiseAffine(nn.Module): 281 | def __init__(self, channels): 282 | super().__init__() 283 | self.channels = channels 284 | self.m = nn.Parameter(torch.zeros(channels,1)) 285 | self.logs = nn.Parameter(torch.zeros(channels,1)) 286 | 287 | def forward(self, x, x_mask, reverse=False, **kwargs): 288 | if not reverse: 289 | y = self.m + torch.exp(self.logs) * x 290 | y = y * x_mask 291 | logdet = torch.sum(self.logs * x_mask, [1,2]) 292 | return y, logdet 293 | else: 294 | x = (x - self.m) * torch.exp(-self.logs) * x_mask 295 | return x 296 | 297 | 298 | class ResidualCouplingLayer(nn.Module): 299 | def __init__(self, 300 | channels, 301 | hidden_channels, 302 | kernel_size, 303 | dilation_rate, 304 | n_layers, 305 | p_dropout=0, 306 | gin_channels=0, 307 | mean_only=False): 308 | assert channels % 2 == 0, "channels should be divisible by 2" 309 | super().__init__() 310 | self.channels = channels 311 | self.hidden_channels = hidden_channels 312 | self.kernel_size = kernel_size 313 | self.dilation_rate = dilation_rate 314 | self.n_layers = n_layers 315 | self.half_channels = channels // 2 316 | self.mean_only = mean_only 317 | 318 | self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) 319 | self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels) 320 | self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) 321 | self.post.weight.data.zero_() 322 | self.post.bias.data.zero_() 323 | 324 | def forward(self, x, x_mask, g=None, reverse=False): 325 | x0, x1 = torch.split(x, [self.half_channels]*2, 1) 326 | h = self.pre(x0) * x_mask 327 | h = self.enc(h, x_mask, g=g) 328 | stats = self.post(h) * x_mask 329 | if not self.mean_only: 330 | m, logs = torch.split(stats, [self.half_channels]*2, 1) 331 | else: 332 | m = stats 333 | logs = torch.zeros_like(m) 334 | 335 | if not reverse: 336 | x1 = m + x1 * torch.exp(logs) * x_mask 337 | x = torch.cat([x0, x1], 1) 338 | logdet = torch.sum(logs, [1,2]) 339 | return x, logdet 340 | else: 341 | x1 = (x1 - m) * torch.exp(-logs) * x_mask 342 | x = torch.cat([x0, x1], 1) 343 | return x 344 | 345 | 346 | class ConvFlow(nn.Module): 347 | def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0): 348 | super().__init__() 349 | self.in_channels = in_channels 350 | self.filter_channels = filter_channels 351 | self.kernel_size = kernel_size 352 | self.n_layers = n_layers 353 | self.num_bins = num_bins 354 | self.tail_bound = tail_bound 355 | 
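        # ConvFlow is a spline-based coupling layer: the first half of the channels is mapped
        # (pre -> DDSConv -> proj) to num_bins*3 - 1 parameters per element (bin widths, bin
        # heights, knot derivatives), which define the piecewise rational-quadratic transform
        # applied to the second half in forward().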
self.half_channels = in_channels // 2 356 | 357 | self.pre = nn.Conv1d(self.half_channels, filter_channels, 1) 358 | self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.) 359 | self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1) 360 | self.proj.weight.data.zero_() 361 | self.proj.bias.data.zero_() 362 | 363 | def forward(self, x, x_mask, g=None, reverse=False): 364 | x0, x1 = torch.split(x, [self.half_channels]*2, 1) 365 | h = self.pre(x0) 366 | h = self.convs(h, x_mask, g=g) 367 | h = self.proj(h) * x_mask 368 | 369 | b, c, t = x0.shape 370 | h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] 371 | 372 | unnormalized_widths = h[..., :self.num_bins] / math.sqrt(self.filter_channels) 373 | unnormalized_heights = h[..., self.num_bins:2*self.num_bins] / math.sqrt(self.filter_channels) 374 | unnormalized_derivatives = h[..., 2 * self.num_bins:] 375 | 376 | x1, logabsdet = piecewise_rational_quadratic_transform(x1, 377 | unnormalized_widths, 378 | unnormalized_heights, 379 | unnormalized_derivatives, 380 | inverse=reverse, 381 | tails='linear', 382 | tail_bound=self.tail_bound 383 | ) 384 | 385 | x = torch.cat([x0, x1], 1) * x_mask 386 | logdet = torch.sum(logabsdet * x_mask, [1,2]) 387 | if not reverse: 388 | return x, logdet 389 | else: 390 | return x 391 | -------------------------------------------------------------------------------- /VITS/monotonic_align/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from .monotonic_align.core import maximum_path_c 4 | 5 | 6 | def maximum_path(neg_cent, mask): 7 | """ Cython optimized version. 8 | neg_cent: [b, t_t, t_s] 9 | mask: [b, t_t, t_s] 10 | """ 11 | device = neg_cent.device 12 | dtype = neg_cent.dtype 13 | neg_cent = neg_cent.data.cpu().numpy().astype(np.float32) 14 | path = np.zeros(neg_cent.shape, dtype=np.int32) 15 | 16 | t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32) 17 | t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32) 18 | maximum_path_c(path, neg_cent, t_t_max, t_s_max) 19 | return torch.from_numpy(path).to(device=device, dtype=dtype) 20 | -------------------------------------------------------------------------------- /VITS/monotonic_align/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/monotonic_align/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /VITS/monotonic_align/build/lib.win-amd64-3.10/monotonic_align/core.cp310-win_amd64.pyd: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/monotonic_align/build/lib.win-amd64-3.10/monotonic_align/core.cp310-win_amd64.pyd -------------------------------------------------------------------------------- /VITS/monotonic_align/build/temp.win-amd64-3.10/Release/core.cp310-win_amd64.exp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/monotonic_align/build/temp.win-amd64-3.10/Release/core.cp310-win_amd64.exp 
-------------------------------------------------------------------------------- /VITS/monotonic_align/build/temp.win-amd64-3.10/Release/core.cp310-win_amd64.lib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/monotonic_align/build/temp.win-amd64-3.10/Release/core.cp310-win_amd64.lib -------------------------------------------------------------------------------- /VITS/monotonic_align/build/temp.win-amd64-3.10/Release/core.obj: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/monotonic_align/build/temp.win-amd64-3.10/Release/core.obj -------------------------------------------------------------------------------- /VITS/monotonic_align/core.pyx: -------------------------------------------------------------------------------- 1 | cimport cython 2 | from cython.parallel import prange 3 | 4 | 5 | @cython.boundscheck(False) 6 | @cython.wraparound(False) 7 | cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil: 8 | cdef int x 9 | cdef int y 10 | cdef float v_prev 11 | cdef float v_cur 12 | cdef float tmp 13 | cdef int index = t_x - 1 14 | 15 | for y in range(t_y): 16 | for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): 17 | if x == y: 18 | v_cur = max_neg_val 19 | else: 20 | v_cur = value[y-1, x] 21 | if x == 0: 22 | if y == 0: 23 | v_prev = 0. 24 | else: 25 | v_prev = max_neg_val 26 | else: 27 | v_prev = value[y-1, x-1] 28 | value[y, x] += max(v_prev, v_cur) 29 | 30 | for y in range(t_y - 1, -1, -1): 31 | path[y, index] = 1 32 | if index != 0 and (index == y or value[y-1, index] < value[y-1, index-1]): 33 | index = index - 1 34 | 35 | 36 | @cython.boundscheck(False) 37 | @cython.wraparound(False) 38 | cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_ys, int[::1] t_xs) nogil: 39 | cdef int b = paths.shape[0] 40 | cdef int i 41 | for i in prange(b, nogil=True): 42 | maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i]) 43 | -------------------------------------------------------------------------------- /VITS/monotonic_align/monotonic_align/core.cp310-win_amd64.pyd: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/monotonic_align/monotonic_align/core.cp310-win_amd64.pyd -------------------------------------------------------------------------------- /VITS/monotonic_align/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from Cython.Build import cythonize 3 | import numpy 4 | 5 | setup( 6 | name = 'monotonic_align', 7 | ext_modules = cythonize("core.pyx"), 8 | include_dirs=[numpy.get_include()] 9 | ) 10 | -------------------------------------------------------------------------------- /VITS/text/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017 Keith Ito 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, 
merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /VITS/text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | from VITS.text import cleaners 3 | from VITS.text.symbols import symbols 4 | 5 | 6 | # Mappings from symbol to numeric ID and vice versa: 7 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 8 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 9 | 10 | 11 | def text_to_sequence(text, symbols, cleaner_names): 12 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 13 | Args: 14 | text: string to convert to a sequence 15 | cleaner_names: names of the cleaner functions to run the text through 16 | Returns: 17 | List of integers corresponding to the symbols in the text 18 | ''' 19 | sequence = [] 20 | symbol_to_id = {s: i for i, s in enumerate(symbols)} 21 | clean_text = _clean_text(text, cleaner_names) 22 | ##print(clean_text) 23 | ##print(f" length:{len(clean_text)}") 24 | for symbol in clean_text: 25 | if symbol not in symbol_to_id.keys(): 26 | continue 27 | symbol_id = symbol_to_id[symbol] 28 | sequence += [symbol_id] 29 | ##print(f" length:{len(sequence)}") 30 | return sequence 31 | 32 | 33 | def cleaned_text_to_sequence(cleaned_text, symbols): 34 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 
35 | Args: 36 | text: string to convert to a sequence 37 | Returns: 38 | List of integers corresponding to the symbols in the text 39 | ''' 40 | symbol_to_id = {s: i for i, s in enumerate(symbols)} 41 | sequence = [symbol_to_id[symbol] for symbol in cleaned_text if symbol in symbol_to_id.keys()] 42 | return sequence 43 | 44 | 45 | def sequence_to_text(sequence): 46 | '''Converts a sequence of IDs back to a string''' 47 | result = '' 48 | for symbol_id in sequence: 49 | s = _id_to_symbol[symbol_id] 50 | result += s 51 | return result 52 | 53 | 54 | def _clean_text(text, cleaner_names): 55 | for name in cleaner_names: 56 | cleaner = getattr(cleaners, name) 57 | if not cleaner: 58 | raise Exception('Unknown cleaner: %s' % name) 59 | text = cleaner(text) 60 | return text 61 | -------------------------------------------------------------------------------- /VITS/text/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/text/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /VITS/text/__pycache__/cleaners.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/text/__pycache__/cleaners.cpython-310.pyc -------------------------------------------------------------------------------- /VITS/text/__pycache__/english.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/text/__pycache__/english.cpython-310.pyc -------------------------------------------------------------------------------- /VITS/text/__pycache__/japanese.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/text/__pycache__/japanese.cpython-310.pyc -------------------------------------------------------------------------------- /VITS/text/__pycache__/korean.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/text/__pycache__/korean.cpython-310.pyc -------------------------------------------------------------------------------- /VITS/text/__pycache__/mandarin.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/text/__pycache__/mandarin.cpython-310.pyc -------------------------------------------------------------------------------- /VITS/text/__pycache__/sanskrit.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/text/__pycache__/sanskrit.cpython-310.pyc -------------------------------------------------------------------------------- /VITS/text/__pycache__/symbols.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/text/__pycache__/symbols.cpython-310.pyc -------------------------------------------------------------------------------- /VITS/text/__pycache__/thai.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/text/__pycache__/thai.cpython-310.pyc -------------------------------------------------------------------------------- /VITS/text/cantonese.py: -------------------------------------------------------------------------------- 1 | import re 2 | import cn2an 3 | import opencc 4 | 5 | 6 | converter = opencc.OpenCC('jyutjyu') 7 | 8 | # List of (Latin alphabet, ipa) pairs: 9 | _latin_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ 10 | ('A', 'ei˥'), 11 | ('B', 'biː˥'), 12 | ('C', 'siː˥'), 13 | ('D', 'tiː˥'), 14 | ('E', 'iː˥'), 15 | ('F', 'e˥fuː˨˩'), 16 | ('G', 'tsiː˥'), 17 | ('H', 'ɪk̚˥tsʰyː˨˩'), 18 | ('I', 'ɐi˥'), 19 | ('J', 'tsei˥'), 20 | ('K', 'kʰei˥'), 21 | ('L', 'e˥llou˨˩'), 22 | ('M', 'ɛːm˥'), 23 | ('N', 'ɛːn˥'), 24 | ('O', 'ou˥'), 25 | ('P', 'pʰiː˥'), 26 | ('Q', 'kʰiːu˥'), 27 | ('R', 'aː˥lou˨˩'), 28 | ('S', 'ɛː˥siː˨˩'), 29 | ('T', 'tʰiː˥'), 30 | ('U', 'juː˥'), 31 | ('V', 'wiː˥'), 32 | ('W', 'tʊk̚˥piː˥juː˥'), 33 | ('X', 'ɪk̚˥siː˨˩'), 34 | ('Y', 'waːi˥'), 35 | ('Z', 'iː˨sɛːt̚˥') 36 | ]] 37 | 38 | 39 | def number_to_cantonese(text): 40 | return re.sub(r'\d+(?:\.?\d+)?', lambda x: cn2an.an2cn(x.group()), text) 41 | 42 | 43 | def latin_to_ipa(text): 44 | for regex, replacement in _latin_to_ipa: 45 | text = re.sub(regex, replacement, text) 46 | return text 47 | 48 | 49 | def cantonese_to_ipa(text): 50 | text = number_to_cantonese(text.upper()) 51 | text = converter.convert(text).replace('-','').replace('$',' ') 52 | text = re.sub(r'[A-Z]', lambda x: latin_to_ipa(x.group())+' ', text) 53 | text = re.sub(r'[、;:]', ',', text) 54 | text = re.sub(r'\s*,\s*', ', ', text) 55 | text = re.sub(r'\s*。\s*', '. ', text) 56 | text = re.sub(r'\s*?\s*', '? ', text) 57 | text = re.sub(r'\s*!\s*', '! 
', text) 58 | text = re.sub(r'\s*$', '', text) 59 | return text 60 | -------------------------------------------------------------------------------- /VITS/text/cleaners.py: -------------------------------------------------------------------------------- 1 | import re 2 | from VITS.text.japanese import japanese_to_romaji_with_accent, japanese_to_ipa, japanese_to_ipa2, japanese_to_ipa3 3 | from VITS.text.korean import latin_to_hangul, number_to_hangul, divide_hangul, korean_to_lazy_ipa, korean_to_ipa 4 | from VITS.text.mandarin import number_to_chinese, chinese_to_bopomofo, latin_to_bopomofo, chinese_to_romaji, chinese_to_lazy_ipa, chinese_to_ipa, chinese_to_ipa2 5 | from VITS.text.sanskrit import devanagari_to_ipa 6 | from VITS.text.english import english_to_lazy_ipa, english_to_ipa2, english_to_lazy_ipa2 7 | from VITS.text.thai import num_to_thai, latin_to_thai 8 | # from text.shanghainese import shanghainese_to_ipa 9 | # from text.cantonese import cantonese_to_ipa 10 | # from text.ngu_dialect import ngu_dialect_to_ipa 11 | 12 | 13 | def japanese_cleaners(text): 14 | text = japanese_to_romaji_with_accent(text) 15 | text = re.sub(r'([A-Za-z])$', r'\1.', text) 16 | return text 17 | 18 | 19 | def japanese_cleaners2(text): 20 | return japanese_cleaners(text).replace('ts', 'ʦ').replace('...', '…') 21 | 22 | 23 | def korean_cleaners(text): 24 | '''Pipeline for Korean text''' 25 | text = latin_to_hangul(text) 26 | text = number_to_hangul(text) 27 | text = divide_hangul(text) 28 | text = re.sub(r'([\u3131-\u3163])$', r'\1.', text) 29 | return text 30 | 31 | 32 | # def chinese_cleaners(text): 33 | # '''Pipeline for Chinese text''' 34 | # text = number_to_chinese(text) 35 | # text = chinese_to_bopomofo(text) 36 | # text = latin_to_bopomofo(text) 37 | # text = re.sub(r'([ˉˊˇˋ˙])$', r'\1。', text) 38 | # return text 39 | 40 | def chinese_cleaners(text): 41 | from pypinyin import Style, pinyin 42 | text = text.replace("[ZH]", "") 43 | phones = [phone[0] for phone in pinyin(text, style=Style.TONE3)] 44 | return ' '.join(phones) 45 | 46 | 47 | def zh_ja_mixture_cleaners(text): 48 | text = re.sub(r'\[ZH\](.*?)\[ZH\]', 49 | lambda x: chinese_to_romaji(x.group(1))+' ', text) 50 | text = re.sub(r'\[JA\](.*?)\[JA\]', lambda x: japanese_to_romaji_with_accent( 51 | x.group(1)).replace('ts', 'ʦ').replace('u', 'ɯ').replace('...', '…')+' ', text) 52 | text = re.sub(r'\s+$', '', text) 53 | text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text) 54 | return text 55 | 56 | 57 | def sanskrit_cleaners(text): 58 | text = text.replace('॥', '।').replace('ॐ', 'ओम्') 59 | text = re.sub(r'([^।])$', r'\1।', text) 60 | return text 61 | 62 | 63 | def cjks_cleaners(text): 64 | text = re.sub(r'\[ZH\](.*?)\[ZH\]', 65 | lambda x: chinese_to_lazy_ipa(x.group(1))+' ', text) 66 | text = re.sub(r'\[JA\](.*?)\[JA\]', 67 | lambda x: japanese_to_ipa(x.group(1))+' ', text) 68 | text = re.sub(r'\[KO\](.*?)\[KO\]', 69 | lambda x: korean_to_lazy_ipa(x.group(1))+' ', text) 70 | text = re.sub(r'\[SA\](.*?)\[SA\]', 71 | lambda x: devanagari_to_ipa(x.group(1))+' ', text) 72 | text = re.sub(r'\[EN\](.*?)\[EN\]', 73 | lambda x: english_to_lazy_ipa(x.group(1))+' ', text) 74 | text = re.sub(r'\s+$', '', text) 75 | text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text) 76 | return text 77 | 78 | 79 | def cjke_cleaners(text): 80 | text = re.sub(r'\[ZH\](.*?)\[ZH\]', lambda x: chinese_to_lazy_ipa(x.group(1)).replace( 81 | 'ʧ', 'tʃ').replace('ʦ', 'ts').replace('ɥan', 'ɥæn')+' ', text) 82 | text = re.sub(r'\[JA\](.*?)\[JA\]', lambda x: 
japanese_to_ipa(x.group(1)).replace('ʧ', 'tʃ').replace( 83 | 'ʦ', 'ts').replace('ɥan', 'ɥæn').replace('ʥ', 'dz')+' ', text) 84 | text = re.sub(r'\[KO\](.*?)\[KO\]', 85 | lambda x: korean_to_ipa(x.group(1))+' ', text) 86 | text = re.sub(r'\[EN\](.*?)\[EN\]', lambda x: english_to_ipa2(x.group(1)).replace('ɑ', 'a').replace( 87 | 'ɔ', 'o').replace('ɛ', 'e').replace('ɪ', 'i').replace('ʊ', 'u')+' ', text) 88 | text = re.sub(r'\s+$', '', text) 89 | text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text) 90 | return text 91 | 92 | 93 | def cjke_cleaners2(text): 94 | text = re.sub(r'\[ZH\](.*?)\[ZH\]', 95 | lambda x: chinese_to_ipa(x.group(1))+' ', text) 96 | text = re.sub(r'\[JA\](.*?)\[JA\]', 97 | lambda x: japanese_to_ipa2(x.group(1))+' ', text) 98 | text = re.sub(r'\[KO\](.*?)\[KO\]', 99 | lambda x: korean_to_ipa(x.group(1))+' ', text) 100 | text = re.sub(r'\[EN\](.*?)\[EN\]', 101 | lambda x: english_to_ipa2(x.group(1))+' ', text) 102 | text = re.sub(r'\s+$', '', text) 103 | text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text) 104 | return text 105 | 106 | 107 | def thai_cleaners(text): 108 | text = num_to_thai(text) 109 | text = latin_to_thai(text) 110 | return text 111 | 112 | 113 | # def shanghainese_cleaners(text): 114 | # text = shanghainese_to_ipa(text) 115 | # text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text) 116 | # return text 117 | 118 | 119 | # def chinese_dialect_cleaners(text): 120 | # text = re.sub(r'\[ZH\](.*?)\[ZH\]', 121 | # lambda x: chinese_to_ipa2(x.group(1))+' ', text) 122 | # text = re.sub(r'\[JA\](.*?)\[JA\]', 123 | # lambda x: japanese_to_ipa3(x.group(1)).replace('Q', 'ʔ')+' ', text) 124 | # text = re.sub(r'\[SH\](.*?)\[SH\]', lambda x: shanghainese_to_ipa(x.group(1)).replace('1', '˥˧').replace('5', 125 | # '˧˧˦').replace('6', '˩˩˧').replace('7', '˥').replace('8', '˩˨').replace('ᴀ', 'ɐ').replace('ᴇ', 'e')+' ', text) 126 | # text = re.sub(r'\[GD\](.*?)\[GD\]', 127 | # lambda x: cantonese_to_ipa(x.group(1))+' ', text) 128 | # text = re.sub(r'\[EN\](.*?)\[EN\]', 129 | # lambda x: english_to_lazy_ipa2(x.group(1))+' ', text) 130 | # text = re.sub(r'\[([A-Z]{2})\](.*?)\[\1\]', lambda x: ngu_dialect_to_ipa(x.group(2), x.group( 131 | # 1)).replace('ʣ', 'dz').replace('ʥ', 'dʑ').replace('ʦ', 'ts').replace('ʨ', 'tɕ')+' ', text) 132 | # text = re.sub(r'\s+$', '', text) 133 | # text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text) 134 | # return text 135 | -------------------------------------------------------------------------------- /VITS/text/english.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Cleaners are transformations that run over the input text at both training and eval time. 5 | 6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 8 | 1. "english_cleaners" for English text 9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 12 | the symbols in symbols.py to match your data). 
13 | ''' 14 | 15 | 16 | # Regular expression matching whitespace: 17 | 18 | 19 | import re 20 | import inflect 21 | from unidecode import unidecode 22 | import eng_to_ipa as ipa 23 | _inflect = inflect.engine() 24 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') 25 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') 26 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') 27 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') 28 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') 29 | _number_re = re.compile(r'[0-9]+') 30 | 31 | # List of (regular expression, replacement) pairs for abbreviations: 32 | _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [ 33 | ('mrs', 'misess'), 34 | ('mr', 'mister'), 35 | ('dr', 'doctor'), 36 | ('st', 'saint'), 37 | ('co', 'company'), 38 | ('jr', 'junior'), 39 | ('maj', 'major'), 40 | ('gen', 'general'), 41 | ('drs', 'doctors'), 42 | ('rev', 'reverend'), 43 | ('lt', 'lieutenant'), 44 | ('hon', 'honorable'), 45 | ('sgt', 'sergeant'), 46 | ('capt', 'captain'), 47 | ('esq', 'esquire'), 48 | ('ltd', 'limited'), 49 | ('col', 'colonel'), 50 | ('ft', 'fort'), 51 | ]] 52 | 53 | 54 | # List of (ipa, lazy ipa) pairs: 55 | _lazy_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ 56 | ('r', 'ɹ'), 57 | ('æ', 'e'), 58 | ('ɑ', 'a'), 59 | ('ɔ', 'o'), 60 | ('ð', 'z'), 61 | ('θ', 's'), 62 | ('ɛ', 'e'), 63 | ('ɪ', 'i'), 64 | ('ʊ', 'u'), 65 | ('ʒ', 'ʥ'), 66 | ('ʤ', 'ʥ'), 67 | ('ˈ', '↓'), 68 | ]] 69 | 70 | # List of (ipa, lazy ipa2) pairs: 71 | _lazy_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [ 72 | ('r', 'ɹ'), 73 | ('ð', 'z'), 74 | ('θ', 's'), 75 | ('ʒ', 'ʑ'), 76 | ('ʤ', 'dʑ'), 77 | ('ˈ', '↓'), 78 | ]] 79 | 80 | # List of (ipa, ipa2) pairs 81 | _ipa_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [ 82 | ('r', 'ɹ'), 83 | ('ʤ', 'dʒ'), 84 | ('ʧ', 'tʃ') 85 | ]] 86 | 87 | 88 | def expand_abbreviations(text): 89 | for regex, replacement in _abbreviations: 90 | text = re.sub(regex, replacement, text) 91 | return text 92 | 93 | 94 | def collapse_whitespace(text): 95 | return re.sub(r'\s+', ' ', text) 96 | 97 | 98 | def _remove_commas(m): 99 | return m.group(1).replace(',', '') 100 | 101 | 102 | def _expand_decimal_point(m): 103 | return m.group(1).replace('.', ' point ') 104 | 105 | 106 | def _expand_dollars(m): 107 | match = m.group(1) 108 | parts = match.split('.') 109 | if len(parts) > 2: 110 | return match + ' dollars' # Unexpected format 111 | dollars = int(parts[0]) if parts[0] else 0 112 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 113 | if dollars and cents: 114 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 115 | cent_unit = 'cent' if cents == 1 else 'cents' 116 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) 117 | elif dollars: 118 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 119 | return '%s %s' % (dollars, dollar_unit) 120 | elif cents: 121 | cent_unit = 'cent' if cents == 1 else 'cents' 122 | return '%s %s' % (cents, cent_unit) 123 | else: 124 | return 'zero dollars' 125 | 126 | 127 | def _expand_ordinal(m): 128 | return _inflect.number_to_words(m.group(0)) 129 | 130 | 131 | def _expand_number(m): 132 | num = int(m.group(0)) 133 | if num > 1000 and num < 3000: 134 | if num == 2000: 135 | return 'two thousand' 136 | elif num > 2000 and num < 2010: 137 | return 'two thousand ' + _inflect.number_to_words(num % 100) 138 | elif num % 100 == 0: 139 | return _inflect.number_to_words(num // 100) + ' hundred' 140 | else: 141 | return _inflect.number_to_words(num, andword='', 
zero='oh', group=2).replace(', ', ' ') 142 | else: 143 | return _inflect.number_to_words(num, andword='') 144 | 145 | 146 | def normalize_numbers(text): 147 | text = re.sub(_comma_number_re, _remove_commas, text) 148 | text = re.sub(_pounds_re, r'\1 pounds', text) 149 | text = re.sub(_dollars_re, _expand_dollars, text) 150 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 151 | text = re.sub(_ordinal_re, _expand_ordinal, text) 152 | text = re.sub(_number_re, _expand_number, text) 153 | return text 154 | 155 | 156 | def mark_dark_l(text): 157 | return re.sub(r'l([^aeiouæɑɔəɛɪʊ ]*(?: |$))', lambda x: 'ɫ'+x.group(1), text) 158 | 159 | 160 | def english_to_ipa(text): 161 | text = unidecode(text).lower() 162 | text = expand_abbreviations(text) 163 | text = normalize_numbers(text) 164 | phonemes = ipa.convert(text) 165 | phonemes = collapse_whitespace(phonemes) 166 | return phonemes 167 | 168 | 169 | def english_to_lazy_ipa(text): 170 | text = english_to_ipa(text) 171 | for regex, replacement in _lazy_ipa: 172 | text = re.sub(regex, replacement, text) 173 | return text 174 | 175 | 176 | def english_to_ipa2(text): 177 | text = english_to_ipa(text) 178 | text = mark_dark_l(text) 179 | for regex, replacement in _ipa_to_ipa2: 180 | text = re.sub(regex, replacement, text) 181 | return text.replace('...', '…') 182 | 183 | 184 | def english_to_lazy_ipa2(text): 185 | text = english_to_ipa(text) 186 | for regex, replacement in _lazy_ipa2: 187 | text = re.sub(regex, replacement, text) 188 | return text 189 | -------------------------------------------------------------------------------- /VITS/text/japanese.py: -------------------------------------------------------------------------------- 1 | import re 2 | from unidecode import unidecode 3 | import pyopenjtalk 4 | 5 | 6 | # Regular expression matching Japanese without punctuation marks: 7 | _japanese_characters = re.compile( 8 | r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]') 9 | 10 | # Regular expression matching non-Japanese characters or punctuation marks: 11 | _japanese_marks = re.compile( 12 | r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]') 13 | 14 | # List of (symbol, Japanese) pairs for marks: 15 | _symbols_to_japanese = [(re.compile('%s' % x[0]), x[1]) for x in [ 16 | ('%', 'パーセント') 17 | ]] 18 | 19 | # List of (romaji, ipa) pairs for marks: 20 | _romaji_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ 21 | ('ts', 'ʦ'), 22 | ('u', 'ɯ'), 23 | ('j', 'ʥ'), 24 | ('y', 'j'), 25 | ('ni', 'n^i'), 26 | ('nj', 'n^'), 27 | ('hi', 'çi'), 28 | ('hj', 'ç'), 29 | ('f', 'ɸ'), 30 | ('I', 'i*'), 31 | ('U', 'ɯ*'), 32 | ('r', 'ɾ') 33 | ]] 34 | 35 | # List of (romaji, ipa2) pairs for marks: 36 | _romaji_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [ 37 | ('u', 'ɯ'), 38 | ('ʧ', 'tʃ'), 39 | ('j', 'dʑ'), 40 | ('y', 'j'), 41 | ('ni', 'n^i'), 42 | ('nj', 'n^'), 43 | ('hi', 'çi'), 44 | ('hj', 'ç'), 45 | ('f', 'ɸ'), 46 | ('I', 'i*'), 47 | ('U', 'ɯ*'), 48 | ('r', 'ɾ') 49 | ]] 50 | 51 | # List of (consonant, sokuon) pairs: 52 | _real_sokuon = [(re.compile('%s' % x[0]), x[1]) for x in [ 53 | (r'Q([↑↓]*[kg])', r'k#\1'), 54 | (r'Q([↑↓]*[tdjʧ])', r't#\1'), 55 | (r'Q([↑↓]*[sʃ])', r's\1'), 56 | (r'Q([↑↓]*[pb])', r'p#\1') 57 | ]] 58 | 59 | # List of (consonant, hatsuon) pairs: 60 | _real_hatsuon = [(re.compile('%s' % x[0]), x[1]) for x in [ 61 | (r'N([↑↓]*[pbm])', r'm\1'), 62 | (r'N([↑↓]*[ʧʥj])', r'n^\1'), 63 | (r'N([↑↓]*[tdn])', r'n\1'), 64 | 
(r'N([↑↓]*[kg])', r'ŋ\1') 65 | ]] 66 | 67 | 68 | def symbols_to_japanese(text): 69 | for regex, replacement in _symbols_to_japanese: 70 | text = re.sub(regex, replacement, text) 71 | return text 72 | 73 | 74 | def japanese_to_romaji_with_accent(text): 75 | '''Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html''' 76 | text = symbols_to_japanese(text) 77 | sentences = re.split(_japanese_marks, text) 78 | marks = re.findall(_japanese_marks, text) 79 | text = '' 80 | for i, sentence in enumerate(sentences): 81 | if re.match(_japanese_characters, sentence): 82 | if text != '': 83 | text += ' ' 84 | labels = pyopenjtalk.extract_fullcontext(sentence) 85 | for n, label in enumerate(labels): 86 | phoneme = re.search(r'\-([^\+]*)\+', label).group(1) 87 | if phoneme not in ['sil', 'pau']: 88 | text += phoneme.replace('ch', 'ʧ').replace('sh', 89 | 'ʃ').replace('cl', 'Q') 90 | else: 91 | continue 92 | # n_moras = int(re.search(r'/F:(\d+)_', label).group(1)) 93 | a1 = int(re.search(r"/A:(\-?[0-9]+)\+", label).group(1)) 94 | a2 = int(re.search(r"\+(\d+)\+", label).group(1)) 95 | a3 = int(re.search(r"\+(\d+)/", label).group(1)) 96 | if re.search(r'\-([^\+]*)\+', labels[n + 1]).group(1) in ['sil', 'pau']: 97 | a2_next = -1 98 | else: 99 | a2_next = int( 100 | re.search(r"\+(\d+)\+", labels[n + 1]).group(1)) 101 | # Accent phrase boundary 102 | if a3 == 1 and a2_next == 1: 103 | text += ' ' 104 | # Falling 105 | elif a1 == 0 and a2_next == a2 + 1: 106 | text += '↓' 107 | # Rising 108 | elif a2 == 1 and a2_next == 2: 109 | text += '↑' 110 | if i < len(marks): 111 | text += unidecode(marks[i]).replace(' ', '') 112 | return text 113 | 114 | 115 | def get_real_sokuon(text): 116 | for regex, replacement in _real_sokuon: 117 | text = re.sub(regex, replacement, text) 118 | return text 119 | 120 | 121 | def get_real_hatsuon(text): 122 | for regex, replacement in _real_hatsuon: 123 | text = re.sub(regex, replacement, text) 124 | return text 125 | 126 | 127 | def japanese_to_ipa(text): 128 | text = japanese_to_romaji_with_accent(text).replace('...', '…') 129 | text = re.sub( 130 | r'([aiueo])\1+', lambda x: x.group(0)[0]+'ː'*(len(x.group(0))-1), text) 131 | text = get_real_sokuon(text) 132 | text = get_real_hatsuon(text) 133 | for regex, replacement in _romaji_to_ipa: 134 | text = re.sub(regex, replacement, text) 135 | return text 136 | 137 | 138 | def japanese_to_ipa2(text): 139 | text = japanese_to_romaji_with_accent(text).replace('...', '…') 140 | text = get_real_sokuon(text) 141 | text = get_real_hatsuon(text) 142 | for regex, replacement in _romaji_to_ipa2: 143 | text = re.sub(regex, replacement, text) 144 | return text 145 | 146 | 147 | def japanese_to_ipa3(text): 148 | text = japanese_to_ipa2(text).replace('n^', 'ȵ').replace( 149 | 'ʃ', 'ɕ').replace('*', '\u0325').replace('#', '\u031a') 150 | text = re.sub( 151 | r'([aiɯeo])\1+', lambda x: x.group(0)[0]+'ː'*(len(x.group(0))-1), text) 152 | text = re.sub(r'((?:^|\s)(?:ts|tɕ|[kpt]))', r'\1ʰ', text) 153 | return text 154 | -------------------------------------------------------------------------------- /VITS/text/korean.py: -------------------------------------------------------------------------------- 1 | import re 2 | from jamo import h2j, j2hcj 3 | import ko_pron 4 | 5 | 6 | # This is a list of Korean classifiers preceded by pure Korean numerals. 
7 | _korean_classifiers = '군데 권 개 그루 닢 대 두 마리 모 모금 뭇 발 발짝 방 번 벌 보루 살 수 술 시 쌈 움큼 정 짝 채 척 첩 축 켤레 톨 통' 8 | 9 | # List of (hangul, hangul divided) pairs: 10 | _hangul_divided = [(re.compile('%s' % x[0]), x[1]) for x in [ 11 | ('ㄳ', 'ㄱㅅ'), 12 | ('ㄵ', 'ㄴㅈ'), 13 | ('ㄶ', 'ㄴㅎ'), 14 | ('ㄺ', 'ㄹㄱ'), 15 | ('ㄻ', 'ㄹㅁ'), 16 | ('ㄼ', 'ㄹㅂ'), 17 | ('ㄽ', 'ㄹㅅ'), 18 | ('ㄾ', 'ㄹㅌ'), 19 | ('ㄿ', 'ㄹㅍ'), 20 | ('ㅀ', 'ㄹㅎ'), 21 | ('ㅄ', 'ㅂㅅ'), 22 | ('ㅘ', 'ㅗㅏ'), 23 | ('ㅙ', 'ㅗㅐ'), 24 | ('ㅚ', 'ㅗㅣ'), 25 | ('ㅝ', 'ㅜㅓ'), 26 | ('ㅞ', 'ㅜㅔ'), 27 | ('ㅟ', 'ㅜㅣ'), 28 | ('ㅢ', 'ㅡㅣ'), 29 | ('ㅑ', 'ㅣㅏ'), 30 | ('ㅒ', 'ㅣㅐ'), 31 | ('ㅕ', 'ㅣㅓ'), 32 | ('ㅖ', 'ㅣㅔ'), 33 | ('ㅛ', 'ㅣㅗ'), 34 | ('ㅠ', 'ㅣㅜ') 35 | ]] 36 | 37 | # List of (Latin alphabet, hangul) pairs: 38 | _latin_to_hangul = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ 39 | ('a', '에이'), 40 | ('b', '비'), 41 | ('c', '시'), 42 | ('d', '디'), 43 | ('e', '이'), 44 | ('f', '에프'), 45 | ('g', '지'), 46 | ('h', '에이치'), 47 | ('i', '아이'), 48 | ('j', '제이'), 49 | ('k', '케이'), 50 | ('l', '엘'), 51 | ('m', '엠'), 52 | ('n', '엔'), 53 | ('o', '오'), 54 | ('p', '피'), 55 | ('q', '큐'), 56 | ('r', '아르'), 57 | ('s', '에스'), 58 | ('t', '티'), 59 | ('u', '유'), 60 | ('v', '브이'), 61 | ('w', '더블유'), 62 | ('x', '엑스'), 63 | ('y', '와이'), 64 | ('z', '제트') 65 | ]] 66 | 67 | # List of (ipa, lazy ipa) pairs: 68 | _ipa_to_lazy_ipa = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ 69 | ('t͡ɕ','ʧ'), 70 | ('d͡ʑ','ʥ'), 71 | ('ɲ','n^'), 72 | ('ɕ','ʃ'), 73 | ('ʷ','w'), 74 | ('ɭ','l`'), 75 | ('ʎ','ɾ'), 76 | ('ɣ','ŋ'), 77 | ('ɰ','ɯ'), 78 | ('ʝ','j'), 79 | ('ʌ','ə'), 80 | ('ɡ','g'), 81 | ('\u031a','#'), 82 | ('\u0348','='), 83 | ('\u031e',''), 84 | ('\u0320',''), 85 | ('\u0339','') 86 | ]] 87 | 88 | 89 | def latin_to_hangul(text): 90 | for regex, replacement in _latin_to_hangul: 91 | text = re.sub(regex, replacement, text) 92 | return text 93 | 94 | 95 | def divide_hangul(text): 96 | text = j2hcj(h2j(text)) 97 | for regex, replacement in _hangul_divided: 98 | text = re.sub(regex, replacement, text) 99 | return text 100 | 101 | 102 | def hangul_number(num, sino=True): 103 | '''Reference https://github.com/Kyubyong/g2pK''' 104 | num = re.sub(',', '', num) 105 | 106 | if num == '0': 107 | return '영' 108 | if not sino and num == '20': 109 | return '스무' 110 | 111 | digits = '123456789' 112 | names = '일이삼사오육칠팔구' 113 | digit2name = {d: n for d, n in zip(digits, names)} 114 | 115 | modifiers = '한 두 세 네 다섯 여섯 일곱 여덟 아홉' 116 | decimals = '열 스물 서른 마흔 쉰 예순 일흔 여든 아흔' 117 | digit2mod = {d: mod for d, mod in zip(digits, modifiers.split())} 118 | digit2dec = {d: dec for d, dec in zip(digits, decimals.split())} 119 | 120 | spelledout = [] 121 | for i, digit in enumerate(num): 122 | i = len(num) - i - 1 123 | if sino: 124 | if i == 0: 125 | name = digit2name.get(digit, '') 126 | elif i == 1: 127 | name = digit2name.get(digit, '') + '십' 128 | name = name.replace('일십', '십') 129 | else: 130 | if i == 0: 131 | name = digit2mod.get(digit, '') 132 | elif i == 1: 133 | name = digit2dec.get(digit, '') 134 | if digit == '0': 135 | if i % 4 == 0: 136 | last_three = spelledout[-min(3, len(spelledout)):] 137 | if ''.join(last_three) == '': 138 | spelledout.append('') 139 | continue 140 | else: 141 | spelledout.append('') 142 | continue 143 | if i == 2: 144 | name = digit2name.get(digit, '') + '백' 145 | name = name.replace('일백', '백') 146 | elif i == 3: 147 | name = digit2name.get(digit, '') + '천' 148 | name = name.replace('일천', '천') 149 | elif i == 4: 150 | name = digit2name.get(digit, '') + '만' 151 | name = name.replace('일만', '만') 152 | elif i == 5: 
153 | name = digit2name.get(digit, '') + '십' 154 | name = name.replace('일십', '십') 155 | elif i == 6: 156 | name = digit2name.get(digit, '') + '백' 157 | name = name.replace('일백', '백') 158 | elif i == 7: 159 | name = digit2name.get(digit, '') + '천' 160 | name = name.replace('일천', '천') 161 | elif i == 8: 162 | name = digit2name.get(digit, '') + '억' 163 | elif i == 9: 164 | name = digit2name.get(digit, '') + '십' 165 | elif i == 10: 166 | name = digit2name.get(digit, '') + '백' 167 | elif i == 11: 168 | name = digit2name.get(digit, '') + '천' 169 | elif i == 12: 170 | name = digit2name.get(digit, '') + '조' 171 | elif i == 13: 172 | name = digit2name.get(digit, '') + '십' 173 | elif i == 14: 174 | name = digit2name.get(digit, '') + '백' 175 | elif i == 15: 176 | name = digit2name.get(digit, '') + '천' 177 | spelledout.append(name) 178 | return ''.join(elem for elem in spelledout) 179 | 180 | 181 | def number_to_hangul(text): 182 | '''Reference https://github.com/Kyubyong/g2pK''' 183 | tokens = set(re.findall(r'(\d[\d,]*)([\uac00-\ud71f]+)', text)) 184 | for token in tokens: 185 | num, classifier = token 186 | if classifier[:2] in _korean_classifiers or classifier[0] in _korean_classifiers: 187 | spelledout = hangul_number(num, sino=False) 188 | else: 189 | spelledout = hangul_number(num, sino=True) 190 | text = text.replace(f'{num}{classifier}', f'{spelledout}{classifier}') 191 | # digit by digit for remaining digits 192 | digits = '0123456789' 193 | names = '영일이삼사오육칠팔구' 194 | for d, n in zip(digits, names): 195 | text = text.replace(d, n) 196 | return text 197 | 198 | 199 | def korean_to_lazy_ipa(text): 200 | text = latin_to_hangul(text) 201 | text = number_to_hangul(text) 202 | text=re.sub('[\uac00-\ud7af]+',lambda x:ko_pron.romanise(x.group(0),'ipa').split('] ~ [')[0],text) 203 | for regex, replacement in _ipa_to_lazy_ipa: 204 | text = re.sub(regex, replacement, text) 205 | return text 206 | 207 | 208 | def korean_to_ipa(text): 209 | text = korean_to_lazy_ipa(text) 210 | return text.replace('ʧ','tʃ').replace('ʥ','dʑ') 211 | -------------------------------------------------------------------------------- /VITS/text/mandarin.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import re 4 | from pypinyin import lazy_pinyin, BOPOMOFO 5 | import jieba 6 | import cn2an 7 | import logging 8 | 9 | 10 | # List of (Latin alphabet, bopomofo) pairs: 11 | _latin_to_bopomofo = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ 12 | ('a', 'ㄟˉ'), 13 | ('b', 'ㄅㄧˋ'), 14 | ('c', 'ㄙㄧˉ'), 15 | ('d', 'ㄉㄧˋ'), 16 | ('e', 'ㄧˋ'), 17 | ('f', 'ㄝˊㄈㄨˋ'), 18 | ('g', 'ㄐㄧˋ'), 19 | ('h', 'ㄝˇㄑㄩˋ'), 20 | ('i', 'ㄞˋ'), 21 | ('j', 'ㄐㄟˋ'), 22 | ('k', 'ㄎㄟˋ'), 23 | ('l', 'ㄝˊㄛˋ'), 24 | ('m', 'ㄝˊㄇㄨˋ'), 25 | ('n', 'ㄣˉ'), 26 | ('o', 'ㄡˉ'), 27 | ('p', 'ㄆㄧˉ'), 28 | ('q', 'ㄎㄧㄡˉ'), 29 | ('r', 'ㄚˋ'), 30 | ('s', 'ㄝˊㄙˋ'), 31 | ('t', 'ㄊㄧˋ'), 32 | ('u', 'ㄧㄡˉ'), 33 | ('v', 'ㄨㄧˉ'), 34 | ('w', 'ㄉㄚˋㄅㄨˋㄌㄧㄡˋ'), 35 | ('x', 'ㄝˉㄎㄨˋㄙˋ'), 36 | ('y', 'ㄨㄞˋ'), 37 | ('z', 'ㄗㄟˋ') 38 | ]] 39 | 40 | # List of (bopomofo, romaji) pairs: 41 | _bopomofo_to_romaji = [(re.compile('%s' % x[0]), x[1]) for x in [ 42 | ('ㄅㄛ', 'p⁼wo'), 43 | ('ㄆㄛ', 'pʰwo'), 44 | ('ㄇㄛ', 'mwo'), 45 | ('ㄈㄛ', 'fwo'), 46 | ('ㄅ', 'p⁼'), 47 | ('ㄆ', 'pʰ'), 48 | ('ㄇ', 'm'), 49 | ('ㄈ', 'f'), 50 | ('ㄉ', 't⁼'), 51 | ('ㄊ', 'tʰ'), 52 | ('ㄋ', 'n'), 53 | ('ㄌ', 'l'), 54 | ('ㄍ', 'k⁼'), 55 | ('ㄎ', 'kʰ'), 56 | ('ㄏ', 'h'), 57 | ('ㄐ', 'ʧ⁼'), 58 | ('ㄑ', 'ʧʰ'), 59 | ('ㄒ', 'ʃ'), 60 | ('ㄓ', 'ʦ`⁼'), 61 | ('ㄔ', 'ʦ`ʰ'), 62 | ('ㄕ', 's`'), 63 | ('ㄖ', 
'ɹ`'), 64 | ('ㄗ', 'ʦ⁼'), 65 | ('ㄘ', 'ʦʰ'), 66 | ('ㄙ', 's'), 67 | ('ㄚ', 'a'), 68 | ('ㄛ', 'o'), 69 | ('ㄜ', 'ə'), 70 | ('ㄝ', 'e'), 71 | ('ㄞ', 'ai'), 72 | ('ㄟ', 'ei'), 73 | ('ㄠ', 'au'), 74 | ('ㄡ', 'ou'), 75 | ('ㄧㄢ', 'yeNN'), 76 | ('ㄢ', 'aNN'), 77 | ('ㄧㄣ', 'iNN'), 78 | ('ㄣ', 'əNN'), 79 | ('ㄤ', 'aNg'), 80 | ('ㄧㄥ', 'iNg'), 81 | ('ㄨㄥ', 'uNg'), 82 | ('ㄩㄥ', 'yuNg'), 83 | ('ㄥ', 'əNg'), 84 | ('ㄦ', 'əɻ'), 85 | ('ㄧ', 'i'), 86 | ('ㄨ', 'u'), 87 | ('ㄩ', 'ɥ'), 88 | ('ˉ', '→'), 89 | ('ˊ', '↑'), 90 | ('ˇ', '↓↑'), 91 | ('ˋ', '↓'), 92 | ('˙', ''), 93 | (',', ','), 94 | ('。', '.'), 95 | ('!', '!'), 96 | ('?', '?'), 97 | ('—', '-') 98 | ]] 99 | 100 | # List of (romaji, ipa) pairs: 101 | _romaji_to_ipa = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ 102 | ('ʃy', 'ʃ'), 103 | ('ʧʰy', 'ʧʰ'), 104 | ('ʧ⁼y', 'ʧ⁼'), 105 | ('NN', 'n'), 106 | ('Ng', 'ŋ'), 107 | ('y', 'j'), 108 | ('h', 'x') 109 | ]] 110 | 111 | # List of (bopomofo, ipa) pairs: 112 | _bopomofo_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ 113 | ('ㄅㄛ', 'p⁼wo'), 114 | ('ㄆㄛ', 'pʰwo'), 115 | ('ㄇㄛ', 'mwo'), 116 | ('ㄈㄛ', 'fwo'), 117 | ('ㄅ', 'p⁼'), 118 | ('ㄆ', 'pʰ'), 119 | ('ㄇ', 'm'), 120 | ('ㄈ', 'f'), 121 | ('ㄉ', 't⁼'), 122 | ('ㄊ', 'tʰ'), 123 | ('ㄋ', 'n'), 124 | ('ㄌ', 'l'), 125 | ('ㄍ', 'k⁼'), 126 | ('ㄎ', 'kʰ'), 127 | ('ㄏ', 'x'), 128 | ('ㄐ', 'tʃ⁼'), 129 | ('ㄑ', 'tʃʰ'), 130 | ('ㄒ', 'ʃ'), 131 | ('ㄓ', 'ts`⁼'), 132 | ('ㄔ', 'ts`ʰ'), 133 | ('ㄕ', 's`'), 134 | ('ㄖ', 'ɹ`'), 135 | ('ㄗ', 'ts⁼'), 136 | ('ㄘ', 'tsʰ'), 137 | ('ㄙ', 's'), 138 | ('ㄚ', 'a'), 139 | ('ㄛ', 'o'), 140 | ('ㄜ', 'ə'), 141 | ('ㄝ', 'ɛ'), 142 | ('ㄞ', 'aɪ'), 143 | ('ㄟ', 'eɪ'), 144 | ('ㄠ', 'ɑʊ'), 145 | ('ㄡ', 'oʊ'), 146 | ('ㄧㄢ', 'jɛn'), 147 | ('ㄩㄢ', 'ɥæn'), 148 | ('ㄢ', 'an'), 149 | ('ㄧㄣ', 'in'), 150 | ('ㄩㄣ', 'ɥn'), 151 | ('ㄣ', 'ən'), 152 | ('ㄤ', 'ɑŋ'), 153 | ('ㄧㄥ', 'iŋ'), 154 | ('ㄨㄥ', 'ʊŋ'), 155 | ('ㄩㄥ', 'jʊŋ'), 156 | ('ㄥ', 'əŋ'), 157 | ('ㄦ', 'əɻ'), 158 | ('ㄧ', 'i'), 159 | ('ㄨ', 'u'), 160 | ('ㄩ', 'ɥ'), 161 | ('ˉ', '→'), 162 | ('ˊ', '↑'), 163 | ('ˇ', '↓↑'), 164 | ('ˋ', '↓'), 165 | ('˙', ''), 166 | (',', ','), 167 | ('。', '.'), 168 | ('!', '!'), 169 | ('?', '?'), 170 | ('—', '-') 171 | ]] 172 | 173 | # List of (bopomofo, ipa2) pairs: 174 | _bopomofo_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [ 175 | ('ㄅㄛ', 'pwo'), 176 | ('ㄆㄛ', 'pʰwo'), 177 | ('ㄇㄛ', 'mwo'), 178 | ('ㄈㄛ', 'fwo'), 179 | ('ㄅ', 'p'), 180 | ('ㄆ', 'pʰ'), 181 | ('ㄇ', 'm'), 182 | ('ㄈ', 'f'), 183 | ('ㄉ', 't'), 184 | ('ㄊ', 'tʰ'), 185 | ('ㄋ', 'n'), 186 | ('ㄌ', 'l'), 187 | ('ㄍ', 'k'), 188 | ('ㄎ', 'kʰ'), 189 | ('ㄏ', 'h'), 190 | ('ㄐ', 'tɕ'), 191 | ('ㄑ', 'tɕʰ'), 192 | ('ㄒ', 'ɕ'), 193 | ('ㄓ', 'tʂ'), 194 | ('ㄔ', 'tʂʰ'), 195 | ('ㄕ', 'ʂ'), 196 | ('ㄖ', 'ɻ'), 197 | ('ㄗ', 'ts'), 198 | ('ㄘ', 'tsʰ'), 199 | ('ㄙ', 's'), 200 | ('ㄚ', 'a'), 201 | ('ㄛ', 'o'), 202 | ('ㄜ', 'ɤ'), 203 | ('ㄝ', 'ɛ'), 204 | ('ㄞ', 'aɪ'), 205 | ('ㄟ', 'eɪ'), 206 | ('ㄠ', 'ɑʊ'), 207 | ('ㄡ', 'oʊ'), 208 | ('ㄧㄢ', 'jɛn'), 209 | ('ㄩㄢ', 'yæn'), 210 | ('ㄢ', 'an'), 211 | ('ㄧㄣ', 'in'), 212 | ('ㄩㄣ', 'yn'), 213 | ('ㄣ', 'ən'), 214 | ('ㄤ', 'ɑŋ'), 215 | ('ㄧㄥ', 'iŋ'), 216 | ('ㄨㄥ', 'ʊŋ'), 217 | ('ㄩㄥ', 'jʊŋ'), 218 | ('ㄥ', 'ɤŋ'), 219 | ('ㄦ', 'əɻ'), 220 | ('ㄧ', 'i'), 221 | ('ㄨ', 'u'), 222 | ('ㄩ', 'y'), 223 | ('ˉ', '˥'), 224 | ('ˊ', '˧˥'), 225 | ('ˇ', '˨˩˦'), 226 | ('ˋ', '˥˩'), 227 | ('˙', ''), 228 | (',', ','), 229 | ('。', '.'), 230 | ('!', '!'), 231 | ('?', '?'), 232 | ('—', '-') 233 | ]] 234 | 235 | 236 | def number_to_chinese(text): 237 | numbers = re.findall(r'\d+(?:\.?\d+)?', text) 238 | for number in numbers: 239 | text = text.replace(number, cn2an.an2cn(number), 1) 240 | return text 241 | 242 
| 243 | def chinese_to_bopomofo(text): 244 | text = text.replace('、', ',').replace(';', ',').replace(':', ',') 245 | words = jieba.lcut(text, cut_all=False) 246 | text = '' 247 | for word in words: 248 | bopomofos = lazy_pinyin(word, BOPOMOFO) 249 | if not re.search('[\u4e00-\u9fff]', word): 250 | text += word 251 | continue 252 | for i in range(len(bopomofos)): 253 | bopomofos[i] = re.sub(r'([\u3105-\u3129])$', r'\1ˉ', bopomofos[i]) 254 | if text != '': 255 | text += ' ' 256 | text += ''.join(bopomofos) 257 | return text 258 | 259 | 260 | def latin_to_bopomofo(text): 261 | for regex, replacement in _latin_to_bopomofo: 262 | text = re.sub(regex, replacement, text) 263 | return text 264 | 265 | 266 | def bopomofo_to_romaji(text): 267 | for regex, replacement in _bopomofo_to_romaji: 268 | text = re.sub(regex, replacement, text) 269 | return text 270 | 271 | 272 | def bopomofo_to_ipa(text): 273 | for regex, replacement in _bopomofo_to_ipa: 274 | text = re.sub(regex, replacement, text) 275 | return text 276 | 277 | 278 | def bopomofo_to_ipa2(text): 279 | for regex, replacement in _bopomofo_to_ipa2: 280 | text = re.sub(regex, replacement, text) 281 | return text 282 | 283 | 284 | def chinese_to_romaji(text): 285 | text = number_to_chinese(text) 286 | text = chinese_to_bopomofo(text) 287 | text = latin_to_bopomofo(text) 288 | text = bopomofo_to_romaji(text) 289 | text = re.sub('i([aoe])', r'y\1', text) 290 | text = re.sub('u([aoəe])', r'w\1', text) 291 | text = re.sub('([ʦsɹ]`[⁼ʰ]?)([→↓↑ ]+|$)', 292 | r'\1ɹ`\2', text).replace('ɻ', 'ɹ`') 293 | text = re.sub('([ʦs][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text) 294 | return text 295 | 296 | 297 | def chinese_to_lazy_ipa(text): 298 | text = chinese_to_romaji(text) 299 | for regex, replacement in _romaji_to_ipa: 300 | text = re.sub(regex, replacement, text) 301 | return text 302 | 303 | 304 | def chinese_to_ipa(text): 305 | text = number_to_chinese(text) 306 | text = chinese_to_bopomofo(text) 307 | text = latin_to_bopomofo(text) 308 | text = bopomofo_to_ipa(text) 309 | text = re.sub('i([aoe])', r'j\1', text) 310 | text = re.sub('u([aoəe])', r'w\1', text) 311 | text = re.sub('([sɹ]`[⁼ʰ]?)([→↓↑ ]+|$)', 312 | r'\1ɹ`\2', text).replace('ɻ', 'ɹ`') 313 | text = re.sub('([s][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text) 314 | return text 315 | 316 | 317 | def chinese_to_ipa2(text): 318 | text = number_to_chinese(text) 319 | text = chinese_to_bopomofo(text) 320 | text = latin_to_bopomofo(text) 321 | text = bopomofo_to_ipa2(text) 322 | text = re.sub(r'i([aoe])', r'j\1', text) 323 | text = re.sub(r'u([aoəe])', r'w\1', text) 324 | text = re.sub(r'([ʂɹ]ʰ?)([˩˨˧˦˥ ]+|$)', r'\1ʅ\2', text) 325 | text = re.sub(r'(sʰ?)([˩˨˧˦˥ ]+|$)', r'\1ɿ\2', text) 326 | return text 327 | -------------------------------------------------------------------------------- /VITS/text/ngu_dialect.py: -------------------------------------------------------------------------------- 1 | import re 2 | import opencc 3 | 4 | 5 | dialects = {'SZ': 'suzhou', 'WX': 'wuxi', 'CZ': 'changzhou', 'HZ': 'hangzhou', 6 | 'SX': 'shaoxing', 'NB': 'ningbo', 'JJ': 'jingjiang', 'YX': 'yixing', 7 | 'JD': 'jiading', 'ZR': 'zhenru', 'PH': 'pinghu', 'TX': 'tongxiang', 8 | 'JS': 'jiashan', 'HN': 'xiashi', 'LP': 'linping', 'XS': 'xiaoshan', 9 | 'FY': 'fuyang', 'RA': 'ruao', 'CX': 'cixi', 'SM': 'sanmen', 10 | 'TT': 'tiantai', 'WZ': 'wenzhou', 'SC': 'suichang', 'YB': 'youbu'} 11 | 12 | converters = {} 13 | 14 | for dialect in dialects.values(): 15 | try: 16 | converters[dialect] = opencc.OpenCC(dialect) 17 | except: 18 | pass 19 | 20 | 21 | 
def ngu_dialect_to_ipa(text, dialect): 22 | dialect = dialects[dialect] 23 | text = converters[dialect].convert(text).replace('-','').replace('$',' ') 24 | text = re.sub(r'[、;:]', ',', text) 25 | text = re.sub(r'\s*,\s*', ', ', text) 26 | text = re.sub(r'\s*。\s*', '. ', text) 27 | text = re.sub(r'\s*?\s*', '? ', text) 28 | text = re.sub(r'\s*!\s*', '! ', text) 29 | text = re.sub(r'\s*$', '', text) 30 | return text 31 | -------------------------------------------------------------------------------- /VITS/text/sanskrit.py: -------------------------------------------------------------------------------- 1 | import re 2 | from indic_transliteration import sanscript 3 | 4 | 5 | # List of (iast, ipa) pairs: 6 | _iast_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ 7 | ('a', 'ə'), 8 | ('ā', 'aː'), 9 | ('ī', 'iː'), 10 | ('ū', 'uː'), 11 | ('ṛ', 'ɹ`'), 12 | ('ṝ', 'ɹ`ː'), 13 | ('ḷ', 'l`'), 14 | ('ḹ', 'l`ː'), 15 | ('e', 'eː'), 16 | ('o', 'oː'), 17 | ('k', 'k⁼'), 18 | ('k⁼h', 'kʰ'), 19 | ('g', 'g⁼'), 20 | ('g⁼h', 'gʰ'), 21 | ('ṅ', 'ŋ'), 22 | ('c', 'ʧ⁼'), 23 | ('ʧ⁼h', 'ʧʰ'), 24 | ('j', 'ʥ⁼'), 25 | ('ʥ⁼h', 'ʥʰ'), 26 | ('ñ', 'n^'), 27 | ('ṭ', 't`⁼'), 28 | ('t`⁼h', 't`ʰ'), 29 | ('ḍ', 'd`⁼'), 30 | ('d`⁼h', 'd`ʰ'), 31 | ('ṇ', 'n`'), 32 | ('t', 't⁼'), 33 | ('t⁼h', 'tʰ'), 34 | ('d', 'd⁼'), 35 | ('d⁼h', 'dʰ'), 36 | ('p', 'p⁼'), 37 | ('p⁼h', 'pʰ'), 38 | ('b', 'b⁼'), 39 | ('b⁼h', 'bʰ'), 40 | ('y', 'j'), 41 | ('ś', 'ʃ'), 42 | ('ṣ', 's`'), 43 | ('r', 'ɾ'), 44 | ('l̤', 'l`'), 45 | ('h', 'ɦ'), 46 | ("'", ''), 47 | ('~', '^'), 48 | ('ṃ', '^') 49 | ]] 50 | 51 | 52 | def devanagari_to_ipa(text): 53 | text = text.replace('ॐ', 'ओम्') 54 | text = re.sub(r'\s*।\s*$', '.', text) 55 | text = re.sub(r'\s*।\s*', ', ', text) 56 | text = re.sub(r'\s*॥', '.', text) 57 | text = sanscript.transliterate(text, sanscript.DEVANAGARI, sanscript.IAST) 58 | for regex, replacement in _iast_to_ipa: 59 | text = re.sub(regex, replacement, text) 60 | text = re.sub('(.)[`ː]*ḥ', lambda x: x.group(0) 61 | [:-1]+'h'+x.group(1)+'*', text) 62 | return text 63 | -------------------------------------------------------------------------------- /VITS/text/shanghainese.py: -------------------------------------------------------------------------------- 1 | import re 2 | import cn2an 3 | import opencc 4 | 5 | 6 | converter = opencc.OpenCC('zaonhe') 7 | 8 | # List of (Latin alphabet, ipa) pairs: 9 | _latin_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ 10 | ('A', 'ᴇ'), 11 | ('B', 'bi'), 12 | ('C', 'si'), 13 | ('D', 'di'), 14 | ('E', 'i'), 15 | ('F', 'ᴇf'), 16 | ('G', 'dʑi'), 17 | ('H', 'ᴇtɕʰ'), 18 | ('I', 'ᴀi'), 19 | ('J', 'dʑᴇ'), 20 | ('K', 'kʰᴇ'), 21 | ('L', 'ᴇl'), 22 | ('M', 'ᴇm'), 23 | ('N', 'ᴇn'), 24 | ('O', 'o'), 25 | ('P', 'pʰi'), 26 | ('Q', 'kʰiu'), 27 | ('R', 'ᴀl'), 28 | ('S', 'ᴇs'), 29 | ('T', 'tʰi'), 30 | ('U', 'ɦiu'), 31 | ('V', 'vi'), 32 | ('W', 'dᴀbɤliu'), 33 | ('X', 'ᴇks'), 34 | ('Y', 'uᴀi'), 35 | ('Z', 'zᴇ') 36 | ]] 37 | 38 | 39 | def _number_to_shanghainese(num): 40 | num = cn2an.an2cn(num).replace('一十','十').replace('二十', '廿').replace('二', '两') 41 | return re.sub(r'((?:^|[^三四五六七八九])十|廿)两', r'\1二', num) 42 | 43 | 44 | def number_to_shanghainese(text): 45 | return re.sub(r'\d+(?:\.?\d+)?', lambda x: _number_to_shanghainese(x.group()), text) 46 | 47 | 48 | def latin_to_ipa(text): 49 | for regex, replacement in _latin_to_ipa: 50 | text = re.sub(regex, replacement, text) 51 | return text 52 | 53 | 54 | def shanghainese_to_ipa(text): 55 | text = number_to_shanghainese(text.upper()) 56 | text = 
converter.convert(text).replace('-','').replace('$',' ') 57 | text = re.sub(r'[A-Z]', lambda x: latin_to_ipa(x.group())+' ', text) 58 | text = re.sub(r'[、;:]', ',', text) 59 | text = re.sub(r'\s*,\s*', ', ', text) 60 | text = re.sub(r'\s*。\s*', '. ', text) 61 | text = re.sub(r'\s*?\s*', '? ', text) 62 | text = re.sub(r'\s*!\s*', '! ', text) 63 | text = re.sub(r'\s*$', '', text) 64 | return text 65 | -------------------------------------------------------------------------------- /VITS/text/symbols.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Defines the set of symbols used in text input to the model. 3 | ''' 4 | 5 | # japanese_cleaners 6 | # _pad = '_' 7 | # _punctuation = ',.!?-' 8 | # _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧ↓↑ ' 9 | 10 | 11 | '''# japanese_cleaners2 12 | _pad = '_' 13 | _punctuation = ',.!?-~…' 14 | _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧʦ↓↑ ' 15 | ''' 16 | 17 | 18 | '''# korean_cleaners 19 | _pad = '_' 20 | _punctuation = ',.!?…~' 21 | _letters = 'ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ ' 22 | ''' 23 | 24 | '''# chinese_cleaners 25 | _pad = '_' 26 | _punctuation = ',。!?—…' 27 | _letters = 'ㄅㄆㄇㄈㄉㄊㄋㄌㄍㄎㄏㄐㄑㄒㄓㄔㄕㄖㄗㄘㄙㄚㄛㄜㄝㄞㄟㄠㄡㄢㄣㄤㄥㄦㄧㄨㄩˉˊˇˋ˙ ' 28 | ''' 29 | 30 | # # zh_ja_mixture_cleaners 31 | # _pad = '_' 32 | # _punctuation = ',.!?-~…' 33 | # _letters = 'AEINOQUabdefghijklmnoprstuvwyzʃʧʦɯɹəɥ⁼ʰ`→↓↑ ' 34 | 35 | 36 | '''# sanskrit_cleaners 37 | _pad = '_' 38 | _punctuation = '।' 39 | _letters = 'ँंःअआइईउऊऋएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलळवशषसहऽािीुूृॄेैोौ्ॠॢ ' 40 | ''' 41 | 42 | '''# cjks_cleaners 43 | _pad = '_' 44 | _punctuation = ',.!?-~…' 45 | _letters = 'NQabdefghijklmnopstuvwxyzʃʧʥʦɯɹəɥçɸɾβŋɦː⁼ʰ`^#*=→↓↑ ' 46 | ''' 47 | 48 | '''# thai_cleaners 49 | _pad = '_' 50 | _punctuation = '.!? ' 51 | _letters = 'กขฃคฆงจฉชซฌญฎฏฐฑฒณดตถทธนบปผฝพฟภมยรฤลวศษสหฬอฮฯะัาำิีึืุูเแโใไๅๆ็่้๊๋์' 52 | ''' 53 | 54 | # # cjke_cleaners2 55 | _pad = '_' 56 | _punctuation = ',.!?-~…' 57 | _letters = 'NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ ' 58 | 59 | 60 | '''# shanghainese_cleaners 61 | _pad = '_' 62 | _punctuation = ',.!?…' 63 | _letters = 'abdfghiklmnopstuvyzøŋȵɑɔɕəɤɦɪɿʑʔʰ̩̃ᴀᴇ15678 ' 64 | ''' 65 | 66 | '''# chinese_dialect_cleaners 67 | _pad = '_' 68 | _punctuation = ',.!?~…─' 69 | _letters = '#Nabdefghijklmnoprstuvwxyzæçøŋœȵɐɑɒɓɔɕɗɘəɚɛɜɣɤɦɪɭɯɵɷɸɻɾɿʂʅʊʋʌʏʑʔʦʮʰʷˀː˥˦˧˨˩̥̩̃̚ᴀᴇ↑↓∅ⱼ ' 70 | ''' 71 | 72 | # Export all symbols: 73 | symbols = [_pad] + list(_punctuation) + list(_letters) 74 | 75 | # Special symbol ids 76 | SPACE_ID = symbols.index(" ") 77 | -------------------------------------------------------------------------------- /VITS/text/thai.py: -------------------------------------------------------------------------------- 1 | import re 2 | from num_thai.thainumbers import NumThai 3 | 4 | 5 | num = NumThai() 6 | 7 | # List of (Latin alphabet, Thai) pairs: 8 | _latin_to_thai = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ 9 | ('a', 'เอ'), 10 | ('b','บี'), 11 | ('c','ซี'), 12 | ('d','ดี'), 13 | ('e','อี'), 14 | ('f','เอฟ'), 15 | ('g','จี'), 16 | ('h','เอช'), 17 | ('i','ไอ'), 18 | ('j','เจ'), 19 | ('k','เค'), 20 | ('l','แอล'), 21 | ('m','เอ็ม'), 22 | ('n','เอ็น'), 23 | ('o','โอ'), 24 | ('p','พี'), 25 | ('q','คิว'), 26 | ('r','แอร์'), 27 | ('s','เอส'), 28 | ('t','ที'), 29 | ('u','ยู'), 30 | ('v','วี'), 31 | ('w','ดับเบิลยู'), 32 | ('x','เอ็กซ์'), 33 | ('y','วาย'), 34 | ('z','ซี') 35 | ]] 36 | 37 | 38 | def num_to_thai(text): 39 | return re.sub(r'(?:\d+(?:,?\d+)?)+(?:\.\d+(?:,?\d+)?)?', lambda x: 
''.join(num.NumberToTextThai(float(x.group(0).replace(',', '')))), text) 40 | 41 | def latin_to_thai(text): 42 | for regex, replacement in _latin_to_thai: 43 | text = re.sub(regex, replacement, text) 44 | return text 45 | -------------------------------------------------------------------------------- /VITS/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | import numpy as np 5 | 6 | 7 | DEFAULT_MIN_BIN_WIDTH = 1e-3 8 | DEFAULT_MIN_BIN_HEIGHT = 1e-3 9 | DEFAULT_MIN_DERIVATIVE = 1e-3 10 | 11 | 12 | def piecewise_rational_quadratic_transform(inputs, 13 | unnormalized_widths, 14 | unnormalized_heights, 15 | unnormalized_derivatives, 16 | inverse=False, 17 | tails=None, 18 | tail_bound=1., 19 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 20 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 21 | min_derivative=DEFAULT_MIN_DERIVATIVE): 22 | 23 | if tails is None: 24 | spline_fn = rational_quadratic_spline 25 | spline_kwargs = {} 26 | else: 27 | spline_fn = unconstrained_rational_quadratic_spline 28 | spline_kwargs = { 29 | 'tails': tails, 30 | 'tail_bound': tail_bound 31 | } 32 | 33 | outputs, logabsdet = spline_fn( 34 | inputs=inputs, 35 | unnormalized_widths=unnormalized_widths, 36 | unnormalized_heights=unnormalized_heights, 37 | unnormalized_derivatives=unnormalized_derivatives, 38 | inverse=inverse, 39 | min_bin_width=min_bin_width, 40 | min_bin_height=min_bin_height, 41 | min_derivative=min_derivative, 42 | **spline_kwargs 43 | ) 44 | return outputs, logabsdet 45 | 46 | 47 | def searchsorted(bin_locations, inputs, eps=1e-6): 48 | bin_locations[..., -1] += eps 49 | return torch.sum( 50 | inputs[..., None] >= bin_locations, 51 | dim=-1 52 | ) - 1 53 | 54 | 55 | def unconstrained_rational_quadratic_spline(inputs, 56 | unnormalized_widths, 57 | unnormalized_heights, 58 | unnormalized_derivatives, 59 | inverse=False, 60 | tails='linear', 61 | tail_bound=1., 62 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 63 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 64 | min_derivative=DEFAULT_MIN_DERIVATIVE): 65 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) 66 | outside_interval_mask = ~inside_interval_mask 67 | 68 | outputs = torch.zeros_like(inputs) 69 | logabsdet = torch.zeros_like(inputs) 70 | 71 | if tails == 'linear': 72 | unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) 73 | constant = np.log(np.exp(1 - min_derivative) - 1) 74 | unnormalized_derivatives[..., 0] = constant 75 | unnormalized_derivatives[..., -1] = constant 76 | 77 | outputs[outside_interval_mask] = inputs[outside_interval_mask] 78 | logabsdet[outside_interval_mask] = 0 79 | else: 80 | raise RuntimeError('{} tails are not implemented.'.format(tails)) 81 | 82 | outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline( 83 | inputs=inputs[inside_interval_mask], 84 | unnormalized_widths=unnormalized_widths[inside_interval_mask, :], 85 | unnormalized_heights=unnormalized_heights[inside_interval_mask, :], 86 | unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], 87 | inverse=inverse, 88 | left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound, 89 | min_bin_width=min_bin_width, 90 | min_bin_height=min_bin_height, 91 | min_derivative=min_derivative 92 | ) 93 | 94 | return outputs, logabsdet 95 | 96 | def rational_quadratic_spline(inputs, 97 | unnormalized_widths, 98 | unnormalized_heights, 99 | unnormalized_derivatives, 100 | inverse=False, 
101 | left=0., right=1., bottom=0., top=1., 102 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 103 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 104 | min_derivative=DEFAULT_MIN_DERIVATIVE): 105 | if torch.min(inputs) < left or torch.max(inputs) > right: 106 | raise ValueError('Input to a transform is not within its domain') 107 | 108 | num_bins = unnormalized_widths.shape[-1] 109 | 110 | if min_bin_width * num_bins > 1.0: 111 | raise ValueError('Minimal bin width too large for the number of bins') 112 | if min_bin_height * num_bins > 1.0: 113 | raise ValueError('Minimal bin height too large for the number of bins') 114 | 115 | widths = F.softmax(unnormalized_widths, dim=-1) 116 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths 117 | cumwidths = torch.cumsum(widths, dim=-1) 118 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0) 119 | cumwidths = (right - left) * cumwidths + left 120 | cumwidths[..., 0] = left 121 | cumwidths[..., -1] = right 122 | widths = cumwidths[..., 1:] - cumwidths[..., :-1] 123 | 124 | derivatives = min_derivative + F.softplus(unnormalized_derivatives) 125 | 126 | heights = F.softmax(unnormalized_heights, dim=-1) 127 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights 128 | cumheights = torch.cumsum(heights, dim=-1) 129 | cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0) 130 | cumheights = (top - bottom) * cumheights + bottom 131 | cumheights[..., 0] = bottom 132 | cumheights[..., -1] = top 133 | heights = cumheights[..., 1:] - cumheights[..., :-1] 134 | 135 | if inverse: 136 | bin_idx = searchsorted(cumheights, inputs)[..., None] 137 | else: 138 | bin_idx = searchsorted(cumwidths, inputs)[..., None] 139 | 140 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] 141 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0] 142 | 143 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] 144 | delta = heights / widths 145 | input_delta = delta.gather(-1, bin_idx)[..., 0] 146 | 147 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] 148 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] 149 | 150 | input_heights = heights.gather(-1, bin_idx)[..., 0] 151 | 152 | if inverse: 153 | a = (((inputs - input_cumheights) * (input_derivatives 154 | + input_derivatives_plus_one 155 | - 2 * input_delta) 156 | + input_heights * (input_delta - input_derivatives))) 157 | b = (input_heights * input_derivatives 158 | - (inputs - input_cumheights) * (input_derivatives 159 | + input_derivatives_plus_one 160 | - 2 * input_delta)) 161 | c = - input_delta * (inputs - input_cumheights) 162 | 163 | discriminant = b.pow(2) - 4 * a * c 164 | assert (discriminant >= 0).all() 165 | 166 | root = (2 * c) / (-b - torch.sqrt(discriminant)) 167 | outputs = root * input_bin_widths + input_cumwidths 168 | 169 | theta_one_minus_theta = root * (1 - root) 170 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) 171 | * theta_one_minus_theta) 172 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2) 173 | + 2 * input_delta * theta_one_minus_theta 174 | + input_derivatives * (1 - root).pow(2)) 175 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 176 | 177 | return outputs, -logabsdet 178 | else: 179 | theta = (inputs - input_cumwidths) / input_bin_widths 180 | theta_one_minus_theta = theta * (1 - theta) 181 | 182 | numerator = input_heights * (input_delta * theta.pow(2) 183 | + input_derivatives 
* theta_one_minus_theta) 184 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) 185 | * theta_one_minus_theta) 186 | outputs = input_cumheights + numerator / denominator 187 | 188 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2) 189 | + 2 * input_delta * theta_one_minus_theta 190 | + input_derivatives * (1 - theta).pow(2)) 191 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 192 | 193 | return outputs, logabsdet 194 | -------------------------------------------------------------------------------- /VITS/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import sys 4 | import argparse 5 | import logging 6 | import json 7 | import subprocess 8 | import numpy as np 9 | from scipy.io.wavfile import read 10 | import torch 11 | 12 | MATPLOTLIB_FLAG = False 13 | 14 | #logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) 15 | logger = logging 16 | 17 | 18 | def load_checkpoint(checkpoint_path, model, optimizer=None, drop_speaker_emb=False): 19 | assert os.path.isfile(checkpoint_path) 20 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') 21 | iteration = checkpoint_dict['iteration'] 22 | learning_rate = checkpoint_dict['learning_rate'] 23 | if optimizer is not None: 24 | optimizer.load_state_dict(checkpoint_dict['optimizer']) 25 | saved_state_dict = checkpoint_dict['model'] 26 | if hasattr(model, 'module'): 27 | state_dict = model.module.state_dict() 28 | else: 29 | state_dict = model.state_dict() 30 | new_state_dict = {} 31 | for k, v in state_dict.items(): 32 | try: 33 | if k == 'emb_g.weight': 34 | if drop_speaker_emb: 35 | new_state_dict[k] = v 36 | continue 37 | v[:saved_state_dict[k].shape[0], :] = saved_state_dict[k] 38 | new_state_dict[k] = v 39 | else: 40 | new_state_dict[k] = saved_state_dict[k] 41 | except: 42 | logger.info("%s is not in the checkpoint" % k) 43 | new_state_dict[k] = v 44 | if hasattr(model, 'module'): 45 | model.module.load_state_dict(new_state_dict) 46 | else: 47 | model.load_state_dict(new_state_dict) 48 | logger.info("Loaded checkpoint '{}' (iteration {})".format( 49 | checkpoint_path, iteration)) 50 | return model, optimizer, learning_rate, iteration 51 | 52 | 53 | def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): 54 | logger.info("Saving model and optimizer state at iteration {} to {}".format( 55 | iteration, checkpoint_path)) 56 | if hasattr(model, 'module'): 57 | state_dict = model.module.state_dict() 58 | else: 59 | state_dict = model.state_dict() 60 | torch.save({'model': state_dict, 61 | 'iteration': iteration, 62 | 'optimizer': optimizer.state_dict() if optimizer is not None else None, 63 | 'learning_rate': learning_rate}, checkpoint_path) 64 | 65 | 66 | def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050): 67 | for k, v in scalars.items(): 68 | writer.add_scalar(k, v, global_step) 69 | for k, v in histograms.items(): 70 | writer.add_histogram(k, v, global_step) 71 | for k, v in images.items(): 72 | writer.add_image(k, v, global_step, dataformats='HWC') 73 | for k, v in audios.items(): 74 | writer.add_audio(k, v, global_step, audio_sampling_rate) 75 | 76 | 77 | def latest_checkpoint_path(dir_path, regex="G_*.pth"): 78 | f_list = glob.glob(os.path.join(dir_path, regex)) 79 | f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) 80 | x = f_list[-1] 81 | 
#print(x) 82 | return x 83 | 84 | 85 | def plot_spectrogram_to_numpy(spectrogram): 86 | global MATPLOTLIB_FLAG 87 | if not MATPLOTLIB_FLAG: 88 | import matplotlib 89 | matplotlib.use("Agg") 90 | MATPLOTLIB_FLAG = True 91 | #mpl_logger = logging.getLogger('matplotlib') 92 | #mpl_logger.setLevel(logging.WARNING) 93 | import matplotlib.pylab as plt 94 | import numpy as np 95 | 96 | fig, ax = plt.subplots(figsize=(10, 2)) 97 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", 98 | interpolation='none') 99 | plt.colorbar(im, ax=ax) 100 | plt.xlabel("Frames") 101 | plt.ylabel("Channels") 102 | plt.tight_layout() 103 | 104 | fig.canvas.draw() 105 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 106 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 107 | plt.close() 108 | return data 109 | 110 | 111 | def plot_alignment_to_numpy(alignment, info=None): 112 | global MATPLOTLIB_FLAG 113 | if not MATPLOTLIB_FLAG: 114 | import matplotlib 115 | matplotlib.use("Agg") 116 | MATPLOTLIB_FLAG = True 117 | #mpl_logger = logging.getLogger('matplotlib') 118 | #mpl_logger.setLevel(logging.WARNING) 119 | import matplotlib.pylab as plt 120 | import numpy as np 121 | 122 | fig, ax = plt.subplots(figsize=(6, 4)) 123 | im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower', 124 | interpolation='none') 125 | fig.colorbar(im, ax=ax) 126 | xlabel = 'Decoder timestep' 127 | if info is not None: 128 | xlabel += '\n\n' + info 129 | plt.xlabel(xlabel) 130 | plt.ylabel('Encoder timestep') 131 | plt.tight_layout() 132 | 133 | fig.canvas.draw() 134 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 135 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 136 | plt.close() 137 | return data 138 | 139 | 140 | def load_wav_to_torch(full_path): 141 | sampling_rate, data = read(full_path) 142 | return torch.FloatTensor(data.astype(np.float32)), sampling_rate 143 | 144 | 145 | def load_filepaths_and_text(filename, split="|"): 146 | with open(filename, encoding='utf-8') as f: 147 | filepaths_and_text = [line.strip().split(split) for line in f] 148 | return filepaths_and_text 149 | 150 | 151 | def get_hparams(init=True): 152 | parser = argparse.ArgumentParser() 153 | parser.add_argument('-c', '--config', type=str, default="./configs/modified_finetune_speaker.json", 154 | help='JSON file for configuration') 155 | parser.add_argument('-m', '--model', type=str, default="pretrained_models", 156 | help='Model name') 157 | parser.add_argument('-n', '--max_epochs', type=int, default=50, 158 | help='finetune epochs') 159 | parser.add_argument('--drop_speaker_embed', type=bool, default=False, help='whether to drop existing characters') 160 | 161 | args = parser.parse_args() 162 | model_dir = os.path.join("./", args.model) 163 | 164 | if not os.path.exists(model_dir): 165 | os.makedirs(model_dir) 166 | 167 | config_path = args.config 168 | config_save_path = os.path.join(model_dir, "config.json") 169 | if init: 170 | with open(config_path, "r") as f: 171 | data = f.read() 172 | with open(config_save_path, "w") as f: 173 | f.write(data) 174 | else: 175 | with open(config_save_path, "r") as f: 176 | data = f.read() 177 | config = json.loads(data) 178 | 179 | hparams = HParams(**config) 180 | hparams.model_dir = model_dir 181 | hparams.max_epochs = args.max_epochs 182 | hparams.drop_speaker_embed = args.drop_speaker_embed 183 | return hparams 184 | 185 | 186 | def get_hparams_from_dir(model_dir): 187 | config_save_path = os.path.join(model_dir, 
"config.json") 188 | with open(config_save_path, "r") as f: 189 | data = f.read() 190 | config = json.loads(data) 191 | 192 | hparams = HParams(**config) 193 | hparams.model_dir = model_dir 194 | return hparams 195 | 196 | 197 | def get_hparams_from_file(config_path): 198 | with open(config_path, "r", encoding="utf-8") as f: 199 | data = f.read() 200 | config = json.loads(data) 201 | 202 | hparams = HParams(**config) 203 | return hparams 204 | 205 | 206 | def check_git_hash(model_dir): 207 | source_dir = os.path.dirname(os.path.realpath(__file__)) 208 | if not os.path.exists(os.path.join(source_dir, ".git")): 209 | logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format( 210 | source_dir 211 | )) 212 | return 213 | 214 | cur_hash = subprocess.getoutput("git rev-parse HEAD") 215 | 216 | path = os.path.join(model_dir, "githash") 217 | if os.path.exists(path): 218 | saved_hash = open(path).read() 219 | if saved_hash != cur_hash: 220 | logger.warn("git hash values are different. {}(saved) != {}(current)".format( 221 | saved_hash[:8], cur_hash[:8])) 222 | else: 223 | open(path, "w").write(cur_hash) 224 | 225 | 226 | def get_logger(model_dir, filename="train.log"): 227 | global logger 228 | logger = logging.getLogger(os.path.basename(model_dir)) 229 | #logger.setLevel(logging.DEBUG) 230 | 231 | formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") 232 | if not os.path.exists(model_dir): 233 | os.makedirs(model_dir) 234 | h = logging.FileHandler(os.path.join(model_dir, filename)) 235 | #h.setLevel(logging.DEBUG) 236 | h.setFormatter(formatter) 237 | logger.addHandler(h) 238 | return logger 239 | 240 | 241 | class HParams(): 242 | def __init__(self, **kwargs): 243 | for k, v in kwargs.items(): 244 | if type(v) == dict: 245 | v = HParams(**v) 246 | self[k] = v 247 | 248 | def keys(self): 249 | return self.__dict__.keys() 250 | 251 | def items(self): 252 | return self.__dict__.items() 253 | 254 | def values(self): 255 | return self.__dict__.values() 256 | 257 | def __len__(self): 258 | return len(self.__dict__) 259 | 260 | def __getitem__(self, key): 261 | return getattr(self, key) 262 | 263 | def __setitem__(self, key, value): 264 | return setattr(self, key, value) 265 | 266 | def __contains__(self, key): 267 | return key in self.__dict__ 268 | 269 | def __repr__(self): 270 | return self.__dict__.__repr__() -------------------------------------------------------------------------------- /baiduApi.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- # 2 | from sys import byteorder 3 | from array import array 4 | from struct import pack 5 | import time 6 | 7 | import pyaudio 8 | import wave 9 | import subprocess 10 | import urllib.request 11 | import urllib 12 | import json 13 | import base64 14 | import requests 15 | 16 | class BaiduRest: 17 | def __init__(self, cu_id, api_key, api_secert): 18 | self.token_url = "https://openapi.baidu.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s" 19 | self.getvoice_url = "http://tsn.baidu.com/text2audio?tex=%s&lan=zh&cuid=%s&ctp=1&tok=%s" 20 | self.upvoice_url = 'http://vop.baidu.com/server_api' 21 | self.cu_id = cu_id 22 | self.getToken(api_key, api_secert) 23 | self.THRESHOLD = 500 24 | self.CHUNK_SIZE = 1024 25 | self.FORMAT = pyaudio.paInt16 26 | self.RATE = 8000 27 | 28 | return 29 | 30 | def getToken(self, api_key, api_secert): 31 | token_url = self.token_url % (api_key, api_secert) 32 | 33 | 
r_str = urllib.request.urlopen(token_url).read() 34 | token_data = json.loads(r_str) 35 | self.token_str = token_data['access_token'] 36 | pass 37 | 38 | def getVoice(self, text, filename): 39 | get_url = self.getvoice_url % ( 40 | urllib.parse.quote(text), self.cu_id, self.token_str) 41 | 42 | voice_data = requests.post(get_url).content 43 | voice_fp = open(filename, 'wb+') 44 | voice_fp.write(voice_data) 45 | voice_fp.close() 46 | pass 47 | 48 | def getText(self, filename): 49 | data = {} 50 | data['format'] = 'wav' 51 | data['rate'] = 8000 52 | data['channel'] = 1 53 | data['cuid'] = self.cu_id 54 | data['token'] = self.token_str 55 | wav_fp = open(filename, 'rb') 56 | voice_data = wav_fp.read() 57 | data['len'] = len(voice_data) 58 | data['speech'] = base64.b64encode(voice_data).decode('utf-8') 59 | post_data = json.dumps(data) 60 | r_data = requests.post( 61 | self.upvoice_url, data=bytes( 62 | post_data, encoding="utf-8")).text 63 | # 3.处理返回数据 64 | 65 | return json.loads(r_data)['result'][0] 66 | 67 | # def speakMac(self, audio_file): 68 | # return_code = subprocess.call(["afplay", audio_file]) 69 | # return return_code 70 | 71 | # def speak(self, audio_file): 72 | # import mp3play 73 | # clip = mp3play.load(audio_file) 74 | # clip.play() 75 | # time.sleep(10) 76 | # clip.stop() 77 | 78 | def is_silent(self, snd_data): 79 | return max(snd_data) < self.THRESHOLD 80 | 81 | def normalize(self, snd_data): 82 | MAXIMUM = 16384 83 | times = float(MAXIMUM) / max(abs(i) for i in snd_data) 84 | 85 | r = array('h') 86 | for i in snd_data: 87 | r.append(int(i * times)) 88 | return r 89 | 90 | def trim(self, snd_data): 91 | def _trim(snd_data): 92 | snd_started = False 93 | r = array('h') 94 | 95 | for i in snd_data: 96 | if not snd_started and abs(i) > self.THRESHOLD: 97 | snd_started = True 98 | r.append(i) 99 | 100 | elif snd_started: 101 | r.append(i) 102 | return r 103 | 104 | # Trim to the left 105 | snd_data = _trim(snd_data) 106 | 107 | # Trim to the right 108 | snd_data.reverse() 109 | snd_data = _trim(snd_data) 110 | snd_data.reverse() 111 | return snd_data 112 | 113 | def add_silence(self, snd_data, seconds): 114 | r = array('h', [0 for i in range(int(seconds * self.RATE))]) 115 | r.extend(snd_data) 116 | r.extend([0 for i in range(int(seconds * self.RATE))]) 117 | return r 118 | 119 | def record(self):#录音 120 | p = pyaudio.PyAudio() 121 | stream = p.open(format=self.FORMAT, channels=1, rate=self.RATE, 122 | input=True, output=True, 123 | frames_per_buffer=self.CHUNK_SIZE) 124 | 125 | num_silent = 0 126 | snd_started = False 127 | 128 | r = array('h') 129 | 130 | while True: 131 | snd_data = array('h', stream.read(self.CHUNK_SIZE)) 132 | if byteorder == 'big': 133 | snd_data.byteswap() 134 | r.extend(snd_data) 135 | 136 | silent = self.is_silent(snd_data) 137 | 138 | if silent and snd_started: 139 | num_silent += 1 140 | elif not silent and not snd_started: 141 | snd_started = True 142 | 143 | if snd_started and num_silent > 10: 144 | break 145 | 146 | sample_width = p.get_sample_size(self.FORMAT) 147 | stream.stop_stream() 148 | stream.close() 149 | p.terminate() 150 | 151 | r = self.normalize(r) 152 | r = self.trim(r) 153 | r = self.add_silence(r, 0.5) 154 | return sample_width, r 155 | 156 | def record_to_file(self, path): 157 | 158 | sample_width, data = self.record() 159 | data = pack('<' + ('h' * len(data)), *data) 160 | 161 | wf = wave.open(path, 'wb') 162 | wf.setnchannels(1) 163 | wf.setsampwidth(sample_width) 164 | wf.setframerate(self.RATE) 165 | wf.writeframes(data) 166 | 
wf.close() 167 | 168 | def recorder(self, filename): 169 | self.record_to_file(filename) 170 | 171 | -------------------------------------------------------------------------------- /gptApi.py: -------------------------------------------------------------------------------- 1 | import openai 2 | import os 3 | import ssl 4 | 5 | import warnings 6 | warnings.filterwarnings("ignore") 7 | # Setting the proxy inside Python is enough 8 | #os.environ['HTTP_PROXY'] = 'http://202.79.168.14:6666' 9 | #os.environ['HTTPS_PROXY'] = 'http://202.79.168.14:6666' 10 | # openai cannot use the global proxy; add the proxy in lib\site-packages\openai\api_requestor.py instead 11 | 12 | # ssl._create_default_https_context = ssl._create_unverified_context 13 | # openai.api_key = '*********' 14 | # # openai.proxy = 'http://202.79.168.46:6666' 15 | # # openai.verify_ssl_certs = False 16 | # openai.api_base='https://202.79.168.46/v1' 17 | # def test_openai(string): 18 | # completion = openai.ChatCompletion.create( 19 | # model="gpt-3.5-turbo", 20 | # messages=[{"role": "user", "content": string}] 21 | # ) 22 | # return completion['choices'][0]['message']['content'].strip() 23 | # 24 | # res=test_openai("巩义在哪里?") 25 | # 26 | # print(res) 27 | 28 | class GptApi: 29 | def __init__(self): 30 | self.api_key='***********' 31 | openai.api_key = self.api_key 32 | openai.api_base='https://202.79.168.46/v1' 33 | 34 | def test_openai(self,string): 35 | completion = openai.ChatCompletion.create( 36 | model="gpt-3.5-turbo", 37 | messages=[{"role": "user", "content": string}] 38 | ) 39 | return completion['choices'][0]['message']['content'].strip() 40 | 41 | 42 | 43 | if __name__ == '__main__': 44 | test=GptApi() 45 | res=test.test_openai('你是谁?') 46 | print(res) 47 | -------------------------------------------------------------------------------- /gptApi_duolun.py: -------------------------------------------------------------------------------- 1 | import os 2 | import openai 3 | from dotenv.main import load_dotenv 4 | 5 | load_dotenv() 6 | openai.api_key = os.environ['openai_api'] 7 | openai.api_base='https://202.79.168.46/v1' 8 | 9 | # Single-turn conversation call 10 | # model can be "gpt-3.5-turbo" or "gpt-3.5-turbo-0301" 11 | def generate_answer(messages): 12 | completion = openai.ChatCompletion.create( 13 | model="gpt-3.5-turbo", 14 | messages=messages, 15 | temperature=0.7 16 | ) 17 | res_msg = completion.choices[0].message 18 | return res_msg["content"].strip() 19 | 20 | 21 | if __name__ == '__main__': 22 | # Maintain a list that stores the multi-turn conversation history 23 | messages = [{"role": "system", "content": "你现在是很有用的助手!"}] 24 | while True: 25 | prompt = input("请输入你的问题:") 26 | messages.append({"role": "user", "content": prompt}) 27 | res_msg = generate_answer(messages) 28 | messages.append({"role": "assistant", "content": res_msg}) 29 | print(res_msg) -------------------------------------------------------------------------------- /input/input.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/input/input.wav -------------------------------------------------------------------------------- /output/output.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/output/output.wav -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | 
absl-py==1.4.0 2 | aiofiles==23.1.0 3 | aiohttp==3.8.4 4 | aiosignal==1.3.1 5 | altair==4.2.2 6 | antlr4-python3-runtime==4.9.3 7 | anyio==3.6.2 8 | async-timeout==4.0.2 9 | attrs==22.2.0 10 | audioread==3.0.0 11 | backports.functools-lru-cache==1.6.4 12 | cachetools==5.3.0 13 | certifi==2022.12.7 14 | cffi==1.15.1 15 | charset-normalizer==3.1.0 16 | click==8.1.3 17 | cloudpickle==2.2.1 18 | cn2an==0.5.19 19 | colorama==0.4.6 20 | contourpy==1.0.7 21 | cycler==0.11.0 22 | Cython==0.29.33 23 | decorator==5.1.1 24 | demucs==4.0.0 25 | diffq==0.2.3 26 | dora-search==0.1.11 27 | einops==0.6.0 28 | eng-to-ipa==0.0.2 29 | entrypoints==0.4 30 | fastapi==0.95.0 31 | ffmpeg-python==0.2.0 32 | ffmpy==0.3.0 33 | filelock==3.10.7 34 | fonttools==4.39.3 35 | frozenlist==1.3.3 36 | fsspec==2023.3.0 37 | future==0.18.3 38 | google-auth==2.17.0 39 | google-auth-oauthlib==0.4.6 40 | gradio==3.23.0 41 | grpcio==1.53.0 42 | h11==0.14.0 43 | httpcore==0.16.3 44 | httpx==0.23.3 45 | huggingface-hub==0.13.3 46 | idna==3.4 47 | indic-transliteration==2.3.37 48 | inflect==6.0.2 49 | jamo==0.4.1 50 | jieba==0.42.1 51 | Jinja2==3.1.2 52 | joblib==1.2.0 53 | jsonschema==4.17.3 54 | julius==0.2.7 55 | kiwisolver==1.4.4 56 | ko-pron==1.3 57 | lameenc==1.4.2 58 | librosa==0.9.1 59 | linkify-it-py==2.0.0 60 | llvmlite==0.39.1 61 | Markdown==3.4.3 62 | markdown-it-py==2.2.0 63 | MarkupSafe==2.1.2 64 | matplotlib==3.7.1 65 | mdit-py-plugins==0.3.3 66 | mdurl==0.1.2 67 | more-itertools==9.1.0 68 | mp3play==0.1.15 69 | mpmath==1.3.0 70 | multidict==6.0.4 71 | networkx==3.0 72 | num-thai==0.0.5 73 | numba==0.56.4 74 | numpy==1.23.5 75 | oauthlib==3.2.2 76 | omegaconf==2.3.0 77 | openai == 0.27.2 78 | openai-whisper==20230314 79 | OpenCC==1.1.1 80 | openunmix==1.2.1 81 | orjson==3.8.9 82 | packaging==23.0 83 | pandas==1.5.3 84 | Pillow==9.4.0 85 | platformdirs==3.2.0 86 | playsound==1.3.0 87 | pooch==1.7.0 88 | proces==0.1.4 89 | protobuf==4.22.1 90 | pyasn1==0.4.8 91 | pyasn1-modules==0.2.8 92 | PyAudio==0.2.13 93 | pycparser==2.21 94 | pydantic==1.10.7 95 | pydub==0.25.1 96 | pyopenjtalk==0.3.0 97 | pyparsing==3.0.9 98 | pypinyin==0.48.0 99 | pyrsistent==0.19.3 100 | python-dateutil==2.8.2 101 | python-dotenv==1.0.0 102 | python-multipart==0.0.6 103 | pytz==2023.3 104 | PyYAML==6.0 105 | regex==2023.3.23 106 | requests==2.28.2 107 | requests-oauthlib==1.3.1 108 | resampy==0.4.2 109 | retrying==1.3.4 110 | rfc3986==1.5.0 111 | roman==4.0 112 | rsa==4.9 113 | scikit-learn==1.2.2 114 | scipy==1.10.1 115 | semantic-version==2.10.0 116 | six==1.16.0 117 | sniffio==1.3.0 118 | soundfile==0.12.1 119 | starlette==0.26.1 120 | submitit==1.4.5 121 | sympy==1.11.1 122 | tensorboard==2.12.0 123 | tensorboard-data-server==0.7.0 124 | tensorboard-plugin-wit==1.8.1 125 | threadpoolctl==3.1.0 126 | tiktoken==0.3.1 127 | toml==0.10.2 128 | toolz==0.12.0 129 | torch==2.0.0 130 | torchaudio==2.0.1 131 | tqdm==4.65.0 132 | treetable==0.2.5 133 | typer==0.7.0 134 | typing_extensions==4.5.0 135 | uc-micro-py==1.0.1 136 | Unidecode==1.3.6 137 | urllib3==1.26.15 138 | uvicorn==0.21.1 139 | websockets==10.4 140 | Werkzeug==2.2.3 141 | yarl==1.8.2 142 | -------------------------------------------------------------------------------- /robot.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import baiduApi 3 | import gptApi 4 | from playsound import playsound 5 | import gptApi_duolun 6 | from VITS.VC_inference import getVoice 7 | from dotenv.main import load_dotenv 8 | import os 9 | 10 | if 
__name__ == "__main__": 11 | load_dotenv() 12 | 13 | api_id = os.environ['api_id'] 14 | api_key = os.environ['api_key'] 15 | api_secert = os.environ['api_secert'] 16 | bdr = baiduApi.BaiduRest(api_id, api_key, api_secert) 17 | # robot = gptApi.GptApi() 18 | messages = [{"role": "system", "content": "你现在是很有用的助手!"}] 19 | while True: 20 | input("按下回车开始说话,自动停止") 21 | print('开始录音') 22 | bdr.recorder("./input/input.wav") 23 | print("结束") 24 | ask = bdr.getText('./input/input.wav') 25 | print('你:', ask) 26 | 27 | # robot=gptApi.GptApi() 28 | # ans = robot.test_openai(ask) 29 | messages.append({"role": "user", "content": ask}) 30 | ans = gptApi_duolun.generate_answer(messages) 31 | messages.append({"role": "assistant", "content": ans}) 32 | print('机器人:', ans) 33 | 34 | #bdr.getVoice(ans, "./output.mp3")  # synthesize speech with the Baidu API; deprecated 35 | getVoice(ans.lstrip(),'./output/output.wav')  # default path: ./output/output.mp3; also strip leading whitespace from the text so speech generation does not fail 36 | #bdr.speakMac("./output.mp3") 37 | playsound('./output/output.wav') 38 | --------------------------------------------------------------------------------
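A minimal usage sketch (not part of the repository): the cleaner functions in VITS/text/cleaners.py expect every language segment to be wrapped in a matching tag pair ([ZH]...[ZH], [JA]...[JA], [KO]...[KO], [EN]...[EN]) before it is converted to IPA, as the regular expressions in that file show. The snippet below is an illustration only; it assumes the repository root is on sys.path, that the VITS directory is importable as a package (otherwise use "from text.cleaners import cjke_cleaners2" from inside VITS/), and that the phonemizer dependencies from requirements.txt (eng_to_ipa, unidecode, inflect, pypinyin, jieba, cn2an) are installed.

    # Hypothetical usage sketch; the import path assumes the repo root is on sys.path.
    from VITS.text.cleaners import cjke_cleaners2

    # Each language span is wrapped in its tag pair, as the regexes in cleaners.py expect.
    mixed = '[ZH]你好[ZH] [EN]hello world[EN]'
    print(cjke_cleaners2(mixed))  # IPA string; a final '.' is appended if the text does not end in punctuation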