├── .gitignore ├── .pre-commit-config.yaml ├── A1_pretrained_models └── .gitkeep ├── A2_prepared_audios └── .gitkeep ├── A31_singleinfer.py ├── A32_caogao.py ├── A35_inference_server.py ├── A36_post.py ├── A3_scripts ├── .gitkeep ├── A33_ASR_ScriptsGen.py ├── A34_deleteModels.py ├── Finetune_Scripts.txt ├── __pycache__ │ └── asr_model_list.cpython-310.pyc └── asr_model_list.py ├── A40_一键启动微调pipeline.sh ├── A4_model_output └── .gitkeep ├── A5_finetuned_trainingout └── .gitkeep ├── LICENSE.txt ├── README.md ├── __pycache__ ├── attentions.cpython-310.pyc ├── commons.cpython-310.pyc ├── config.cpython-310.pyc ├── data_utils.cpython-310.pyc ├── losses.cpython-310.pyc ├── mel_processing.cpython-310.pyc ├── models.cpython-310.pyc ├── modules.cpython-310.pyc ├── transforms.cpython-310.pyc └── utils.cpython-310.pyc ├── attentions.py ├── bert_gen.py ├── commons.py ├── compress_model.py ├── config.py ├── config.yml ├── configs └── config.json ├── css └── custom.css ├── data_utils.py ├── default_config.yml ├── docs ├── 011.png ├── SSB00050007.wav ├── SSB0005_50_月光.wav ├── SSB0005_50_根据.wav ├── gentel_truth0.wav ├── gentle_girl_月光.wav ├── gentle_girl_根据.wav ├── index.html ├── 合成文本.txt └── 语音学习群.png ├── emotional ├── clap-htsat-fused │ ├── .gitattributes │ ├── README.md │ ├── config.json │ ├── merges.txt │ ├── preprocessor_config.json │ ├── special_tokens_map.json │ ├── tokenizer.json │ ├── tokenizer_config.json │ └── vocab.json └── wav2vec2-large-robust-12-ft-emotion-msp-dim │ ├── .gitattributes │ ├── LICENSE │ ├── README.md │ ├── config.json │ ├── preprocessor_config.json │ └── vocab.json ├── export_onnx.py ├── filelists └── sample.list ├── for_deploy ├── infer.py ├── infer_utils.py └── webui.py ├── hiyoriUI.py ├── hlf.sh ├── img ├── yuyu.png ├── 参数说明.png ├── 宵宫.png ├── 微信图片_20231010105112.png ├── 神里绫华.png └── 纳西妲.png ├── infer.py ├── losses.py ├── mel_processing.py ├── models.py ├── modules.py ├── monotonic_align ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ └── core.cpython-310.pyc └── core.py ├── onnx_infer.py ├── preprocess_text.py ├── re_matching.py ├── requirements.txt ├── resample.py ├── resample_legacy.py ├── spec_gen.py ├── text ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── bert_utils.cpython-310.pyc │ ├── chinese.cpython-310.pyc │ ├── chinese_bert.cpython-310.pyc │ ├── cleaner.cpython-310.pyc │ ├── english.cpython-310.pyc │ ├── english_bert_mock.cpython-310.pyc │ ├── japanese.cpython-310.pyc │ ├── symbols.cpython-310.pyc │ └── tone_sandhi.cpython-310.pyc ├── bert_utils.py ├── chinese.py ├── chinese_bert.py ├── cleaner.py ├── cmudict.rep ├── cmudict_cache.pickle ├── english.py ├── english_bert_mock.py ├── japanese.py ├── japanese_bert.py ├── opencpop-strict.txt ├── symbols.py └── tone_sandhi.py ├── tools ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ └── log.cpython-310.pyc ├── classify_language.py ├── log.py ├── sentence.py └── translate.py ├── train_ms.py ├── transforms.py ├── utils.py ├── 微调用语音时长计算.py └── 语音数据复制.py /.gitignore: -------------------------------------------------------------------------------- 1 | # 忽略模型文件 2 | *.pt 3 | *.pyc 4 | 5 | bert/chinese-roberta-wwm-ext-large/* 6 | !bert/chinese-roberta-wwm-ext-large/.gitkeep 7 | 8 | slm/wavlm-base-plus/* 9 | !slm/wavlm-base-plus/.gitkeep 10 | 11 | A1_pretrained_models/* 12 | !A1_pretrained_models/.gitkeep 13 | 14 | A2_prepared_audios/* 15 | !A2_prepared_audios/.gitkeep 16 | 17 | A4_model_output/* 18 | !A4_model_output/.gitkeep 19 | 20 | A5_finetuned_trainingout/* 21 | 
!A5_finetuned_trainingout/.gitkeep 22 | 23 | 24 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.6.0 4 | hooks: 5 | - id: check-yaml 6 | - id: end-of-file-fixer 7 | - id: trailing-whitespace 8 | 9 | - repo: https://github.com/astral-sh/ruff-pre-commit 10 | rev: v0.6.3 11 | hooks: 12 | - id: ruff 13 | args: [ --fix ] 14 | 15 | - repo: https://github.com/psf/black 16 | rev: 24.8.0 17 | hooks: 18 | - id: black 19 | 20 | - repo: https://github.com/codespell-project/codespell 21 | rev: v2.3.0 22 | hooks: 23 | - id: codespell 24 | files: ^.*\.(py|md|rst|yml)$ 25 | args: [-L=fro] 26 | -------------------------------------------------------------------------------- /A1_pretrained_models/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/A1_pretrained_models/.gitkeep -------------------------------------------------------------------------------- /A2_prepared_audios/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/A2_prepared_audios/.gitkeep -------------------------------------------------------------------------------- /A31_singleinfer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import commons 3 | import soundfile 4 | from text import cleaned_text_to_sequence, get_bert 5 | 6 | # from clap_wrapper import get_clap_audio_feature, get_clap_text_feature 7 | from typing import Union 8 | from text.cleaner import clean_text 9 | import utils 10 | from models import SynthesizerTrn 11 | from text.symbols import symbols 12 | 13 | def get_text(text, language_str, hps, device, style_text=None, style_weight=0.7): 14 | style_text = None if style_text == "" else style_text 15 | # 在此处实现当前版本的get_text 16 | norm_text, phone, tone, word2ph = clean_text(text, language_str) 17 | phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str) 18 | 19 | if hps.data.add_blank: 20 | phone = commons.intersperse(phone, 0) 21 | tone = commons.intersperse(tone, 0) 22 | language = commons.intersperse(language, 0) 23 | for i in range(len(word2ph)): 24 | word2ph[i] = word2ph[i] * 2 25 | word2ph[0] += 1 26 | bert_ori = get_bert( 27 | norm_text, word2ph, language_str, device, style_text, style_weight 28 | ) 29 | del word2ph 30 | assert bert_ori.shape[-1] == len(phone), phone 31 | 32 | if language_str == "ZH": 33 | bert = bert_ori 34 | ja_bert = torch.randn(1024, len(phone)) 35 | en_bert = torch.randn(1024, len(phone)) 36 | elif language_str == "JP": 37 | bert = torch.randn(1024, len(phone)) 38 | ja_bert = bert_ori 39 | en_bert = torch.randn(1024, len(phone)) 40 | elif language_str == "EN": 41 | bert = torch.randn(1024, len(phone)) 42 | ja_bert = torch.randn(1024, len(phone)) 43 | en_bert = bert_ori 44 | else: 45 | raise ValueError("language_str should be ZH, JP or EN") 46 | 47 | assert bert.shape[-1] == len( 48 | phone 49 | ), f"Bert seq len {bert.shape[-1]} != {len(phone)}" 50 | 51 | phone = torch.LongTensor(phone) 52 | tone = torch.LongTensor(tone) 53 | language = torch.LongTensor(language) 54 | return bert, ja_bert, en_bert, phone, tone, language 55 | 56 | 57 | 
def infer( 58 | text,sdp_ratio,noise_scale,noise_scale_w,length_scale,sid,language,hps,net_g,reference_audio=None,skip_start=False, 59 | skip_end=False, 60 | style_text=None, 61 | style_weight=0.7, 62 | ): 63 | bert, ja_bert, en_bert, phones, tones, lang_ids = get_text( 64 | text, 65 | language, 66 | hps, 67 | device, 68 | style_text=style_text, 69 | style_weight=style_weight, 70 | ) 71 | if skip_start: 72 | phones = phones[3:] 73 | tones = tones[3:] 74 | lang_ids = lang_ids[3:] 75 | bert = bert[:, 3:] 76 | ja_bert = ja_bert[:, 3:] 77 | en_bert = en_bert[:, 3:] 78 | if skip_end: 79 | phones = phones[:-2] 80 | tones = tones[:-2] 81 | lang_ids = lang_ids[:-2] 82 | bert = bert[:, :-2] 83 | ja_bert = ja_bert[:, :-2] 84 | en_bert = en_bert[:, :-2] 85 | with torch.no_grad(): 86 | x_tst = phones.to(device).unsqueeze(0) 87 | tones = tones.to(device).unsqueeze(0) 88 | lang_ids = lang_ids.to(device).unsqueeze(0) 89 | bert = bert.to(device).unsqueeze(0) 90 | ja_bert = ja_bert.to(device).unsqueeze(0) 91 | en_bert = en_bert.to(device).unsqueeze(0) 92 | x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device) 93 | # emo = emo.to(device).unsqueeze(0) 94 | del phones 95 | speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device) 96 | audio = ( 97 | net_g.infer( 98 | x_tst, 99 | x_tst_lengths, 100 | speakers, 101 | tones, 102 | lang_ids, 103 | bert, 104 | ja_bert, 105 | en_bert, 106 | sdp_ratio=sdp_ratio, 107 | noise_scale=noise_scale, 108 | noise_scale_w=noise_scale_w, 109 | length_scale=length_scale, 110 | )[0][0, 0] 111 | .data.cpu() 112 | .float() 113 | .numpy() 114 | ) 115 | del ( 116 | x_tst, 117 | tones, 118 | lang_ids, 119 | bert, 120 | x_tst_lengths, 121 | speakers, 122 | ja_bert, 123 | en_bert, 124 | ) # , emo 125 | if torch.cuda.is_available(): 126 | torch.cuda.empty_cache() 127 | return audio 128 | 129 | # 1 导入hps加载函数 130 | from utils import get_hparams_from_file 131 | 132 | # 2 设置输出日志配置 133 | import logging 134 | logging.basicConfig( 135 | level=logging.INFO, # 设置日志等级为 INFO 及以上 136 | format='%(asctime)s - %(levelname)s - %(message)s' # 设置日志格式 137 | ) 138 | from pathlib import Path 139 | if __name__=="__main__": 140 | ## 一、 超参数加载。 141 | hps = get_hparams_from_file(config_path="A5_finetuned_trainingout/gentle_girl/config.json") 142 | device = "cuda:0" 143 | model_path = "A5_finetuned_trainingout/gentle_girl/models/G_8000.pth" 144 | ## 145 | speaker_name = Path(model_path).parts[-3] 146 | language = "ZH" 147 | length_scale = 1.2 148 | infer_text = "根据我们上面的描述,我们的目标是希望获取一组相对最优的参数来作为模型的初始化参数" 149 | infer_id = 0 150 | sdp_ratio=0.4 151 | output_path = f'A4_model_output/{speaker_name}_{infer_id}.wav' 152 | 153 | ## 二、模型类实例初始化。 (已经初始化并加载了bert模型,但是初始化了vits模型,没加载预训练参数) 154 | 155 | net_g = SynthesizerTrn( 156 | len(symbols), 157 | hps.data.filter_length // 2 + 1, 158 | hps.train.segment_size // hps.data.hop_length, 159 | n_speakers=hps.data.n_speakers, 160 | mas_noise_scale_initial=0.01, 161 | noise_scale_delta=2e-6, 162 | **hps.model, 163 | ).to(device ) 164 | print('bert vits 的net_g初始化') 165 | 166 | ## 三、加载vits模型的预训练参数 167 | _ = utils.load_checkpoint(model_path, net_g, None, skip_optimizer=True) 168 | net_g = net_g.to(device) 169 | 170 | ## 四、根据各种参数,进行合成。 171 | audio = infer( 172 | text=infer_text, 173 | sdp_ratio=sdp_ratio, # 重要参数~~~ 174 | noise_scale=0.667, 175 | noise_scale_w=0.8, 176 | length_scale=length_scale, 177 | sid=speaker_name, 178 | language=language, 179 | hps=hps, 180 | net_g = net_g, 181 | reference_audio=None, 182 | skip_start=False, 183 | skip_end=False, 184 | 
style_text=None, 185 | style_weight=0.7, 186 | ) 187 | 188 | 189 | # 五、写入语音。 190 | soundfile.write(output_path,audio,samplerate=44100) 191 | print(f"tts 合成:{output_path}") 192 | 193 | pass 194 | 195 | 196 | -------------------------------------------------------------------------------- /A32_caogao.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | # 设置日志配置 4 | logging.basicConfig( 5 | level=logging.INFO, # 设置日志等级为 INFO 及以上 6 | format='%(asctime)s - %(levelname)s - %(message)s' # 设置日志格式 7 | ) 8 | 9 | # 示例日志 10 | logging.info("This is an info message.") -------------------------------------------------------------------------------- /A35_inference_server.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import commons 3 | from text import cleaned_text_to_sequence, get_bert 4 | 5 | # from clap_wrapper import get_clap_audio_feature, get_clap_text_feature 6 | from typing import Union 7 | from text.cleaner import clean_text 8 | import utils 9 | from models import SynthesizerTrn 10 | from text.symbols import symbols 11 | 12 | def get_text(text, language_str, hps, device, style_text=None, style_weight=0.7): 13 | style_text = None if style_text == "" else style_text 14 | # 在此处实现当前版本的get_text 15 | norm_text, phone, tone, word2ph = clean_text(text, language_str) 16 | phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str) 17 | 18 | if hps.data.add_blank: 19 | phone = commons.intersperse(phone, 0) 20 | tone = commons.intersperse(tone, 0) 21 | language = commons.intersperse(language, 0) 22 | for i in range(len(word2ph)): 23 | word2ph[i] = word2ph[i] * 2 24 | word2ph[0] += 1 25 | bert_ori = get_bert( 26 | norm_text, word2ph, language_str, device, style_text, style_weight 27 | ) 28 | del word2ph 29 | assert bert_ori.shape[-1] == len(phone), phone 30 | 31 | if language_str == "ZH": 32 | bert = bert_ori 33 | ja_bert = torch.randn(1024, len(phone)) 34 | en_bert = torch.randn(1024, len(phone)) 35 | elif language_str == "JP": 36 | bert = torch.randn(1024, len(phone)) 37 | ja_bert = bert_ori 38 | en_bert = torch.randn(1024, len(phone)) 39 | elif language_str == "EN": 40 | bert = torch.randn(1024, len(phone)) 41 | ja_bert = torch.randn(1024, len(phone)) 42 | en_bert = bert_ori 43 | else: 44 | raise ValueError("language_str should be ZH, JP or EN") 45 | 46 | assert bert.shape[-1] == len( 47 | phone 48 | ), f"Bert seq len {bert.shape[-1]} != {len(phone)}" 49 | 50 | phone = torch.LongTensor(phone) 51 | tone = torch.LongTensor(tone) 52 | language = torch.LongTensor(language) 53 | return bert, ja_bert, en_bert, phone, tone, language 54 | 55 | 56 | def infer( 57 | text,sdp_ratio,noise_scale,noise_scale_w,length_scale,sid,language,hps,net_g,reference_audio=None,skip_start=False, 58 | skip_end=False, 59 | style_text=None, 60 | style_weight=0.7, 61 | ): 62 | bert, ja_bert, en_bert, phones, tones, lang_ids = get_text( 63 | text, 64 | language, 65 | hps, 66 | device, 67 | style_text=style_text, 68 | style_weight=style_weight, 69 | ) 70 | if skip_start: 71 | phones = phones[3:] 72 | tones = tones[3:] 73 | lang_ids = lang_ids[3:] 74 | bert = bert[:, 3:] 75 | ja_bert = ja_bert[:, 3:] 76 | en_bert = en_bert[:, 3:] 77 | if skip_end: 78 | phones = phones[:-2] 79 | tones = tones[:-2] 80 | lang_ids = lang_ids[:-2] 81 | bert = bert[:, :-2] 82 | ja_bert = ja_bert[:, :-2] 83 | en_bert = en_bert[:, :-2] 84 | with torch.no_grad(): 85 | x_tst = phones.to(device).unsqueeze(0) 86 | tones = 
tones.to(device).unsqueeze(0) 87 | lang_ids = lang_ids.to(device).unsqueeze(0) 88 | bert = bert.to(device).unsqueeze(0) 89 | ja_bert = ja_bert.to(device).unsqueeze(0) 90 | en_bert = en_bert.to(device).unsqueeze(0) 91 | x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device) 92 | # emo = emo.to(device).unsqueeze(0) 93 | del phones 94 | speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device) 95 | audio = ( 96 | net_g.infer( 97 | x_tst, 98 | x_tst_lengths, 99 | speakers, 100 | tones, 101 | lang_ids, 102 | bert, 103 | ja_bert, 104 | en_bert, 105 | sdp_ratio=sdp_ratio, 106 | noise_scale=noise_scale, 107 | noise_scale_w=noise_scale_w, 108 | length_scale=length_scale, 109 | )[0][0, 0] 110 | .data.cpu() 111 | .float() 112 | .numpy() 113 | ) 114 | del ( 115 | x_tst, 116 | tones, 117 | lang_ids, 118 | bert, 119 | x_tst_lengths, 120 | speakers, 121 | ja_bert, 122 | en_bert, 123 | ) # , emo 124 | if torch.cuda.is_available(): 125 | torch.cuda.empty_cache() 126 | return audio 127 | 128 | # 1 导入hps加载函数 129 | from utils import get_hparams_from_file 130 | 131 | # 2 设置输出日志配置 132 | import logging 133 | logging.basicConfig( 134 | level=logging.INFO, # 设置日志等级为 INFO 及以上 135 | format='%(asctime)s - %(levelname)s - %(message)s' # 设置日志格式 136 | ) 137 | 138 | 139 | from fastapi import FastAPI, Request 140 | from pydantic import BaseModel 141 | import torch 142 | import soundfile as sf 143 | import time 144 | 145 | 146 | # 假设这里有你需要的模块 147 | # from synthesizer import SynthesizerTrn, utils, infer, symbols, get_hparams_from_file 148 | 149 | app = FastAPI() 150 | 151 | # 全局变量保存模型和超参数,确保只加载一次 152 | model = None 153 | hps = None 154 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 155 | 156 | # 请求体定义 157 | class InferRequest(BaseModel): 158 | speaker_name: str 159 | language: str 160 | length_scale: float 161 | infer_text: str 162 | infer_id: int 163 | sdp_ratio: float 164 | output_path: str 165 | 166 | 167 | # 加载模型 168 | def load_model(): 169 | global model, hps 170 | if model is None and hps is None: 171 | hps = get_hparams_from_file(config_path="configs/config.json") 172 | model_path = "A1_pretrained_models/Bert-VITS2_2.3/G_0.pth" 173 | 174 | # 初始化模型 175 | model = SynthesizerTrn( 176 | len(symbols), 177 | hps.data.filter_length // 2 + 1, 178 | hps.train.segment_size // hps.data.hop_length, 179 | n_speakers=hps.data.n_speakers, 180 | mas_noise_scale_initial=0.01, 181 | noise_scale_delta=2e-6, 182 | **hps.model, 183 | ).to(device) 184 | 185 | print('bert vits 的net_g初始化') 186 | 187 | # 加载预训练模型参数 188 | utils.load_checkpoint(model_path, model, None, skip_optimizer=True) 189 | model = model.to(device) 190 | print("模型已加载") 191 | 192 | # 加载模型(仅服务程序启动时加载) 193 | load_model() 194 | # 推理接口 195 | @app.post("/infer_bertvits2/") 196 | async def infer_speech(infer_req: InferRequest): 197 | 198 | st = time.perf_counter() 199 | # 提取请求体中的参数 200 | speaker_name = infer_req.speaker_name 201 | language = infer_req.language 202 | length_scale = infer_req.length_scale 203 | infer_text = infer_req.infer_text 204 | infer_id = infer_req.infer_id 205 | sdp_ratio = infer_req.sdp_ratio 206 | output_path = infer_req.output_path 207 | 208 | # 进行语音合成 209 | audio = infer( 210 | text=infer_text, 211 | sdp_ratio=sdp_ratio, 212 | noise_scale=0.667, 213 | noise_scale_w=0.8, 214 | length_scale=length_scale, 215 | sid=speaker_name, 216 | language=language, 217 | hps=hps, 218 | net_g=model, 219 | reference_audio=None, 220 | skip_start=False, 221 | skip_end=False, 222 | style_text=None, 223 | style_weight=0.7, 224 | ) 225 | 226 | # 
保存生成的音频 227 | sf.write(output_path, audio, samplerate=44100) 228 | print(f"语音已保存至: {output_path}") 229 | 230 | # 计算推理速度、效率 231 | audiolen = len(audio) / 44100 # 单位秒 232 | textlen = len(infer_text) # 单位 token数 233 | 234 | et =time.perf_counter() 235 | 236 | usetime = f"{(et-st):.4f}" 237 | outs1 = f"{{audiolen:{audiolen:.4f},textlen:{textlen:.4f},usetime:{usetime}}}" 238 | 239 | 240 | return {"output_path": output_path,"out_info1":outs1} 241 | 242 | if __name__ == "__main__": 243 | import uvicorn 244 | uvicorn.run(app, host="0.0.0.0", port=8102) 245 | 246 | -------------------------------------------------------------------------------- /A36_post.py: -------------------------------------------------------------------------------- 1 | import requests 2 | 3 | # 定义 FastAPI 服务的 URL 4 | url = "http://127.0.0.1:8102/infer_bertvits2/" 5 | 6 | # 定义请求体内容 7 | request_data = { 8 | "speaker_name": "八重神子_ZH", 9 | "language": "ZH", 10 | "length_scale": 1.2, 11 | "infer_text": "即使引导已经破碎,也请觐见艾尔登法环", 12 | "infer_id": 4, 13 | "sdp_ratio": 0.4, 14 | 15 | } 16 | ## 这里用 infer_id 这个变量控制输出语音的路径。 id是index的意思。而不是identity。 17 | a = request_data["infer_id"] 18 | request_data["output_path"] = f"A4_model_output/SSB0005_{a}.wav" 19 | 20 | # 发送 POST 请求 21 | response = requests.post(url, json=request_data) 22 | 23 | # 检查响应状态并处理 24 | if response.status_code == 200: 25 | # 输出服务器返回的内容 26 | print("请求成功:") 27 | print(response.json()) 28 | else: 29 | print(f"请求失败,状态码: {response.status_code}") 30 | print(response.text) -------------------------------------------------------------------------------- /A3_scripts/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/A3_scripts/.gitkeep -------------------------------------------------------------------------------- /A3_scripts/A33_ASR_ScriptsGen.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | 4 | from asr_model_list import get_vad_punc_model,get_model02 5 | import click 6 | from pathlib import Path 7 | import time 8 | import logging 9 | # 配置日志记录,设置级别为 INFO 10 | logging.basicConfig(level=logging.INFO) 11 | 12 | import time 13 | import librosa 14 | from typing import List, Tuple 15 | import wave 16 | import numpy as np 17 | 18 | def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: 19 | """ 20 | Args: 21 | wave_filename: 22 | Path to a wave file. It should be single channel and each sample should 23 | be 16-bit. Its sample rate does not need to be 16kHz. 24 | Returns: 25 | Return a tuple containing: 26 | - A 1-D array of dtype np.float32 containing the samples, which are 27 | normalized to the range [-1, 1]. 
28 | - sample rate of the wave file 29 | """ 30 | 31 | with wave.open(wave_filename) as f: 32 | assert f.getnchannels() == 1, f.getnchannels() 33 | assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes 34 | num_samples = f.getnframes() 35 | samples = f.readframes(num_samples) 36 | samples_int16 = np.frombuffer(samples, dtype=np.int16) 37 | samples_float32 = samples_int16.astype(np.float32) 38 | 39 | samples_float32 = samples_float32 / 32768 40 | return samples_float32, f.getframerate() 41 | 42 | 43 | 44 | 45 | 46 | 47 | class BaseASRModel(object): 48 | def __init__(self) -> None: 49 | 50 | # 模型、识别、 51 | 52 | self.recognizer = None 53 | self.vad = None 54 | self.punct = None 55 | self.window_size = None 56 | self.recognizer_name = None 57 | 58 | def asr_model_init(self,model_func): 59 | self.recognizer ,self.recognizer_name = model_func() 60 | logging.info("load asr model") 61 | 62 | def get_asr_modelname(self,): 63 | return self.recognizer_name 64 | 65 | def vad_punc_model_init(self,): 66 | self.vad,self.punct,self.window_size = get_vad_punc_model() 67 | 68 | logging.info("load vad,punc model") 69 | 70 | 71 | def single_wav_recognize(self,input_wavfile): 72 | 73 | samples, sample_rate = read_wave(input_wavfile) ## 74 | duration = len(samples) / sample_rate 75 | # 采样率修正为16K 76 | if sample_rate != 16000: 77 | samples = librosa.resample(samples, orig_sr=sample_rate, target_sr=16000) 78 | sample_rate = 16000 79 | 80 | ## VAD 开始 ,VAD 会去除静音部分。将语音切割为多段。 81 | 82 | speech_samples = [] 83 | time1 = time.time() 84 | while len(samples) > self.window_size: 85 | self.vad.accept_waveform(samples[:self.window_size]) 86 | samples = samples[self.window_size:] 87 | while not self.vad.empty(): 88 | speech_samples.append(self.vad.front.samples) 89 | self.vad.pop() 90 | self.vad.flush() 91 | while not self.vad.empty(): 92 | speech_samples.append(self.vad.front.samples) 93 | self.vad.pop() 94 | #print(len(speech_samples),type(speech_samples)) 95 | time2 = time.time() 96 | ### ASR 开始。 97 | results = [] 98 | for i in range(len(speech_samples)): 99 | s = self.recognizer.create_stream() 100 | s.accept_waveform(sample_rate, speech_samples[i]) 101 | self.recognizer.decode_stream(s) 102 | results.append(s.result.text) 103 | #print("asr result:",results) 104 | time3 = time.time() 105 | 106 | ## 最终判断一下是不是要做 标点符号模型。 (某些ASR模型直接输出了 标点符号。) 107 | endsens = ["。",",",".",",","!","?","!","!"] 108 | results = [ x for x in results if x != ""] 109 | try: 110 | if results[0][-1] not in endsens: 111 | texts_with_punct = [ self.punct.add_punctuation(t) for t in results] 112 | time4 = time.time() 113 | final_result = "".join(texts_with_punct) 114 | 115 | t1,t2,t3 = (time2-time1),(time3-time2),(time4-time3) 116 | logging.info(f"语音时长:{duration:.4f},识别耗时:VAD:{t1:.4f}秒,ASR:{t2:.4f}秒,PUNC:{t3:.4f}秒") 117 | logging.info(f"识别结果:{final_result}") 118 | return {"text":final_result,"consume":{"vad":t1,"asr":t2,"punc":t3}} 119 | else: 120 | final_result = "".join(results) 121 | t1,t2,t3 = (time2-time1),(time3-time2),0.0 122 | logging.info(f"语音时长:{duration:.4f},识别耗时:VAD:{t1:.4f}秒,ASR:{t2:.4f}秒,PUNC:{t3:.4f}秒") 123 | logging.info(f"识别结果:{final_result}") 124 | return {"text":final_result,"consume":{"vad":t1,"asr":t2,"punc":t3}} 125 | except Exception as e: 126 | print(f"无正常结果,跳过该语音,{e}") 127 | 128 | 129 | 130 | 131 | ## out = aa.single_wav_recognize(input_wavfile=params.input_s) 132 | 133 | @click.command() 134 | @click.option('--wavdir', type=str, help='Your Wav datadir') 135 | @click.option('--output_txt', type=str, help='Your Annotation 
text') 136 | @click.option('--lang', type=str, help='Yourlanguage') 137 | def biaozhu(wavdir, output_txt, lang): 138 | wavdir = Path(wavdir) 139 | 140 | # 检查 wavdir 目录是否存在 141 | if not wavdir.exists(): 142 | raise FileNotFoundError(f"目录 {wavdir} 不存在") 143 | 144 | # 检查 output_txt 文件是否存在,存在则删除 145 | if os.path.exists(output_txt): 146 | os.remove(output_txt) 147 | logging.info(f"已删除存在的文件: {output_txt}") 148 | 149 | # 创建outputtext的父目录 150 | output_txt = Path(output_txt) 151 | output_txt_parent = output_txt.parent 152 | output_txt_parent.mkdir(exist_ok=True,parents=True) 153 | 154 | 155 | wavfiles = sorted([x for x in wavdir.rglob("*.wav")], key=lambda x: x.stem) 156 | 157 | # 逐个识别语音 158 | ttnum = len(wavfiles) 159 | kk = 0 160 | with open(output_txt, 'a', encoding="utf-8") as f1: 161 | speakername = wavdir.name 162 | for wavf in wavfiles: 163 | try: 164 | Annotation = aa.single_wav_recognize(input_wavfile=str(wavf)) 165 | writestr = f"{str(wavf)}|{speakername}|{lang}|{Annotation['text']}" 166 | 167 | f1.write(f"{writestr}\n") 168 | 169 | kk += 1 170 | logging.info(f"第 {kk}/{ttnum} 个识别完毕") 171 | 172 | except Exception as e: 173 | print(f"无正常结果,跳过该标注,{e}") 174 | continue 175 | 176 | 177 | 178 | 179 | 180 | pass 181 | 182 | 183 | if __name__ =="__main__": 184 | 185 | st = time.time() 186 | ## ASR识别pipeline类 的初始化 187 | aa = BaseASRModel() 188 | aa.vad_punc_model_init() 189 | aa.asr_model_init(model_func=get_model02) 190 | 191 | biaozhu() 192 | 193 | et = time.time() 194 | print(f'总用时间:{et-st}s') 195 | 196 | """ 197 | python A3_scripts/A33_ASR_ScriptsGen.py \ 198 | --wavdir A2_prepared_audios/SSB0005_50\ 199 | --output_txt A5_finetuned_trainingout/SSB0005_50/filelists/script.txt \ 200 | --lang ZH 201 | """ 202 | pass 203 | 204 | 205 | 206 | 207 | -------------------------------------------------------------------------------- /A3_scripts/A34_deleteModels.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | def delete_files_less_than_N(directory, N): 5 | # 正则表达式匹配 _.pth 格式的文件 6 | pattern = re.compile(r'.*_(\d+)\.pth$') 7 | 8 | # 遍历目录中的文件 9 | for filename in os.listdir(directory): 10 | match = pattern.match(filename) 11 | if match: 12 | # 获取文件中的数字 13 | file_number = int(match.group(1)) 14 | # 如果文件中的数字小于 N,则删除文件 15 | if file_number < N or filename[0] != "G": 16 | file_path = os.path.join(directory, filename) 17 | os.remove(file_path) 18 | print(f"Deleted: {file_path}") 19 | 20 | # 示例使用 21 | directory = 'A5_finetuned_trainingout/SSB0005_50/models' # 目录A的路径 22 | N = 8000 # 你想要删除的文件序号阈值 23 | delete_files_less_than_N(directory, N) -------------------------------------------------------------------------------- /A3_scripts/Finetune_Scripts.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/A3_scripts/Finetune_Scripts.txt -------------------------------------------------------------------------------- /A3_scripts/__pycache__/asr_model_list.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/A3_scripts/__pycache__/asr_model_list.cpython-310.pyc -------------------------------------------------------------------------------- /A3_scripts/asr_model_list.py: -------------------------------------------------------------------------------- 1 | import sherpa_onnx 2 | 3 | 4 
| def get_vad_punc_model(): 5 | # load PUCN model 6 | pcmodel = "A1_pretrained_models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12/model.onnx" 7 | 8 | config = sherpa_onnx.OfflinePunctuationConfig( 9 | model=sherpa_onnx.OfflinePunctuationModelConfig(ct_transformer=pcmodel)) 10 | 11 | punct = sherpa_onnx.OfflinePunctuation(config) 12 | 13 | ## load VAD模型 14 | config = sherpa_onnx.VadModelConfig() 15 | config.silero_vad.model = "A1_pretrained_models/VAD_model/silero_vad.onnx" 16 | config.sample_rate = 16000 17 | 18 | window_size = config.silero_vad.window_size 19 | vad = sherpa_onnx.VoiceActivityDetector(config, buffer_size_in_seconds=30) 20 | 21 | return vad,punct,window_size 22 | 23 | 24 | def get_model02(): 25 | recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer( 26 | paraformer="A1_pretrained_models/sherpa-onnx-paraformer-zh-2023-03-28/model.onnx", 27 | tokens="A1_pretrained_models/sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt", 28 | num_threads=1, 29 | sample_rate=16000, 30 | feature_dim=80, 31 | decoding_method="greedy_search", 32 | debug=False, 33 | provider='cuda' 34 | ) 35 | name = "sherpa-onnx-paraformer-zh-2023-03-28" 36 | return recognizer , name 37 | -------------------------------------------------------------------------------- /A40_一键启动微调pipeline.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # 定义数据集名字。 4 | MODEL_ID="SSB0273_50" 5 | 6 | ## 数据集位置:A2_prepared_audios/$MODEL_ID 7 | 8 | ## 1 ASR模型识别文本 9 | python A3_scripts/A33_ASR_ScriptsGen.py \ 10 | --wavdir A2_prepared_audios/$MODEL_ID \ 11 | --output_txt A5_finetuned_trainingout/$MODEL_ID/filelists/script.txt \ 12 | --lang ZH 13 | 14 | ## 2 复制配置文件 15 | cp configs/config.json A5_finetuned_trainingout/$MODEL_ID 16 | 17 | ## 3 根据文件产生音素标注、训练验证集、 18 | python preprocess_text.py \ 19 | --transcription-path A5_finetuned_trainingout/$MODEL_ID/filelists/script.txt \ 20 | --cleaned-path A5_finetuned_trainingout/$MODEL_ID/filelists/script.txt.cleaned \ 21 | --train-path A5_finetuned_trainingout/$MODEL_ID/filelists/script.txt.cleaned.train \ 22 | --val-path A5_finetuned_trainingout/$MODEL_ID/filelists/script.txt.cleaned.val \ 23 | --config-path A5_finetuned_trainingout/$MODEL_ID/config.json 24 | 25 | # 4 用bert模型产生音素的 .pt文件 26 | python bert_gen.py -c A5_finetuned_trainingout/$MODEL_ID/config.json 27 | 28 | # 5 melspec 29 | python spec_gen.py --script A5_finetuned_trainingout/$MODEL_ID/filelists/script.txt.cleaned.train 30 | 31 | # 6 启动train ms 32 | CUDA_VISIBLE_DEVICES=1,2 torchrun --nproc_per_node=2 train_ms.py -c A5_finetuned_trainingout/$MODEL_ID/config.json \ 33 | -m A5_finetuned_trainingout/$MODEL_ID \ 34 | -mb A1_pretrained_models/Bert-VITS2_2.3 -------------------------------------------------------------------------------- /A4_model_output/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/A4_model_output/.gitkeep -------------------------------------------------------------------------------- /A5_finetuned_trainingout/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/A5_finetuned_trainingout/.gitkeep -------------------------------------------------------------------------------- /README.md: 
--------------------------------------------------------------------------------
# Preface
Original project: [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2).
This is a usage guide for an improved fork of Bert-VITS2 with as many bugs removed as possible. Please pitch in: open issues so the remaining bugs get fixed and fine-tuning can be started quickly.
If you are interested in discussing speech technology, you are welcome to join QQ group 742922321.

![QQ group](/docs/语音学习群.png)

The Bert-VITS2 speech-synthesis project is no longer maintained, so it is worth sharing deployment experience for this final version of the code.
The Bert-VITS2 base model is essentially BERT + VITS, trained mostly on Genshin Impact character voices. Fine-tuning updates only the VITS model and keeps the BERT model frozen; there is no speaker encoder or emotion encoder.
The BERT model produces the text encoding vector Ht, and the VITS model synthesizes the waveform: wav = vits(Ht).

The project supports both inference and fine-tuning. Fine-tuning needs at least 50 clips of 1-5 seconds each. With high-quality speech data, the fine-tuned voice quality and inference speed are basically good enough for commercial use.

Compared with newer TTS models such as GPT-SoVITS and fish-speech, it has a few advantages: 1. the model is small, so synthesis is fast — wrapped as an API, the speed basically meets commercial dialogue requirements; 2. after fine-tuning, the timbre is stable.
Models such as FishSpeech are much more stochastic: the timbre can drift, and they may even speak text that was never in the input. There are also drawbacks: 1. only three languages are supported; 2. the original code contains many bugs that you have to fix yourself.

This project also ships a document, 《第三版dhtz-2024年0912Bert-vits2项目部署经验.pdf》, which records how the bugs were fixed; the deployment commands below are still the recommended reference.

My setup: Ubuntu 22.04 with two V100 GPUs. The project also runs on Windows, but every path has to be rewritten in Windows format.

The fixed, bug-free code has been published on 123pan. Note that the archive already contains all models and pretrained files, plus one fine-tuned model trained on speaker SSB0005 from AISHELL-3. You can download everything from there and upload it to your server.
```
https://www.123pan.com/s/KLIzVv-pnMsh
提取码 (extraction code): wxkO
```

# Fine-tuning demo
For a fine-tuned example, visit: [Bert-VITS2 demo (loads a bit slowly, please be patient)](http://semanticexplore.com/bert_vits_demo)

![Model architecture overview](docs/011.png)

# 1. Conda environment setup
```
# Install torch and torchaudio first
conda create -n vits2 python=3.10.12
conda activate vits2
pip install torch torchaudio -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
```
```
## To avoid manual annotation during fine-tuning, an ASR (speech recognition) model is used to transcribe the fine-tuning audio.
## This requires the Python sherpa-onnx library, an open-source project from Xiaomi that is very convenient to use.
## The recommended way to install it is from a .whl package:
https://k2-fsa.github.io/sherpa/onnx/cuda.html
# Pick the wheel matching your OS and Python version. For Linux with a Python 3.10 virtual environment, download:
sherpa_onnx-1.10.27+cuda-cp310-cp310-linux_x86_64.whl
# Then upload the file to the project and run:
pip install sherpa_onnx-1.10.27+cuda-cp310-cp310-linux_x86_64.whl
```

# 2. Model and data preparation
The walkthrough fine-tunes speaker SSB0005 from the AISHELL-3 Mandarin corpus.
You need the BERT model, the VITS base model, the WavLM model, and the SSB0005 speaker's audio.
On servers inside mainland China, downloading with hlf.sh is recommended. Usage: bash hlf.sh <huggingface model repo> <local directory for the model>

## 2.1 Copy the model repo name from Hugging Face
```
https://huggingface.co/hfl/chinese-roberta-wwm-ext-large
```
hfl/chinese-roberta-wwm-ext-large is the model repo name; the other models are downloaded the same way.
## 2.2 Download the Chinese BERT model
```bash
bash hlf.sh hfl/chinese-roberta-wwm-ext-large chinese-roberta-wwm-ext-large

# Move it into the bert folder
mv chinese-roberta-wwm-ext-large bert
```
For the BERT models of the other languages, see:
```
"- [中文 RoBERTa](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large)\n"
"- [日文 DeBERTa](https://huggingface.co/ku-nlp/deberta-v2-large-japanese-char-wwm)\n"
"- [英文 DeBERTa](https://huggingface.co/microsoft/deberta-v3-large)\n"
"- [WavLM](https://huggingface.co/microsoft/wavlm-base-plus)\n"
```
Note that all BERT models go into the ./bert folder.
## 2.3 Download the WavLM model
```bash
bash hlf.sh microsoft/wavlm-base-plus wavlm-base-plus
# Move it into the slm folder
mv wavlm-base-plus slm

```
## 2.4 Download the VITS base model
Download the base model from the site below, then upload it to the matching directory on your server.
```
https://openi.pcl.ac.cn/Stardust_minus/Bert-VITS2/modelmanage/show_model
```
This project uses the "Bert-VITS2_2.3" base model.
Place the model files at:
```
./A1_pretrained_models/Bert-VITS2_2.3
# The directory should look like this
A1_pretrained_models/Bert-VITS2_2.3
├── D_0.pth
├── DUR_0.pth
├── G_0.pth
├── README
└── WD_0.pth
```
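Before running inference or fine-tuning, it can save time to confirm that the files described in sections 2.2-2.4 are actually where the later scripts expect them. The snippet below is a small sanity-check sketch, not part of the repository; the paths it lists are the ones given above.

```python
# Sanity-check sketch (not part of this repo): verify that the pretrained
# files described in sections 2.2-2.4 are in the expected locations.
from pathlib import Path

expected = [
    "A1_pretrained_models/Bert-VITS2_2.3/G_0.pth",    # VITS generator base model
    "A1_pretrained_models/Bert-VITS2_2.3/D_0.pth",    # discriminator
    "A1_pretrained_models/Bert-VITS2_2.3/DUR_0.pth",  # duration predictor
    "A1_pretrained_models/Bert-VITS2_2.3/WD_0.pth",   # WavLM discriminator
    "bert/chinese-roberta-wwm-ext-large",             # Chinese BERT directory
    "slm/wavlm-base-plus",                            # WavLM directory
]

missing = [p for p in expected if not Path(p).exists()]
print("all pretrained files found" if not missing else f"missing: {missing}")
```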
## 2.5 Base-model inference
With all the models in place, run:
```
python A31_singleinfer.py
```
The key parameters in the script are:
```python
## 1. Load hyperparameters.
hps = get_hparams_from_file(config_path="configs/config.json") # the config file must match the model
device = "cuda:0"
model_path = "A1_pretrained_models/Bert-VITS2_2.3/G_0.pth" ## path to the generator checkpoint
##
speaker_name = "八重神子_ZH"
language = "ZH"
length_scale = 1.2
infer_text = "今夜的月光如此清亮,不做些什么真是浪费。随我一同去月下漫步吧,不许拒绝。"
infer_id = 3 ## index of the current synthesized utterance
sdp_ratio=0.4
output_path = f'A4_model_output/{speaker_name}_{infer_id}.wav'
```

# 3. Fine-tuning with your own data

## 3.1 Prepare high-quality audio files
44 kHz Mandarin recordings are recommended: more than 50 clips, each with more than 5 text tokens. (Note: upsampling low-sample-rate audio is basically useless.)
For example, put the audio into a folder like this:
```
A2_prepared_audios/SSB0005
├── SSB00050001.wav
├── SSB00050002.wav
......
└── SSB00050490.wav
```
## 3.2 Prepare the ASR, VAD and punctuation models (and environment) needed by sherpa-onnx
sherpa-onnx has a C/C++ core with bindings for Python, Java and other languages.
If you want GPU-based automatic annotation, sherpa-onnx 1.10.27 only works with CUDA toolkit 11.8.
In practice, running on CPU is also acceptably fast.
Download three models in your browser: the speech recognition model (ASR), the voice activity detection model (VAD), and the punctuation model (PUNC).
```
https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-03-28.tar.bz2
https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx
https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
```
Arrange the models as:
```
A1_pretrained_models
├── Bert-VITS2_2.3
│   ├── D_0.pth
│   ├── DUR_0.pth
│   ├── G_0.pth
│   ├── README
│   └── WD_0.pth
├── sherpa-onnx-paraformer-zh-2023-03-28
│   ├── model.int8.onnx
│   ├── model.onnx
│   ├── README.md
│   ....
│   └── tokens.txt
├── sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12
│   ├── add-model-metadata.py
│   ├── config.yaml
│   ├── model.onnx
│   ├── README.md
│   ├── show-model-input-output.py
│   ├── test.py
│   └── tokens.json
└── VAD_model
    └── silero_vad.onnx
```
If you do not have CUDA 11.8 available, modify the code:
```
## A3_scripts/asr_model_list.py
def get_model02():
    recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
        paraformer="A1_pretrained_models/sherpa-onnx-paraformer-zh-2023-03-28/model.onnx",
        tokens="A1_pretrained_models/sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt",
        num_threads=1,
        sample_rate=16000,
        feature_dim=80,
        decoding_method="greedy_search",
        debug=False,
        provider='cpu'   ### change provider to cpu
    )

```


## 3.3 Transcribe the prepared audio to build the annotation list (audio path, speaker, language, text)
```bash
## The command below annotates every clip in the directory and writes the list to A5_finetuned_trainingout/SSB0005/filelists/script.txt.

python A3_scripts/A33_ASR_ScriptsGen.py \
    --wavdir A2_prepared_audios/SSB0005 \
    --output_txt A5_finetuned_trainingout/SSB0005/filelists/script.txt \
    --lang ZH
```
The resulting annotation file looks like:
```
A2_prepared_audios/SSB0005/SSB00050001.wav|SSB0005|ZH|广州女大学生登山失联四天,警方找到疑似女尸。
......
```
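Before moving on it is worth spot-checking the generated list, since a malformed line will break the later preprocessing steps. The check below is an optional, illustrative sketch (not included in the repository); it assumes the `{wav_path}|{speaker_name}|{language}|{text}` format shown above.

```python
# Illustrative check (not part of this repo): validate the ASR-generated list
# before running preprocess_text.py. Each line must have 4 '|'-separated fields
# and the referenced wav file must exist.
from pathlib import Path

script = Path("A5_finetuned_trainingout/SSB0005/filelists/script.txt")
for i, line in enumerate(script.read_text(encoding="utf-8").splitlines(), 1):
    fields = line.strip().split("|")
    if len(fields) != 4:
        print(f"line {i}: expected 4 fields, got {len(fields)}")
        continue
    wav_path, speaker, lang, text = fields
    if not Path(wav_path).is_file():
        print(f"line {i}: missing wav file {wav_path}")
    if not text:
        print(f"line {i}: empty transcription")
```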
## 3.4 G2P
G2P turns the text into a phoneme sequence (producing .cleaned) and splits the data into training and validation sets (producing .train and .val).
It therefore takes 1 text file as input and outputs 3 text files. Copy the config file first.
```bash
cp configs/config.json A5_finetuned_trainingout/SSB0005

python preprocess_text.py \
    --transcription-path A5_finetuned_trainingout/SSB0005/filelists/script.txt \
    --cleaned-path A5_finetuned_trainingout/SSB0005/filelists/script.txt.cleaned \
    --train-path A5_finetuned_trainingout/SSB0005/filelists/script.txt.cleaned.train \
    --val-path A5_finetuned_trainingout/SSB0005/filelists/script.txt.cleaned.val \
    --config-path A5_finetuned_trainingout/SSB0005/config.json
```
You can also see that A5_finetuned_trainingout/SSB0005/config.json has been updated:
```
"data": {
    "training_files": "A5_finetuned_trainingout/SSB0005/filelists/script.txt.cleaned.train",
    "validation_files": "A5_finetuned_trainingout/SSB0005/filelists/script.txt.cleaned.val",
```
## 3.5 Feed the text through BERT to generate features
One BERT feature tensor is generated per utterance and saved as a .bert.pt file.
```bash
python bert_gen.py -c A5_finetuned_trainingout/SSB0005/config.json
```
## 3.6 Generate mel spectrograms from the audio
The mel spectrograms are used during VITS training; one is generated per clip and also stored as a .pt file.
First edit spec_gen.py:
```
if __name__ == "__main__":
    ## Fill in the path of script.txt.cleaned.train below, i.e. the phoneme training list
    with open("A5_finetuned_trainingout/SSB0005/filelists/script.txt.cleaned.train", "r") as f:
        filepaths = [line.split("|")[0] for line in f]  # take the first field of each line as the audio path
```
Then run
```bash
python spec_gen.py
```
## 3.7 Start fine-tuning
Running inside a tmux session is recommended so training can continue in the background. For example:
```
tmux new -s vits2
conda activate vits2
```
Then run
```
# Three arguments: the new config file, the fine-tuning output directory, and the base-model directory
# The script copies the base-model files into the output directory automatically, so the base model is always found.
python train_ms.py -c A5_finetuned_trainingout/SSB0005/config.json \
    -m A5_finetuned_trainingout/SSB0005 \
    -mb A1_pretrained_models/Bert-VITS2_2.3
```
If you want to control which GPUs of a multi-GPU machine are used, set the environment variable.
The command below trains on GPUs 1 and 2:
```bash
CUDA_VISIBLE_DEVICES=1,2 torchrun --nproc_per_node=2 train_ms.py -c A5_finetuned_trainingout/SSB0005/config.json \
    -m A5_finetuned_trainingout/SSB0005 \
    -mb A1_pretrained_models/Bert-VITS2_2.3
```

## 3.8 Inference with the fine-tuned model
In A31_singleinfer.py, change the config path, the generator checkpoint and the speaker name:
```
## 1. Load hyperparameters.
hps = get_hparams_from_file(config_path="A5_finetuned_trainingout/SSB0005/config.json")
device = "cuda:0"
model_path = "A5_finetuned_trainingout/SSB0005/models/G_1000.pth"
##
speaker_name = "SSB0005"
language = "ZH"
```

# 4. Deploying as a FastAPI service
Edit the model-initialization code, then start the server:
```
# Load the model
def load_model():
    global model, hps
    if model is None and hps is None:
        hps = get_hparams_from_file(config_path="configs/config.json") ## set these two paths
        model_path = "A1_pretrained_models/Bert-VITS2_2.3/G_0.pth"     ## set these two paths
```
```bash
python A35_inference_server.py
```
Client request:
```bash
python A36_post.py
## When posting, make sure the fields below match the model you trained.
request_data = {
    "speaker_name": "八重神子_ZH",
    "language": "ZH",
    "length_scale": 1.2,
    "infer_text": "即使引导已经破碎,也请觐见艾尔登法环",
    "infer_id": 4,
    "sdp_ratio": 0.4,

}
## infer_id controls the output path of the synthesized audio; here id means index, not identity.
a = request_data["infer_id"]
request_data["output_path"] = f"A4_model_output/SSB0005_{a}.wav" ## the output directory must exist.
```

## 4.1 Inference speed of the API
Generating 3.4 seconds of audio from an 18-character text took about 0.15 seconds:
```
{'output_path': 'A4_model_output/SSB0005_4.wav', 'out_info1': '{audiolen:3.4946,textlen:18.0000,usetime:0.1406}'}
```
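If you want a single speed number, the `out_info1` field returned by A35_inference_server.py can be turned into a real-time factor (compute seconds per second of generated audio). The sketch below is illustrative only: it mirrors the request in A36_post.py, and it parses `out_info1` with a regular expression because that field is a plain formatted string rather than JSON.

```python
# Illustrative only: measure a real-time factor (RTF) for the deployed service.
import re
import requests

request_data = {
    "speaker_name": "八重神子_ZH",   # must match a speaker of the loaded model
    "language": "ZH",
    "length_scale": 1.2,
    "infer_text": "即使引导已经破碎,也请觐见艾尔登法环",
    "infer_id": 4,
    "sdp_ratio": 0.4,
    "output_path": "A4_model_output/SSB0005_4.wav",
}
resp = requests.post("http://127.0.0.1:8102/infer_bertvits2/", json=request_data).json()

# out_info1 looks like "{audiolen:3.4946,textlen:18.0000,usetime:0.1406}"
stats = dict(re.findall(r"(\w+):([\d.]+)", resp["out_info1"]))
audiolen = float(stats["audiolen"])  # seconds of synthesized audio
usetime = float(stats["usetime"])    # wall-clock seconds for the whole request
print(f"RTF = {usetime / audiolen:.3f}  (below 1.0 means faster than real time)")
```

With the timing quoted above this gives an RTF of roughly 0.04, i.e. well under real time.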
# 5. Extra tools
## 5.1 Delete checkpoints below a given step count
```bash
python A3_scripts/A34_deleteModels.py
```
## 5.2 Automatically split audio with the VAD model
See the two files below; usage is not repeated here.
```
A37-ffmpeg.py
A38-VAD_batch.py
```

## 5.3 One-click launch of the whole pipeline — just put the data under A2_prepared_audios/gentle_girl
```bash
A40_一键启动微调pipeline.sh
```

## Thanks to all contributors
GitYesm
--------------------------------------------------------------------------------
/__pycache__/attentions.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/__pycache__/attentions.cpython-310.pyc
--------------------------------------------------------------------------------
/__pycache__/commons.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/__pycache__/commons.cpython-310.pyc
--------------------------------------------------------------------------------
/__pycache__/config.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/__pycache__/config.cpython-310.pyc
--------------------------------------------------------------------------------
/__pycache__/data_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/__pycache__/data_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/__pycache__/losses.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/__pycache__/losses.cpython-310.pyc
--------------------------------------------------------------------------------
/__pycache__/mel_processing.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/__pycache__/mel_processing.cpython-310.pyc
--------------------------------------------------------------------------------
/__pycache__/models.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/__pycache__/models.cpython-310.pyc
--------------------------------------------------------------------------------
/__pycache__/modules.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/__pycache__/modules.cpython-310.pyc
--------------------------------------------------------------------------------
/__pycache__/transforms.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/__pycache__/transforms.cpython-310.pyc -------------------------------------------------------------------------------- /__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /bert_gen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from multiprocessing import Pool 3 | import commons 4 | import utils 5 | from tqdm import tqdm 6 | #from text import check_bert_models, cleaned_text_to_sequence, get_bert 7 | from text import cleaned_text_to_sequence, get_bert 8 | 9 | import argparse 10 | import torch.multiprocessing as mp 11 | from config import config 12 | 13 | 14 | def process_line(x): 15 | line, add_blank = x 16 | device = config.bert_gen_config.device 17 | if config.bert_gen_config.use_multi_device: 18 | rank = mp.current_process()._identity 19 | rank = rank[0] if len(rank) > 0 else 0 20 | if torch.cuda.is_available(): 21 | gpu_id = rank % torch.cuda.device_count() 22 | device = torch.device(f"cuda:{gpu_id}") 23 | else: 24 | device = torch.device("cpu") 25 | wav_path, _, language_str, text, phones, tone, word2ph = line.strip().split("|") 26 | phone = phones.split(" ") 27 | tone = [int(i) for i in tone.split(" ")] 28 | word2ph = [int(i) for i in word2ph.split(" ")] 29 | word2ph = [i for i in word2ph] 30 | phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str) 31 | 32 | if add_blank: 33 | phone = commons.intersperse(phone, 0) 34 | tone = commons.intersperse(tone, 0) 35 | language = commons.intersperse(language, 0) 36 | for i in range(len(word2ph)): 37 | word2ph[i] = word2ph[i] * 2 38 | word2ph[0] += 1 39 | 40 | bert_path = wav_path.replace(".WAV", ".wav").replace(".wav", ".bert.pt") 41 | 42 | try: 43 | bert = torch.load(bert_path) 44 | assert bert.shape[0] == 2048 45 | except Exception: 46 | bert = get_bert(text, word2ph, language_str, device) 47 | assert bert.shape[-1] == len(phone) 48 | torch.save(bert, bert_path) 49 | 50 | 51 | preprocess_text_config = config.preprocess_text_config 52 | 53 | if __name__ == "__main__": 54 | parser = argparse.ArgumentParser() 55 | parser.add_argument( 56 | "-c", "--config", type=str, default=config.bert_gen_config.config_path 57 | ) 58 | parser.add_argument( 59 | "--num_processes", type=int, default=config.bert_gen_config.num_processes 60 | ) 61 | args, _ = parser.parse_known_args() 62 | config_path = args.config 63 | hps = utils.get_hparams_from_file(config_path) 64 | # check_bert_models() 65 | lines = [] 66 | with open(hps.data.training_files, encoding="utf-8") as f: 67 | lines.extend(f.readlines()) 68 | 69 | with open(hps.data.validation_files, encoding="utf-8") as f: 70 | lines.extend(f.readlines()) 71 | add_blank = [hps.data.add_blank] * len(lines) 72 | 73 | if len(lines) != 0: 74 | num_processes = args.num_processes 75 | with Pool(processes=num_processes) as pool: 76 | for _ in tqdm( 77 | pool.imap_unordered(process_line, zip(lines, add_blank)), 78 | total=len(lines), 79 | ): 80 | # 这里是缩进的代码块,表示循环体 81 | pass # 使用pass语句作为占位符 82 | 83 | print(f"bert生成完毕!, 共有{len(lines)}个bert.pt生成!") 84 | 85 | ### python bert_gen.py -c A5_finetuned_trainingout/SSB0005_50/config.json 86 | 
-------------------------------------------------------------------------------- /commons.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | 6 | def init_weights(m, mean=0.0, std=0.01): 7 | classname = m.__class__.__name__ 8 | if classname.find("Conv") != -1: 9 | m.weight.data.normal_(mean, std) 10 | 11 | 12 | def get_padding(kernel_size, dilation=1): 13 | return int((kernel_size * dilation - dilation) / 2) 14 | 15 | 16 | def convert_pad_shape(pad_shape): 17 | layer = pad_shape[::-1] 18 | pad_shape = [item for sublist in layer for item in sublist] 19 | return pad_shape 20 | 21 | 22 | def intersperse(lst, item): 23 | result = [item] * (len(lst) * 2 + 1) 24 | result[1::2] = lst 25 | return result 26 | 27 | 28 | def kl_divergence(m_p, logs_p, m_q, logs_q): 29 | """KL(P||Q)""" 30 | kl = (logs_q - logs_p) - 0.5 31 | kl += ( 32 | 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) 33 | ) 34 | return kl 35 | 36 | 37 | def rand_gumbel(shape): 38 | """Sample from the Gumbel distribution, protect from overflows.""" 39 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 40 | return -torch.log(-torch.log(uniform_samples)) 41 | 42 | 43 | def rand_gumbel_like(x): 44 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) 45 | return g 46 | 47 | 48 | def slice_segments(x, ids_str, segment_size=4): 49 | gather_indices = ids_str.view(x.size(0), 1, 1).repeat( 50 | 1, x.size(1), 1 51 | ) + torch.arange(segment_size, device=x.device) 52 | return torch.gather(x, 2, gather_indices) 53 | 54 | 55 | def rand_slice_segments(x, x_lengths=None, segment_size=4): 56 | b, d, t = x.size() 57 | if x_lengths is None: 58 | x_lengths = t 59 | ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0) 60 | ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long) 61 | ret = slice_segments(x, ids_str, segment_size) 62 | return ret, ids_str 63 | 64 | 65 | def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): 66 | position = torch.arange(length, dtype=torch.float) 67 | num_timescales = channels // 2 68 | log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / ( 69 | num_timescales - 1 70 | ) 71 | inv_timescales = min_timescale * torch.exp( 72 | torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment 73 | ) 74 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) 75 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) 76 | signal = F.pad(signal, [0, 0, 0, channels % 2]) 77 | signal = signal.view(1, channels, length) 78 | return signal 79 | 80 | 81 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): 82 | b, channels, length = x.size() 83 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 84 | return x + signal.to(dtype=x.dtype, device=x.device) 85 | 86 | 87 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): 88 | b, channels, length = x.size() 89 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 90 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) 91 | 92 | 93 | def subsequent_mask(length): 94 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) 95 | return mask 96 | 97 | 98 | @torch.jit.script 99 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 100 | n_channels_int = 
n_channels[0] 101 | in_act = input_a + input_b 102 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 103 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 104 | acts = t_act * s_act 105 | return acts 106 | 107 | 108 | def convert_pad_shape(pad_shape): 109 | layer = pad_shape[::-1] 110 | pad_shape = [item for sublist in layer for item in sublist] 111 | return pad_shape 112 | 113 | 114 | def shift_1d(x): 115 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] 116 | return x 117 | 118 | 119 | def sequence_mask(length, max_length=None): 120 | if max_length is None: 121 | max_length = length.max() 122 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 123 | return x.unsqueeze(0) < length.unsqueeze(1) 124 | 125 | 126 | def generate_path(duration, mask): 127 | """ 128 | duration: [b, 1, t_x] 129 | mask: [b, 1, t_y, t_x] 130 | """ 131 | 132 | b, _, t_y, t_x = mask.shape 133 | cum_duration = torch.cumsum(duration, -1) 134 | 135 | cum_duration_flat = cum_duration.view(b * t_x) 136 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 137 | path = path.view(b, t_x, t_y) 138 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 139 | path = path.unsqueeze(1).transpose(2, 3) * mask 140 | return path 141 | 142 | 143 | def clip_grad_value_(parameters, clip_value, norm_type=2): 144 | if isinstance(parameters, torch.Tensor): 145 | parameters = [parameters] 146 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 147 | norm_type = float(norm_type) 148 | if clip_value is not None: 149 | clip_value = float(clip_value) 150 | 151 | total_norm = 0 152 | for p in parameters: 153 | param_norm = p.grad.data.norm(norm_type) 154 | total_norm += param_norm.item() ** norm_type 155 | if clip_value is not None: 156 | p.grad.data.clamp_(min=-clip_value, max=clip_value) 157 | total_norm = total_norm ** (1.0 / norm_type) 158 | return total_norm 159 | -------------------------------------------------------------------------------- /compress_model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from text.symbols import symbols 3 | import torch 4 | 5 | from tools.log import logger 6 | import utils 7 | from models import SynthesizerTrn 8 | import os 9 | 10 | 11 | def copyStateDict(state_dict): 12 | if list(state_dict.keys())[0].startswith("module"): 13 | start_idx = 1 14 | else: 15 | start_idx = 0 16 | new_state_dict = OrderedDict() 17 | for k, v in state_dict.items(): 18 | name = ",".join(k.split(".")[start_idx:]) 19 | new_state_dict[name] = v 20 | return new_state_dict 21 | 22 | 23 | def removeOptimizer(config: str, input_model: str, ishalf: bool, output_model: str): 24 | hps = utils.get_hparams_from_file(config) 25 | 26 | net_g = SynthesizerTrn( 27 | len(symbols), 28 | hps.data.filter_length // 2 + 1, 29 | hps.train.segment_size // hps.data.hop_length, 30 | n_speakers=hps.data.n_speakers, 31 | **hps.model, 32 | ) 33 | 34 | optim_g = torch.optim.AdamW( 35 | net_g.parameters(), 36 | hps.train.learning_rate, 37 | betas=hps.train.betas, 38 | eps=hps.train.eps, 39 | ) 40 | 41 | state_dict_g = torch.load(input_model, map_location="cpu") 42 | new_dict_g = copyStateDict(state_dict_g) 43 | keys = [] 44 | for k, v in new_dict_g["model"].items(): 45 | if "enc_q" in k: 46 | continue # noqa: E701 47 | keys.append(k) 48 | 49 | new_dict_g = ( 50 | {k: new_dict_g["model"][k].half() for k in keys} 51 | if ishalf 52 | else {k: new_dict_g["model"][k] for k in keys} 
53 | ) 54 | 55 | torch.save( 56 | { 57 | "model": new_dict_g, 58 | "iteration": 0, 59 | "optimizer": optim_g.state_dict(), 60 | "learning_rate": 0.0001, 61 | }, 62 | output_model, 63 | ) 64 | 65 | 66 | if __name__ == "__main__": 67 | import argparse 68 | 69 | parser = argparse.ArgumentParser() 70 | parser.add_argument("-c", "--config", type=str, default="configs/config.json") 71 | parser.add_argument("-i", "--input", type=str) 72 | parser.add_argument("-o", "--output", type=str, default=None) 73 | parser.add_argument( 74 | "-hf", "--half", action="store_true", default=False, help="Save as FP16" 75 | ) 76 | 77 | args = parser.parse_args() 78 | 79 | output = args.output 80 | 81 | if output is None: 82 | import os.path 83 | 84 | filename, ext = os.path.splitext(args.input) 85 | half = "_half" if args.half else "" 86 | output = filename + "_release" + half + ext 87 | 88 | removeOptimizer(args.config, args.input, args.half, output) 89 | logger.info(f"压缩模型成功, 输出模型: {os.path.abspath(output)}") 90 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Desc: 全局配置文件读取 3 | """ 4 | 5 | import argparse 6 | import yaml 7 | from typing import Dict, List 8 | import os 9 | import shutil 10 | import sys 11 | 12 | 13 | class Resample_config: 14 | """重采样配置""" 15 | 16 | def __init__(self, in_dir: str, out_dir: str, sampling_rate: int = 44100): 17 | self.sampling_rate: int = sampling_rate # 目标采样率 18 | self.in_dir: str = in_dir # 待处理音频目录路径 19 | self.out_dir: str = out_dir # 重采样输出路径 20 | 21 | @classmethod 22 | def from_dict(cls, dataset_path: str, data: Dict[str, any]): 23 | """从字典中生成实例""" 24 | 25 | # 不检查路径是否有效,此逻辑在resample.py中处理 26 | data["in_dir"] = os.path.join(dataset_path, data["in_dir"]) 27 | data["out_dir"] = os.path.join(dataset_path, data["out_dir"]) 28 | 29 | return cls(**data) 30 | 31 | 32 | class Preprocess_text_config: 33 | """数据预处理配置""" 34 | 35 | def __init__( 36 | self, 37 | transcription_path: str, 38 | cleaned_path: str, 39 | train_path: str, 40 | val_path: str, 41 | config_path: str, 42 | val_per_lang: int = 5, 43 | max_val_total: int = 10000, 44 | clean: bool = True, 45 | ): 46 | self.transcription_path: str = ( 47 | transcription_path # 原始文本文件路径,文本格式应为{wav_path}|{speaker_name}|{language}|{text}。 48 | ) 49 | self.cleaned_path: str = ( 50 | cleaned_path # 数据清洗后文本路径,可以不填。不填则将在原始文本目录生成 51 | ) 52 | self.train_path: str = ( 53 | train_path # 训练集路径,可以不填。不填则将在原始文本目录生成 54 | ) 55 | self.val_path: str = ( 56 | val_path # 验证集路径,可以不填。不填则将在原始文本目录生成 57 | ) 58 | self.config_path: str = config_path # 配置文件路径 59 | self.val_per_lang: int = val_per_lang # 每个speaker的验证集条数 60 | self.max_val_total: int = ( 61 | max_val_total # 验证集最大条数,多于的会被截断并放到训练集中 62 | ) 63 | self.clean: bool = clean # 是否进行数据清洗 64 | 65 | @classmethod 66 | def from_dict(cls, dataset_path: str, data: Dict[str, any]): 67 | """从字典中生成实例""" 68 | 69 | data["transcription_path"] = os.path.join( 70 | dataset_path, data["transcription_path"] 71 | ) 72 | if data["cleaned_path"] == "" or data["cleaned_path"] is None: 73 | data["cleaned_path"] = None 74 | else: 75 | data["cleaned_path"] = os.path.join(dataset_path, data["cleaned_path"]) 76 | data["train_path"] = os.path.join(dataset_path, data["train_path"]) 77 | data["val_path"] = os.path.join(dataset_path, data["val_path"]) 78 | data["config_path"] = os.path.join(dataset_path, data["config_path"]) 79 | 80 | return cls(**data) 81 | 82 | 83 | class Bert_gen_config: 84 | 
"""bert_gen 配置""" 85 | 86 | def __init__( 87 | self, 88 | config_path: str, 89 | num_processes: int = 2, 90 | device: str = "cuda", 91 | use_multi_device: bool = False, 92 | ): 93 | self.config_path = config_path 94 | self.num_processes = num_processes 95 | self.device = device 96 | self.use_multi_device = use_multi_device 97 | 98 | @classmethod 99 | def from_dict(cls, dataset_path: str, data: Dict[str, any]): 100 | data["config_path"] = os.path.join(dataset_path, data["config_path"]) 101 | 102 | return cls(**data) 103 | 104 | 105 | class Emo_gen_config: 106 | """emo_gen 配置""" 107 | 108 | def __init__( 109 | self, 110 | config_path: str, 111 | num_processes: int = 2, 112 | device: str = "cuda", 113 | use_multi_device: bool = False, 114 | ): 115 | self.config_path = config_path 116 | self.num_processes = num_processes 117 | self.device = device 118 | self.use_multi_device = use_multi_device 119 | 120 | @classmethod 121 | def from_dict(cls, dataset_path: str, data: Dict[str, any]): 122 | data["config_path"] = os.path.join(dataset_path, data["config_path"]) 123 | 124 | return cls(**data) 125 | 126 | 127 | class Train_ms_config: 128 | """训练配置""" 129 | 130 | def __init__( 131 | self, 132 | config_path: str, 133 | env: Dict[str, any], 134 | base: Dict[str, any], 135 | model: str, 136 | num_workers: int, 137 | spec_cache: bool, 138 | keep_ckpts: int, 139 | ): 140 | self.env = env # 需要加载的环境变量 141 | self.base = base # 底模配置 142 | self.model = ( 143 | model # 训练模型存储目录,该路径为相对于dataset_path的路径,而非项目根目录 144 | ) 145 | self.config_path = config_path # 配置文件路径 146 | self.num_workers = num_workers # worker数量 147 | self.spec_cache = spec_cache # 是否启用spec缓存 148 | self.keep_ckpts = keep_ckpts # ckpt数量 149 | 150 | @classmethod 151 | def from_dict(cls, dataset_path: str, data: Dict[str, any]): 152 | # data["model"] = os.path.join(dataset_path, data["model"]) 153 | data["config_path"] = os.path.join(dataset_path, data["config_path"]) 154 | 155 | return cls(**data) 156 | 157 | 158 | class Webui_config: 159 | """webui 配置""" 160 | 161 | def __init__( 162 | self, 163 | device: str, 164 | model: str, 165 | config_path: str, 166 | language_identification_library: str, 167 | port: int = 7860, 168 | share: bool = False, 169 | debug: bool = False, 170 | ): 171 | self.device: str = device 172 | self.model: str = model # 端口号 173 | self.config_path: str = config_path # 是否公开部署,对外网开放 174 | self.port: int = port # 是否开启debug模式 175 | self.share: bool = share # 模型路径 176 | self.debug: bool = debug # 配置文件路径 177 | self.language_identification_library: str = ( 178 | language_identification_library # 语种识别库 179 | ) 180 | 181 | @classmethod 182 | def from_dict(cls, dataset_path: str, data: Dict[str, any]): 183 | data["config_path"] = os.path.join(dataset_path, data["config_path"]) 184 | data["model"] = os.path.join(dataset_path, data["model"]) 185 | return cls(**data) 186 | 187 | 188 | class Server_config: 189 | def __init__( 190 | self, models: List[Dict[str, any]], port: int = 5000, device: str = "cuda" 191 | ): 192 | self.models: List[Dict[str, any]] = models # 需要加载的所有模型的配置 193 | self.port: int = port # 端口号 194 | self.device: str = device # 模型默认使用设备 195 | 196 | @classmethod 197 | def from_dict(cls, data: Dict[str, any]): 198 | return cls(**data) 199 | 200 | 201 | class Translate_config: 202 | """翻译api配置""" 203 | 204 | def __init__(self, app_key: str, secret_key: str): 205 | self.app_key = app_key 206 | self.secret_key = secret_key 207 | 208 | @classmethod 209 | def from_dict(cls, data: Dict[str, any]): 210 | return cls(**data) 211 | 212 
| 213 | class Config: 214 | def __init__(self, config_path: str): 215 | if not os.path.isfile(config_path) and os.path.isfile("default_config.yml"): 216 | shutil.copy(src="default_config.yml", dst=config_path) 217 | print( 218 | f"已根据默认配置文件default_config.yml生成配置文件{config_path}。请按该配置文件的说明进行配置后重新运行。" 219 | ) 220 | print("如无特殊需求,请勿修改default_config.yml或备份该文件。") 221 | sys.exit(0) 222 | with open(file=config_path, mode="r", encoding="utf-8") as file: 223 | yaml_config: Dict[str, any] = yaml.safe_load(file.read()) 224 | dataset_path: str = yaml_config["dataset_path"] 225 | openi_token: str = yaml_config["openi_token"] 226 | self.dataset_path: str = dataset_path 227 | self.mirror: str = yaml_config["mirror"] 228 | self.openi_token: str = openi_token 229 | self.resample_config: Resample_config = Resample_config.from_dict( 230 | dataset_path, yaml_config["resample"] 231 | ) 232 | self.preprocess_text_config: Preprocess_text_config = ( 233 | Preprocess_text_config.from_dict( 234 | dataset_path, yaml_config["preprocess_text"] 235 | ) 236 | ) 237 | self.bert_gen_config: Bert_gen_config = Bert_gen_config.from_dict( 238 | dataset_path, yaml_config["bert_gen"] 239 | ) 240 | self.emo_gen_config: Emo_gen_config = Emo_gen_config.from_dict( 241 | dataset_path, yaml_config["emo_gen"] 242 | ) 243 | self.train_ms_config: Train_ms_config = Train_ms_config.from_dict( 244 | dataset_path, yaml_config["train_ms"] 245 | ) 246 | self.webui_config: Webui_config = Webui_config.from_dict( 247 | dataset_path, yaml_config["webui"] 248 | ) 249 | self.server_config: Server_config = Server_config.from_dict( 250 | yaml_config["server"] 251 | ) 252 | self.translate_config: Translate_config = Translate_config.from_dict( 253 | yaml_config["translate"] 254 | ) 255 | 256 | 257 | parser = argparse.ArgumentParser() 258 | # 为避免与以前的config.json起冲突,将其更名如下 259 | parser.add_argument("-y", "--yml_config", type=str, default="config.yml") 260 | args, _ = parser.parse_known_args() 261 | config = Config(args.yml_config) 262 | -------------------------------------------------------------------------------- /config.yml: -------------------------------------------------------------------------------- 1 | # 全局配置 2 | # 对于希望在同一时间使用多个配置文件的情况,例如两个GPU同时跑两个训练集:通过环境变量指定配置文件,不指定则默认为./config.yml 3 | 4 | # 拟提供通用路径配置,统一存放数据,避免数据放得很乱 5 | # 每个数据集与其对应的模型存放至统一路径下,后续所有的路径配置均为相对于datasetPath的路径 6 | # 不填或者填空则路径为相对于项目根目录的路径 7 | dataset_path: "Data/" 8 | 9 | # 模型镜像源,默认huggingface,使用openi镜像源需指定openi_token 10 | mirror: "" 11 | openi_token: "" # openi token 12 | 13 | # resample 音频重采样配置 14 | # 注意, “:” 后需要加空格 15 | resample: 16 | # 目标重采样率 17 | sampling_rate: 44100 18 | # 音频文件输入路径,重采样会将该路径下所有.wav音频文件重采样 19 | # 请填入相对于datasetPath的相对路径 20 | in_dir: "audios/raw" # 相对于根目录的路径为 /datasetPath/in_dir 21 | # 音频文件重采样后输出路径 22 | out_dir: "audios/wavs" 23 | 24 | 25 | # preprocess_text 数据集预处理相关配置 26 | # 注意, “:” 后需要加空格 27 | preprocess_text: 28 | # 原始文本文件路径,文本格式应为{wav_path}|{speaker_name}|{language}|{text}。 29 | transcription_path: "filelists/你的数据集文本.list" 30 | # 数据清洗后文本路径,可以不填。不填则将在原始文本目录生成 31 | cleaned_path: "" 32 | # 训练集路径 33 | train_path: "filelists/train.list" 34 | # 验证集路径 35 | val_path: "filelists/val.list" 36 | # 配置文件路径 37 | config_path: "config.json" 38 | # 每个语言的验证集条数 39 | val_per_lang: 4 40 | # 验证集最大条数,多于的会被截断并放到训练集中 41 | max_val_total: 12 42 | # 是否进行数据清洗 43 | clean: true 44 | 45 | 46 | # bert_gen 相关配置 47 | # 注意, “:” 后需要加空格 48 | bert_gen: 49 | # 训练数据集配置文件路径 50 | config_path: "config.json" 51 | # 并行数 52 | num_processes: 4 53 | # 使用设备:可选项 "cuda" 显卡推理,"cpu" cpu推理 54 | # 
该选项同时决定了get_bert_feature的默认设备 55 | device: "cuda" 56 | # 使用多卡推理 57 | use_multi_device: false 58 | 59 | # emo_gen 相关配置 60 | # 注意, “:” 后需要加空格 61 | emo_gen: 62 | # 训练数据集配置文件路径 63 | config_path: "config.json" 64 | # 并行数 65 | num_processes: 4 66 | # 使用设备:可选项 "cuda" 显卡推理,"cpu" cpu推理 67 | device: "cuda" 68 | # 使用多卡推理 69 | use_multi_device: false 70 | 71 | # train 训练配置 72 | # 注意, “:” 后需要加空格 73 | train_ms: 74 | env: 75 | MASTER_ADDR: "localhost" 76 | MASTER_PORT: 10086 77 | WORLD_SIZE: 1 78 | LOCAL_RANK: 0 79 | RANK: 0 80 | # 可以填写任意名的环境变量 81 | # THE_ENV_VAR_YOU_NEED_TO_USE: "1234567" 82 | # 底模设置 83 | base: 84 | use_base_model: false 85 | repo_id: "Stardust_minus/Bert-VITS2" 86 | model_image: "Bert-VITS2_2.3底模" # openi网页的模型名 87 | # 训练模型存储目录:与旧版本的区别,原先数据集是存放在logs/model_name下的,现在改为统一存放在Data/你的数据集/models下 88 | model: "models" 89 | # 配置文件路径 90 | config_path: "configs/config.json" 91 | # 训练使用的worker,不建议超过CPU核心数 92 | num_workers: 16 93 | # 关闭此项可以节约接近70%的磁盘空间,但是可能导致实际训练速度变慢和更高的CPU使用率。 94 | spec_cache: False 95 | # 保存的检查点数量,多于此数目的权重会被删除来节省空间。 96 | keep_ckpts: 8 97 | 98 | 99 | # webui webui配置 100 | # 注意, “:” 后需要加空格 101 | webui: 102 | # 推理设备 103 | device: "cuda" 104 | # 模型路径 105 | model: "models/G_8000.pth" 106 | # 配置文件路径 107 | config_path: "configs/config.json" 108 | # 端口号 109 | port: 7860 110 | # 是否公开部署,对外网开放 111 | share: false 112 | # 是否开启debug模式 113 | debug: false 114 | # 语种识别库,可选langid, fastlid 115 | language_identification_library: "langid" 116 | 117 | 118 | # server-fastapi配置 119 | # 注意, “:” 后需要加空格 120 | # 注意,本配置下的所有配置均为相对于根目录的路径 121 | server: 122 | # 端口号 123 | port: 5000 124 | # 模型默认使用设备:但是当前并没有实现这个配置。 125 | device: "cuda" 126 | # 需要加载的所有模型的配置,可以填多个模型,也可以不填模型,等网页成功后手动加载模型 127 | # 不加载模型的配置格式:删除默认给的两个模型配置,给models赋值 [ ],也就是空列表。参考模型2的speakers 即 models: [ ] 128 | # 注意,所有模型都必须正确配置model与config的路径,空路径会导致加载错误。 129 | # 也可以不填模型,等网页加载成功后手动填写models。 130 | models: 131 | - # 模型的路径 132 | model: "" 133 | # 模型config.json的路径 134 | config: "" 135 | # 模型使用设备,若填写则会覆盖默认配置 136 | device: "cuda" 137 | # 模型默认使用的语言 138 | language: "ZH" 139 | # 模型人物默认参数 140 | # 不必填写所有人物,不填的使用默认值 141 | # 暂时不用填写,当前尚未实现按人区分配置 142 | speakers: 143 | - speaker: "科比" 144 | sdp_ratio: 0.2 145 | noise_scale: 0.6 146 | noise_scale_w: 0.8 147 | length_scale: 1 148 | - speaker: "五条悟" 149 | sdp_ratio: 0.3 150 | noise_scale: 0.7 151 | noise_scale_w: 0.8 152 | length_scale: 0.5 153 | - speaker: "安倍晋三" 154 | sdp_ratio: 0.2 155 | noise_scale: 0.6 156 | noise_scale_w: 0.8 157 | length_scale: 1.2 158 | - # 模型的路径 159 | model: "" 160 | # 模型config.json的路径 161 | config: "" 162 | # 模型使用设备,若填写则会覆盖默认配置 163 | device: "cpu" 164 | # 模型默认使用的语言 165 | language: "JP" 166 | # 模型人物默认参数 167 | # 不必填写所有人物,不填的使用默认值 168 | speakers: [ ] # 也可以不填 169 | 170 | # 百度翻译开放平台 api配置 171 | # api接入文档 https://api.fanyi.baidu.com/doc/21 172 | # 请不要在github等网站公开分享你的app id 与 key 173 | translate: 174 | # 你的APPID 175 | "app_key": "" 176 | # 你的密钥 177 | "secret_key": "" 178 | -------------------------------------------------------------------------------- /css/custom.css: -------------------------------------------------------------------------------- 1 | 2 | #yml_code { 3 | height: 600px; 4 | flex-grow: inherit; 5 | overflow-y: auto; 6 | } 7 | 8 | #json_code { 9 | height: 600px; 10 | flex-grow: inherit; 11 | overflow-y: auto; 12 | } 13 | 14 | #gpu_code { 15 | height: 300px; 16 | flex-grow: inherit; 17 | overflow-y: auto; 18 | } 19 | -------------------------------------------------------------------------------- /default_config.yml: 
-------------------------------------------------------------------------------- 1 | # 全局配置 2 | # 对于希望在同一时间使用多个配置文件的情况,例如两个GPU同时跑两个训练集:通过环境变量指定配置文件,不指定则默认为./config.yml 3 | 4 | # 拟提供通用路径配置,统一存放数据,避免数据放得很乱 5 | # 每个数据集与其对应的模型存放至统一路径下,后续所有的路径配置均为相对于datasetPath的路径 6 | # 不填或者填空则路径为相对于项目根目录的路径 7 | dataset_path: "Data/" 8 | 9 | # 模型镜像源,默认huggingface,使用openi镜像源需指定openi_token 10 | mirror: "" 11 | openi_token: "" # openi token 12 | 13 | # resample 音频重采样配置 14 | # 注意, “:” 后需要加空格 15 | resample: 16 | # 目标重采样率 17 | sampling_rate: 44100 18 | # 音频文件输入路径,重采样会将该路径下所有.wav音频文件重采样 19 | # 请填入相对于datasetPath的相对路径 20 | in_dir: "audios/raw" # 相对于根目录的路径为 /datasetPath/in_dir 21 | # 音频文件重采样后输出路径 22 | out_dir: "audios/wavs" 23 | 24 | 25 | # preprocess_text 数据集预处理相关配置 26 | # 注意, “:” 后需要加空格 27 | preprocess_text: 28 | # 原始文本文件路径,文本格式应为{wav_path}|{speaker_name}|{language}|{text}。 29 | transcription_path: "filelists/你的数据集文本.list" 30 | # 数据清洗后文本路径,可以不填。不填则将在原始文本目录生成 31 | cleaned_path: "" 32 | # 训练集路径 33 | train_path: "filelists/train.list" 34 | # 验证集路径 35 | val_path: "filelists/val.list" 36 | # 配置文件路径 37 | config_path: "config.json" 38 | # 每个语言的验证集条数 39 | val_per_lang: 4 40 | # 验证集最大条数,多于的会被截断并放到训练集中 41 | max_val_total: 12 42 | # 是否进行数据清洗 43 | clean: true 44 | 45 | 46 | # bert_gen 相关配置 47 | # 注意, “:” 后需要加空格 48 | bert_gen: 49 | # 训练数据集配置文件路径 50 | config_path: "config.json" 51 | # 并行数 52 | num_processes: 4 53 | # 使用设备:可选项 "cuda" 显卡推理,"cpu" cpu推理 54 | # 该选项同时决定了get_bert_feature的默认设备 55 | device: "cuda" 56 | # 使用多卡推理 57 | use_multi_device: false 58 | 59 | # emo_gen 相关配置 60 | # 注意, “:” 后需要加空格 61 | emo_gen: 62 | # 训练数据集配置文件路径 63 | config_path: "config.json" 64 | # 并行数 65 | num_processes: 4 66 | # 使用设备:可选项 "cuda" 显卡推理,"cpu" cpu推理 67 | device: "cuda" 68 | # 使用多卡推理 69 | use_multi_device: false 70 | 71 | # train 训练配置 72 | # 注意, “:” 后需要加空格 73 | train_ms: 74 | env: 75 | MASTER_ADDR: "localhost" 76 | MASTER_PORT: 10086 77 | WORLD_SIZE: 1 78 | LOCAL_RANK: 0 79 | RANK: 0 80 | # 可以填写任意名的环境变量 81 | # THE_ENV_VAR_YOU_NEED_TO_USE: "1234567" 82 | # 底模设置 83 | base: 84 | use_base_model: false 85 | repo_id: "Stardust_minus/Bert-VITS2" 86 | model_image: "Bert-VITS2_2.3底模" # openi网页的模型名 87 | # 训练模型存储目录:与旧版本的区别,原先数据集是存放在logs/model_name下的,现在改为统一存放在Data/你的数据集/models下 88 | model: "models" 89 | # 配置文件路径 90 | config_path: "configs/config.json" 91 | # 训练使用的worker,不建议超过CPU核心数 92 | num_workers: 16 93 | # 关闭此项可以节约接近70%的磁盘空间,但是可能导致实际训练速度变慢和更高的CPU使用率。 94 | spec_cache: False 95 | # 保存的检查点数量,多于此数目的权重会被删除来节省空间。 96 | keep_ckpts: 8 97 | 98 | 99 | # webui webui配置 100 | # 注意, “:” 后需要加空格 101 | webui: 102 | # 推理设备 103 | device: "cuda" 104 | # 模型路径 105 | model: "models/G_8000.pth" 106 | # 配置文件路径 107 | config_path: "configs/config.json" 108 | # 端口号 109 | port: 7860 110 | # 是否公开部署,对外网开放 111 | share: false 112 | # 是否开启debug模式 113 | debug: false 114 | # 语种识别库,可选langid, fastlid 115 | language_identification_library: "langid" 116 | 117 | 118 | # server-fastapi配置 119 | # 注意, “:” 后需要加空格 120 | # 注意,本配置下的所有配置均为相对于根目录的路径 121 | server: 122 | # 端口号 123 | port: 5000 124 | # 模型默认使用设备:但是当前并没有实现这个配置。 125 | device: "cuda" 126 | # 需要加载的所有模型的配置,可以填多个模型,也可以不填模型,等网页成功后手动加载模型 127 | # 不加载模型的配置格式:删除默认给的两个模型配置,给models赋值 [ ],也就是空列表。参考模型2的speakers 即 models: [ ] 128 | # 注意,所有模型都必须正确配置model与config的路径,空路径会导致加载错误。 129 | # 也可以不填模型,等网页加载成功后手动填写models。 130 | models: 131 | - # 模型的路径 132 | model: "" 133 | # 模型config.json的路径 134 | config: "" 135 | # 模型使用设备,若填写则会覆盖默认配置 136 | device: "cuda" 137 | # 模型默认使用的语言 138 | language: "ZH" 139 | # 模型人物默认参数 140 | # 不必填写所有人物,不填的使用默认值 141 | # 暂时不用填写,当前尚未实现按人区分配置 142 | 
speakers: 143 | - speaker: "科比" 144 | sdp_ratio: 0.2 145 | noise_scale: 0.6 146 | noise_scale_w: 0.8 147 | length_scale: 1 148 | - speaker: "五条悟" 149 | sdp_ratio: 0.3 150 | noise_scale: 0.7 151 | noise_scale_w: 0.8 152 | length_scale: 0.5 153 | - speaker: "安倍晋三" 154 | sdp_ratio: 0.2 155 | noise_scale: 0.6 156 | noise_scale_w: 0.8 157 | length_scale: 1.2 158 | - # 模型的路径 159 | model: "" 160 | # 模型config.json的路径 161 | config: "" 162 | # 模型使用设备,若填写则会覆盖默认配置 163 | device: "cpu" 164 | # 模型默认使用的语言 165 | language: "JP" 166 | # 模型人物默认参数 167 | # 不必填写所有人物,不填的使用默认值 168 | speakers: [ ] # 也可以不填 169 | 170 | # 百度翻译开放平台 api配置 171 | # api接入文档 https://api.fanyi.baidu.com/doc/21 172 | # 请不要在github等网站公开分享你的app id 与 key 173 | translate: 174 | # 你的APPID 175 | "app_key": "" 176 | # 你的密钥 177 | "secret_key": "" 178 | -------------------------------------------------------------------------------- /docs/011.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/docs/011.png -------------------------------------------------------------------------------- /docs/SSB00050007.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/docs/SSB00050007.wav -------------------------------------------------------------------------------- /docs/SSB0005_50_月光.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/docs/SSB0005_50_月光.wav -------------------------------------------------------------------------------- /docs/SSB0005_50_根据.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/docs/SSB0005_50_根据.wav -------------------------------------------------------------------------------- /docs/gentel_truth0.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/docs/gentel_truth0.wav -------------------------------------------------------------------------------- /docs/gentle_girl_月光.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/docs/gentle_girl_月光.wav -------------------------------------------------------------------------------- /docs/gentle_girl_根据.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/docs/gentle_girl_根据.wav -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Speaker Information 7 | 8 | 9 | 10 |
[docs/index.html: HTML markup and <audio> source paths were lost in extraction; only the readable page content is kept below.]

Speaker Information

Speaker table (columns: 微调speaker name | 语音数量/个 | 语音总时长/min | 性别):
- gentle_girl | 数量/时长 digits "1939.12" (column boundary lost) | 女
- SSB0005 | 50 | 3.86 | 女

gentle_girl合成demo
- 参考文本: 长时间给他进行错误的心理暗示
- 合成文本: 今夜的月光如此清亮,不做些什么真是浪费。随我一同去月下漫步吧,不许拒绝
- 合成文本: 根据我们上面的描述,我们的目标是希望获取一组相对最优的参数来作为模型的初始化参数

SSB0005_50合成demo
- 参考文本: 广州大学女尸案嫌疑人辩称死者为女友。
- 合成文本: 今夜的月光如此清亮,不做些什么真是浪费。随我一同去月下漫步吧,不许拒绝
- 合成文本: 根据我们上面的描述,我们的目标是希望获取一组相对最优的参数来作为模型的初始化参数
132 | 133 | 134 | 135 | 136 | -------------------------------------------------------------------------------- /docs/合成文本.txt: -------------------------------------------------------------------------------- 1 | 根据我们上面的描述,我们的目标是希望获取一组相对最优的参数来作为模型的初始化参数 2 | 今夜的月光如此清亮,不做些什么真是浪费。随我一同去月下漫步吧,不许拒绝 -------------------------------------------------------------------------------- /docs/语音学习群.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/docs/语音学习群.png -------------------------------------------------------------------------------- /emotional/clap-htsat-fused/.gitattributes: -------------------------------------------------------------------------------- 1 | *.7z filter=lfs diff=lfs merge=lfs -text 2 | *.arrow filter=lfs diff=lfs merge=lfs -text 3 | *.bin filter=lfs diff=lfs merge=lfs -text 4 | *.bz2 filter=lfs diff=lfs merge=lfs -text 5 | *.ckpt filter=lfs diff=lfs merge=lfs -text 6 | *.ftz filter=lfs diff=lfs merge=lfs -text 7 | *.gz filter=lfs diff=lfs merge=lfs -text 8 | *.h5 filter=lfs diff=lfs merge=lfs -text 9 | *.joblib filter=lfs diff=lfs merge=lfs -text 10 | *.lfs.* filter=lfs diff=lfs merge=lfs -text 11 | *.mlmodel filter=lfs diff=lfs merge=lfs -text 12 | *.model filter=lfs diff=lfs merge=lfs -text 13 | *.msgpack filter=lfs diff=lfs merge=lfs -text 14 | *.npy filter=lfs diff=lfs merge=lfs -text 15 | *.npz filter=lfs diff=lfs merge=lfs -text 16 | *.onnx filter=lfs diff=lfs merge=lfs -text 17 | *.ot filter=lfs diff=lfs merge=lfs -text 18 | *.parquet filter=lfs diff=lfs merge=lfs -text 19 | *.pb filter=lfs diff=lfs merge=lfs -text 20 | *.pickle filter=lfs diff=lfs merge=lfs -text 21 | *.pkl filter=lfs diff=lfs merge=lfs -text 22 | *.pt filter=lfs diff=lfs merge=lfs -text 23 | *.pth filter=lfs diff=lfs merge=lfs -text 24 | *.rar filter=lfs diff=lfs merge=lfs -text 25 | *.safetensors filter=lfs diff=lfs merge=lfs -text 26 | saved_model/**/* filter=lfs diff=lfs merge=lfs -text 27 | *.tar.* filter=lfs diff=lfs merge=lfs -text 28 | *.tflite filter=lfs diff=lfs merge=lfs -text 29 | *.tgz filter=lfs diff=lfs merge=lfs -text 30 | *.wasm filter=lfs diff=lfs merge=lfs -text 31 | *.xz filter=lfs diff=lfs merge=lfs -text 32 | *.zip filter=lfs diff=lfs merge=lfs -text 33 | *.zst filter=lfs diff=lfs merge=lfs -text 34 | *tfevents* filter=lfs diff=lfs merge=lfs -text 35 | -------------------------------------------------------------------------------- /emotional/clap-htsat-fused/README.md: -------------------------------------------------------------------------------- 1 | --- 2 | license: apache-2.0 3 | --- 4 | # Model card for CLAP 5 | 6 | Model card for CLAP: Contrastive Language-Audio Pretraining 7 | 8 | ![clap_image](https://s3.amazonaws.com/moonup/production/uploads/1678811100805-62441d1d9fdefb55a0b7d12c.png) 9 | 10 | 11 | # Table of Contents 12 | 13 | 0. [TL;DR](#TL;DR) 14 | 1. [Model Details](#model-details) 15 | 2. [Usage](#usage) 16 | 3. [Uses](#uses) 17 | 4. [Citation](#citation) 18 | 19 | # TL;DR 20 | 21 | The abstract of the paper states that: 22 | 23 | > Contrastive learning has shown remarkable success in the field of multimodal representation learning. In this paper, we propose a pipeline of contrastive language-audio pretraining to develop an audio representation by combining audio data with natural language descriptions. 
To accomplish this target, we first release LAION-Audio-630K, a large collection of 633,526 audio-text pairs from different data sources. Second, we construct a contrastive language-audio pretraining model by considering different audio encoders and text encoders. We incorporate the feature fusion mechanism and keyword-to-caption augmentation into the model design to further enable the model to process audio inputs of variable lengths and enhance the performance. Third, we perform comprehensive experiments to evaluate our model across three tasks: text-to-audio retrieval, zero-shot audio classification, and supervised audio classification. The results demonstrate that our model achieves superior performance in text-to-audio retrieval task. In audio classification tasks, the model achieves state-of-the-art performance in the zero-shot setting and is able to obtain performance comparable to models' results in the non-zero-shot setting. LAION-Audio-630K and the proposed model are both available to the public. 24 | 25 | 26 | # Usage 27 | 28 | You can use this model for zero shot audio classification or extracting audio and/or textual features. 29 | 30 | # Uses 31 | 32 | ## Perform zero-shot audio classification 33 | 34 | ### Using `pipeline` 35 | 36 | ```python 37 | from datasets import load_dataset 38 | from transformers import pipeline 39 | 40 | dataset = load_dataset("ashraq/esc50") 41 | audio = dataset["train"]["audio"][-1]["array"] 42 | 43 | audio_classifier = pipeline(task="zero-shot-audio-classification", model="laion/clap-htsat-fused") 44 | output = audio_classifier(audio, candidate_labels=["Sound of a dog", "Sound of vaccum cleaner"]) 45 | print(output) 46 | >>> [{"score": 0.999, "label": "Sound of a dog"}, {"score": 0.001, "label": "Sound of vaccum cleaner"}] 47 | ``` 48 | 49 | ## Run the model: 50 | 51 | You can also get the audio and text embeddings using `ClapModel` 52 | 53 | ### Run the model on CPU: 54 | 55 | ```python 56 | from datasets import load_dataset 57 | from transformers import ClapModel, ClapProcessor 58 | 59 | librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") 60 | audio_sample = librispeech_dummy[0] 61 | 62 | model = ClapModel.from_pretrained("laion/clap-htsat-fused") 63 | processor = ClapProcessor.from_pretrained("laion/clap-htsat-fused") 64 | 65 | inputs = processor(audios=audio_sample["audio"]["array"], return_tensors="pt") 66 | audio_embed = model.get_audio_features(**inputs) 67 | ``` 68 | 69 | ### Run the model on GPU: 70 | 71 | ```python 72 | from datasets import load_dataset 73 | from transformers import ClapModel, ClapProcessor 74 | 75 | librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") 76 | audio_sample = librispeech_dummy[0] 77 | 78 | model = ClapModel.from_pretrained("laion/clap-htsat-fused").to(0) 79 | processor = ClapProcessor.from_pretrained("laion/clap-htsat-fused") 80 | 81 | inputs = processor(audios=audio_sample["audio"]["array"], return_tensors="pt").to(0) 82 | audio_embed = model.get_audio_features(**inputs) 83 | ``` 84 | 85 | 86 | # Citation 87 | 88 | If you are using this model for your work, please consider citing the original paper: 89 | ``` 90 | @misc{https://doi.org/10.48550/arxiv.2211.06687, 91 | doi = {10.48550/ARXIV.2211.06687}, 92 | 93 | url = {https://arxiv.org/abs/2211.06687}, 94 | 95 | author = {Wu, Yusong and Chen, Ke and Zhang, Tianyu and Hui, Yuchen and Berg-Kirkpatrick, Taylor and Dubnov, Shlomo}, 96 | 97 | keywords = 
{Sound (cs.SD), Audio and Speech Processing (eess.AS), FOS: Computer and information sciences, FOS: Computer and information sciences, FOS: Electrical engineering, electronic engineering, information engineering, FOS: Electrical engineering, electronic engineering, information engineering}, 98 | 99 | title = {Large-scale Contrastive Language-Audio Pretraining with Feature Fusion and Keyword-to-Caption Augmentation}, 100 | 101 | publisher = {arXiv}, 102 | 103 | year = {2022}, 104 | 105 | copyright = {Creative Commons Attribution 4.0 International} 106 | } 107 | ``` 108 | -------------------------------------------------------------------------------- /emotional/clap-htsat-fused/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_commit_hash": null, 3 | "architectures": [ 4 | "ClapModel" 5 | ], 6 | "audio_config": { 7 | "_name_or_path": "", 8 | "add_cross_attention": false, 9 | "aff_block_r": 4, 10 | "architectures": null, 11 | "attention_probs_dropout_prob": 0.0, 12 | "bad_words_ids": null, 13 | "begin_suppress_tokens": null, 14 | "bos_token_id": null, 15 | "chunk_size_feed_forward": 0, 16 | "cross_attention_hidden_size": null, 17 | "decoder_start_token_id": null, 18 | "depths": [ 19 | 2, 20 | 2, 21 | 6, 22 | 2 23 | ], 24 | "diversity_penalty": 0.0, 25 | "do_sample": false, 26 | "drop_path_rate": 0.0, 27 | "early_stopping": false, 28 | "enable_fusion": true, 29 | "enable_patch_fusion": true, 30 | "enable_patch_layer_norm": true, 31 | "encoder_no_repeat_ngram_size": 0, 32 | "eos_token_id": null, 33 | "exponential_decay_length_penalty": null, 34 | "finetuning_task": null, 35 | "flatten_patch_embeds": true, 36 | "forced_bos_token_id": null, 37 | "forced_eos_token_id": null, 38 | "fusion_num_hidden_layers": 2, 39 | "fusion_type": null, 40 | "hidden_act": "gelu", 41 | "hidden_dropout_prob": 0.1, 42 | "hidden_size": 768, 43 | "id2label": { 44 | "0": "LABEL_0", 45 | "1": "LABEL_1" 46 | }, 47 | "initializer_factor": 1.0, 48 | "is_decoder": false, 49 | "is_encoder_decoder": false, 50 | "label2id": { 51 | "LABEL_0": 0, 52 | "LABEL_1": 1 53 | }, 54 | "layer_norm_eps": 1e-05, 55 | "length_penalty": 1.0, 56 | "max_length": 20, 57 | "min_length": 0, 58 | "mlp_ratio": 4.0, 59 | "model_type": "clap_audio_model", 60 | "no_repeat_ngram_size": 0, 61 | "num_attention_heads": [ 62 | 4, 63 | 8, 64 | 16, 65 | 32 66 | ], 67 | "num_beam_groups": 1, 68 | "num_beams": 1, 69 | "num_classes": 527, 70 | "num_hidden_layers": 4, 71 | "num_mel_bins": 64, 72 | "num_return_sequences": 1, 73 | "output_attentions": false, 74 | "output_hidden_states": false, 75 | "output_scores": false, 76 | "pad_token_id": null, 77 | "patch_embed_input_channels": 1, 78 | "patch_embeds_hidden_size": 96, 79 | "patch_size": 4, 80 | "patch_stride": [ 81 | 4, 82 | 4 83 | ], 84 | "prefix": null, 85 | "problem_type": null, 86 | "projection_dim": 512, 87 | "projection_hidden_act": "relu", 88 | "projection_hidden_size": 768, 89 | "pruned_heads": {}, 90 | "qkv_bias": true, 91 | "remove_invalid_values": false, 92 | "repetition_penalty": 1.0, 93 | "return_dict": true, 94 | "return_dict_in_generate": false, 95 | "sep_token_id": null, 96 | "spec_size": 256, 97 | "suppress_tokens": null, 98 | "task_specific_params": null, 99 | "temperature": 1.0, 100 | "tf_legacy_loss": false, 101 | "tie_encoder_decoder": false, 102 | "tie_word_embeddings": true, 103 | "tokenizer_class": null, 104 | "top_k": 50, 105 | "top_p": 1.0, 106 | "torch_dtype": null, 107 | "torchscript": false, 108 | "transformers_version": 
"4.27.0.dev0", 109 | "typical_p": 1.0, 110 | "use_bfloat16": false, 111 | "window_size": 8 112 | }, 113 | "hidden_size": 768, 114 | "initializer_factor": 1.0, 115 | "logit_scale_init_value": 14.285714285714285, 116 | "model_type": "clap", 117 | "num_hidden_layers": 16, 118 | "projection_dim": 512, 119 | "projection_hidden_act": "relu", 120 | "text_config": { 121 | "_name_or_path": "", 122 | "add_cross_attention": false, 123 | "architectures": null, 124 | "attention_probs_dropout_prob": 0.1, 125 | "bad_words_ids": null, 126 | "begin_suppress_tokens": null, 127 | "bos_token_id": 0, 128 | "chunk_size_feed_forward": 0, 129 | "classifier_dropout": null, 130 | "cross_attention_hidden_size": null, 131 | "decoder_start_token_id": null, 132 | "diversity_penalty": 0.0, 133 | "do_sample": false, 134 | "early_stopping": false, 135 | "encoder_no_repeat_ngram_size": 0, 136 | "eos_token_id": 2, 137 | "exponential_decay_length_penalty": null, 138 | "finetuning_task": null, 139 | "forced_bos_token_id": null, 140 | "forced_eos_token_id": null, 141 | "fusion_hidden_size": 768, 142 | "fusion_num_hidden_layers": 2, 143 | "hidden_act": "gelu", 144 | "hidden_dropout_prob": 0.1, 145 | "hidden_size": 768, 146 | "id2label": { 147 | "0": "LABEL_0", 148 | "1": "LABEL_1" 149 | }, 150 | "initializer_factor": 1.0, 151 | "initializer_range": 0.02, 152 | "intermediate_size": 3072, 153 | "is_decoder": false, 154 | "is_encoder_decoder": false, 155 | "label2id": { 156 | "LABEL_0": 0, 157 | "LABEL_1": 1 158 | }, 159 | "layer_norm_eps": 1e-12, 160 | "length_penalty": 1.0, 161 | "max_length": 20, 162 | "max_position_embeddings": 514, 163 | "min_length": 0, 164 | "model_type": "clap_text_model", 165 | "no_repeat_ngram_size": 0, 166 | "num_attention_heads": 12, 167 | "num_beam_groups": 1, 168 | "num_beams": 1, 169 | "num_hidden_layers": 12, 170 | "num_return_sequences": 1, 171 | "output_attentions": false, 172 | "output_hidden_states": false, 173 | "output_scores": false, 174 | "pad_token_id": 1, 175 | "position_embedding_type": "absolute", 176 | "prefix": null, 177 | "problem_type": null, 178 | "projection_dim": 512, 179 | "projection_hidden_act": "relu", 180 | "projection_hidden_size": 768, 181 | "pruned_heads": {}, 182 | "remove_invalid_values": false, 183 | "repetition_penalty": 1.0, 184 | "return_dict": true, 185 | "return_dict_in_generate": false, 186 | "sep_token_id": null, 187 | "suppress_tokens": null, 188 | "task_specific_params": null, 189 | "temperature": 1.0, 190 | "tf_legacy_loss": false, 191 | "tie_encoder_decoder": false, 192 | "tie_word_embeddings": true, 193 | "tokenizer_class": null, 194 | "top_k": 50, 195 | "top_p": 1.0, 196 | "torch_dtype": null, 197 | "torchscript": false, 198 | "transformers_version": "4.27.0.dev0", 199 | "type_vocab_size": 1, 200 | "typical_p": 1.0, 201 | "use_bfloat16": false, 202 | "use_cache": true, 203 | "vocab_size": 50265 204 | }, 205 | "torch_dtype": "float32", 206 | "transformers_version": null 207 | } 208 | -------------------------------------------------------------------------------- /emotional/clap-htsat-fused/preprocessor_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "chunk_length_s": 10, 3 | "feature_extractor_type": "ClapFeatureExtractor", 4 | "feature_size": 64, 5 | "fft_window_size": 1024, 6 | "frequency_max": 14000, 7 | "frequency_min": 50, 8 | "hop_length": 480, 9 | "max_length_s": 10, 10 | "n_fft": 1024, 11 | "nb_frequency_bins": 513, 12 | "nb_max_frames": 1000, 13 | "nb_max_samples": 480000, 14 | "padding": 
"repeatpad", 15 | "padding_side": "right", 16 | "padding_value": 0.0, 17 | "processor_class": "ClapProcessor", 18 | "return_attention_mask": false, 19 | "sampling_rate": 48000, 20 | "top_db": null, 21 | "truncation": "fusion" 22 | } 23 | -------------------------------------------------------------------------------- /emotional/clap-htsat-fused/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "bos_token": "", 3 | "cls_token": "", 4 | "eos_token": "", 5 | "mask_token": { 6 | "content": "", 7 | "lstrip": true, 8 | "normalized": false, 9 | "rstrip": false, 10 | "single_word": false 11 | }, 12 | "pad_token": "", 13 | "sep_token": "", 14 | "unk_token": "" 15 | } 16 | -------------------------------------------------------------------------------- /emotional/clap-htsat-fused/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_prefix_space": false, 3 | "bos_token": "", 4 | "cls_token": "", 5 | "eos_token": "", 6 | "errors": "replace", 7 | "mask_token": "", 8 | "model_max_length": 512, 9 | "pad_token": "", 10 | "processor_class": "ClapProcessor", 11 | "sep_token": "", 12 | "special_tokens_map_file": null, 13 | "tokenizer_class": "RobertaTokenizer", 14 | "trim_offsets": true, 15 | "unk_token": "" 16 | } 17 | -------------------------------------------------------------------------------- /emotional/wav2vec2-large-robust-12-ft-emotion-msp-dim/.gitattributes: -------------------------------------------------------------------------------- 1 | *.7z filter=lfs diff=lfs merge=lfs -text 2 | *.arrow filter=lfs diff=lfs merge=lfs -text 3 | *.bin filter=lfs diff=lfs merge=lfs -text 4 | *.bin.* filter=lfs diff=lfs merge=lfs -text 5 | *.bz2 filter=lfs diff=lfs merge=lfs -text 6 | *.ftz filter=lfs diff=lfs merge=lfs -text 7 | *.gz filter=lfs diff=lfs merge=lfs -text 8 | *.h5 filter=lfs diff=lfs merge=lfs -text 9 | *.joblib filter=lfs diff=lfs merge=lfs -text 10 | *.lfs.* filter=lfs diff=lfs merge=lfs -text 11 | *.model filter=lfs diff=lfs merge=lfs -text 12 | *.msgpack filter=lfs diff=lfs merge=lfs -text 13 | *.onnx filter=lfs diff=lfs merge=lfs -text 14 | *.ot filter=lfs diff=lfs merge=lfs -text 15 | *.parquet filter=lfs diff=lfs merge=lfs -text 16 | *.pb filter=lfs diff=lfs merge=lfs -text 17 | *.pt filter=lfs diff=lfs merge=lfs -text 18 | *.pth filter=lfs diff=lfs merge=lfs -text 19 | *.rar filter=lfs diff=lfs merge=lfs -text 20 | saved_model/**/* filter=lfs diff=lfs merge=lfs -text 21 | *.tar.* filter=lfs diff=lfs merge=lfs -text 22 | *.tflite filter=lfs diff=lfs merge=lfs -text 23 | *.tgz filter=lfs diff=lfs merge=lfs -text 24 | *.wasm filter=lfs diff=lfs merge=lfs -text 25 | *.xz filter=lfs diff=lfs merge=lfs -text 26 | *.zip filter=lfs diff=lfs merge=lfs -text 27 | *.zstandard filter=lfs diff=lfs merge=lfs -text 28 | *tfevents* filter=lfs diff=lfs merge=lfs -text 29 | -------------------------------------------------------------------------------- /emotional/wav2vec2-large-robust-12-ft-emotion-msp-dim/README.md: -------------------------------------------------------------------------------- 1 | --- 2 | language: en 3 | datasets: 4 | - msp-podcast 5 | inference: true 6 | tags: 7 | - speech 8 | - audio 9 | - wav2vec2 10 | - audio-classification 11 | - emotion-recognition 12 | license: cc-by-nc-sa-4.0 13 | pipeline_tag: audio-classification 14 | --- 15 | 16 | # Model for Dimensional Speech Emotion Recognition based on Wav2vec 2.0 17 | 18 | The model expects a raw audio 
signal as input and outputs predictions for arousal, dominance and valence in a range of approximately 0...1. In addition, it also provides the pooled states of the last transformer layer. The model was created by fine-tuning [ 19 | Wav2Vec2-Large-Robust](https://huggingface.co/facebook/wav2vec2-large-robust) on [MSP-Podcast](https://ecs.utdallas.edu/research/researchlabs/msp-lab/MSP-Podcast.html) (v1.7). The model was pruned from 24 to 12 transformer layers before fine-tuning. An [ONNX](https://onnx.ai/") export of the model is available from [doi:10.5281/zenodo.6221127](https://zenodo.org/record/6221127). Further details are given in the associated [paper](https://arxiv.org/abs/2203.07378) and [tutorial](https://github.com/audeering/w2v2-how-to). 20 | 21 | # Usage 22 | 23 | ```python 24 | import numpy as np 25 | import torch 26 | import torch.nn as nn 27 | from transformers import Wav2Vec2Processor 28 | from transformers.models.wav2vec2.modeling_wav2vec2 import ( 29 | Wav2Vec2Model, 30 | Wav2Vec2PreTrainedModel, 31 | ) 32 | 33 | 34 | class RegressionHead(nn.Module): 35 | r"""Classification head.""" 36 | 37 | def __init__(self, config): 38 | 39 | super().__init__() 40 | 41 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 42 | self.dropout = nn.Dropout(config.final_dropout) 43 | self.out_proj = nn.Linear(config.hidden_size, config.num_labels) 44 | 45 | def forward(self, features, **kwargs): 46 | 47 | x = features 48 | x = self.dropout(x) 49 | x = self.dense(x) 50 | x = torch.tanh(x) 51 | x = self.dropout(x) 52 | x = self.out_proj(x) 53 | 54 | return x 55 | 56 | 57 | class EmotionModel(Wav2Vec2PreTrainedModel): 58 | r"""Speech emotion classifier.""" 59 | 60 | def __init__(self, config): 61 | 62 | super().__init__(config) 63 | 64 | self.config = config 65 | self.wav2vec2 = Wav2Vec2Model(config) 66 | self.classifier = RegressionHead(config) 67 | self.init_weights() 68 | 69 | def forward( 70 | self, 71 | input_values, 72 | ): 73 | 74 | outputs = self.wav2vec2(input_values) 75 | hidden_states = outputs[0] 76 | hidden_states = torch.mean(hidden_states, dim=1) 77 | logits = self.classifier(hidden_states) 78 | 79 | return hidden_states, logits 80 | 81 | 82 | 83 | # load model from hub 84 | device = 'cpu' 85 | model_name = 'audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim' 86 | processor = Wav2Vec2Processor.from_pretrained(model_name) 87 | model = EmotionModel.from_pretrained(model_name) 88 | 89 | # dummy signal 90 | sampling_rate = 16000 91 | signal = np.zeros((1, sampling_rate), dtype=np.float32) 92 | 93 | 94 | def process_func( 95 | x: np.ndarray, 96 | sampling_rate: int, 97 | embeddings: bool = False, 98 | ) -> np.ndarray: 99 | r"""Predict emotions or extract embeddings from raw audio signal.""" 100 | 101 | # run through processor to normalize signal 102 | # always returns a batch, so we just get the first entry 103 | # then we put it on the device 104 | y = processor(x, sampling_rate=sampling_rate) 105 | y = y['input_values'][0] 106 | y = y.reshape(1, -1) 107 | y = torch.from_numpy(y).to(device) 108 | 109 | # run through model 110 | with torch.no_grad(): 111 | y = model(y)[0 if embeddings else 1] 112 | 113 | # convert to numpy 114 | y = y.detach().cpu().numpy() 115 | 116 | return y 117 | 118 | 119 | print(process_func(signal, sampling_rate)) 120 | # Arousal dominance valence 121 | # [[0.5460754 0.6062266 0.40431657]] 122 | 123 | print(process_func(signal, sampling_rate, embeddings=True)) 124 | # Pooled hidden states of last transformer layer 125 | # [[-0.00752167 0.0065819 
-0.00746342 ... 0.00663632 0.00848748 126 | # 0.00599211]] 127 | ``` 128 | -------------------------------------------------------------------------------- /emotional/wav2vec2-large-robust-12-ft-emotion-msp-dim/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "torch", 3 | "activation_dropout": 0.1, 4 | "adapter_kernel_size": 3, 5 | "adapter_stride": 2, 6 | "add_adapter": false, 7 | "apply_spec_augment": true, 8 | "architectures": [ 9 | "Wav2Vec2ForSpeechClassification" 10 | ], 11 | "attention_dropout": 0.1, 12 | "bos_token_id": 1, 13 | "classifier_proj_size": 256, 14 | "codevector_dim": 768, 15 | "contrastive_logits_temperature": 0.1, 16 | "conv_bias": true, 17 | "conv_dim": [ 18 | 512, 19 | 512, 20 | 512, 21 | 512, 22 | 512, 23 | 512, 24 | 512 25 | ], 26 | "conv_kernel": [ 27 | 10, 28 | 3, 29 | 3, 30 | 3, 31 | 3, 32 | 2, 33 | 2 34 | ], 35 | "conv_stride": [ 36 | 5, 37 | 2, 38 | 2, 39 | 2, 40 | 2, 41 | 2, 42 | 2 43 | ], 44 | "ctc_loss_reduction": "sum", 45 | "ctc_zero_infinity": false, 46 | "diversity_loss_weight": 0.1, 47 | "do_stable_layer_norm": true, 48 | "eos_token_id": 2, 49 | "feat_extract_activation": "gelu", 50 | "feat_extract_dropout": 0.0, 51 | "feat_extract_norm": "layer", 52 | "feat_proj_dropout": 0.1, 53 | "feat_quantizer_dropout": 0.0, 54 | "final_dropout": 0.1, 55 | "finetuning_task": "wav2vec2_reg", 56 | "gradient_checkpointing": false, 57 | "hidden_act": "gelu", 58 | "hidden_dropout": 0.1, 59 | "hidden_dropout_prob": 0.1, 60 | "hidden_size": 1024, 61 | "id2label": { 62 | "0": "arousal", 63 | "1": "dominance", 64 | "2": "valence" 65 | }, 66 | "initializer_range": 0.02, 67 | "intermediate_size": 4096, 68 | "label2id": { 69 | "arousal": 0, 70 | "dominance": 1, 71 | "valence": 2 72 | }, 73 | "layer_norm_eps": 1e-05, 74 | "layerdrop": 0.1, 75 | "mask_feature_length": 10, 76 | "mask_feature_min_masks": 0, 77 | "mask_feature_prob": 0.0, 78 | "mask_time_length": 10, 79 | "mask_time_min_masks": 2, 80 | "mask_time_prob": 0.05, 81 | "model_type": "wav2vec2", 82 | "num_adapter_layers": 3, 83 | "num_attention_heads": 16, 84 | "num_codevector_groups": 2, 85 | "num_codevectors_per_group": 320, 86 | "num_conv_pos_embedding_groups": 16, 87 | "num_conv_pos_embeddings": 128, 88 | "num_feat_extract_layers": 7, 89 | "num_hidden_layers": 12, 90 | "num_negatives": 100, 91 | "output_hidden_size": 1024, 92 | "pad_token_id": 0, 93 | "pooling_mode": "mean", 94 | "problem_type": "regression", 95 | "proj_codevector_dim": 768, 96 | "tdnn_dilation": [ 97 | 1, 98 | 2, 99 | 3, 100 | 1, 101 | 1 102 | ], 103 | "tdnn_dim": [ 104 | 512, 105 | 512, 106 | 512, 107 | 512, 108 | 1500 109 | ], 110 | "tdnn_kernel": [ 111 | 5, 112 | 3, 113 | 3, 114 | 1, 115 | 1 116 | ], 117 | "torch_dtype": "float32", 118 | "transformers_version": "4.17.0.dev0", 119 | "use_weighted_layer_sum": false, 120 | "vocab_size": null, 121 | "xvector_output_dim": 512 122 | } 123 | -------------------------------------------------------------------------------- /emotional/wav2vec2-large-robust-12-ft-emotion-msp-dim/preprocessor_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_normalize": true, 3 | "feature_extractor_type": "Wav2Vec2FeatureExtractor", 4 | "feature_size": 1, 5 | "padding_side": "right", 6 | "padding_value": 0.0, 7 | "return_attention_mask": true, 8 | "sampling_rate": 16000 9 | } 10 | -------------------------------------------------------------------------------- 
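The feature extractor above expects 16 kHz mono float input (`sampling_rate: 16000`, `do_normalize: true`). A minimal sketch of feeding an arbitrary wav file to the `process_func` helper defined in the README above; the file path is hypothetical and `librosa` is assumed to be installed:

```python
import librosa
import numpy as np

# Load and resample to the 16 kHz mono format expected by Wav2Vec2Processor.
signal, sr = librosa.load("some_clip.wav", sr=16000, mono=True)  # hypothetical path
signal = signal.astype(np.float32)[np.newaxis, :]  # shape (1, num_samples)

# process_func comes from the README above; it returns
# [[arousal, dominance, valence]] predictions in roughly the 0...1 range.
print(process_func(signal, sampling_rate=16000))
```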
/emotional/wav2vec2-large-robust-12-ft-emotion-msp-dim/vocab.json: -------------------------------------------------------------------------------- 1 | {} 2 | -------------------------------------------------------------------------------- /export_onnx.py: -------------------------------------------------------------------------------- 1 | from onnx_modules import export_onnx 2 | import os 3 | 4 | if __name__ == "__main__": 5 | export_path = "BertVits2.2PT" 6 | model_path = "model\\G_0.pth" 7 | config_path = "model\\config.json" 8 | novq = False 9 | dev = False 10 | Extra = "chinese" # japanese or chinese 11 | if not os.path.exists("onnx"): 12 | os.makedirs("onnx") 13 | if not os.path.exists(f"onnx/{export_path}"): 14 | os.makedirs(f"onnx/{export_path}") 15 | export_onnx(export_path, model_path, config_path, novq, dev, Extra) 16 | -------------------------------------------------------------------------------- /filelists/sample.list: -------------------------------------------------------------------------------- 1 | Example: 2 | {wav_path}|{speaker_name}|{language}|{text} 3 | 派蒙_1.wav|派蒙|ZH|前面的区域,以后再来探索吧! 4 | -------------------------------------------------------------------------------- /for_deploy/infer.py: -------------------------------------------------------------------------------- 1 | """ 2 | 版本管理、兼容推理及模型加载实现。 3 | 版本说明: 4 | 1. 版本号与github的release版本号对应,使用哪个release版本训练的模型即对应其版本号 5 | 2. 请在模型的config.json中显示声明版本号,添加一个字段"version" : "你的版本号" 6 | 特殊版本说明: 7 | 1.1.1-fix: 1.1.1版本训练的模型,但是在推理时使用dev的日语修复 8 | 2.2:当前版本 9 | """ 10 | 11 | import torch 12 | import commons 13 | from text import cleaned_text_to_sequence 14 | from text.cleaner import clean_text 15 | import utils 16 | import numpy as np 17 | 18 | from models import SynthesizerTrn 19 | from text.symbols import symbols 20 | 21 | from oldVersion.V210.models import SynthesizerTrn as V210SynthesizerTrn 22 | from oldVersion.V210.text import symbols as V210symbols 23 | from oldVersion.V200.models import SynthesizerTrn as V200SynthesizerTrn 24 | from oldVersion.V200.text import symbols as V200symbols 25 | from oldVersion.V111.models import SynthesizerTrn as V111SynthesizerTrn 26 | from oldVersion.V111.text import symbols as V111symbols 27 | from oldVersion.V110.models import SynthesizerTrn as V110SynthesizerTrn 28 | from oldVersion.V110.text import symbols as V110symbols 29 | from oldVersion.V101.models import SynthesizerTrn as V101SynthesizerTrn 30 | from oldVersion.V101.text import symbols as V101symbols 31 | 32 | from oldVersion import V111, V110, V101, V200, V210 33 | 34 | # 当前版本信息 35 | latest_version = "2.2" 36 | 37 | # 版本兼容 38 | SynthesizerTrnMap = { 39 | "2.1": V210SynthesizerTrn, 40 | "2.0.2-fix": V200SynthesizerTrn, 41 | "2.0.1": V200SynthesizerTrn, 42 | "2.0": V200SynthesizerTrn, 43 | "1.1.1-fix": V111SynthesizerTrn, 44 | "1.1.1": V111SynthesizerTrn, 45 | "1.1": V110SynthesizerTrn, 46 | "1.1.0": V110SynthesizerTrn, 47 | "1.0.1": V101SynthesizerTrn, 48 | "1.0": V101SynthesizerTrn, 49 | "1.0.0": V101SynthesizerTrn, 50 | } 51 | 52 | symbolsMap = { 53 | "2.1": V210symbols, 54 | "2.0.2-fix": V200symbols, 55 | "2.0.1": V200symbols, 56 | "2.0": V200symbols, 57 | "1.1.1-fix": V111symbols, 58 | "1.1.1": V111symbols, 59 | "1.1": V110symbols, 60 | "1.1.0": V110symbols, 61 | "1.0.1": V101symbols, 62 | "1.0": V101symbols, 63 | "1.0.0": V101symbols, 64 | } 65 | 66 | 67 | # def get_emo_(reference_audio, emotion, sid): 68 | # emo = ( 69 | # torch.from_numpy(get_emo(reference_audio)) 70 | # if reference_audio and emotion == -1 71 | # else 
torch.FloatTensor( 72 | # np.load(f"emo_clustering/{sid}/cluster_center_{emotion}.npy") 73 | # ) 74 | # ) 75 | # return emo 76 | 77 | 78 | def get_net_g(model_path: str, version: str, device: str, hps): 79 | if version != latest_version: 80 | net_g = SynthesizerTrnMap[version]( 81 | len(symbolsMap[version]), 82 | hps.data.filter_length // 2 + 1, 83 | hps.train.segment_size // hps.data.hop_length, 84 | n_speakers=hps.data.n_speakers, 85 | **hps.model, 86 | ).to(device) 87 | else: 88 | # 当前版本模型 net_g 89 | net_g = SynthesizerTrn( 90 | len(symbols), 91 | hps.data.filter_length // 2 + 1, 92 | hps.train.segment_size // hps.data.hop_length, 93 | n_speakers=hps.data.n_speakers, 94 | **hps.model, 95 | ).to(device) 96 | _ = net_g.eval() 97 | _ = utils.load_checkpoint(model_path, net_g, None, skip_optimizer=True) 98 | return net_g 99 | 100 | 101 | def get_text(text, language_str, bert, hps, device): 102 | # 在此处实现当前版本的get_text 103 | norm_text, phone, tone, word2ph = clean_text(text, language_str) 104 | phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str) 105 | 106 | if hps.data.add_blank: 107 | phone = commons.intersperse(phone, 0) 108 | tone = commons.intersperse(tone, 0) 109 | language = commons.intersperse(language, 0) 110 | for i in range(len(word2ph)): 111 | word2ph[i] = word2ph[i] * 2 112 | word2ph[0] += 1 113 | # bert_ori = get_bert(norm_text, word2ph, language_str, device) 114 | bert_ori = bert[language_str].get_bert_feature(norm_text, word2ph, device) 115 | del word2ph 116 | assert bert_ori.shape[-1] == len(phone), phone 117 | 118 | if language_str == "ZH": 119 | bert = bert_ori 120 | ja_bert = torch.randn(1024, len(phone)) 121 | en_bert = torch.randn(1024, len(phone)) 122 | elif language_str == "JP": 123 | bert = torch.randn(1024, len(phone)) 124 | ja_bert = bert_ori 125 | en_bert = torch.randn(1024, len(phone)) 126 | elif language_str == "EN": 127 | bert = torch.randn(1024, len(phone)) 128 | ja_bert = torch.randn(1024, len(phone)) 129 | en_bert = bert_ori 130 | else: 131 | raise ValueError("language_str should be ZH, JP or EN") 132 | 133 | assert bert.shape[-1] == len( 134 | phone 135 | ), f"Bert seq len {bert.shape[-1]} != {len(phone)}" 136 | 137 | phone = torch.LongTensor(phone) 138 | tone = torch.LongTensor(tone) 139 | language = torch.LongTensor(language) 140 | return bert, ja_bert, en_bert, phone, tone, language 141 | 142 | 143 | def infer( 144 | text, 145 | emotion, 146 | sdp_ratio, 147 | noise_scale, 148 | noise_scale_w, 149 | length_scale, 150 | sid, 151 | language, 152 | hps, 153 | net_g, 154 | device, 155 | bert=None, 156 | clap=None, 157 | reference_audio=None, 158 | skip_start=False, 159 | skip_end=False, 160 | ): 161 | # 2.2版本参数位置变了 162 | # 2.1 参数新增 emotion reference_audio skip_start skip_end 163 | inferMap_V3 = { 164 | "2.1": V210.infer, 165 | } 166 | # 支持中日英三语版本 167 | inferMap_V2 = { 168 | "2.0.2-fix": V200.infer, 169 | "2.0.1": V200.infer, 170 | "2.0": V200.infer, 171 | "1.1.1-fix": V111.infer_fix, 172 | "1.1.1": V111.infer, 173 | "1.1": V110.infer, 174 | "1.1.0": V110.infer, 175 | } 176 | # 仅支持中文版本 177 | # 在测试中,并未发现两个版本的模型不能互相通用 178 | inferMap_V1 = { 179 | "1.0.1": V101.infer, 180 | "1.0": V101.infer, 181 | "1.0.0": V101.infer, 182 | } 183 | version = hps.version if hasattr(hps, "version") else latest_version 184 | # 非当前版本,根据版本号选择合适的infer 185 | if version != latest_version: 186 | if version in inferMap_V3.keys(): 187 | return inferMap_V3[version]( 188 | text, 189 | sdp_ratio, 190 | noise_scale, 191 | noise_scale_w, 192 | length_scale, 193 | sid, 194 | 
language, 195 | hps, 196 | net_g, 197 | device, 198 | reference_audio, 199 | emotion, 200 | skip_start, 201 | skip_end, 202 | ) 203 | if version in inferMap_V2.keys(): 204 | return inferMap_V2[version]( 205 | text, 206 | sdp_ratio, 207 | noise_scale, 208 | noise_scale_w, 209 | length_scale, 210 | sid, 211 | language, 212 | hps, 213 | net_g, 214 | device, 215 | ) 216 | if version in inferMap_V1.keys(): 217 | return inferMap_V1[version]( 218 | text, 219 | sdp_ratio, 220 | noise_scale, 221 | noise_scale_w, 222 | length_scale, 223 | sid, 224 | hps, 225 | net_g, 226 | device, 227 | ) 228 | # 在此处实现当前版本的推理 229 | # emo = get_emo_(reference_audio, emotion, sid) 230 | if isinstance(reference_audio, np.ndarray): 231 | emo = clap.get_clap_audio_feature(reference_audio, device) 232 | else: 233 | emo = clap.get_clap_text_feature(emotion, device) 234 | emo = torch.squeeze(emo, dim=1) 235 | 236 | bert, ja_bert, en_bert, phones, tones, lang_ids = get_text( 237 | text, language, bert, hps, device 238 | ) 239 | if skip_start: 240 | phones = phones[3:] 241 | tones = tones[3:] 242 | lang_ids = lang_ids[3:] 243 | bert = bert[:, 3:] 244 | ja_bert = ja_bert[:, 3:] 245 | en_bert = en_bert[:, 3:] 246 | if skip_end: 247 | phones = phones[:-2] 248 | tones = tones[:-2] 249 | lang_ids = lang_ids[:-2] 250 | bert = bert[:, :-2] 251 | ja_bert = ja_bert[:, :-2] 252 | en_bert = en_bert[:, :-2] 253 | with torch.no_grad(): 254 | x_tst = phones.to(device).unsqueeze(0) 255 | tones = tones.to(device).unsqueeze(0) 256 | lang_ids = lang_ids.to(device).unsqueeze(0) 257 | bert = bert.to(device).unsqueeze(0) 258 | ja_bert = ja_bert.to(device).unsqueeze(0) 259 | en_bert = en_bert.to(device).unsqueeze(0) 260 | x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device) 261 | emo = emo.to(device).unsqueeze(0) 262 | del phones 263 | speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device) 264 | audio = ( 265 | net_g.infer( 266 | x_tst, 267 | x_tst_lengths, 268 | speakers, 269 | tones, 270 | lang_ids, 271 | bert, 272 | ja_bert, 273 | en_bert, 274 | emo, 275 | sdp_ratio=sdp_ratio, 276 | noise_scale=noise_scale, 277 | noise_scale_w=noise_scale_w, 278 | length_scale=length_scale, 279 | )[0][0, 0] 280 | .data.cpu() 281 | .float() 282 | .numpy() 283 | ) 284 | del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers, ja_bert, en_bert, emo 285 | if torch.cuda.is_available(): 286 | torch.cuda.empty_cache() 287 | return audio 288 | 289 | 290 | def infer_multilang( 291 | text, 292 | sdp_ratio, 293 | noise_scale, 294 | noise_scale_w, 295 | length_scale, 296 | sid, 297 | language, 298 | hps, 299 | net_g, 300 | device, 301 | bert=None, 302 | clap=None, 303 | reference_audio=None, 304 | emotion=None, 305 | skip_start=False, 306 | skip_end=False, 307 | ): 308 | bert, ja_bert, en_bert, phones, tones, lang_ids = [], [], [], [], [], [] 309 | # emo = get_emo_(reference_audio, emotion, sid) 310 | if isinstance(reference_audio, np.ndarray): 311 | emo = clap.get_clap_audio_feature(reference_audio, device) 312 | else: 313 | emo = clap.get_clap_text_feature(emotion, device) 314 | emo = torch.squeeze(emo, dim=1) 315 | for idx, (txt, lang) in enumerate(zip(text, language)): 316 | skip_start = (idx != 0) or (skip_start and idx == 0) 317 | skip_end = (idx != len(text) - 1) or (skip_end and idx == len(text) - 1) 318 | ( 319 | temp_bert, 320 | temp_ja_bert, 321 | temp_en_bert, 322 | temp_phones, 323 | temp_tones, 324 | temp_lang_ids, 325 | ) = get_text(txt, lang, bert, hps, device) 326 | if skip_start: 327 | temp_bert = temp_bert[:, 3:] 328 | temp_ja_bert = 
temp_ja_bert[:, 3:] 329 | temp_en_bert = temp_en_bert[:, 3:] 330 | temp_phones = temp_phones[3:] 331 | temp_tones = temp_tones[3:] 332 | temp_lang_ids = temp_lang_ids[3:] 333 | if skip_end: 334 | temp_bert = temp_bert[:, :-2] 335 | temp_ja_bert = temp_ja_bert[:, :-2] 336 | temp_en_bert = temp_en_bert[:, :-2] 337 | temp_phones = temp_phones[:-2] 338 | temp_tones = temp_tones[:-2] 339 | temp_lang_ids = temp_lang_ids[:-2] 340 | bert.append(temp_bert) 341 | ja_bert.append(temp_ja_bert) 342 | en_bert.append(temp_en_bert) 343 | phones.append(temp_phones) 344 | tones.append(temp_tones) 345 | lang_ids.append(temp_lang_ids) 346 | bert = torch.concatenate(bert, dim=1) 347 | ja_bert = torch.concatenate(ja_bert, dim=1) 348 | en_bert = torch.concatenate(en_bert, dim=1) 349 | phones = torch.concatenate(phones, dim=0) 350 | tones = torch.concatenate(tones, dim=0) 351 | lang_ids = torch.concatenate(lang_ids, dim=0) 352 | with torch.no_grad(): 353 | x_tst = phones.to(device).unsqueeze(0) 354 | tones = tones.to(device).unsqueeze(0) 355 | lang_ids = lang_ids.to(device).unsqueeze(0) 356 | bert = bert.to(device).unsqueeze(0) 357 | ja_bert = ja_bert.to(device).unsqueeze(0) 358 | en_bert = en_bert.to(device).unsqueeze(0) 359 | emo = emo.to(device).unsqueeze(0) 360 | x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device) 361 | del phones 362 | speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device) 363 | audio = ( 364 | net_g.infer( 365 | x_tst, 366 | x_tst_lengths, 367 | speakers, 368 | tones, 369 | lang_ids, 370 | bert, 371 | ja_bert, 372 | en_bert, 373 | emo, 374 | sdp_ratio=sdp_ratio, 375 | noise_scale=noise_scale, 376 | noise_scale_w=noise_scale_w, 377 | length_scale=length_scale, 378 | )[0][0, 0] 379 | .data.cpu() 380 | .float() 381 | .numpy() 382 | ) 383 | del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers, ja_bert, en_bert, emo 384 | if torch.cuda.is_available(): 385 | torch.cuda.empty_cache() 386 | return audio 387 | -------------------------------------------------------------------------------- /for_deploy/infer_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import torch 4 | from transformers import ( 5 | AutoModelForMaskedLM, 6 | AutoTokenizer, 7 | DebertaV2Model, 8 | DebertaV2Tokenizer, 9 | ClapModel, 10 | ClapProcessor, 11 | ) 12 | 13 | from config import config 14 | from text.japanese import text2sep_kata 15 | 16 | 17 | class BertFeature: 18 | def __init__(self, model_path, language="ZH"): 19 | self.model_path = model_path 20 | self.language = language 21 | self.tokenizer = None 22 | self.model = None 23 | self.device = None 24 | 25 | self._prepare() 26 | 27 | def _get_device(self, device=config.bert_gen_config.device): 28 | if ( 29 | sys.platform == "darwin" 30 | and torch.backends.mps.is_available() 31 | and device == "cpu" 32 | ): 33 | device = "mps" 34 | if not device: 35 | device = "cuda" 36 | return device 37 | 38 | def _prepare(self): 39 | self.device = self._get_device() 40 | 41 | if self.language == "EN": 42 | self.tokenizer = DebertaV2Tokenizer.from_pretrained(self.model_path) 43 | self.model = DebertaV2Model.from_pretrained(self.model_path).to(self.device) 44 | else: 45 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_path) 46 | self.model = AutoModelForMaskedLM.from_pretrained(self.model_path).to( 47 | self.device 48 | ) 49 | self.model.eval() 50 | 51 | def get_bert_feature(self, text, word2ph): 52 | if self.language == "JP": 53 | text = "".join(text2sep_kata(text)[0]) 54 | with 
torch.no_grad(): 55 | inputs = self.tokenizer(text, return_tensors="pt") 56 | for i in inputs: 57 | inputs[i] = inputs[i].to(self.device) 58 | res = self.model(**inputs, output_hidden_states=True) 59 | res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu() 60 | 61 | word2phone = word2ph 62 | phone_level_feature = [] 63 | for i in range(len(word2phone)): 64 | repeat_feature = res[i].repeat(word2phone[i], 1) 65 | phone_level_feature.append(repeat_feature) 66 | 67 | phone_level_feature = torch.cat(phone_level_feature, dim=0) 68 | 69 | return phone_level_feature.T 70 | 71 | 72 | class ClapFeature: 73 | def __init__(self, model_path): 74 | self.model_path = model_path 75 | self.processor = None 76 | self.model = None 77 | self.device = None 78 | 79 | self._prepare() 80 | 81 | def _get_device(self, device=config.bert_gen_config.device): 82 | if ( 83 | sys.platform == "darwin" 84 | and torch.backends.mps.is_available() 85 | and device == "cpu" 86 | ): 87 | device = "mps" 88 | if not device: 89 | device = "cuda" 90 | return device 91 | 92 | def _prepare(self): 93 | self.device = self._get_device() 94 | 95 | self.processor = ClapProcessor.from_pretrained(self.model_path) 96 | self.model = ClapModel.from_pretrained(self.model_path).to(self.device) 97 | self.model.eval() 98 | 99 | def get_clap_audio_feature(self, audio_data): 100 | with torch.no_grad(): 101 | inputs = self.processor( 102 | audios=audio_data, return_tensors="pt", sampling_rate=48000 103 | ).to(self.device) 104 | emb = self.model.get_audio_features(**inputs) 105 | return emb.T 106 | 107 | def get_clap_text_feature(self, text): 108 | with torch.no_grad(): 109 | inputs = self.processor(text=text, return_tensors="pt").to(self.device) 110 | emb = self.model.get_text_features(**inputs) 111 | return emb.T 112 | -------------------------------------------------------------------------------- /hlf.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Color definitions 3 | 4 | export HF_ENDPOINT="https://hf-mirror.com" 5 | RED='\033[0;31m' 6 | GREEN='\033[0;32m' 7 | YELLOW='\033[1;33m' 8 | NC='\033[0m' # No Color 9 | 10 | trap 'printf "${YELLOW}\nDownload interrupted. If you re-run the command, you can resume the download from the breakpoint.\n${NC}"; exit 1' INT 11 | 12 | display_help() { 13 | cat << EOF 14 | Usage: 15 | hfd [--include include_pattern] [--exclude exclude_pattern] [--hf_username username] [--hf_token token] [--tool aria2c|wget] [-x threads] [--dataset] [--local-dir path] 16 | Description: 17 | Downloads a model or dataset from Hugging Face using the provided repo ID. 18 | Parameters: 19 | repo_id The Hugging Face repo ID in the format 'org/repo_name'. 20 | --include (Optional) Flag to specify a string pattern to include files for downloading. 21 | --exclude (Optional) Flag to specify a string pattern to exclude files from downloading. 22 | include/exclude_pattern The pattern to match against filenames, supports wildcard characters. e.g., '--exclude *.safetensor', '--include vae/*'. 23 | --hf_username (Optional) Hugging Face username for authentication. **NOT EMAIL**. 24 | --hf_token (Optional) Hugging Face token for authentication. 25 | --tool (Optional) Download tool to use. Can be aria2c (default) or wget. 26 | -x (Optional) Number of download threads for aria2c. Defaults to 4. 27 | --dataset (Optional) Flag to indicate downloading a dataset. 28 | --local-dir (Optional) Local directory path where the model or dataset will be stored. 
29 | Example: 30 | hfd bigscience/bloom-560m --exclude *.safetensors 31 | hfd meta-llama/Llama-2-7b --hf_username myuser --hf_token mytoken -x 4 32 | hfd lavita/medical-qa-shared-task-v1-toy --dataset 33 | EOF 34 | exit 1 35 | } 36 | 37 | MODEL_ID=$1 38 | shift 39 | 40 | # Default values 41 | TOOL="aria2c" 42 | THREADS=4 43 | HF_ENDPOINT=${HF_ENDPOINT:-"https://huggingface.co"} 44 | 45 | while [[ $# -gt 0 ]]; do 46 | case $1 in 47 | --include) INCLUDE_PATTERN="$2"; shift 2 ;; 48 | --exclude) EXCLUDE_PATTERN="$2"; shift 2 ;; 49 | --hf_username) HF_USERNAME="$2"; shift 2 ;; 50 | --hf_token) HF_TOKEN="$2"; shift 2 ;; 51 | --tool) TOOL="$2"; shift 2 ;; 52 | -x) THREADS="$2"; shift 2 ;; 53 | --dataset) DATASET=1; shift ;; 54 | --local-dir) LOCAL_DIR="$2"; shift 2 ;; 55 | *) shift ;; 56 | esac 57 | done 58 | 59 | # Check if aria2, wget, curl, git, and git-lfs are installed 60 | check_command() { 61 | if ! command -v $1 &>/dev/null; then 62 | echo -e "${RED}$1 is not installed. Please install it first.${NC}" 63 | exit 1 64 | fi 65 | } 66 | 67 | # Mark current repo safe when using shared file system like samba or nfs 68 | ensure_ownership() { 69 | if git status 2>&1 | grep "fatal: detected dubious ownership in repository at" > /dev/null; then 70 | git config --global --add safe.directory "${PWD}" 71 | printf "${YELLOW}Detected dubious ownership in repository, mark ${PWD} safe using git, edit ~/.gitconfig if you want to reverse this.\n${NC}" 72 | fi 73 | } 74 | 75 | [[ "$TOOL" == "aria2c" ]] && check_command aria2c 76 | [[ "$TOOL" == "wget" ]] && check_command wget 77 | check_command curl; check_command git; check_command git-lfs 78 | 79 | [[ -z "$MODEL_ID" || "$MODEL_ID" =~ ^-h ]] && display_help 80 | 81 | if [[ -z "$LOCAL_DIR" ]]; then 82 | LOCAL_DIR="${MODEL_ID#*/}" 83 | fi 84 | 85 | if [[ "$DATASET" == 1 ]]; then 86 | MODEL_ID="datasets/$MODEL_ID" 87 | fi 88 | echo "Downloading to $LOCAL_DIR" 89 | 90 | if [ -d "$LOCAL_DIR/.git" ]; then 91 | printf "${YELLOW}%s exists, Skip Clone.\n${NC}" "$LOCAL_DIR" 92 | cd "$LOCAL_DIR" && ensure_ownership && GIT_LFS_SKIP_SMUDGE=1 git pull || { printf "${RED}Git pull failed.${NC}\n"; exit 1; } 93 | else 94 | REPO_URL="$HF_ENDPOINT/$MODEL_ID" 95 | GIT_REFS_URL="${REPO_URL}/info/refs?service=git-upload-pack" 96 | echo "Testing GIT_REFS_URL: $GIT_REFS_URL" 97 | response=$(curl -s -o /dev/null -w "%{http_code}" "$GIT_REFS_URL") 98 | if [ "$response" == "401" ] || [ "$response" == "403" ]; then 99 | if [[ -z "$HF_USERNAME" || -z "$HF_TOKEN" ]]; then 100 | printf "${RED}HTTP Status Code: $response.\nThe repository requires authentication, but --hf_username and --hf_token is not passed. 
Please get token from https://huggingface.co/settings/tokens.\nExiting.\n${NC}" 101 | exit 1 102 | fi 103 | REPO_URL="https://$HF_USERNAME:$HF_TOKEN@${HF_ENDPOINT#https://}/$MODEL_ID" 104 | elif [ "$response" != "200" ]; then 105 | printf "${RED}Unexpected HTTP Status Code: $response\n${NC}" 106 | printf "${YELLOW}Executing debug command: curl -v %s\nOutput:${NC}\n" "$GIT_REFS_URL" 107 | curl -v "$GIT_REFS_URL"; printf "\n${RED}Git clone failed.\n${NC}"; exit 1 108 | fi 109 | echo "GIT_LFS_SKIP_SMUDGE=1 git clone $REPO_URL $LOCAL_DIR" 110 | 111 | GIT_LFS_SKIP_SMUDGE=1 git clone $REPO_URL $LOCAL_DIR && cd "$LOCAL_DIR" || { printf "${RED}Git clone failed.\n${NC}"; exit 1; } 112 | 113 | ensure_ownership 114 | 115 | while IFS= read -r file; do 116 | truncate -s 0 "$file" 117 | done <<< $(git lfs ls-files | cut -d ' ' -f 3-) 118 | fi 119 | 120 | printf "\nStart Downloading lfs files, bash script:\ncd $LOCAL_DIR\n" 121 | files=$(git lfs ls-files | cut -d ' ' -f 3-) 122 | declare -a urls 123 | 124 | while IFS= read -r file; do 125 | url="$HF_ENDPOINT/$MODEL_ID/resolve/main/$file" 126 | file_dir=$(dirname "$file") 127 | mkdir -p "$file_dir" 128 | if [[ "$TOOL" == "wget" ]]; then 129 | download_cmd="wget -c \"$url\" -O \"$file\"" 130 | [[ -n "$HF_TOKEN" ]] && download_cmd="wget --header=\"Authorization: Bearer ${HF_TOKEN}\" -c \"$url\" -O \"$file\"" 131 | else 132 | download_cmd="aria2c --console-log-level=error --file-allocation=none -x $THREADS -s $THREADS -k 1M -c \"$url\" -d \"$file_dir\" -o \"$(basename "$file")\"" 133 | [[ -n "$HF_TOKEN" ]] && download_cmd="aria2c --header=\"Authorization: Bearer ${HF_TOKEN}\" --console-log-level=error --file-allocation=none -x $THREADS -s $THREADS -k 1M -c \"$url\" -d \"$file_dir\" -o \"$(basename "$file")\"" 134 | fi 135 | [[ -n "$INCLUDE_PATTERN" && ! "$file" == $INCLUDE_PATTERN ]] && printf "# %s\n" "$download_cmd" && continue 136 | [[ -n "$EXCLUDE_PATTERN" && "$file" == $EXCLUDE_PATTERN ]] && printf "# %s\n" "$download_cmd" && continue 137 | printf "%s\n" "$download_cmd" 138 | urls+=("$url|$file") 139 | done <<< "$files" 140 | 141 | for url_file in "${urls[@]}"; do 142 | IFS='|' read -r url file <<< "$url_file" 143 | printf "${YELLOW}Start downloading ${file}.\n${NC}" 144 | file_dir=$(dirname "$file") 145 | if [[ "$TOOL" == "wget" ]]; then 146 | [[ -n "$HF_TOKEN" ]] && wget --header="Authorization: Bearer ${HF_TOKEN}" -c "$url" -O "$file" || wget -c "$url" -O "$file" 147 | else 148 | [[ -n "$HF_TOKEN" ]] && aria2c --header="Authorization: Bearer ${HF_TOKEN}" --console-log-level=error --file-allocation=none -x $THREADS -s $THREADS -k 1M -c "$url" -d "$file_dir" -o "$(basename "$file")" || aria2c --console-log-level=error --file-allocation=none -x $THREADS -s $THREADS -k 1M -c "$url" -d "$file_dir" -o "$(basename "$file")" 149 | fi 150 | [[ $? 
-eq 0 ]] && printf "Downloaded %s successfully.\n" "$url" || { printf "${RED}Failed to download %s.\n${NC}" "$url"; exit 1; } 151 | done 152 | 153 | printf "${GREEN}Download completed successfully.\n${NC}" -------------------------------------------------------------------------------- /img/yuyu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/img/yuyu.png -------------------------------------------------------------------------------- /img/参数说明.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/img/参数说明.png -------------------------------------------------------------------------------- /img/宵宫.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/img/宵宫.png -------------------------------------------------------------------------------- /img/微信图片_20231010105112.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/img/微信图片_20231010105112.png -------------------------------------------------------------------------------- /img/神里绫华.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/img/神里绫华.png -------------------------------------------------------------------------------- /img/纳西妲.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/img/纳西妲.png -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio 3 | from transformers import AutoModel 4 | 5 | 6 | def feature_loss(fmap_r, fmap_g): 7 | loss = 0 8 | for dr, dg in zip(fmap_r, fmap_g): 9 | for rl, gl in zip(dr, dg): 10 | rl = rl.float().detach() 11 | gl = gl.float() 12 | loss += torch.mean(torch.abs(rl - gl)) 13 | 14 | return loss * 2 15 | 16 | 17 | def discriminator_loss(disc_real_outputs, disc_generated_outputs): 18 | loss = 0 19 | r_losses = [] 20 | g_losses = [] 21 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 22 | dr = dr.float() 23 | dg = dg.float() 24 | r_loss = torch.mean((1 - dr) ** 2) 25 | g_loss = torch.mean(dg**2) 26 | loss += r_loss + g_loss 27 | r_losses.append(r_loss.item()) 28 | g_losses.append(g_loss.item()) 29 | 30 | return loss, r_losses, g_losses 31 | 32 | 33 | def generator_loss(disc_outputs): 34 | loss = 0 35 | gen_losses = [] 36 | for dg in disc_outputs: 37 | dg = dg.float() 38 | l = torch.mean((1 - dg) ** 2) 39 | gen_losses.append(l) 40 | loss += l 41 | 42 | return loss, gen_losses 43 | 44 | 45 | def kl_loss(z_p, logs_q, m_p, logs_p, z_mask): 46 | """ 47 | z_p, logs_q: [b, h, t_t] 48 | m_p, logs_p: [b, h, t_t] 49 | """ 50 | z_p = z_p.float() 51 | logs_q = logs_q.float() 52 | m_p = m_p.float() 53 | logs_p = logs_p.float() 54 | z_mask = z_mask.float() 55 | 56 | kl = logs_p - logs_q - 0.5 57 | kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p) 58 | kl 
= torch.sum(kl * z_mask) 59 | l = kl / torch.sum(z_mask) 60 | return l 61 | 62 | 63 | class WavLMLoss(torch.nn.Module): 64 | def __init__(self, model, wd, model_sr, slm_sr=16000): 65 | super(WavLMLoss, self).__init__() 66 | self.wavlm = AutoModel.from_pretrained(model) 67 | self.wd = wd 68 | self.resample = torchaudio.transforms.Resample(model_sr, slm_sr) 69 | self.wavlm.eval() 70 | for param in self.wavlm.parameters(): 71 | param.requires_grad = False 72 | 73 | def forward(self, wav, y_rec): 74 | with torch.no_grad(): 75 | wav_16 = self.resample(wav) 76 | wav_embeddings = self.wavlm( 77 | input_values=wav_16, output_hidden_states=True 78 | ).hidden_states 79 | y_rec_16 = self.resample(y_rec) 80 | y_rec_embeddings = self.wavlm( 81 | input_values=y_rec_16.squeeze(), output_hidden_states=True 82 | ).hidden_states 83 | 84 | floss = 0 85 | for er, eg in zip(wav_embeddings, y_rec_embeddings): 86 | floss += torch.mean(torch.abs(er - eg)) 87 | 88 | return floss.mean() 89 | 90 | def generator(self, y_rec): 91 | y_rec_16 = self.resample(y_rec) 92 | y_rec_embeddings = self.wavlm( 93 | input_values=y_rec_16, output_hidden_states=True 94 | ).hidden_states 95 | y_rec_embeddings = ( 96 | torch.stack(y_rec_embeddings, dim=1) 97 | .transpose(-1, -2) 98 | .flatten(start_dim=1, end_dim=2) 99 | ) 100 | y_df_hat_g = self.wd(y_rec_embeddings) 101 | loss_gen = torch.mean((1 - y_df_hat_g) ** 2) 102 | 103 | return loss_gen 104 | 105 | def discriminator(self, wav, y_rec): 106 | with torch.no_grad(): 107 | wav_16 = self.resample(wav) 108 | wav_embeddings = self.wavlm( 109 | input_values=wav_16, output_hidden_states=True 110 | ).hidden_states 111 | y_rec_16 = self.resample(y_rec) 112 | y_rec_embeddings = self.wavlm( 113 | input_values=y_rec_16, output_hidden_states=True 114 | ).hidden_states 115 | 116 | y_embeddings = ( 117 | torch.stack(wav_embeddings, dim=1) 118 | .transpose(-1, -2) 119 | .flatten(start_dim=1, end_dim=2) 120 | ) 121 | y_rec_embeddings = ( 122 | torch.stack(y_rec_embeddings, dim=1) 123 | .transpose(-1, -2) 124 | .flatten(start_dim=1, end_dim=2) 125 | ) 126 | 127 | y_d_rs = self.wd(y_embeddings) 128 | y_d_gs = self.wd(y_rec_embeddings) 129 | 130 | y_df_hat_r, y_df_hat_g = y_d_rs, y_d_gs 131 | 132 | r_loss = torch.mean((1 - y_df_hat_r) ** 2) 133 | g_loss = torch.mean((y_df_hat_g) ** 2) 134 | 135 | loss_disc_f = r_loss + g_loss 136 | 137 | return loss_disc_f.mean() 138 | 139 | def discriminator_forward(self, wav): 140 | with torch.no_grad(): 141 | wav_16 = self.resample(wav) 142 | wav_embeddings = self.wavlm( 143 | input_values=wav_16, output_hidden_states=True 144 | ).hidden_states 145 | y_embeddings = ( 146 | torch.stack(wav_embeddings, dim=1) 147 | .transpose(-1, -2) 148 | .flatten(start_dim=1, end_dim=2) 149 | ) 150 | 151 | y_d_rs = self.wd(y_embeddings) 152 | 153 | return y_d_rs 154 | -------------------------------------------------------------------------------- /mel_processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | import librosa 4 | 5 | import warnings 6 | 7 | # warnings.simplefilter(action='ignore', category=FutureWarning) 8 | warnings.filterwarnings(action="ignore") 9 | MAX_WAV_VALUE = 32768.0 10 | 11 | """ 12 | librosa 0.10.2.post1 13 | 14 | """ 15 | 16 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 17 | """ 18 | PARAMS 19 | ------ 20 | C: compression factor 21 | """ 22 | return torch.log(torch.clamp(x, min=clip_val) * C) 23 | 24 | 25 | def dynamic_range_decompression_torch(x, 
C=1): 26 | """ 27 | PARAMS 28 | ------ 29 | C: compression factor used to compress 30 | """ 31 | return torch.exp(x) / C 32 | 33 | 34 | def spectral_normalize_torch(magnitudes): 35 | output = dynamic_range_compression_torch(magnitudes) 36 | return output 37 | 38 | 39 | def spectral_de_normalize_torch(magnitudes): 40 | output = dynamic_range_decompression_torch(magnitudes) 41 | return output 42 | 43 | 44 | mel_basis = {} 45 | hann_window = {} 46 | 47 | 48 | import torch 49 | import librosa 50 | 51 | def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): 52 | global mel_basis 53 | dtype_device = str(spec.dtype) + "_" + str(spec.device) 54 | fmax_dtype_device = str(fmax) + "_" + dtype_device 55 | 56 | # If mel_basis is not cached, generate and cache it 57 | if fmax_dtype_device not in mel_basis: 58 | mel = librosa.filters.mel( 59 | sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax 60 | ) 61 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) 62 | 63 | # Apply the Mel filterbank to the spectrogram 64 | mel_spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 65 | 66 | # Apply spectral normalization (assuming you have this function) 67 | mel_spec = spectral_normalize_torch(mel_spec) 68 | 69 | return mel_spec 70 | 71 | 72 | def mel_spectrogram_torch( 73 | y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False 74 | ): 75 | if torch.min(y) < -1.0: 76 | print("min value is ", torch.min(y)) 77 | if torch.max(y) > 1.0: 78 | print("max value is ", torch.max(y)) 79 | 80 | global mel_basis, hann_window 81 | dtype_device = str(y.dtype) + "_" + str(y.device) 82 | fmax_dtype_device = str(fmax) + "_" + dtype_device 83 | wnsize_dtype_device = str(win_size) + "_" + dtype_device 84 | 85 | if fmax_dtype_device not in mel_basis: 86 | mel = librosa.filters.mel( 87 | sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax 88 | ) 89 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) 90 | 91 | if wnsize_dtype_device not in hann_window: 92 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) 93 | 94 | # Pad the input to match the expected STFT behavior 95 | y = torch.nn.functional.pad( 96 | y.unsqueeze(1), 97 | (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), 98 | mode="reflect", 99 | ) 100 | y = y.squeeze(1) 101 | 102 | # Perform the Short-Time Fourier Transform (STFT) 103 | spec = torch.stft( 104 | y, 105 | n_fft, 106 | hop_length=hop_size, 107 | win_length=win_size, 108 | window=hann_window[wnsize_dtype_device], 109 | center=center, 110 | pad_mode="reflect", 111 | normalized=False, 112 | onesided=True, 113 | return_complex=False, 114 | ) 115 | 116 | # Compute magnitude spectrogram 117 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 118 | 119 | # Apply the mel filterbank 120 | mel_spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 121 | mel_spec = spectral_normalize_torch(mel_spec) 122 | 123 | return mel_spec 124 | 125 | 126 | if __name__ =="__main__": 127 | 128 | 129 | audio_path = "随机片段01.wav" 130 | y, sr = librosa.load(audio_path,sr=22050) 131 | y= torch.FloatTensor(y).unsqueeze(0) 132 | print(y.shape) # [1,T] 133 | mel_spech = mel_spectrogram_torch(y=y, 134 | n_fft=1024, 135 | num_mels=80, 136 | sampling_rate=22050, 137 | hop_size=256, 138 | win_size=1024, 139 | fmax=None, 140 | fmin=0) 141 | print(mel_spech.shape) 142 | 143 | from matplotlib import pyplot as plt 144 | plt.figure() 145 | 
plt.imshow(mel_spech[0]) 146 | plt.savefig('对数梅尔谱图.png') 147 | 148 | 149 | pass -------------------------------------------------------------------------------- /monotonic_align/__init__.py: -------------------------------------------------------------------------------- 1 | from numpy import zeros, int32, float32 2 | from torch import from_numpy 3 | 4 | from .core import maximum_path_jit 5 | 6 | 7 | def maximum_path(neg_cent, mask): 8 | device = neg_cent.device 9 | dtype = neg_cent.dtype 10 | neg_cent = neg_cent.data.cpu().numpy().astype(float32) 11 | path = zeros(neg_cent.shape, dtype=int32) 12 | 13 | t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(int32) 14 | t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(int32) 15 | maximum_path_jit(path, neg_cent, t_t_max, t_s_max) 16 | return from_numpy(path).to(device=device, dtype=dtype) 17 | -------------------------------------------------------------------------------- /monotonic_align/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/monotonic_align/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /monotonic_align/__pycache__/core.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/monotonic_align/__pycache__/core.cpython-310.pyc -------------------------------------------------------------------------------- /monotonic_align/core.py: -------------------------------------------------------------------------------- 1 | import numba 2 | 3 | 4 | @numba.jit( 5 | numba.void( 6 | numba.int32[:, :, ::1], 7 | numba.float32[:, :, ::1], 8 | numba.int32[::1], 9 | numba.int32[::1], 10 | ), 11 | nopython=True, 12 | nogil=True, 13 | ) 14 | def maximum_path_jit(paths, values, t_ys, t_xs): 15 | b = paths.shape[0] 16 | max_neg_val = -1e9 17 | for i in range(int(b)): 18 | path = paths[i] 19 | value = values[i] 20 | t_y = t_ys[i] 21 | t_x = t_xs[i] 22 | 23 | v_prev = v_cur = 0.0 24 | index = t_x - 1 25 | 26 | for y in range(t_y): 27 | for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): 28 | if x == y: 29 | v_cur = max_neg_val 30 | else: 31 | v_cur = value[y - 1, x] 32 | if x == 0: 33 | if y == 0: 34 | v_prev = 0.0 35 | else: 36 | v_prev = max_neg_val 37 | else: 38 | v_prev = value[y - 1, x - 1] 39 | value[y, x] += max(v_prev, v_cur) 40 | 41 | for y in range(t_y - 1, -1, -1): 42 | path[y, index] = 1 43 | if index != 0 and ( 44 | index == y or value[y - 1, index] < value[y - 1, index - 1] 45 | ): 46 | index = index - 1 47 | -------------------------------------------------------------------------------- /onnx_infer.py: -------------------------------------------------------------------------------- 1 | from onnx_modules.V220_OnnxInference import OnnxInferenceSession 2 | import numpy as np 3 | 4 | Session = OnnxInferenceSession( 5 | { 6 | "enc": "onnx/BertVits2.2PT/BertVits2.2PT_enc_p.onnx", 7 | "emb_g": "onnx/BertVits2.2PT/BertVits2.2PT_emb.onnx", 8 | "dp": "onnx/BertVits2.2PT/BertVits2.2PT_dp.onnx", 9 | "sdp": "onnx/BertVits2.2PT/BertVits2.2PT_sdp.onnx", 10 | "flow": "onnx/BertVits2.2PT/BertVits2.2PT_flow.onnx", 11 | "dec": "onnx/BertVits2.2PT/BertVits2.2PT_dec.onnx", 12 | }, 13 | Providers=["CPUExecutionProvider"], 14 | ) 15 | 16 | # 
这里的输入和原版是一样的,只需要在原版预处理结果出来之后加上.numpy()即可 17 | x = np.array( 18 | [ 19 | 0, 20 | 97, 21 | 0, 22 | 8, 23 | 0, 24 | 78, 25 | 0, 26 | 8, 27 | 0, 28 | 76, 29 | 0, 30 | 37, 31 | 0, 32 | 40, 33 | 0, 34 | 97, 35 | 0, 36 | 8, 37 | 0, 38 | 23, 39 | 0, 40 | 8, 41 | 0, 42 | 74, 43 | 0, 44 | 26, 45 | 0, 46 | 104, 47 | 0, 48 | ] 49 | ) 50 | tone = np.zeros_like(x) 51 | language = np.zeros_like(x) 52 | sid = np.array([0]) 53 | bert = np.random.randn(x.shape[0], 1024) 54 | ja_bert = np.random.randn(x.shape[0], 1024) 55 | en_bert = np.random.randn(x.shape[0], 1024) 56 | emo = np.random.randn(512, 1) 57 | 58 | audio = Session(x, tone, language, bert, ja_bert, en_bert, emo, sid) 59 | 60 | print(audio) 61 | -------------------------------------------------------------------------------- /preprocess_text.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import defaultdict 3 | from random import shuffle 4 | from typing import Optional 5 | import os 6 | 7 | from tqdm import tqdm 8 | import click 9 | from text.cleaner import clean_text 10 | from config import config 11 | #from infer import latest_version 12 | 13 | preprocess_text_config = config.preprocess_text_config 14 | 15 | 16 | @click.command() 17 | @click.option( 18 | "--transcription-path", 19 | default=preprocess_text_config.transcription_path, 20 | type=click.Path(exists=True, file_okay=True, dir_okay=False), 21 | ) 22 | @click.option("--cleaned-path", default=preprocess_text_config.cleaned_path) 23 | @click.option("--train-path", default=preprocess_text_config.train_path) 24 | @click.option("--val-path", default=preprocess_text_config.val_path) 25 | @click.option( 26 | "--config-path", 27 | default=preprocess_text_config.config_path, 28 | type=click.Path(exists=True, file_okay=True, dir_okay=False), 29 | ) 30 | @click.option("--val-per-lang", default=preprocess_text_config.val_per_lang) 31 | @click.option("--max-val-total", default=preprocess_text_config.max_val_total) 32 | @click.option("--clean/--no-clean", default=preprocess_text_config.clean) 33 | @click.option("-y", "--yml_config") 34 | def preprocess( 35 | transcription_path: str, 36 | cleaned_path: Optional[str], 37 | train_path: str, 38 | val_path: str, 39 | config_path: str, 40 | val_per_lang: int, 41 | max_val_total: int, 42 | clean: bool, 43 | yml_config: str, # 这个不要删 44 | ): 45 | if cleaned_path == "" or cleaned_path is None: 46 | cleaned_path = transcription_path + ".cleaned" 47 | 48 | if clean: 49 | with open(cleaned_path, "w", encoding="utf-8") as out_file: 50 | with open(transcription_path, "r", encoding="utf-8") as trans_file: 51 | lines = trans_file.readlines() 52 | # print(lines, ' ', len(lines)) 53 | if len(lines) != 0: 54 | for line in tqdm(lines): 55 | try: 56 | utt, spk, language, text = line.strip().split("|") 57 | norm_text, phones, tones, word2ph = clean_text( 58 | text, language 59 | ) 60 | out_file.write( 61 | "{}|{}|{}|{}|{}|{}|{}\n".format( 62 | utt, 63 | spk, 64 | language, 65 | norm_text, 66 | " ".join(phones), 67 | " ".join([str(i) for i in tones]), 68 | " ".join([str(i) for i in word2ph]), 69 | ) 70 | ) 71 | except Exception as e: 72 | print(line) 73 | print(f"生成训练集和验证集时发生错误!, 详细信息:\n{e}") 74 | 75 | transcription_path = cleaned_path 76 | spk_utt_map = defaultdict(list) 77 | spk_id_map = {} 78 | current_sid = 0 79 | 80 | with open(transcription_path, "r", encoding="utf-8") as f: 81 | audioPaths = set() 82 | countSame = 0 83 | countNotFound = 0 84 | for line in f.readlines(): 85 | utt, spk, language, 
text, phones, tones, word2ph = line.strip().split("|") 86 | if utt in audioPaths: 87 | # 过滤数据集错误:相同的音频匹配多个文本,导致后续bert出问题 88 | print(f"重复音频文本:{line}") 89 | countSame += 1 90 | continue 91 | if not os.path.isfile(utt): 92 | # 过滤数据集错误:不存在对应音频 93 | print(f"没有找到对应的音频:{utt}") 94 | countNotFound += 1 95 | continue 96 | audioPaths.add(utt) 97 | spk_utt_map[language].append(line) 98 | if spk not in spk_id_map.keys(): 99 | spk_id_map[spk] = current_sid 100 | current_sid += 1 101 | print(f"总重复音频数:{countSame},总未找到的音频数:{countNotFound}") 102 | 103 | train_list = [] 104 | val_list = [] 105 | 106 | for spk, utts in spk_utt_map.items(): 107 | shuffle(utts) 108 | val_list += utts[:val_per_lang] 109 | train_list += utts[val_per_lang:] 110 | 111 | shuffle(val_list) 112 | if len(val_list) > max_val_total: 113 | train_list += val_list[max_val_total:] 114 | val_list = val_list[:max_val_total] 115 | 116 | with open(train_path, "w", encoding="utf-8") as f: 117 | for line in train_list: 118 | f.write(line) 119 | 120 | with open(val_path, "w", encoding="utf-8") as f: 121 | for line in val_list: 122 | f.write(line) 123 | 124 | json_config = json.load(open(config_path, encoding="utf-8")) 125 | json_config["data"]["spk2id"] = spk_id_map 126 | json_config["data"]["n_speakers"] = len(spk_id_map) 127 | # # 新增写入:写入训练版本、数据集路径 128 | # json_config["version"] = latest_version 129 | json_config["data"]["training_files"] = os.path.normpath(train_path).replace( 130 | "\\", "/" 131 | ) 132 | json_config["data"]["validation_files"] = os.path.normpath(val_path).replace( 133 | "\\", "/" 134 | ) 135 | with open(config_path, "w", encoding="utf-8") as f: 136 | json.dump(json_config, f, indent=2, ensure_ascii=False) 137 | print("训练集和验证集生成完成!") 138 | 139 | 140 | if __name__ == "__main__": 141 | preprocess() 142 | 143 | 144 | 145 | """ 146 | cp configs/config.json A5_finetuned_trainingout/SSB0005_50 147 | python preprocess_text.py \ 148 | --transcription-path A5_finetuned_trainingout/SSB0005_50/filelists/script.txt \ 149 | --cleaned-path A5_finetuned_trainingout/SSB0005_50/filelists/script.txt.cleaned \ 150 | --train-path A5_finetuned_trainingout/SSB0005_50/filelists/script.txt.cleaned.train \ 151 | --val-path A5_finetuned_trainingout/SSB0005_50/filelists/script.txt.cleaned.val \ 152 | --config-path A5_finetuned_trainingout/SSB0005_50/config.json 153 | 154 | 155 | """ -------------------------------------------------------------------------------- /re_matching.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def extract_language_and_text_updated(speaker, dialogue): 5 | # 使用正则表达式匹配<语言>标签和其后的文本 6 | pattern_language_text = r"<(\S+?)>([^<]+)" 7 | matches = re.findall(pattern_language_text, dialogue, re.DOTALL) 8 | speaker = speaker[1:-1] 9 | # 清理文本:去除两边的空白字符 10 | matches_cleaned = [(lang.upper(), text.strip()) for lang, text in matches] 11 | matches_cleaned.append(speaker) 12 | return matches_cleaned 13 | 14 | 15 | def validate_text(input_text): 16 | # 验证说话人的正则表达式 17 | pattern_speaker = r"(\[\S+?\])((?:\s*<\S+?>[^<\[\]]+?)+)" 18 | 19 | # 使用re.DOTALL标志使.匹配包括换行符在内的所有字符 20 | matches = re.findall(pattern_speaker, input_text, re.DOTALL) 21 | 22 | # 对每个匹配到的说话人内容进行进一步验证 23 | for _, dialogue in matches: 24 | language_text_matches = extract_language_and_text_updated(_, dialogue) 25 | if not language_text_matches: 26 | return ( 27 | False, 28 | "Error: Invalid format detected in dialogue content. 
Please check your input.", 29 | ) 30 | 31 | # 如果输入的文本中没有找到任何匹配项 32 | if not matches: 33 | return ( 34 | False, 35 | "Error: No valid speaker format detected. Please check your input.", 36 | ) 37 | 38 | return True, "Input is valid." 39 | 40 | 41 | def text_matching(text: str) -> list: 42 | speaker_pattern = r"(\[\S+?\])(.+?)(?=\[\S+?\]|$)" 43 | matches = re.findall(speaker_pattern, text, re.DOTALL) 44 | result = [] 45 | for speaker, dialogue in matches: 46 | result.append(extract_language_and_text_updated(speaker, dialogue)) 47 | return result 48 | 49 | 50 | def cut_para(text): 51 | splitted_para = re.split("[\n]", text) # 按段分 52 | splitted_para = [ 53 | sentence.strip() for sentence in splitted_para if sentence.strip() 54 | ] # 删除空字符串 55 | return splitted_para 56 | 57 | 58 | def cut_sent(para): 59 | para = re.sub("([。!;?\?])([^”’])", r"\1\n\2", para) # 单字符断句符 60 | para = re.sub("(\.{6})([^”’])", r"\1\n\2", para) # 英文省略号 61 | para = re.sub("(\…{2})([^”’])", r"\1\n\2", para) # 中文省略号 62 | para = re.sub("([。!?\?][”’])([^,。!?\?])", r"\1\n\2", para) 63 | para = para.rstrip() # 段尾如果有多余的\n就去掉它 64 | return para.split("\n") 65 | 66 | 67 | if __name__ == "__main__": 68 | text = """ 69 | [说话人1] 70 | [说话人2]你好吗?元気ですか?こんにちは,世界。你好吗? 71 | [说话人3]谢谢。どういたしまして。 72 | """ 73 | text_matching(text) 74 | # 测试函数 75 | test_text = """ 76 | [说话人1]你好,こんにちは!こんにちは,世界。 77 | [说话人2]你好吗? 78 | """ 79 | text_matching(test_text) 80 | res = validate_text(test_text) 81 | print(res) 82 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | aiofiles==23.2.1 3 | altair==5.4.1 4 | AMFM_decompy==1.0.11 5 | annotated-types==0.7.0 6 | anyio==4.4.0 7 | attrs==24.2.0 8 | audioread==3.0.1 9 | av==13.0.0 10 | babel==2.16.0 11 | bibtexparser==2.0.0b7 12 | certifi==2024.8.30 13 | cffi==1.17.1 14 | charset-normalizer==3.3.2 15 | click==8.1.7 16 | clldutils==3.22.2 17 | cmudict==1.0.30 18 | cn2an==0.5.22 19 | colorama==0.4.6 20 | colorlog==6.8.2 21 | contourpy==1.3.0 22 | csvw==3.3.1 23 | cycler==0.12.1 24 | Cython==3.0.11 25 | decorator==5.1.1 26 | Deprecated==1.2.14 27 | Distance==0.1.3 28 | dlinfo==1.2.1 29 | docopt==0.6.2 30 | einops==0.8.0 31 | einx==0.3.0 32 | exceptiongroup==1.2.2 33 | fastapi==0.114.1 34 | ffmpy==0.4.0 35 | filelock==3.16.0 36 | fonttools==4.53.1 37 | frozendict==2.4.4 38 | fsspec==2024.9.0 39 | fugashi==1.3.2 40 | g2p-en==2.1.0 41 | GPUtil==1.4.0 42 | gradio==3.50.2 43 | gradio_client==0.6.1 44 | grpcio==1.66.1 45 | h11==0.14.0 46 | httpcore==1.0.5 47 | httpx==0.27.2 48 | huggingface-hub==0.24.6 49 | idna==3.8 50 | importlib_metadata==8.5.0 51 | importlib_resources==6.4.5 52 | inflect==7.4.0 53 | isodate==0.6.1 54 | jaconv==0.4.0 55 | jieba==0.42.1 56 | Jinja2==3.1.4 57 | joblib==1.4.2 58 | jsonschema==4.23.0 59 | jsonschema-specifications==2023.12.1 60 | kiwisolver==1.4.7 61 | langid==1.1.6 62 | language-tags==1.2.0 63 | librosa==0.9.2 64 | llvmlite==0.43.0 65 | loguru==0.7.2 66 | lxml==5.3.0 67 | Markdown==3.7 68 | MarkupSafe==2.1.5 69 | matplotlib==3.9.2 70 | mecab-python3==1.0.9 71 | more-itertools==10.5.0 72 | mpmath==1.3.0 73 | narwhals==1.7.0 74 | networkx==3.3 75 | nltk==3.9.1 76 | num2words==0.5.13 77 | numba==0.60.0 78 | numpy==1.26.4 79 | nvidia-cublas-cu12==12.1.3.1 80 | nvidia-cuda-cupti-cu12==12.1.105 81 | nvidia-cuda-nvrtc-cu12==12.1.105 82 | nvidia-cuda-runtime-cu12==12.1.105 83 | nvidia-cudnn-cu12==9.1.0.70 84 | nvidia-cufft-cu12==11.0.2.54 85 | 
nvidia-curand-cu12==10.3.2.106 86 | nvidia-cusolver-cu12==11.4.5.107 87 | nvidia-cusparse-cu12==12.1.0.106 88 | nvidia-nccl-cu12==2.20.5 89 | nvidia-nvjitlink-cu12==12.6.68 90 | nvidia-nvtx-cu12==12.1.105 91 | openi==2.0.2.post3 92 | orjson==3.10.7 93 | packaging==24.1 94 | pandas==2.2.2 95 | phonemizer==3.3.0 96 | pillow==10.4.0 97 | platformdirs==4.3.2 98 | pooch==1.8.2 99 | proces==0.1.7 100 | protobuf==5.28.1 101 | psutil==6.0.0 102 | pwinput==1.0.3 103 | pycparser==2.22 104 | pydantic==2.9.1 105 | pydantic_core==2.23.3 106 | pydub==0.25.1 107 | pykakasi==2.3.0 108 | pylatexenc==2.10 109 | pynini==2.1.6 110 | pyopenjtalk-prebuilt==0.3.0 111 | pyparsing==3.1.4 112 | pypinyin==0.52.0 113 | python-dateutil==2.9.0.post0 114 | python-multipart==0.0.9 115 | pytz==2024.2 116 | PyYAML==6.0.2 117 | rdflib==7.0.0 118 | referencing==0.35.1 119 | regex==2024.9.11 120 | requests==2.32.3 121 | resampy==0.4.3 122 | rfc3986==1.5.0 123 | rpds-py==0.20.0 124 | safetensors==0.4.5 125 | scikit-learn==1.5.2 126 | scipy==1.14.1 127 | segments==2.2.1 128 | semantic-version==2.10.0 129 | sentencepiece==0.2.0 130 | six==1.16.0 131 | sniffio==1.3.1 132 | soundfile==0.12.1 133 | starlette==0.38.5 134 | sympy==1.13.2 135 | tabulate==0.9.0 136 | tensorboard==2.17.1 137 | tensorboard-data-server==0.7.2 138 | threadpoolctl==3.5.0 139 | tokenizers==0.19.1 140 | torch==2.4.1 141 | torchaudio==2.4.1 142 | tqdm==4.66.5 143 | transformers==4.44.2 144 | triton==3.0.0 145 | typeguard==4.3.0 146 | typing_extensions==4.12.2 147 | tzdata==2024.1 148 | Unidecode==1.3.8 149 | unidic-lite==1.0.8 150 | uritemplate==4.1.1 151 | urllib3==2.2.2 152 | uvicorn==0.30.6 153 | vector-quantize-pytorch==1.17.3 154 | websockets==11.0.3 155 | Werkzeug==3.0.4 156 | WeTextProcessing==1.0.4.1 157 | wrapt==1.16.0 158 | zipp==3.20.1 159 | -------------------------------------------------------------------------------- /resample.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import librosa 4 | from multiprocessing import Pool, cpu_count 5 | 6 | import soundfile 7 | from tqdm import tqdm 8 | 9 | from config import config 10 | 11 | 12 | def process(item): 13 | spkdir, wav_name, args = item 14 | wav_path = os.path.join(args.in_dir, spkdir, wav_name) 15 | if os.path.exists(wav_path) and wav_path.lower().endswith(".wav"): 16 | wav, sr = librosa.load(wav_path, sr=args.sr) 17 | soundfile.write(os.path.join(args.out_dir, spkdir, wav_name), wav, sr) 18 | 19 | 20 | if __name__ == "__main__": 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument( 23 | "--sr", 24 | type=int, 25 | default=config.resample_config.sampling_rate, 26 | help="sampling rate", 27 | ) 28 | parser.add_argument( 29 | "--in_dir", 30 | type=str, 31 | default=config.resample_config.in_dir, 32 | help="path to source dir", 33 | ) 34 | parser.add_argument( 35 | "--out_dir", 36 | type=str, 37 | default=config.resample_config.out_dir, 38 | help="path to target dir", 39 | ) 40 | parser.add_argument( 41 | "--processes", 42 | type=int, 43 | default=0, 44 | help="cpu_processes", 45 | ) 46 | args, _ = parser.parse_known_args() 47 | # autodl 无卡模式会识别出46个cpu 48 | if args.processes == 0: 49 | processes = cpu_count() - 2 if cpu_count() > 4 else 1 50 | else: 51 | processes = args.processes 52 | pool = Pool(processes=processes) 53 | 54 | tasks = [] 55 | 56 | for dirpath, _, filenames in os.walk(args.in_dir): 57 | # 子级目录 58 | spk_dir = os.path.relpath(dirpath, args.in_dir) 59 | spk_dir_out = os.path.join(args.out_dir, 
spk_dir) 60 | if not os.path.isdir(spk_dir_out): 61 | os.makedirs(spk_dir_out, exist_ok=True) 62 | for filename in filenames: 63 | if filename.lower().endswith(".wav"): 64 | twople = (spk_dir, filename, args) 65 | tasks.append(twople) 66 | 67 | for _ in tqdm( 68 | pool.imap_unordered(process, tasks), 69 | ): 70 | pass 71 | 72 | pool.close() 73 | pool.join() 74 | 75 | print("音频重采样完毕!") 76 | -------------------------------------------------------------------------------- /resample_legacy.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import librosa 4 | from multiprocessing import Pool, cpu_count 5 | 6 | import soundfile 7 | from tqdm import tqdm 8 | 9 | from config import config 10 | 11 | 12 | def process(item): 13 | wav_name, args = item 14 | wav_path = os.path.join(args.in_dir, wav_name) 15 | if os.path.exists(wav_path) and wav_path.lower().endswith(".wav"): 16 | wav, sr = librosa.load(wav_path, sr=args.sr) 17 | soundfile.write(os.path.join(args.out_dir, wav_name), wav, sr) 18 | 19 | 20 | if __name__ == "__main__": 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument( 23 | "--sr", 24 | type=int, 25 | default=config.resample_config.sampling_rate, 26 | help="sampling rate", 27 | ) 28 | parser.add_argument( 29 | "--in_dir", 30 | type=str, 31 | default=config.resample_config.in_dir, 32 | help="path to source dir", 33 | ) 34 | parser.add_argument( 35 | "--out_dir", 36 | type=str, 37 | default=config.resample_config.out_dir, 38 | help="path to target dir", 39 | ) 40 | parser.add_argument( 41 | "--processes", 42 | type=int, 43 | default=0, 44 | help="cpu_processes", 45 | ) 46 | args, _ = parser.parse_known_args() 47 | # autodl 无卡模式会识别出46个cpu 48 | if args.processes == 0: 49 | processes = cpu_count() - 2 if cpu_count() > 4 else 1 50 | else: 51 | processes = args.processes 52 | pool = Pool(processes=processes) 53 | 54 | tasks = [] 55 | 56 | for dirpath, _, filenames in os.walk(args.in_dir): 57 | if not os.path.isdir(args.out_dir): 58 | os.makedirs(args.out_dir, exist_ok=True) 59 | for filename in filenames: 60 | if filename.lower().endswith(".wav"): 61 | tasks.append((filename, args)) 62 | 63 | for _ in tqdm( 64 | pool.imap_unordered(process, tasks), 65 | ): 66 | pass 67 | 68 | pool.close() 69 | pool.join() 70 | 71 | print("音频重采样完毕!") 72 | -------------------------------------------------------------------------------- /spec_gen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | from multiprocessing import Pool 4 | from mel_processing import spectrogram_torch, mel_spectrogram_torch 5 | from utils import load_wav_to_torch 6 | 7 | 8 | class AudioProcessor: 9 | def __init__( 10 | self, 11 | max_wav_value, 12 | use_mel_spec_posterior, 13 | filter_length, 14 | n_mel_channels, 15 | sampling_rate, 16 | hop_length, 17 | win_length, 18 | mel_fmin, 19 | mel_fmax, 20 | ): 21 | self.max_wav_value = max_wav_value 22 | self.use_mel_spec_posterior = use_mel_spec_posterior 23 | self.filter_length = filter_length 24 | self.n_mel_channels = n_mel_channels 25 | self.sampling_rate = sampling_rate 26 | self.hop_length = hop_length 27 | self.win_length = win_length 28 | self.mel_fmin = mel_fmin 29 | self.mel_fmax = mel_fmax 30 | 31 | def process_audio(self, filename): 32 | audio, sampling_rate = load_wav_to_torch(filename) 33 | audio_norm = audio / self.max_wav_value 34 | audio_norm = audio_norm.unsqueeze(0) 35 | spec_filename = filename.replace(".wav", 
".spec.pt") 36 | if self.use_mel_spec_posterior: 37 | spec_filename = spec_filename.replace(".spec.pt", ".mel.pt") 38 | try: 39 | spec = torch.load(spec_filename) 40 | except: 41 | if self.use_mel_spec_posterior: 42 | spec = mel_spectrogram_torch( 43 | audio_norm, 44 | self.filter_length, 45 | self.n_mel_channels, 46 | self.sampling_rate, 47 | self.hop_length, 48 | self.win_length, 49 | self.mel_fmin, 50 | self.mel_fmax, 51 | center=False, 52 | ) 53 | else: 54 | spec = spectrogram_torch( 55 | audio_norm, 56 | self.filter_length, 57 | self.sampling_rate, 58 | self.hop_length, 59 | self.win_length, 60 | center=False, 61 | ) 62 | spec = torch.squeeze(spec, 0) 63 | torch.save(spec, spec_filename) 64 | return spec, audio_norm 65 | 66 | 67 | # 使用示例 68 | processor = AudioProcessor( 69 | max_wav_value=32768.0, 70 | use_mel_spec_posterior=False, 71 | filter_length=2048, 72 | n_mel_channels=128, 73 | sampling_rate=44100, 74 | hop_length=512, 75 | win_length=2048, 76 | mel_fmin=0.0, 77 | mel_fmax="null", 78 | ) 79 | 80 | import click 81 | @click.command() 82 | @click.option('--script', type=str, help='Your txt.cleaned.train') 83 | def melspec_gen(script): 84 | 85 | with open(script, "r") as f: 86 | filepaths = [line.split("|")[0] for line in f] # 取每一行的第一部分作为audiopath 87 | 88 | # 使用多进程处理 89 | with Pool(processes=32) as pool: # 使用4个进程 90 | with tqdm(total=len(filepaths)) as pbar: 91 | for i, _ in enumerate(pool.imap_unordered(processor.process_audio, filepaths)): 92 | pbar.update() 93 | 94 | pass 95 | 96 | 97 | if __name__ == "__main__": 98 | melspec_gen() 99 | 100 | -------------------------------------------------------------------------------- /text/__init__.py: -------------------------------------------------------------------------------- 1 | from text.symbols import * 2 | 3 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 4 | 5 | 6 | def cleaned_text_to_sequence(cleaned_text, tones, language): 7 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 
8 | Args: 9 | text: string to convert to a sequence 10 | Returns: 11 | List of integers corresponding to the symbols in the text 12 | """ 13 | phones = [_symbol_to_id[symbol] for symbol in cleaned_text] 14 | tone_start = language_tone_start_map[language] 15 | tones = [i + tone_start for i in tones] 16 | lang_id = language_id_map[language] 17 | lang_ids = [lang_id for i in phones] 18 | return phones, tones, lang_ids 19 | 20 | 21 | def get_bert(norm_text, word2ph, language, device, style_text=None, style_weight=0.7): 22 | 23 | bert = None 24 | 25 | try: 26 | # 你要加载BERT模型的代码 27 | # 比如:model = load_bert_model() 28 | if language == "ZH": 29 | from .chinese_bert import get_bert_feature as zh_bert 30 | bert = zh_bert(norm_text, word2ph, device, style_text, style_weight) 31 | if language == "EN": 32 | from .english_bert_mock import get_bert_feature as en_bert 33 | bert = en_bert(norm_text, word2ph, device, style_text, style_weight) 34 | if language == "JP": 35 | from .japanese_bert import get_bert_feature as jp_bert 36 | bert = jp_bert(norm_text, word2ph, device, style_text, style_weight) 37 | 38 | except Exception as e: 39 | print("加载到BERT模型异常") 40 | # 打印具体的异常信息(可选) 41 | print(f"错误详情: {str(e)}") 42 | 43 | 44 | 45 | return bert 46 | 47 | -------------------------------------------------------------------------------- /text/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/text/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /text/__pycache__/bert_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/text/__pycache__/bert_utils.cpython-310.pyc -------------------------------------------------------------------------------- /text/__pycache__/chinese.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/text/__pycache__/chinese.cpython-310.pyc -------------------------------------------------------------------------------- /text/__pycache__/chinese_bert.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/text/__pycache__/chinese_bert.cpython-310.pyc -------------------------------------------------------------------------------- /text/__pycache__/cleaner.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/text/__pycache__/cleaner.cpython-310.pyc -------------------------------------------------------------------------------- /text/__pycache__/english.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/text/__pycache__/english.cpython-310.pyc -------------------------------------------------------------------------------- /text/__pycache__/english_bert_mock.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/text/__pycache__/english_bert_mock.cpython-310.pyc -------------------------------------------------------------------------------- /text/__pycache__/japanese.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/text/__pycache__/japanese.cpython-310.pyc -------------------------------------------------------------------------------- /text/__pycache__/symbols.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/text/__pycache__/symbols.cpython-310.pyc -------------------------------------------------------------------------------- /text/__pycache__/tone_sandhi.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/text/__pycache__/tone_sandhi.cpython-310.pyc -------------------------------------------------------------------------------- /text/bert_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from huggingface_hub import hf_hub_download 4 | 5 | from config import config 6 | 7 | 8 | MIRROR: str = config.mirror 9 | 10 | 11 | def _check_bert(repo_id, files, local_path): 12 | for file in files: 13 | if not Path(local_path).joinpath(file).exists(): 14 | if MIRROR.lower() == "openi": 15 | import openi 16 | 17 | openi.model.download_model( 18 | "Stardust_minus/Bert-VITS2", repo_id.split("/")[-1], "./bert" 19 | ) 20 | else: 21 | hf_hub_download( 22 | repo_id, file, local_dir=local_path, local_dir_use_symlinks=False 23 | ) 24 | -------------------------------------------------------------------------------- /text/chinese.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | from pypinyin import lazy_pinyin, Style 5 | 6 | from text.symbols import punctuation 7 | from text.tone_sandhi import ToneSandhi 8 | 9 | try: 10 | from tn.chinese.normalizer import Normalizer 11 | 12 | normalizer = Normalizer().normalize 13 | except ImportError: 14 | import cn2an 15 | 16 | print("tn.chinese.normalizer not found, use cn2an normalizer") 17 | normalizer = lambda x: cn2an.transform(x, "an2cn") 18 | 19 | current_file_path = os.path.dirname(__file__) 20 | pinyin_to_symbol_map = { 21 | line.split("\t")[0]: line.strip().split("\t")[1] 22 | for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines() 23 | } 24 | 25 | import jieba.posseg as psg 26 | 27 | 28 | rep_map = { 29 | ":": ",", 30 | ";": ",", 31 | ",": ",", 32 | "。": ".", 33 | "!": "!", 34 | "?": "?", 35 | "\n": ".", 36 | "·": ",", 37 | "、": ",", 38 | "...": "…", 39 | "$": ".", 40 | "“": "'", 41 | "”": "'", 42 | '"': "'", 43 | "‘": "'", 44 | "’": "'", 45 | "(": "'", 46 | ")": "'", 47 | "(": "'", 48 | ")": "'", 49 | "《": "'", 50 | "》": "'", 51 | "【": "'", 52 | "】": "'", 53 | "[": "'", 54 | "]": "'", 55 | "—": "-", 56 | "~": "-", 57 | "~": "-", 58 | "「": "'", 59 | "」": "'", 60 | } 61 | 62 | tone_modifier = ToneSandhi() 63 | 64 | 65 | def replace_punctuation(text): 66 | text = text.replace("嗯", 
"恩").replace("呣", "母") 67 | pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys())) 68 | 69 | replaced_text = pattern.sub(lambda x: rep_map[x.group()], text) 70 | 71 | replaced_text = re.sub( 72 | r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text 73 | ) 74 | 75 | return replaced_text 76 | 77 | 78 | def g2p(text): 79 | pattern = r"(?<=[{0}])\s*".format("".join(punctuation)) 80 | sentences = [i for i in re.split(pattern, text) if i.strip() != ""] 81 | phones, tones, word2ph = _g2p(sentences) 82 | assert sum(word2ph) == len(phones) 83 | assert len(word2ph) == len(text) # Sometimes it will crash,you can add a try-catch. 84 | phones = ["_"] + phones + ["_"] 85 | tones = [0] + tones + [0] 86 | word2ph = [1] + word2ph + [1] 87 | return phones, tones, word2ph 88 | 89 | 90 | def _get_initials_finals(word): 91 | initials = [] 92 | finals = [] 93 | orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS) 94 | orig_finals = lazy_pinyin( 95 | word, neutral_tone_with_five=True, style=Style.FINALS_TONE3 96 | ) 97 | for c, v in zip(orig_initials, orig_finals): 98 | initials.append(c) 99 | finals.append(v) 100 | return initials, finals 101 | 102 | 103 | def _g2p(segments): 104 | phones_list = [] 105 | tones_list = [] 106 | word2ph = [] 107 | for seg in segments: 108 | # Replace all English words in the sentence 109 | seg = re.sub("[a-zA-Z]+", "", seg) 110 | seg_cut = psg.lcut(seg) 111 | initials = [] 112 | finals = [] 113 | seg_cut = tone_modifier.pre_merge_for_modify(seg_cut) 114 | for word, pos in seg_cut: 115 | if pos == "eng": 116 | continue 117 | sub_initials, sub_finals = _get_initials_finals(word) 118 | sub_finals = tone_modifier.modified_tone(word, pos, sub_finals) 119 | initials.append(sub_initials) 120 | finals.append(sub_finals) 121 | 122 | # assert len(sub_initials) == len(sub_finals) == len(word) 123 | initials = sum(initials, []) 124 | finals = sum(finals, []) 125 | # 126 | for c, v in zip(initials, finals): 127 | raw_pinyin = c + v 128 | # NOTE: post process for pypinyin outputs 129 | # we discriminate i, ii and iii 130 | if c == v: 131 | assert c in punctuation 132 | phone = [c] 133 | tone = "0" 134 | word2ph.append(1) 135 | else: 136 | v_without_tone = v[:-1] 137 | tone = v[-1] 138 | 139 | pinyin = c + v_without_tone 140 | assert tone in "12345" 141 | 142 | if c: 143 | # 多音节 144 | v_rep_map = { 145 | "uei": "ui", 146 | "iou": "iu", 147 | "uen": "un", 148 | } 149 | if v_without_tone in v_rep_map.keys(): 150 | pinyin = c + v_rep_map[v_without_tone] 151 | else: 152 | # 单音节 153 | pinyin_rep_map = { 154 | "ing": "ying", 155 | "i": "yi", 156 | "in": "yin", 157 | "u": "wu", 158 | } 159 | if pinyin in pinyin_rep_map.keys(): 160 | pinyin = pinyin_rep_map[pinyin] 161 | else: 162 | single_rep_map = { 163 | "v": "yu", 164 | "e": "e", 165 | "i": "y", 166 | "u": "w", 167 | } 168 | if pinyin[0] in single_rep_map.keys(): 169 | pinyin = single_rep_map[pinyin[0]] + pinyin[1:] 170 | 171 | assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin) 172 | phone = pinyin_to_symbol_map[pinyin].split(" ") 173 | word2ph.append(len(phone)) 174 | 175 | phones_list += phone 176 | tones_list += [int(tone)] * len(phone) 177 | return phones_list, tones_list, word2ph 178 | 179 | 180 | def text_normalize(text): 181 | text = normalizer(text) 182 | text = replace_punctuation(text) 183 | return text 184 | 185 | 186 | def get_bert_feature(text, word2ph): 187 | from text import chinese_bert 188 | 189 | return chinese_bert.get_bert_feature(text, word2ph) 
190 | 191 | 192 | if __name__ == "__main__": 193 | from text.chinese_bert import get_bert_feature 194 | 195 | text = "啊!但是《原神》是由,米哈\游自主, [研发]的一款全.新开放世界.冒险游戏" 196 | text = text_normalize(text) 197 | print(text) 198 | phones, tones, word2ph = g2p(text) 199 | bert = get_bert_feature(text, word2ph) 200 | 201 | print(phones, tones, word2ph, bert.shape) 202 | 203 | 204 | # # 示例用法 205 | # text = "这是一个示例文本:,你好!这是一个测试...." 206 | # print(g2p_paddle(text)) # 输出: 这是一个示例文本你好这是一个测试 207 | -------------------------------------------------------------------------------- /text/chinese_bert.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import torch 4 | from transformers import AutoModelForMaskedLM, AutoTokenizer 5 | 6 | from config import config 7 | 8 | LOCAL_PATH = "./bert/chinese-roberta-wwm-ext-large" 9 | 10 | tokenizer = AutoTokenizer.from_pretrained(LOCAL_PATH) 11 | 12 | models = dict() 13 | 14 | 15 | def get_bert_feature( 16 | text, 17 | word2ph, 18 | device=config.bert_gen_config.device, 19 | style_text=None, 20 | style_weight=0.7, 21 | ): 22 | if ( 23 | sys.platform == "darwin" 24 | and torch.backends.mps.is_available() 25 | and device == "cpu" 26 | ): 27 | device = "mps" 28 | if not device: 29 | device = "cuda" 30 | if device not in models.keys(): 31 | models[device] = AutoModelForMaskedLM.from_pretrained(LOCAL_PATH).to(device) 32 | with torch.no_grad(): 33 | inputs = tokenizer(text, return_tensors="pt") 34 | for i in inputs: 35 | inputs[i] = inputs[i].to(device) 36 | res = models[device](**inputs, output_hidden_states=True) 37 | res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu() 38 | if style_text: 39 | style_inputs = tokenizer(style_text, return_tensors="pt") 40 | for i in style_inputs: 41 | style_inputs[i] = style_inputs[i].to(device) 42 | style_res = models[device](**style_inputs, output_hidden_states=True) 43 | style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].cpu() 44 | style_res_mean = style_res.mean(0) 45 | assert len(word2ph) == len(text) + 2 46 | word2phone = word2ph 47 | phone_level_feature = [] 48 | for i in range(len(word2phone)): 49 | if style_text: 50 | repeat_feature = ( 51 | res[i].repeat(word2phone[i], 1) * (1 - style_weight) 52 | + style_res_mean.repeat(word2phone[i], 1) * style_weight 53 | ) 54 | else: 55 | repeat_feature = res[i].repeat(word2phone[i], 1) 56 | phone_level_feature.append(repeat_feature) 57 | 58 | phone_level_feature = torch.cat(phone_level_feature, dim=0) 59 | 60 | return phone_level_feature.T 61 | 62 | 63 | if __name__ == "__main__": 64 | word_level_feature = torch.rand(38, 1024) # 12个词,每个词1024维特征 65 | word2phone = [ 66 | 1, 67 | 2, 68 | 1, 69 | 2, 70 | 2, 71 | 1, 72 | 2, 73 | 2, 74 | 1, 75 | 2, 76 | 2, 77 | 1, 78 | 2, 79 | 2, 80 | 2, 81 | 2, 82 | 2, 83 | 1, 84 | 1, 85 | 2, 86 | 2, 87 | 1, 88 | 2, 89 | 2, 90 | 2, 91 | 2, 92 | 1, 93 | 2, 94 | 2, 95 | 2, 96 | 2, 97 | 2, 98 | 1, 99 | 2, 100 | 2, 101 | 2, 102 | 2, 103 | 1, 104 | ] 105 | 106 | # 计算总帧数 107 | total_frames = sum(word2phone) 108 | print(word_level_feature.shape) 109 | print(word2phone) 110 | phone_level_feature = [] 111 | for i in range(len(word2phone)): 112 | print(word_level_feature[i].shape) 113 | 114 | # 对每个词重复word2phone[i]次 115 | repeat_feature = word_level_feature[i].repeat(word2phone[i], 1) 116 | phone_level_feature.append(repeat_feature) 117 | 118 | phone_level_feature = torch.cat(phone_level_feature, dim=0) 119 | print(phone_level_feature.shape) # torch.Size([36, 1024]) 120 | 
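# Editor's note — a hedged usage sketch appended for clarity; it is not part
# of the original file and is never called. When style_text is supplied,
# get_bert_feature() above mixes each character's hidden state with the mean
# hidden state of the style prompt,
#     feature_i = res[i] * (1 - style_weight) + style_mean * style_weight,
# and repeats it word2ph[i] times, so the result has shape
# [1024, sum(word2ph)]. The sample text, style prompt and device below are
# assumptions, and running this requires the local
# bert/chinese-roberta-wwm-ext-large checkpoint.
def _style_blend_demo(device: str = "cpu"):
    from text.chinese import text_normalize, g2p

    norm = text_normalize("今天天气不错。")
    _, _, word2ph = g2p(norm)
    feat = get_bert_feature(
        norm, word2ph, device=device, style_text="温柔一点", style_weight=0.7
    )
    assert feat.shape == (1024, sum(word2ph))
    return feat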
-------------------------------------------------------------------------------- /text/cleaner.py: -------------------------------------------------------------------------------- 1 | # from text import chinese, japanese, english, cleaned_text_to_sequence 2 | # language_module_map = {"ZH": chinese, "JP": japanese, "EN": english} 3 | 4 | from text import chinese, cleaned_text_to_sequence 5 | language_module_map = {"ZH": chinese} 6 | 7 | 8 | def clean_text(text, language): 9 | language_module = language_module_map[language] 10 | norm_text = language_module.text_normalize(text) 11 | phones, tones, word2ph = language_module.g2p(norm_text) 12 | return norm_text, phones, tones, word2ph 13 | 14 | 15 | def clean_text_bert(text, language): 16 | language_module = language_module_map[language] 17 | norm_text = language_module.text_normalize(text) 18 | phones, tones, word2ph = language_module.g2p(norm_text) 19 | bert = language_module.get_bert_feature(norm_text, word2ph) 20 | return phones, tones, bert 21 | 22 | 23 | def text_to_sequence(text, language): 24 | norm_text, phones, tones, word2ph = clean_text(text, language) 25 | return cleaned_text_to_sequence(phones, tones, language) 26 | 27 | 28 | if __name__ == "__main__": 29 | pass 30 | -------------------------------------------------------------------------------- /text/cmudict_cache.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/text/cmudict_cache.pickle -------------------------------------------------------------------------------- /text/english_bert_mock.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import torch 4 | from transformers import DebertaV2Model, DebertaV2Tokenizer 5 | 6 | from config import config 7 | 8 | 9 | LOCAL_PATH = "./bert/deberta-v3-large" 10 | 11 | tokenizer = DebertaV2Tokenizer.from_pretrained(LOCAL_PATH) 12 | 13 | models = dict() 14 | 15 | 16 | def get_bert_feature( 17 | text, 18 | word2ph, 19 | device=config.bert_gen_config.device, 20 | style_text=None, 21 | style_weight=0.7, 22 | ): 23 | if ( 24 | sys.platform == "darwin" 25 | and torch.backends.mps.is_available() 26 | and device == "cpu" 27 | ): 28 | device = "mps" 29 | if not device: 30 | device = "cuda" 31 | if device not in models.keys(): 32 | models[device] = DebertaV2Model.from_pretrained(LOCAL_PATH).to(device) 33 | with torch.no_grad(): 34 | inputs = tokenizer(text, return_tensors="pt") 35 | for i in inputs: 36 | inputs[i] = inputs[i].to(device) 37 | res = models[device](**inputs, output_hidden_states=True) 38 | res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu() 39 | if style_text: 40 | style_inputs = tokenizer(style_text, return_tensors="pt") 41 | for i in style_inputs: 42 | style_inputs[i] = style_inputs[i].to(device) 43 | style_res = models[device](**style_inputs, output_hidden_states=True) 44 | style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].cpu() 45 | style_res_mean = style_res.mean(0) 46 | assert len(word2ph) == res.shape[0], (text, res.shape[0], len(word2ph)) 47 | word2phone = word2ph 48 | phone_level_feature = [] 49 | for i in range(len(word2phone)): 50 | if style_text: 51 | repeat_feature = ( 52 | res[i].repeat(word2phone[i], 1) * (1 - style_weight) 53 | + style_res_mean.repeat(word2phone[i], 1) * style_weight 54 | ) 55 | else: 56 | repeat_feature = res[i].repeat(word2phone[i], 1) 57 | phone_level_feature.append(repeat_feature) 58 | 
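    # Stack the per-token repeats into one (sum(word2ph), hidden) matrix; the
    # caller receives its transpose, i.e. one 1024-dim DeBERTa vector per phone.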
59 | phone_level_feature = torch.cat(phone_level_feature, dim=0) 60 | 61 | return phone_level_feature.T 62 | -------------------------------------------------------------------------------- /text/japanese_bert.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import torch 4 | from transformers import AutoModelForMaskedLM, AutoTokenizer 5 | 6 | from config import config 7 | from text.japanese import text2sep_kata 8 | 9 | LOCAL_PATH = "./bert/deberta-v2-large-japanese-char-wwm" 10 | 11 | tokenizer = AutoTokenizer.from_pretrained(LOCAL_PATH) 12 | 13 | models = dict() 14 | 15 | 16 | def get_bert_feature( 17 | text, 18 | word2ph, 19 | device=config.bert_gen_config.device, 20 | style_text=None, 21 | style_weight=0.7, 22 | ): 23 | text = "".join(text2sep_kata(text)[0]) 24 | if style_text: 25 | style_text = "".join(text2sep_kata(style_text)[0]) 26 | if ( 27 | sys.platform == "darwin" 28 | and torch.backends.mps.is_available() 29 | and device == "cpu" 30 | ): 31 | device = "mps" 32 | if not device: 33 | device = "cuda" 34 | if device not in models.keys(): 35 | models[device] = AutoModelForMaskedLM.from_pretrained(LOCAL_PATH).to(device) 36 | with torch.no_grad(): 37 | inputs = tokenizer(text, return_tensors="pt") 38 | for i in inputs: 39 | inputs[i] = inputs[i].to(device) 40 | res = models[device](**inputs, output_hidden_states=True) 41 | res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu() 42 | if style_text: 43 | style_inputs = tokenizer(style_text, return_tensors="pt") 44 | for i in style_inputs: 45 | style_inputs[i] = style_inputs[i].to(device) 46 | style_res = models[device](**style_inputs, output_hidden_states=True) 47 | style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].cpu() 48 | style_res_mean = style_res.mean(0) 49 | 50 | assert len(word2ph) == len(text) + 2 51 | word2phone = word2ph 52 | phone_level_feature = [] 53 | for i in range(len(word2phone)): 54 | if style_text: 55 | repeat_feature = ( 56 | res[i].repeat(word2phone[i], 1) * (1 - style_weight) 57 | + style_res_mean.repeat(word2phone[i], 1) * style_weight 58 | ) 59 | else: 60 | repeat_feature = res[i].repeat(word2phone[i], 1) 61 | phone_level_feature.append(repeat_feature) 62 | 63 | phone_level_feature = torch.cat(phone_level_feature, dim=0) 64 | 65 | return phone_level_feature.T 66 | -------------------------------------------------------------------------------- /text/opencpop-strict.txt: -------------------------------------------------------------------------------- 1 | a AA a 2 | ai AA ai 3 | an AA an 4 | ang AA ang 5 | ao AA ao 6 | ba b a 7 | bai b ai 8 | ban b an 9 | bang b ang 10 | bao b ao 11 | bei b ei 12 | ben b en 13 | beng b eng 14 | bi b i 15 | bian b ian 16 | biao b iao 17 | bie b ie 18 | bin b in 19 | bing b ing 20 | bo b o 21 | bu b u 22 | ca c a 23 | cai c ai 24 | can c an 25 | cang c ang 26 | cao c ao 27 | ce c e 28 | cei c ei 29 | cen c en 30 | ceng c eng 31 | cha ch a 32 | chai ch ai 33 | chan ch an 34 | chang ch ang 35 | chao ch ao 36 | che ch e 37 | chen ch en 38 | cheng ch eng 39 | chi ch ir 40 | chong ch ong 41 | chou ch ou 42 | chu ch u 43 | chua ch ua 44 | chuai ch uai 45 | chuan ch uan 46 | chuang ch uang 47 | chui ch ui 48 | chun ch un 49 | chuo ch uo 50 | ci c i0 51 | cong c ong 52 | cou c ou 53 | cu c u 54 | cuan c uan 55 | cui c ui 56 | cun c un 57 | cuo c uo 58 | da d a 59 | dai d ai 60 | dan d an 61 | dang d ang 62 | dao d ao 63 | de d e 64 | dei d ei 65 | den d en 66 | deng d eng 67 | di d i 68 | dia d ia 69 | dian d ian 70 
| diao d iao 71 | die d ie 72 | ding d ing 73 | diu d iu 74 | dong d ong 75 | dou d ou 76 | du d u 77 | duan d uan 78 | dui d ui 79 | dun d un 80 | duo d uo 81 | e EE e 82 | ei EE ei 83 | en EE en 84 | eng EE eng 85 | er EE er 86 | fa f a 87 | fan f an 88 | fang f ang 89 | fei f ei 90 | fen f en 91 | feng f eng 92 | fo f o 93 | fou f ou 94 | fu f u 95 | ga g a 96 | gai g ai 97 | gan g an 98 | gang g ang 99 | gao g ao 100 | ge g e 101 | gei g ei 102 | gen g en 103 | geng g eng 104 | gong g ong 105 | gou g ou 106 | gu g u 107 | gua g ua 108 | guai g uai 109 | guan g uan 110 | guang g uang 111 | gui g ui 112 | gun g un 113 | guo g uo 114 | ha h a 115 | hai h ai 116 | han h an 117 | hang h ang 118 | hao h ao 119 | he h e 120 | hei h ei 121 | hen h en 122 | heng h eng 123 | hong h ong 124 | hou h ou 125 | hu h u 126 | hua h ua 127 | huai h uai 128 | huan h uan 129 | huang h uang 130 | hui h ui 131 | hun h un 132 | huo h uo 133 | ji j i 134 | jia j ia 135 | jian j ian 136 | jiang j iang 137 | jiao j iao 138 | jie j ie 139 | jin j in 140 | jing j ing 141 | jiong j iong 142 | jiu j iu 143 | ju j v 144 | jv j v 145 | juan j van 146 | jvan j van 147 | jue j ve 148 | jve j ve 149 | jun j vn 150 | jvn j vn 151 | ka k a 152 | kai k ai 153 | kan k an 154 | kang k ang 155 | kao k ao 156 | ke k e 157 | kei k ei 158 | ken k en 159 | keng k eng 160 | kong k ong 161 | kou k ou 162 | ku k u 163 | kua k ua 164 | kuai k uai 165 | kuan k uan 166 | kuang k uang 167 | kui k ui 168 | kun k un 169 | kuo k uo 170 | la l a 171 | lai l ai 172 | lan l an 173 | lang l ang 174 | lao l ao 175 | le l e 176 | lei l ei 177 | leng l eng 178 | li l i 179 | lia l ia 180 | lian l ian 181 | liang l iang 182 | liao l iao 183 | lie l ie 184 | lin l in 185 | ling l ing 186 | liu l iu 187 | lo l o 188 | long l ong 189 | lou l ou 190 | lu l u 191 | luan l uan 192 | lun l un 193 | luo l uo 194 | lv l v 195 | lve l ve 196 | ma m a 197 | mai m ai 198 | man m an 199 | mang m ang 200 | mao m ao 201 | me m e 202 | mei m ei 203 | men m en 204 | meng m eng 205 | mi m i 206 | mian m ian 207 | miao m iao 208 | mie m ie 209 | min m in 210 | ming m ing 211 | miu m iu 212 | mo m o 213 | mou m ou 214 | mu m u 215 | na n a 216 | nai n ai 217 | nan n an 218 | nang n ang 219 | nao n ao 220 | ne n e 221 | nei n ei 222 | nen n en 223 | neng n eng 224 | ni n i 225 | nian n ian 226 | niang n iang 227 | niao n iao 228 | nie n ie 229 | nin n in 230 | ning n ing 231 | niu n iu 232 | nong n ong 233 | nou n ou 234 | nu n u 235 | nuan n uan 236 | nun n un 237 | nuo n uo 238 | nv n v 239 | nve n ve 240 | o OO o 241 | ou OO ou 242 | pa p a 243 | pai p ai 244 | pan p an 245 | pang p ang 246 | pao p ao 247 | pei p ei 248 | pen p en 249 | peng p eng 250 | pi p i 251 | pian p ian 252 | piao p iao 253 | pie p ie 254 | pin p in 255 | ping p ing 256 | po p o 257 | pou p ou 258 | pu p u 259 | qi q i 260 | qia q ia 261 | qian q ian 262 | qiang q iang 263 | qiao q iao 264 | qie q ie 265 | qin q in 266 | qing q ing 267 | qiong q iong 268 | qiu q iu 269 | qu q v 270 | qv q v 271 | quan q van 272 | qvan q van 273 | que q ve 274 | qve q ve 275 | qun q vn 276 | qvn q vn 277 | ran r an 278 | rang r ang 279 | rao r ao 280 | re r e 281 | ren r en 282 | reng r eng 283 | ri r ir 284 | rong r ong 285 | rou r ou 286 | ru r u 287 | rua r ua 288 | ruan r uan 289 | rui r ui 290 | run r un 291 | ruo r uo 292 | sa s a 293 | sai s ai 294 | san s an 295 | sang s ang 296 | sao s ao 297 | se s e 298 | sen s en 299 | seng s eng 300 | sha sh a 301 | shai sh ai 302 | shan sh an 303 | shang sh ang 
304 | shao sh ao 305 | she sh e 306 | shei sh ei 307 | shen sh en 308 | sheng sh eng 309 | shi sh ir 310 | shou sh ou 311 | shu sh u 312 | shua sh ua 313 | shuai sh uai 314 | shuan sh uan 315 | shuang sh uang 316 | shui sh ui 317 | shun sh un 318 | shuo sh uo 319 | si s i0 320 | song s ong 321 | sou s ou 322 | su s u 323 | suan s uan 324 | sui s ui 325 | sun s un 326 | suo s uo 327 | ta t a 328 | tai t ai 329 | tan t an 330 | tang t ang 331 | tao t ao 332 | te t e 333 | tei t ei 334 | teng t eng 335 | ti t i 336 | tian t ian 337 | tiao t iao 338 | tie t ie 339 | ting t ing 340 | tong t ong 341 | tou t ou 342 | tu t u 343 | tuan t uan 344 | tui t ui 345 | tun t un 346 | tuo t uo 347 | wa w a 348 | wai w ai 349 | wan w an 350 | wang w ang 351 | wei w ei 352 | wen w en 353 | weng w eng 354 | wo w o 355 | wu w u 356 | xi x i 357 | xia x ia 358 | xian x ian 359 | xiang x iang 360 | xiao x iao 361 | xie x ie 362 | xin x in 363 | xing x ing 364 | xiong x iong 365 | xiu x iu 366 | xu x v 367 | xv x v 368 | xuan x van 369 | xvan x van 370 | xue x ve 371 | xve x ve 372 | xun x vn 373 | xvn x vn 374 | ya y a 375 | yan y En 376 | yang y ang 377 | yao y ao 378 | ye y E 379 | yi y i 380 | yin y in 381 | ying y ing 382 | yo y o 383 | yong y ong 384 | you y ou 385 | yu y v 386 | yv y v 387 | yuan y van 388 | yvan y van 389 | yue y ve 390 | yve y ve 391 | yun y vn 392 | yvn y vn 393 | za z a 394 | zai z ai 395 | zan z an 396 | zang z ang 397 | zao z ao 398 | ze z e 399 | zei z ei 400 | zen z en 401 | zeng z eng 402 | zha zh a 403 | zhai zh ai 404 | zhan zh an 405 | zhang zh ang 406 | zhao zh ao 407 | zhe zh e 408 | zhei zh ei 409 | zhen zh en 410 | zheng zh eng 411 | zhi zh ir 412 | zhong zh ong 413 | zhou zh ou 414 | zhu zh u 415 | zhua zh ua 416 | zhuai zh uai 417 | zhuan zh uan 418 | zhuang zh uang 419 | zhui zh ui 420 | zhun zh un 421 | zhuo zh uo 422 | zi z i0 423 | zong z ong 424 | zou z ou 425 | zu z u 426 | zuan z uan 427 | zui z ui 428 | zun z un 429 | zuo z uo 430 | -------------------------------------------------------------------------------- /text/symbols.py: -------------------------------------------------------------------------------- 1 | punctuation = ["!", "?", "…", ",", ".", "'", "-"] 2 | pu_symbols = punctuation + ["SP", "UNK"] 3 | pad = "_" 4 | 5 | # chinese 6 | zh_symbols = [ 7 | "E", 8 | "En", 9 | "a", 10 | "ai", 11 | "an", 12 | "ang", 13 | "ao", 14 | "b", 15 | "c", 16 | "ch", 17 | "d", 18 | "e", 19 | "ei", 20 | "en", 21 | "eng", 22 | "er", 23 | "f", 24 | "g", 25 | "h", 26 | "i", 27 | "i0", 28 | "ia", 29 | "ian", 30 | "iang", 31 | "iao", 32 | "ie", 33 | "in", 34 | "ing", 35 | "iong", 36 | "ir", 37 | "iu", 38 | "j", 39 | "k", 40 | "l", 41 | "m", 42 | "n", 43 | "o", 44 | "ong", 45 | "ou", 46 | "p", 47 | "q", 48 | "r", 49 | "s", 50 | "sh", 51 | "t", 52 | "u", 53 | "ua", 54 | "uai", 55 | "uan", 56 | "uang", 57 | "ui", 58 | "un", 59 | "uo", 60 | "v", 61 | "van", 62 | "ve", 63 | "vn", 64 | "w", 65 | "x", 66 | "y", 67 | "z", 68 | "zh", 69 | "AA", 70 | "EE", 71 | "OO", 72 | ] 73 | num_zh_tones = 6 74 | 75 | # japanese 76 | ja_symbols = [ 77 | "N", 78 | "a", 79 | "a:", 80 | "b", 81 | "by", 82 | "ch", 83 | "d", 84 | "dy", 85 | "e", 86 | "e:", 87 | "f", 88 | "g", 89 | "gy", 90 | "h", 91 | "hy", 92 | "i", 93 | "i:", 94 | "j", 95 | "k", 96 | "ky", 97 | "m", 98 | "my", 99 | "n", 100 | "ny", 101 | "o", 102 | "o:", 103 | "p", 104 | "py", 105 | "q", 106 | "r", 107 | "ry", 108 | "s", 109 | "sh", 110 | "t", 111 | "ts", 112 | "ty", 113 | "u", 114 | "u:", 115 | "w", 116 | "y", 117 | "z", 118 | "zy", 
119 | ] 120 | num_ja_tones = 2 121 | 122 | # English 123 | en_symbols = [ 124 | "aa", 125 | "ae", 126 | "ah", 127 | "ao", 128 | "aw", 129 | "ay", 130 | "b", 131 | "ch", 132 | "d", 133 | "dh", 134 | "eh", 135 | "er", 136 | "ey", 137 | "f", 138 | "g", 139 | "hh", 140 | "ih", 141 | "iy", 142 | "jh", 143 | "k", 144 | "l", 145 | "m", 146 | "n", 147 | "ng", 148 | "ow", 149 | "oy", 150 | "p", 151 | "r", 152 | "s", 153 | "sh", 154 | "t", 155 | "th", 156 | "uh", 157 | "uw", 158 | "V", 159 | "w", 160 | "y", 161 | "z", 162 | "zh", 163 | ] 164 | num_en_tones = 4 165 | 166 | # combine all symbols 167 | normal_symbols = sorted(set(zh_symbols + ja_symbols + en_symbols)) 168 | symbols = [pad] + normal_symbols + pu_symbols 169 | sil_phonemes_ids = [symbols.index(i) for i in pu_symbols] 170 | 171 | # combine all tones 172 | num_tones = num_zh_tones + num_ja_tones + num_en_tones 173 | 174 | # language maps 175 | language_id_map = {"ZH": 0, "JP": 1, "EN": 2} 176 | num_languages = len(language_id_map.keys()) 177 | 178 | language_tone_start_map = { 179 | "ZH": 0, 180 | "JP": num_zh_tones, 181 | "EN": num_zh_tones + num_ja_tones, 182 | } 183 | 184 | if __name__ == "__main__": 185 | a = set(zh_symbols) 186 | b = set(en_symbols) 187 | print(sorted(a & b)) 188 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 工具包 3 | """ 4 | -------------------------------------------------------------------------------- /tools/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/tools/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /tools/__pycache__/log.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ywh-my/Bert-VITS2-FixBug/cb32db628b6a2ce60873eef140c571e0e767092a/tools/__pycache__/log.cpython-310.pyc -------------------------------------------------------------------------------- /tools/classify_language.py: -------------------------------------------------------------------------------- 1 | import regex as re 2 | 3 | try: 4 | from config import config 5 | 6 | LANGUAGE_IDENTIFICATION_LIBRARY = ( 7 | config.webui_config.language_identification_library 8 | ) 9 | except: 10 | LANGUAGE_IDENTIFICATION_LIBRARY = "langid" 11 | 12 | module = LANGUAGE_IDENTIFICATION_LIBRARY.lower() 13 | 14 | langid_languages = [ 15 | "af", 16 | "am", 17 | "an", 18 | "ar", 19 | "as", 20 | "az", 21 | "be", 22 | "bg", 23 | "bn", 24 | "br", 25 | "bs", 26 | "ca", 27 | "cs", 28 | "cy", 29 | "da", 30 | "de", 31 | "dz", 32 | "el", 33 | "en", 34 | "eo", 35 | "es", 36 | "et", 37 | "eu", 38 | "fa", 39 | "fi", 40 | "fo", 41 | "fr", 42 | "ga", 43 | "gl", 44 | "gu", 45 | "he", 46 | "hi", 47 | "hr", 48 | "ht", 49 | "hu", 50 | "hy", 51 | "id", 52 | "is", 53 | "it", 54 | "ja", 55 | "jv", 56 | "ka", 57 | "kk", 58 | "km", 59 | "kn", 60 | "ko", 61 | "ku", 62 | "ky", 63 | "la", 64 | "lb", 65 | "lo", 66 | "lt", 67 | "lv", 68 | "mg", 69 | "mk", 70 | "ml", 71 | "mn", 72 | "mr", 73 | "ms", 74 | "mt", 75 | "nb", 76 | "ne", 77 | "nl", 78 | "nn", 79 | "no", 80 | "oc", 81 | "or", 82 | "pa", 83 | "pl", 84 | "ps", 85 | "pt", 86 | "qu", 87 | "ro", 88 | "ru", 89 | "rw", 90 | "se", 91 | "si", 92 | "sk", 93 | "sl", 94 | "sq", 95 | 
"sr", 96 | "sv", 97 | "sw", 98 | "ta", 99 | "te", 100 | "th", 101 | "tl", 102 | "tr", 103 | "ug", 104 | "uk", 105 | "ur", 106 | "vi", 107 | "vo", 108 | "wa", 109 | "xh", 110 | "zh", 111 | "zu", 112 | ] 113 | 114 | 115 | def classify_language(text: str, target_languages: list = None) -> str: 116 | if module == "fastlid" or module == "fasttext": 117 | from fastlid import fastlid, supported_langs 118 | 119 | classifier = fastlid 120 | if target_languages != None: 121 | target_languages = [ 122 | lang for lang in target_languages if lang in supported_langs 123 | ] 124 | fastlid.set_languages = target_languages 125 | elif module == "langid": 126 | import langid 127 | 128 | classifier = langid.classify 129 | if target_languages != None: 130 | target_languages = [ 131 | lang for lang in target_languages if lang in langid_languages 132 | ] 133 | langid.set_languages(target_languages) 134 | else: 135 | raise ValueError(f"Wrong module {module}") 136 | 137 | lang = classifier(text)[0] 138 | 139 | return lang 140 | 141 | 142 | def classify_zh_ja(text: str) -> str: 143 | for idx, char in enumerate(text): 144 | unicode_val = ord(char) 145 | 146 | # 检测日语字符 147 | if 0x3040 <= unicode_val <= 0x309F or 0x30A0 <= unicode_val <= 0x30FF: 148 | return "ja" 149 | 150 | # 检测汉字字符 151 | if 0x4E00 <= unicode_val <= 0x9FFF: 152 | # 检查周围的字符 153 | next_char = text[idx + 1] if idx + 1 < len(text) else None 154 | 155 | if next_char and ( 156 | 0x3040 <= ord(next_char) <= 0x309F or 0x30A0 <= ord(next_char) <= 0x30FF 157 | ): 158 | return "ja" 159 | 160 | return "zh" 161 | 162 | 163 | def split_alpha_nonalpha(text, mode=1): 164 | if mode == 1: 165 | pattern = r"(?<=[\u4e00-\u9fff\u3040-\u30FF\d\s])(?=[\p{Latin}])|(?<=[\p{Latin}\s])(?=[\u4e00-\u9fff\u3040-\u30FF\d])" 166 | elif mode == 2: 167 | pattern = r"(?<=[\u4e00-\u9fff\u3040-\u30FF\s])(?=[\p{Latin}\d])|(?<=[\p{Latin}\d\s])(?=[\u4e00-\u9fff\u3040-\u30FF])" 168 | else: 169 | raise ValueError("Invalid mode. 
Supported modes are 1 and 2.") 170 | 171 | return re.split(pattern, text) 172 | 173 | 174 | if __name__ == "__main__": 175 | text = "这是一个测试文本" 176 | print(classify_language(text)) 177 | print(classify_zh_ja(text)) # "zh" 178 | 179 | text = "これはテストテキストです" 180 | print(classify_language(text)) 181 | print(classify_zh_ja(text)) # "ja" 182 | 183 | text = "vits和Bert-VITS2是tts模型。花费3days.花费3天。Take 3 days" 184 | 185 | print(split_alpha_nonalpha(text, mode=1)) 186 | # output: ['vits', '和', 'Bert-VITS', '2是', 'tts', '模型。花费3', 'days.花费3天。Take 3 days'] 187 | 188 | print(split_alpha_nonalpha(text, mode=2)) 189 | # output: ['vits', '和', 'Bert-VITS2', '是', 'tts', '模型。花费', '3days.花费', '3', '天。Take 3 days'] 190 | 191 | text = "vits 和 Bert-VITS2 是 tts 模型。花费3days.花费3天。Take 3 days" 192 | print(split_alpha_nonalpha(text, mode=1)) 193 | # output: ['vits ', '和 ', 'Bert-VITS', '2 ', '是 ', 'tts ', '模型。花费3', 'days.花费3天。Take ', '3 ', 'days'] 194 | 195 | text = "vits 和 Bert-VITS2 是 tts 模型。花费3days.花费3天。Take 3 days" 196 | print(split_alpha_nonalpha(text, mode=2)) 197 | # output: ['vits ', '和 ', 'Bert-VITS2 ', '是 ', 'tts ', '模型。花费', '3days.花费', '3', '天。Take ', '3 ', 'days'] 198 | -------------------------------------------------------------------------------- /tools/log.py: -------------------------------------------------------------------------------- 1 | """ 2 | logger封装 3 | """ 4 | 5 | from loguru import logger 6 | import sys 7 | 8 | 9 | # 移除所有默认的处理器 10 | logger.remove() 11 | 12 | # 自定义格式并添加到标准输出 13 | log_format = ( 14 | "{time:MM-DD HH:mm:ss} {level:<9}| {file}:{line} | {message}" 15 | ) 16 | 17 | logger.add(sys.stdout, format=log_format, backtrace=True, diagnose=True) 18 | -------------------------------------------------------------------------------- /tools/sentence.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import regex as re 4 | 5 | from tools.classify_language import classify_language, split_alpha_nonalpha 6 | 7 | 8 | def check_is_none(item) -> bool: 9 | """none -> True, not none -> False""" 10 | return ( 11 | item is None 12 | or (isinstance(item, str) and str(item).isspace()) 13 | or str(item) == "" 14 | ) 15 | 16 | 17 | def markup_language(text: str, target_languages: list = None) -> str: 18 | pattern = ( 19 | r"[\!\"\#\$\%\&\'\(\)\*\+\,\-\.\/\:\;\<\>\=\?\@\[\]\{\}\\\\\^\_\`" 20 | r"\!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」" 21 | r"『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘\'\‛\“\”\„\‟…‧﹏.]+" 22 | ) 23 | sentences = re.split(pattern, text) 24 | 25 | pre_lang = "" 26 | p = 0 27 | 28 | if target_languages is not None: 29 | sorted_target_languages = sorted(target_languages) 30 | if sorted_target_languages in [["en", "zh"], ["en", "ja"], ["en", "ja", "zh"]]: 31 | new_sentences = [] 32 | for sentence in sentences: 33 | new_sentences.extend(split_alpha_nonalpha(sentence)) 34 | sentences = new_sentences 35 | 36 | for sentence in sentences: 37 | if check_is_none(sentence): 38 | continue 39 | 40 | lang = classify_language(sentence, target_languages) 41 | 42 | if pre_lang == "": 43 | text = text[:p] + text[p:].replace( 44 | sentence, f"[{lang.upper()}]{sentence}", 1 45 | ) 46 | p += len(f"[{lang.upper()}]") 47 | elif pre_lang != lang: 48 | text = text[:p] + text[p:].replace( 49 | sentence, f"[{pre_lang.upper()}][{lang.upper()}]{sentence}", 1 50 | ) 51 | p += len(f"[{pre_lang.upper()}][{lang.upper()}]") 52 | pre_lang = lang 53 | p += text[p:].index(sentence) + len(sentence) 54 | text += f"[{pre_lang.upper()}]" 55 | 56 | return text 57 | 58 | 59 | def 
split_by_language(text: str, target_languages: list = None) -> list: 60 | pattern = ( 61 | r"[\!\"\#\$\%\&\'\(\)\*\+\,\-\.\/\:\;\<\>\=\?\@\[\]\{\}\\\\\^\_\`" 62 | r"\!?\。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」" 63 | r"『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘\'\‛\“\”\„\‟…‧﹏.]+" 64 | ) 65 | sentences = re.split(pattern, text) 66 | 67 | pre_lang = "" 68 | start = 0 69 | end = 0 70 | sentences_list = [] 71 | 72 | if target_languages is not None: 73 | sorted_target_languages = sorted(target_languages) 74 | if sorted_target_languages in [["en", "zh"], ["en", "ja"], ["en", "ja", "zh"]]: 75 | new_sentences = [] 76 | for sentence in sentences: 77 | new_sentences.extend(split_alpha_nonalpha(sentence)) 78 | sentences = new_sentences 79 | 80 | for sentence in sentences: 81 | if check_is_none(sentence): 82 | continue 83 | 84 | lang = classify_language(sentence, target_languages) 85 | 86 | end += text[end:].index(sentence) 87 | if pre_lang != "" and pre_lang != lang: 88 | sentences_list.append((text[start:end], pre_lang)) 89 | start = end 90 | end += len(sentence) 91 | pre_lang = lang 92 | sentences_list.append((text[start:], pre_lang)) 93 | 94 | return sentences_list 95 | 96 | 97 | def sentence_split(text: str, max: int) -> list: 98 | pattern = r"[!(),—+\-.:;??。,、;:]+" 99 | sentences = re.split(pattern, text) 100 | discarded_chars = re.findall(pattern, text) 101 | 102 | sentences_list, count, p = [], 0, 0 103 | 104 | # 按被分割的符号遍历 105 | for i, discarded_chars in enumerate(discarded_chars): 106 | count += len(sentences[i]) + len(discarded_chars) 107 | if count >= max: 108 | sentences_list.append(text[p : p + count].strip()) 109 | p += count 110 | count = 0 111 | 112 | # 加入最后剩余的文本 113 | if p < len(text): 114 | sentences_list.append(text[p:]) 115 | 116 | return sentences_list 117 | 118 | 119 | def sentence_split_and_markup(text, max=50, lang="auto", speaker_lang=None): 120 | # 如果该speaker只支持一种语言 121 | if speaker_lang is not None and len(speaker_lang) == 1: 122 | if lang.upper() not in ["AUTO", "MIX"] and lang.lower() != speaker_lang[0]: 123 | logging.debug( 124 | f'lang "{lang}" is not in speaker_lang {speaker_lang},automatically set lang={speaker_lang[0]}' 125 | ) 126 | lang = speaker_lang[0] 127 | 128 | sentences_list = [] 129 | if lang.upper() != "MIX": 130 | if max <= 0: 131 | sentences_list.append( 132 | markup_language(text, speaker_lang) 133 | if lang.upper() == "AUTO" 134 | else f"[{lang.upper()}]{text}[{lang.upper()}]" 135 | ) 136 | else: 137 | for i in sentence_split(text, max): 138 | if check_is_none(i): 139 | continue 140 | sentences_list.append( 141 | markup_language(i, speaker_lang) 142 | if lang.upper() == "AUTO" 143 | else f"[{lang.upper()}]{i}[{lang.upper()}]" 144 | ) 145 | else: 146 | sentences_list.append(text) 147 | 148 | for i in sentences_list: 149 | logging.debug(i) 150 | 151 | return sentences_list 152 | 153 | 154 | if __name__ == "__main__": 155 | text = "这几天心里颇不宁静。今晚在院子里坐着乘凉,忽然想起日日走过的荷塘,在这满月的光里,总该另有一番样子吧。月亮渐渐地升高了,墙外马路上孩子们的欢笑,已经听不见了;妻在屋里拍着闰儿,迷迷糊糊地哼着眠歌。我悄悄地披了大衫,带上门出去。" 156 | print(markup_language(text, target_languages=None)) 157 | print(sentence_split(text, max=50)) 158 | print(sentence_split_and_markup(text, max=50, lang="auto", speaker_lang=None)) 159 | 160 | text = "你好,这是一段用来测试自动标注的文本。こんにちは,これは自動ラベリングのテスト用テキストです.Hello, this is a piece of text to test autotagging.你好!今天我们要介绍VITS项目,其重点是使用了GAN Duration predictor和transformer flow,并且接入了Bert模型来提升韵律。Bert embedding会在稍后介绍。" 161 | print(split_by_language(text, ["zh", "ja", "en"])) 162 | 163 | text = "vits和Bert-VITS2是tts模型。花费3days.花费3天。Take 3 days" 164 | 165 | 
print(split_by_language(text, ["zh", "ja", "en"])) 166 | # output: [('vits', 'en'), ('和', 'ja'), ('Bert-VITS', 'en'), ('2是', 'zh'), ('tts', 'en'), ('模型。花费3', 'zh'), ('days.', 'en'), ('花费3天。', 'zh'), ('Take 3 days', 'en')] 167 | 168 | print(split_by_language(text, ["zh", "en"])) 169 | # output: [('vits', 'en'), ('和', 'zh'), ('Bert-VITS', 'en'), ('2是', 'zh'), ('tts', 'en'), ('模型。花费3', 'zh'), ('days.', 'en'), ('花费3天。', 'zh'), ('Take 3 days', 'en')] 170 | 171 | text = "vits 和 Bert-VITS2 是 tts 模型。花费 3 days. 花费 3天。Take 3 days" 172 | print(split_by_language(text, ["zh", "en"])) 173 | # output: [('vits ', 'en'), ('和 ', 'zh'), ('Bert-VITS2 ', 'en'), ('是 ', 'zh'), ('tts ', 'en'), ('模型。花费 ', 'zh'), ('3 days. ', 'en'), ('花费 3天。', 'zh'), ('Take 3 days', 'en')] 174 | -------------------------------------------------------------------------------- /tools/translate.py: -------------------------------------------------------------------------------- 1 | """ 2 | 翻译api 3 | """ 4 | 5 | from config import config 6 | 7 | import random 8 | import hashlib 9 | import requests 10 | 11 | 12 | def translate(Sentence: str, to_Language: str = "jp", from_Language: str = ""): 13 | """ 14 | :param Sentence: 待翻译语句 15 | :param from_Language: 待翻译语句语言 16 | :param to_Language: 目标语言 17 | :return: 翻译后语句 出错时返回None 18 | 19 | 常见语言代码:中文 zh 英语 en 日语 jp 20 | """ 21 | appid = config.translate_config.app_key 22 | key = config.translate_config.secret_key 23 | if appid == "" or key == "": 24 | return "请开发者在config.yml中配置app_key与secret_key" 25 | url = "https://fanyi-api.baidu.com/api/trans/vip/translate" 26 | texts = Sentence.splitlines() 27 | outTexts = [] 28 | for t in texts: 29 | if t != "": 30 | # 签名计算 参考文档 https://api.fanyi.baidu.com/product/113 31 | salt = str(random.randint(1, 100000)) 32 | signString = appid + t + salt + key 33 | hs = hashlib.md5() 34 | hs.update(signString.encode("utf-8")) 35 | signString = hs.hexdigest() 36 | if from_Language == "": 37 | from_Language = "auto" 38 | headers = {"Content-Type": "application/x-www-form-urlencoded"} 39 | payload = { 40 | "q": t, 41 | "from": from_Language, 42 | "to": to_Language, 43 | "appid": appid, 44 | "salt": salt, 45 | "sign": signString, 46 | } 47 | # 发送请求 48 | try: 49 | response = requests.post( 50 | url=url, data=payload, headers=headers, timeout=3 51 | ) 52 | response = response.json() 53 | if "trans_result" in response.keys(): 54 | result = response["trans_result"][0] 55 | if "dst" in result.keys(): 56 | dst = result["dst"] 57 | outTexts.append(dst) 58 | except Exception: 59 | return Sentence 60 | else: 61 | outTexts.append(t) 62 | return "\n".join(outTexts) 63 | -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | import numpy as np 5 | 6 | 7 | DEFAULT_MIN_BIN_WIDTH = 1e-3 8 | DEFAULT_MIN_BIN_HEIGHT = 1e-3 9 | DEFAULT_MIN_DERIVATIVE = 1e-3 10 | 11 | 12 | def piecewise_rational_quadratic_transform( 13 | inputs, 14 | unnormalized_widths, 15 | unnormalized_heights, 16 | unnormalized_derivatives, 17 | inverse=False, 18 | tails=None, 19 | tail_bound=1.0, 20 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 21 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 22 | min_derivative=DEFAULT_MIN_DERIVATIVE, 23 | ): 24 | if tails is None: 25 | spline_fn = rational_quadratic_spline 26 | spline_kwargs = {} 27 | else: 28 | spline_fn = unconstrained_rational_quadratic_spline 29 | spline_kwargs = {"tails": tails, 
"tail_bound": tail_bound} 30 | 31 | outputs, logabsdet = spline_fn( 32 | inputs=inputs, 33 | unnormalized_widths=unnormalized_widths, 34 | unnormalized_heights=unnormalized_heights, 35 | unnormalized_derivatives=unnormalized_derivatives, 36 | inverse=inverse, 37 | min_bin_width=min_bin_width, 38 | min_bin_height=min_bin_height, 39 | min_derivative=min_derivative, 40 | **spline_kwargs 41 | ) 42 | return outputs, logabsdet 43 | 44 | 45 | def searchsorted(bin_locations, inputs, eps=1e-6): 46 | bin_locations[..., -1] += eps 47 | return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 48 | 49 | 50 | def unconstrained_rational_quadratic_spline( 51 | inputs, 52 | unnormalized_widths, 53 | unnormalized_heights, 54 | unnormalized_derivatives, 55 | inverse=False, 56 | tails="linear", 57 | tail_bound=1.0, 58 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 59 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 60 | min_derivative=DEFAULT_MIN_DERIVATIVE, 61 | ): 62 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) 63 | outside_interval_mask = ~inside_interval_mask 64 | 65 | outputs = torch.zeros_like(inputs) 66 | logabsdet = torch.zeros_like(inputs) 67 | 68 | if tails == "linear": 69 | unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) 70 | constant = np.log(np.exp(1 - min_derivative) - 1) 71 | unnormalized_derivatives[..., 0] = constant 72 | unnormalized_derivatives[..., -1] = constant 73 | 74 | outputs[outside_interval_mask] = inputs[outside_interval_mask] 75 | logabsdet[outside_interval_mask] = 0 76 | else: 77 | raise RuntimeError("{} tails are not implemented.".format(tails)) 78 | 79 | ( 80 | outputs[inside_interval_mask], 81 | logabsdet[inside_interval_mask], 82 | ) = rational_quadratic_spline( 83 | inputs=inputs[inside_interval_mask], 84 | unnormalized_widths=unnormalized_widths[inside_interval_mask, :], 85 | unnormalized_heights=unnormalized_heights[inside_interval_mask, :], 86 | unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], 87 | inverse=inverse, 88 | left=-tail_bound, 89 | right=tail_bound, 90 | bottom=-tail_bound, 91 | top=tail_bound, 92 | min_bin_width=min_bin_width, 93 | min_bin_height=min_bin_height, 94 | min_derivative=min_derivative, 95 | ) 96 | 97 | return outputs, logabsdet 98 | 99 | 100 | def rational_quadratic_spline( 101 | inputs, 102 | unnormalized_widths, 103 | unnormalized_heights, 104 | unnormalized_derivatives, 105 | inverse=False, 106 | left=0.0, 107 | right=1.0, 108 | bottom=0.0, 109 | top=1.0, 110 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 111 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 112 | min_derivative=DEFAULT_MIN_DERIVATIVE, 113 | ): 114 | if torch.min(inputs) < left or torch.max(inputs) > right: 115 | raise ValueError("Input to a transform is not within its domain") 116 | 117 | num_bins = unnormalized_widths.shape[-1] 118 | 119 | if min_bin_width * num_bins > 1.0: 120 | raise ValueError("Minimal bin width too large for the number of bins") 121 | if min_bin_height * num_bins > 1.0: 122 | raise ValueError("Minimal bin height too large for the number of bins") 123 | 124 | widths = F.softmax(unnormalized_widths, dim=-1) 125 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths 126 | cumwidths = torch.cumsum(widths, dim=-1) 127 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0) 128 | cumwidths = (right - left) * cumwidths + left 129 | cumwidths[..., 0] = left 130 | cumwidths[..., -1] = right 131 | widths = cumwidths[..., 1:] - cumwidths[..., :-1] 132 | 133 | derivatives = min_derivative 
+ F.softplus(unnormalized_derivatives) 134 | 135 | heights = F.softmax(unnormalized_heights, dim=-1) 136 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights 137 | cumheights = torch.cumsum(heights, dim=-1) 138 | cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0) 139 | cumheights = (top - bottom) * cumheights + bottom 140 | cumheights[..., 0] = bottom 141 | cumheights[..., -1] = top 142 | heights = cumheights[..., 1:] - cumheights[..., :-1] 143 | 144 | if inverse: 145 | bin_idx = searchsorted(cumheights, inputs)[..., None] 146 | else: 147 | bin_idx = searchsorted(cumwidths, inputs)[..., None] 148 | 149 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] 150 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0] 151 | 152 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] 153 | delta = heights / widths 154 | input_delta = delta.gather(-1, bin_idx)[..., 0] 155 | 156 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] 157 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] 158 | 159 | input_heights = heights.gather(-1, bin_idx)[..., 0] 160 | 161 | if inverse: 162 | a = (inputs - input_cumheights) * ( 163 | input_derivatives + input_derivatives_plus_one - 2 * input_delta 164 | ) + input_heights * (input_delta - input_derivatives) 165 | b = input_heights * input_derivatives - (inputs - input_cumheights) * ( 166 | input_derivatives + input_derivatives_plus_one - 2 * input_delta 167 | ) 168 | c = -input_delta * (inputs - input_cumheights) 169 | 170 | discriminant = b.pow(2) - 4 * a * c 171 | assert (discriminant >= 0).all() 172 | 173 | root = (2 * c) / (-b - torch.sqrt(discriminant)) 174 | outputs = root * input_bin_widths + input_cumwidths 175 | 176 | theta_one_minus_theta = root * (1 - root) 177 | denominator = input_delta + ( 178 | (input_derivatives + input_derivatives_plus_one - 2 * input_delta) 179 | * theta_one_minus_theta 180 | ) 181 | derivative_numerator = input_delta.pow(2) * ( 182 | input_derivatives_plus_one * root.pow(2) 183 | + 2 * input_delta * theta_one_minus_theta 184 | + input_derivatives * (1 - root).pow(2) 185 | ) 186 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 187 | 188 | return outputs, -logabsdet 189 | else: 190 | theta = (inputs - input_cumwidths) / input_bin_widths 191 | theta_one_minus_theta = theta * (1 - theta) 192 | 193 | numerator = input_heights * ( 194 | input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta 195 | ) 196 | denominator = input_delta + ( 197 | (input_derivatives + input_derivatives_plus_one - 2 * input_delta) 198 | * theta_one_minus_theta 199 | ) 200 | outputs = input_cumheights + numerator / denominator 201 | 202 | derivative_numerator = input_delta.pow(2) * ( 203 | input_derivatives_plus_one * theta.pow(2) 204 | + 2 * input_delta * theta_one_minus_theta 205 | + input_derivatives * (1 - theta).pow(2) 206 | ) 207 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 208 | 209 | return outputs, logabsdet 210 | -------------------------------------------------------------------------------- /微调用语音时长计算.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | 3 | def calculate_total_duration(txt_file): 4 | total_duration = 0.0 5 | 6 | with open(txt_file, 'r', encoding='utf-8') as f: 7 | for line in f: 8 | # 分割每行内容 9 | parts = line.strip().split('|') 10 | audio_path = parts[0] # 获取音频文件路径 11 | 12 | try: 13 | # 使用librosa加载音频文件 14 | y, sr = 
librosa.load(audio_path, sr=None) 15 | duration = librosa.get_duration(y=y, sr=sr) # 计算时长 16 | total_duration += duration 17 | except Exception as e: 18 | print(f"Error loading {audio_path}: {e}") 19 | 20 | return total_duration 21 | 22 | # 示例用法 23 | txt_file = 'A5_finetuned_trainingout/SSB0005_50/filelists/script.txt.cleaned.train' # 替换为你的txt文件路径 24 | total_duration = calculate_total_duration(txt_file) 25 | print(f"Total duration: {total_duration:.2f} seconds, {total_duration / 60 :.2f} min ") -------------------------------------------------------------------------------- /语音数据复制.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import shutil 4 | 5 | def copy_random_wav_files(src_dir, dest_dir, N): 6 | # 确保目标目录存在 7 | os.makedirs(dest_dir, exist_ok=True) 8 | 9 | # 获取所有.wav文件 10 | wav_files = [f for f in os.listdir(src_dir) if f.endswith('.wav')] 11 | 12 | # 随机选择N个文件 13 | selected_files = random.sample(wav_files, min(N, len(wav_files))) 14 | 15 | # 复制选定的文件到目标目录 16 | for file in selected_files: 17 | shutil.copy(os.path.join(src_dir, file), os.path.join(dest_dir, file)) 18 | 19 | # 示例用法 20 | N = 50 21 | copy_random_wav_files('/data/data-aishell3/train/wav/SSB0273', f'A2_prepared_audios/SSB0273_{N}', N=N) --------------------------------------------------------------------------------
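The duration script above (微调用语音时长计算.py) decodes every wav in full with librosa.load just to measure its length. When only the total duration is needed, reading the file header with the soundfile package is typically much faster. A minimal sketch, assuming the same "|"-separated filelist layout with the audio path in the first field:

import soundfile as sf


def calculate_total_duration_fast(txt_file):
    # Header-only variant of calculate_total_duration: sums wav durations via
    # soundfile.info() without decoding any audio samples.
    total = 0.0
    with open(txt_file, "r", encoding="utf-8") as f:
        for line in f:
            audio_path = line.strip().split("|")[0]
            if not audio_path:
                continue
            try:
                total += sf.info(audio_path).duration
            except Exception as e:
                print(f"Error reading {audio_path}: {e}")
    return total

Called with the same filelist path as in the script above, it returns seconds and can be converted to minutes in the same way.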