├── .dockerignore ├── .gitignore ├── Dockerfile ├── EmotiVoice_UserAgreement_易魔声用户协议.pdf ├── HTTP_API_TtsDemo ├── README.md └── apidemo │ ├── TtsDemo.py │ └── utils │ └── AuthV3Util.py ├── LICENSE ├── README.md ├── README.zh.md ├── README_小白安装教程.md ├── ROADMAP.md ├── assets └── audio │ ├── emotivoice_intro_cn.wav │ └── emotivoice_intro_en.wav ├── cn2an ├── an2cn.py └── conf.py ├── cog.yaml ├── config ├── joint │ ├── config.py │ └── config.yaml └── template.py ├── data ├── DataBaker │ ├── README.md │ └── src │ │ ├── step0_download.sh │ │ ├── step1_clean_raw_data.py │ │ └── step2_get_phoneme.py ├── LJspeech │ ├── README.md │ └── src │ │ ├── step0_download.sh │ │ ├── step1_clean_raw_data.py │ │ └── step2_get_phoneme.py ├── inference │ └── text └── youdao │ └── text │ ├── README.md │ ├── emotion │ ├── energy │ ├── pitch │ ├── speaker2 │ ├── speed │ └── tokenlist ├── demo_page.py ├── demo_page_databaker.py ├── frontend.py ├── frontend_cn.py ├── frontend_en.py ├── inference_am_vocoder_exp.py ├── inference_am_vocoder_joint.py ├── inference_tts.py ├── lexicon └── librispeech-lexicon.txt ├── mel_process.py ├── mfa ├── step1_create_dataset.py ├── step2_prepare_data.py ├── step3_prepare_special_tokens.py ├── step4_convert_text_to_phn.py ├── step5_prepare_alignment.py ├── step7_gen_alignment_from_textgrid.py ├── step8_make_data_list.py └── step9_datalist_from_mfa.py ├── models ├── hifigan │ ├── dataset.py │ ├── env.py │ ├── get_random_segments.py │ ├── get_vocoder.py │ ├── models.py │ └── pretrained_discriminator.py └── prompt_tts_modified │ ├── audio_processing.py │ ├── feats.py │ ├── jets.py │ ├── loss.py │ ├── model_open_source.py │ ├── modules │ ├── alignment.py │ ├── encoder.py │ ├── initialize.py │ └── variance.py │ ├── prompt_dataset.py │ ├── scheduler.py │ ├── simbert.py │ ├── stft.py │ ├── style_encoder.py │ └── tacotron_stft.py ├── openaiapi.py ├── plot_image.py ├── predict.py ├── prepare_for_training.py ├── requirements.openaiapi.txt ├── requirements.txt ├── setup.py ├── text ├── __init__.py ├── cleaners.py ├── cmudict.py ├── numbers.py └── symbols.py └── train_am_vocoder_joint.py /.dockerignore: -------------------------------------------------------------------------------- 1 | # The .dockerignore file excludes files from the container build process. 2 | # 3 | # https://docs.docker.com/engine/reference/builder/#dockerignore-file 4 | 5 | # Exclude Git files 6 | .git 7 | .github 8 | .gitignore 9 | 10 | # Exclude Python cache files 11 | __pycache__ 12 | .mypy_cache 13 | .pytest_cache 14 | .ruff_cache 15 | 16 | # Exclude Python virtual environment 17 | /venv 18 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | outputs/ 2 | WangZeJun/ 3 | *.pyc 4 | .vscode/ 5 | __pycache__/ 6 | .idea/ 7 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # syntax=docker/dockerfile:1 2 | FROM ubuntu:22.04 3 | 4 | # install app dependencies 5 | RUN apt-get update && apt-get install -y python3 python3-pip libsndfile1 6 | RUN python3 -m pip install torch==1.11.0 torchaudio numpy numba scipy transformers==4.26.1 soundfile yacs 7 | RUN python3 -m pip install pypinyin jieba 8 | 9 | # install app 10 | RUN mkdir /EmotiVoice 11 | COPY . 
/EmotiVoice/ 12 | 13 | # final configuration 14 | EXPOSE 8501 15 | RUN python3 -m pip install streamlit g2p_en 16 | WORKDIR /EmotiVoice 17 | RUN python3 frontend_en.py 18 | CMD streamlit run demo_page.py --server.port 8501 19 | -------------------------------------------------------------------------------- /EmotiVoice_UserAgreement_易魔声用户协议.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/netease-youdao/EmotiVoice/bc2de8c9eb1121237958ef154cb171e7faefc769/EmotiVoice_UserAgreement_易魔声用户协议.pdf -------------------------------------------------------------------------------- /HTTP_API_TtsDemo/README.md: -------------------------------------------------------------------------------- 1 | # 说明 2 | 项目为有道智云paas接口的python语言调用示例。您可以通过执行项目中的main函数快速调用有道智云相关api服务。 3 | 4 | # 运行环境 5 | 1. python 3.6版本及以上。 6 | 7 | # 运行方式 8 | 1. 在执行前您需要根据代码中的 中文提示 填写相关接口参数,具体参数取值可以访问 [智云官网](https://ai.youdao.com) 文档获取。 9 | 2. 同时您需要获取智云相关 应用ID应用密钥 信息。具体获取方式可以访问 [入门指南](https://ai.youdao.com/doc.s#guide) 获取帮助。 10 | 11 | # 注意事项 12 | 1. 项目中的代码有些仅作展示及参考,生产环境中请根据业务的实际情况进行修改。 13 | 2. 项目中接口返回的数据仅在控制台输出,实际使用中请根据实际情况进行解析。 -------------------------------------------------------------------------------- /HTTP_API_TtsDemo/apidemo/TtsDemo.py: -------------------------------------------------------------------------------- 1 | import requests 2 | 3 | from utils.AuthV3Util import addAuthParams 4 | 5 | # 您的应用ID 6 | APP_KEY = '' 7 | # 您的应用密钥 8 | APP_SECRET = '' 9 | 10 | # 合成音频保存路径, 例windows路径:PATH = "C:\\tts\\media.mp3" 11 | PATH = 'EmotiVoice-8051.mp3' 12 | 13 | 14 | def createRequest(): 15 | ''' 16 | note: 将下列变量替换为需要请求的参数 17 | ''' 18 | q = 'Emoti-Voice - a Multi-Voice and Prompt-Controlled T-T-S Engine,大家好' 19 | voiceName = 'Maria Kasper' # 'Cori Samuel' 20 | format = 'mp3' 21 | 22 | data = {'q': q, 'voiceName': voiceName, 'format': format} 23 | 24 | addAuthParams(APP_KEY, APP_SECRET, data) 25 | 26 | header = {'Content-Type': 'application/x-www-form-urlencoded'} 27 | res = doCall('https://openapi.youdao.com/ttsapi', header, data, 'post') 28 | saveFile(res) 29 | 30 | 31 | def doCall(url, header, params, method): 32 | if 'get' == method: 33 | return requests.get(url, params) 34 | elif 'post' == method: 35 | return requests.post(url, params, header) 36 | 37 | 38 | def saveFile(res): 39 | contentType = res.headers['Content-Type'] 40 | if 'audio' in contentType: 41 | fo = open(PATH, 'wb') 42 | fo.write(res.content) 43 | fo.close() 44 | print('save file path: ' + PATH) 45 | else: 46 | print(str(res.content, 'utf-8')) 47 | 48 | # 网易有道智云语音合成服务api调用demo 49 | # api接口: https://openapi.youdao.com/ttsapi 50 | if __name__ == '__main__': 51 | createRequest() 52 | -------------------------------------------------------------------------------- /HTTP_API_TtsDemo/apidemo/utils/AuthV3Util.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import time 3 | import uuid 4 | 5 | ''' 6 | 添加鉴权相关参数 - 7 | appKey : 应用ID 8 | salt : 随机值 9 | curtime : 当前时间戳(秒) 10 | signType : 签名版本 11 | sign : 请求签名 12 | 13 | @param appKey 您的应用ID 14 | @param appSecret 您的应用密钥 15 | @param paramsMap 请求参数表 16 | ''' 17 | def addAuthParams(appKey, appSecret, params): 18 | q = params.get('q') 19 | if q is None: 20 | q = params.get('img') 21 | salt = str(uuid.uuid1()) 22 | curtime = str(int(time.time())) 23 | sign = calculateSign(appKey, appSecret, q, salt, curtime) 24 | params['appKey'] = appKey 25 | params['salt'] = salt 26 | params['curtime'] = curtime 27 | 
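    # signType marks the request as using the v3 signing scheme; the sign itself
    # is sha256(appKey + input(q) + salt + curtime + appSecret), as computed by
    # calculateSign() below.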
params['signType'] = 'v3' 28 | params['sign'] = sign 29 | 30 | ''' 31 | 计算鉴权签名 - 32 | 计算方式 : sign = sha256(appKey + input(q) + salt + curtime + appSecret) 33 | @param appKey 您的应用ID 34 | @param appSecret 您的应用密钥 35 | @param q 请求内容 36 | @param salt 随机值 37 | @param curtime 当前时间戳(秒) 38 | @return 鉴权签名sign 39 | ''' 40 | def calculateSign(appKey, appSecret, q, salt, curtime): 41 | strSrc = appKey + getInput(q) + salt + curtime + appSecret 42 | return encrypt(strSrc) 43 | 44 | 45 | def encrypt(strSrc): 46 | hash_algorithm = hashlib.sha256() 47 | hash_algorithm.update(strSrc.encode('utf-8')) 48 | return hash_algorithm.hexdigest() 49 | 50 | 51 | def getInput(input): 52 | if input is None: 53 | return input 54 | inputLen = len(input) 55 | return input if inputLen <= 20 else input[0:10] + str(inputLen) + input[inputLen - 10:inputLen] 56 | -------------------------------------------------------------------------------- /README.zh.md: -------------------------------------------------------------------------------- 1 | README: EN | 中文 2 | 3 | 4 |
5 |

EmotiVoice易魔声 😊: 多音色提示控制TTS
17 | 18 | **EmotiVoice**是一个强大的开源TTS引擎,**完全免费**,支持中英文双语,包含2000多种不同的音色,以及特色的**情感合成**功能,支持合成包含快乐、兴奋、悲伤、愤怒等广泛情感的语音。 19 | 20 | EmotiVoice提供一个易于使用的web界面,还有用于批量生成结果的脚本接口。 21 | 22 | 以下是EmotiVoice生成的几个示例: 23 | 24 | - [Chinese audio sample](https://github.com/netease-youdao/EmotiVoice/assets/3909232/6426d7c1-d620-4bfc-ba03-cd7fc046a4fb) 25 | 26 | - [English audio sample](https://github.com/netease-youdao/EmotiVoice/assets/3909232/8f272eba-49db-493b-b479-2d9e5a419e26) 27 | 28 | - [Fun Chinese English audio sample](https://github.com/netease-youdao/EmotiVoice/assets/3909232/a0709012-c3ef-4182-bb0e-b7a2ba386f1c) 29 | 30 | ## 热闻速递 31 | 32 | - [x] 类OpenAI TTS的API已经支持调语速功能,感谢 [@john9405](https://github.com/john9405). [#90](https://github.com/netease-youdao/EmotiVoice/pull/90) [#67](https://github.com/netease-youdao/EmotiVoice/issues/67) [#77](https://github.com/netease-youdao/EmotiVoice/issues/77) 33 | - [x] [Mac版一键安装包](https://github.com/netease-youdao/EmotiVoice/releases/download/v0.3/emotivoice-1.0.0-arm64.dmg) 已于2023年12月28日发布,**强烈推荐尽快下载使用,免费好用!** 34 | - [x] [易魔声 HTTP API](https://github.com/netease-youdao/EmotiVoice/wiki/HTTP-API) 已于2023年12月6日发布上线。更易上手(无需任何安装配置),更快更稳定,单账户提供**超过 13,000 次免费调用**。此外,用户还可以使用[智云](https://ai.youdao.com/)提供的其它迷人的声音。 35 | - [x] [用你自己的数据定制音色](https://github.com/netease-youdao/EmotiVoice/wiki/Voice-Cloning-with-your-personal-data)已于2023年12月13日发布上线,同时提供了两个教程示例:[DataBaker Recipe](https://github.com/netease-youdao/EmotiVoice/tree/main/data/DataBaker) [LJSpeech Recipe](https://github.com/netease-youdao/EmotiVoice/tree/main/data/LJspeech)。 36 | 37 | ## 开发中的特性 38 | 39 | - [ ] 更多语言支持,例如日韩 [#19](https://github.com/netease-youdao/EmotiVoice/issues/19) [#22](https://github.com/netease-youdao/EmotiVoice/issues/22) 40 | 41 | 易魔声倾听社区需求并积极响应,期待您的反馈! 42 | 43 | ## 快速入门 44 | 45 | ### EmotiVoice Docker镜像 46 | 47 | 尝试EmotiVoice最简单的方法是运行docker镜像。你需要一台带有NVidia GPU的机器。先按照[Linux](https://www.server-world.info/en/note?os=Ubuntu_22.04&p=nvidia&f=2)和[Windows WSL2](https://zhuanlan.zhihu.com/p/653173679)平台的说明安装NVidia容器工具包。然后可以直接运行EmotiVoice镜像: 48 | 49 | ```sh 50 | docker run -dp 127.0.0.1:8501:8501 syq163/emoti-voice:latest 51 | ``` 52 | 53 | Docker镜像更新于2024年1月4号。如果你使用了老的版本,推荐运行如下命令进行更新: 54 | ```sh 55 | docker pull syq163/emoti-voice:latest 56 | docker run -dp 127.0.0.1:8501:8501 -p 127.0.0.1:8000:8000 syq163/emoti-voice:latest 57 | ``` 58 | 59 | 现在打开浏览器,导航到 http://localhost:8501 ,就可以体验EmotiVoice强大的TTS功能。从2024年的docker镜像版本开始,通过http://localhost:8000/可以使用类OpenAI TTS的API功能。 60 | 61 | ### 完整安装 62 | 63 | ```sh 64 | conda create -n EmotiVoice python=3.8 -y 65 | conda activate EmotiVoice 66 | pip install torch torchaudio 67 | pip install numpy numba scipy transformers soundfile yacs g2p_en jieba pypinyin pypinyin_dict 68 | python -m nltk.downloader "averaged_perceptron_tagger_eng" 69 | ``` 70 | 71 | ### 准备模型文件 72 | 73 | 强烈推荐用户参考[如何下载预训练模型文件](https://github.com/netease-youdao/EmotiVoice/wiki/Pretrained-models)的维基页面,尤其遇到问题时。 74 | 75 | ```sh 76 | git lfs install 77 | git lfs clone https://huggingface.co/WangZeJun/simbert-base-chinese WangZeJun/simbert-base-chinese 78 | ``` 79 | 80 | 或者你可以运行: 81 | ```sh 82 | git clone https://www.modelscope.cn/syq163/WangZeJun.git 83 | ``` 84 | 85 | ### 推理 86 | 87 | 1. 通过简单运行如下命令来下载[预训练模型](https://drive.google.com/drive/folders/1y6Xwj_GG9ulsAonca_unSGbJ4lxbNymM?usp=sharing): 88 | 89 | ```sh 90 | git clone https://www.modelscope.cn/syq163/outputs.git 91 | ``` 92 | 93 | 2. 推理输入文本格式是:`|||`. 
94 | 95 | - 例如: `8051|非常开心| uo3 sp1 l ai2 sp0 d ao4 sp1 b ei3 sp0 j ing1 sp3 q ing1 sp0 h ua2 sp0 d a4 sp0 x ve2 |我来到北京,清华大学`. 96 | 4. 其中的音素(phonemes)可以这样得到:`python frontend.py data/my_text.txt > data/my_text_for_tts.txt`. 97 | 98 | 5. 然后运行: 99 | ```sh 100 | TEXT=data/inference/text 101 | python inference_am_vocoder_joint.py \ 102 | --logdir prompt_tts_open_source_joint \ 103 | --config_folder config/joint \ 104 | --checkpoint g_00140000 \ 105 | --test_file $TEXT 106 | ``` 107 | 合成的语音结果在:`outputs/prompt_tts_open_source_joint/test_audio`. 108 | 109 | 6. 或者你可以直接使用交互的网页界面: 110 | ```sh 111 | pip install streamlit 112 | streamlit run demo_page.py 113 | ``` 114 | 115 | ### 类OpenAI TTS的API 116 | 117 | 非常感谢 @lewangdev 的相关该工作 [#60](../../issues/60)。通过运行如下命令来完成配置: 118 | 119 | ```sh 120 | pip install fastapi pydub uvicorn[standard] pyrubberband 121 | uvicorn openaiapi:app --reload 122 | ``` 123 | 124 | ### Wiki页面 125 | 126 | 如果遇到问题,或者想获取更多详情,请参考 [wiki](https://github.com/netease-youdao/EmotiVoice/wiki) 页面。 127 | 128 | ## 训练 129 | 130 | [用你自己的数据定制音色](https://github.com/netease-youdao/EmotiVoice/wiki/Voice-Cloning-with-your-personal-data)已于2023年12月13日发布上线。 131 | 132 | ## 路线图和未来的工作 133 | 134 | - 我们未来的计划可以在 [ROADMAP](./ROADMAP.md) 文件中找到。 135 | 136 | - 当前的实现侧重于通过提示控制情绪/风格。它只使用音高、速度、能量和情感作为风格因素,而不使用性别。但是将其更改为样式、音色控制并不复杂,类似于PromptTTS的原始闭源实现。 137 | 138 | ## 微信群 139 | 140 | 欢迎扫描下方左侧二维码加入微信群。商业合作扫描右侧个人二维码。 141 | 142 | qr 143 |      144 | qr 145 | 146 | ## 致谢 147 | 148 | - [PromptTTS](https://speechresearch.github.io/prompttts/). PromptTTS论文是本工作的重要基础。 149 | - [LibriTTS](https://www.openslr.org/60/). 训练使用了LibriTTS开放数据集。 150 | - [HiFiTTS](https://www.openslr.org/109/). 训练使用了HiFi TTS开放数据集。 151 | - [ESPnet](https://github.com/espnet/espnet). 152 | - [WeTTS](https://github.com/wenet-e2e/wetts) 153 | - [HiFi-GAN](https://github.com/jik876/hifi-gan) 154 | - [Transformers](https://github.com/huggingface/transformers) 155 | - [tacotron](https://github.com/keithito/tacotron) 156 | - [KAN-TTS](https://github.com/alibaba-damo-academy/KAN-TTS) 157 | - [StyleTTS](https://github.com/yl4579/StyleTTS) 158 | - [Simbert](https://github.com/ZhuiyiTechnology/simbert) 159 | - [cn2an](https://github.com/Ailln/cn2an). 
易魔声集成了cn2an来处理数字。 160 | 161 | ## 许可 162 | 163 | EmotiVoice是根据Apache-2.0许可证提供的 - 有关详细信息,请参阅[许可证文件](./LICENSE)。 164 | 165 | 交互的网页是根据[用户协议](./EmotiVoice_UserAgreement_易魔声用户协议.pdf)提供的。 166 | -------------------------------------------------------------------------------- /README_小白安装教程.md: -------------------------------------------------------------------------------- 1 | ## 小白安装教程 2 | 3 | #### 环境条件:设备有GPU、已经安装cuda 4 | 5 | 说明:这是针对Linux环境安装的教程,其他系统可作为参考。 6 | 7 | #### 1、创建并进入conda环境 8 | 9 | ``` 10 | conda create -n EmotiVoice python=3.8 11 | conda init 12 | conda activate EmotiVoice 13 | ``` 14 | 15 | 如果你不想使用conda环境,也可以省略该步骤,但要保证python版本为3.8 16 | 17 | 18 | #### 2、安装git-lfs 19 | 20 | 如果是Ubuntu则执行 21 | 22 | ``` 23 | sudo apt update 24 | sudo apt install git 25 | sudo apt-get install git-lfs 26 | ``` 27 | 28 | CentOS则执行 29 | 30 | ``` 31 | sudo yum update 32 | sudo yum install git 33 | sudo yum install git-lfs 34 | ``` 35 | 36 | 37 | 38 | #### 3、克隆仓库 39 | 40 | ``` 41 | git lfs install 42 | git lfs clone https://github.com/netease-youdao/EmotiVoice.git 43 | ``` 44 | 45 | 46 | 47 | #### 4、安装依赖 48 | 49 | ``` 50 | pip install torch torchaudio 51 | pip install numpy numba scipy transformers soundfile yacs g2p_en jieba pypinyin pypinyin_dict 52 | python -m nltk.downloader "averaged_perceptron_tagger_eng" 53 | ``` 54 | 55 | 56 | 57 | 58 | 59 | #### 5、下载预训练模型文件 60 | 61 | (1)首先进入项目文件夹 62 | 63 | ``` 64 | cd EmotiVoice 65 | ``` 66 | 67 | (2)执行下面命令 68 | 69 | ``` 70 | git lfs clone https://huggingface.co/WangZeJun/simbert-base-chinese WangZeJun/simbert-base-chinese 71 | ``` 72 | 73 | 或者 74 | 75 | ``` 76 | git clone https://www.modelscope.cn/syq163/WangZeJun.git 77 | ``` 78 | 79 | 上面两种下载方式二选一即可。 80 | 81 | (3)第三步下载ckpt模型 82 | 83 | ``` 84 | git clone https://www.modelscope.cn/syq163/outputs.git 85 | ``` 86 | 87 | 上面步骤完成后,项目文件夹内会多 `WangZeJun` 和 `outputs` 文件夹,下面是项目文件结构 88 | 89 | ``` 90 | ├── Dockerfile 91 | ├── EmotiVoice_UserAgreement_易魔声用户协议.pdf 92 | ├── demo_page.py 93 | ├── frontend.py 94 | ├── frontend_cn.py 95 | ├── frontend_en.py 96 | ├── WangZeJun 97 | │ └── simbert-base-chinese 98 | │ ├── README.md 99 | │ ├── config.json 100 | │ ├── pytorch_model.bin 101 | │ └── vocab.txt 102 | ├── outputs 103 | │ ├── README.md 104 | │ ├── configuration.json 105 | │ ├── prompt_tts_open_source_joint 106 | │ │ └── ckpt 107 | │ │ ├── do_00140000 108 | │ │ └── g_00140000 109 | │ └── style_encoder 110 | │ └── ckpt 111 | │ └── checkpoint_163431 112 | ``` 113 | 114 | 115 | 116 | #### 6、运行UI交互界面 117 | 118 | (1)安装streamlit 119 | 120 | ``` 121 | pip install streamlit 122 | ``` 123 | 124 | (2)启动 125 | 126 | 打开运行后显示的server地址,如何正常显示页面则部署完成。 127 | 128 | ``` 129 | streamlit run demo_page.py --server.port 6006 --logger.level debug 130 | ``` 131 | 132 | 133 | 134 | #### 7、启动API服务 135 | 136 | 安装依赖 137 | 138 | ``` 139 | pip install fastapi pydub uvicorn[standard] pyrubberband 140 | ``` 141 | 142 | 在6006端口启动服务(端口可根据自己的需求修改) 143 | 144 | ``` 145 | uvicorn openaiapi:app --reload --port 6006 146 | ``` 147 | 148 | 接口文档地址:你的服务地址+`/docs` 149 | 150 |   151 | 152 | #### 8、遇到错误 153 | 154 | **(1) 运行UI界面后,打开页面一直显示 "Please wait..." 
或者显示一片空白** 155 | 156 | 原因: 157 | 158 | 这个错误可能是由于CORS(跨域资源共享)保护配置错误。 159 | 160 | 解决方法: 161 | 162 | 在启动时加上一个 `server.enableCORS=false` 参数,即使用下面命令启动程序 163 | 164 | ``` 165 | streamlit run demo_page.py --server.port 6006 --logger.level debug --server.enableCORS=false 166 | ``` 167 | 168 | 如果通过临时禁用 CORS 保护解决了问题,建议重新启用 CORS 保护并设置正确的 URL 和端口。 169 | 170 |   171 | 172 | **(2) 运行报错 raise BadZipFile("File is not a zip file") zipfile.BadZipFile: File is not a zip file** 173 | 174 | 原因: 175 | 176 | 这可能是由于缺少 `averaged_perceptron_tagger` 这个`nltk`中用于词性标注的一个包,它包含了一个基于平均感知器算法的词性标注器。如果你在代码中使用了这个标注器,但是没有预先下载对应的数据包,就会遇到错误,提示你缺少`averaged_perceptron_tagger.zip`文件。当然也有可能是缺少 `cmudict` CMU 发音词典数据包文件。 177 | 178 | 正常来说,初次运行程序NLTK会自动下载使用的相关数据包,debug模式下运行会显示如下信息 179 | 180 | ``` 181 | [nltk_data] Downloading package averaged_perceptron_tagger to 182 | [nltk_data] /root/nltk_data... 183 | [nltk_data] Unzipping taggers/averaged_perceptron_tagger.zip. 184 | [nltk_data] Downloading package cmudict to /root/nltk_data... 185 | [nltk_data] Unzipping corpora/cmudict.zip. 186 | ``` 187 | 188 | 可能由于网络(需科学上网)等原因,没能自动下载成功,因此缺少相关文件导致加载报错。 189 | 190 | 191 | 192 | 解决方法:重新下载缺少的数据包文件 193 | 194 | 195 | 196 | 1)方法一 197 | 198 | 创建一个 download.py文件,在其中编写如下代码 199 | 200 | ``` 201 | import nltk 202 | print(nltk.data.path) 203 | nltk.download('averaged_perceptron_tagger') 204 | nltk.download('cmudict') 205 | ``` 206 | 207 | 保存并运行 208 | 209 | ``` 210 | python download.py 211 | ``` 212 | 213 | 这将显示其文件索引位置,并自动下载 缺少的 `averaged_perceptron_tagger.zip`和 `cmudict.zip` 文件到/root/nltk_data目录下的子目录,下载完成后查看根目录下是否有`nltk_data`文件夹,并将其中的压缩包都解压。 214 | 215 |   216 | 217 | 2)方法二 218 | 219 | 如果通过上面代码还是无法正常下载数据包 ,也可以打开以下地址手动搜索并下载压缩包文件(需科学上网) 220 | 221 | ``` 222 | https://raw.githubusercontent.com/nltk/nltk_data/gh-pages/index.xml 223 | ``` 224 | 225 | 其中下面是`averaged_perceptron_tagger.zip` 和`cmudict.zip` 数据包文件的下载地址 226 | 227 | ``` 228 | https://raw.githubusercontent.com/nltk/nltk_data/gh-pages/packages/taggers/averaged_perceptron_tagger.zip 229 | https://raw.githubusercontent.com/nltk/nltk_data/gh-pages/packages/corpora/cmudict.zip 230 | ``` 231 | 232 | 然后将该压缩包文件上传至(1)运行`python download.py`时打印显示的文件索引位置,如 `/root/nltk_data` 或者 `/root/miniconda3/envs/EmotiVoice/nltk_data` 等类似目录下,如果没有则创建一个,然后将zip压缩包解压。 233 | 234 |   235 | 236 | 解压后nltk_data目录结构应该是下面这样 237 | 238 | ``` 239 | ├── nltk_data 240 | │ ├── corpora 241 | │ │ ├── cmudict 242 | │ │ │ ├── README 243 | │ │ │ └── cmudict 244 | │ │ └── cmudict.zip 245 | │ └── taggers 246 | │ ├── averaged_perceptron_tagger 247 | │ │ └── averaged_perceptron_tagger.pickle 248 | │ └── averaged_perceptron_tagger.zip 249 | ``` 250 | 251 |   252 | 253 | **(3) 报错 AttributeError: 'NoneType' object has no attribute 'seek'.** 254 | 255 | 原因:未找到模型文件 256 | 257 | 解决方法:大概率是你未下载模型文件或者存放路径不正确,查看自己下载的模型文件是否存在,即outputs文件夹存放路径和里面的模型文件是否正确,正确结构可参考 [第五步](#step5) 中的项目结构。 258 | 259 |   260 | 261 | **(4) 运行API服务出错 ImportError: cannot import name 'Doc' from 'typing_extensions'** 262 | 263 | 原因:typing_extensions 版本问题 264 | 265 | 解决方法: 266 | 267 | 尝试将`typing_extensions`升级至最新版本,如果已经是最新版本,则适当降低版本,以下版本在`fastapi V0.104.1`测试正常。 268 | 269 | ``` 270 | pip install typing_extensions==4.8.0 --force-reinstall 271 | ``` 272 | 273 |   274 | 275 | **(5) 请求文本转语音接口时报错 500 Internal Server Error ,FileNotFoundError: [Errno 2] No such file or directory: 'ffmpeg'** 276 | 277 | 原因:未安装ffmpeg 278 | 279 | 解决方法: 280 | 281 | 执行以下命令进行安装,如果是Ubuntu执行 282 | 283 | ``` 284 | sudo apt update 285 | sudo apt install ffmpeg 286 | ``` 287 | 288 | CentOS则执行 289 | 290 | ``` 291 | sudo yum install 
epel-release 292 | sudo yum install ffmpeg 293 | ``` 294 | 295 | 安装完成后,你可以在终端中运行以下命令来验证"ffmpeg"是否成功安装: 296 | 297 | ``` 298 | ffmpeg -version 299 | ``` 300 | 301 | 如果安装成功,你将看到"ffmpeg"的版本信息。 302 | 303 |   304 | -------------------------------------------------------------------------------- /ROADMAP.md: -------------------------------------------------------------------------------- 1 | # EmotiVoice Roadmap 2 | 3 | This roadmap is for EmotiVoice (易魔声), a project driven by the community. We value your feedback and suggestions on our future direction. 4 | 5 | Please visit https://github.com/netease-youdao/EmotiVoice/issues on GitHub to submit your proposals. 6 | If you are interested, feel free to volunteer for any tasks, even if they are not listed. 7 | 8 | The plan is to finish 0.2 to 0.4 in Q4 2023. 9 | 10 | ## EmotiVoice 0.4 11 | 12 | - [ ] Updated model with potentially improved quality. 13 | - [ ] First version of desktop application. 14 | - [ ] Support longer text. 15 | 16 | ## EmotiVoice 0.3 (2023.12.13) 17 | 18 | - [x] Release [The EmotiVoice HTTP API](https://github.com/netease-youdao/EmotiVoice/wiki/HTTP-API) provided by [Zhiyun](https://mp.weixin.qq.com/s/_Fbj4TI4ifC6N7NFOUrqKQ). 19 | - [x] Release [Voice Cloning with your personal data](https://github.com/netease-youdao/EmotiVoice/wiki/Voice-Cloning-with-your-personal-data) along with [DataBaker Recipe](https://github.com/netease-youdao/EmotiVoice/tree/main/data/DataBaker) and [LJSpeech Recipe](https://github.com/netease-youdao/EmotiVoice/tree/main/data/LJspeech). 20 | - [x] Documentation: wiki page for hardware requirements. [#30](../../issues/30) 21 | 22 | ## EmotiVoice 0.2 (2023.11.17) 23 | 24 | - [x] Support mixed Chinese and English input text. [#28](../../issues/28) 25 | - [x] Resolve bugs related to certain modal particles, to make it more robust. [#18](../../issues/18) 26 | - [x] Documentation: voice list wiki page 27 | - [x] Documentation: this roadmap. 28 | 29 | ## EmotiVoice 0.1 (2023.11.10) first public version 30 | 31 | - [x] We offer a pretrained model with over 2000 voices, supporting both Chinese and English languages. 32 | - [x] You can perform inference using the command line interface. We also offer a user-friendly web demo for easy usage. 33 | - [x] For convenient deployment, we offer a Docker image. 34 | 35 | -------------------------------------------------------------------------------- /assets/audio/emotivoice_intro_cn.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/netease-youdao/EmotiVoice/bc2de8c9eb1121237958ef154cb171e7faefc769/assets/audio/emotivoice_intro_cn.wav -------------------------------------------------------------------------------- /assets/audio/emotivoice_intro_en.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/netease-youdao/EmotiVoice/bc2de8c9eb1121237958ef154cb171e7faefc769/assets/audio/emotivoice_intro_en.wav -------------------------------------------------------------------------------- /cn2an/conf.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is modified from https://github.com/Ailln/cn2an. 
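It provides the numeral lookup tables used for digit handling: mappings between
Arabic digits and Chinese numerals, including the formal/financial forms
(壹, 贰, ...) and the place-value units (十, 百, 千, 万, 亿).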
3 | """ 4 | 5 | NUMBER_CN2AN = { 6 | "零": 0, 7 | "〇": 0, 8 | "一": 1, 9 | "壹": 1, 10 | "幺": 1, 11 | "二": 2, 12 | "贰": 2, 13 | "两": 2, 14 | "三": 3, 15 | "叁": 3, 16 | "四": 4, 17 | "肆": 4, 18 | "五": 5, 19 | "伍": 5, 20 | "六": 6, 21 | "陆": 6, 22 | "七": 7, 23 | "柒": 7, 24 | "八": 8, 25 | "捌": 8, 26 | "九": 9, 27 | "玖": 9, 28 | } 29 | UNIT_CN2AN = { 30 | "十": 10, 31 | "拾": 10, 32 | "百": 100, 33 | "佰": 100, 34 | "千": 1000, 35 | "仟": 1000, 36 | "万": 10000, 37 | "亿": 100000000, 38 | } 39 | UNIT_LOW_AN2CN = { 40 | 10: "十", 41 | 100: "百", 42 | 1000: "千", 43 | 10000: "万", 44 | 100000000: "亿", 45 | } 46 | NUMBER_LOW_AN2CN = { 47 | 0: "零", 48 | 1: "一", 49 | 2: "二", 50 | 3: "三", 51 | 4: "四", 52 | 5: "五", 53 | 6: "六", 54 | 7: "七", 55 | 8: "八", 56 | 9: "九", 57 | } 58 | NUMBER_UP_AN2CN = { 59 | 0: "零", 60 | 1: "壹", 61 | 2: "贰", 62 | 3: "叁", 63 | 4: "肆", 64 | 5: "伍", 65 | 6: "陆", 66 | 7: "柒", 67 | 8: "捌", 68 | 9: "玖", 69 | } 70 | UNIT_LOW_ORDER_AN2CN = [ 71 | "", 72 | "十", 73 | "百", 74 | "千", 75 | "万", 76 | "十", 77 | "百", 78 | "千", 79 | "亿", 80 | "十", 81 | "百", 82 | "千", 83 | "万", 84 | "十", 85 | "百", 86 | "千", 87 | ] 88 | UNIT_UP_ORDER_AN2CN = [ 89 | "", 90 | "拾", 91 | "佰", 92 | "仟", 93 | "万", 94 | "拾", 95 | "佰", 96 | "仟", 97 | "亿", 98 | "拾", 99 | "佰", 100 | "仟", 101 | "万", 102 | "拾", 103 | "佰", 104 | "仟", 105 | ] 106 | STRICT_CN_NUMBER = { 107 | "零": "零", 108 | "一": "一壹", 109 | "二": "二贰", 110 | "三": "三叁", 111 | "四": "四肆", 112 | "五": "五伍", 113 | "六": "六陆", 114 | "七": "七柒", 115 | "八": "八捌", 116 | "九": "九玖", 117 | "十": "十拾", 118 | "百": "百佰", 119 | "千": "千仟", 120 | "万": "万", 121 | "亿": "亿", 122 | } 123 | NORMAL_CN_NUMBER = { 124 | "零": "零〇", 125 | "一": "一壹幺", 126 | "二": "二贰两", 127 | "三": "三叁仨", 128 | "四": "四肆", 129 | "五": "五伍", 130 | "六": "六陆", 131 | "七": "七柒", 132 | "八": "八捌", 133 | "九": "九玖", 134 | "十": "十拾", 135 | "百": "百佰", 136 | "千": "千仟", 137 | "万": "万", 138 | "亿": "亿", 139 | } 140 | -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md 3 | 4 | build: 5 | gpu: true 6 | 7 | # a list of ubuntu apt packages to install 8 | # system_packages: 9 | # - "libgl1-mesa-glx" 10 | # - "libglib2.0-0" 11 | 12 | python_version: "3.8" 13 | python_packages: 14 | - "torch==2.0.1" 15 | - "torchaudio==2.0.2" 16 | - "g2p-en==2.1.0" 17 | - "jieba==0.42.1" 18 | - "numba==0.58.1" 19 | - "numpy==1.24.4" 20 | - "pypinyin==0.49.0" 21 | - "scipy==1.10.1" 22 | - "soundfile==0.12.1" 23 | - "transformers==4.26.1" 24 | - "yacs==0.1.8" 25 | 26 | run: 27 | - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.0.3/pget" && chmod +x /usr/local/bin/pget 28 | 29 | # predict.py defines how predictions are run on your model 30 | predict: "predict.py:Predictor" 31 | -------------------------------------------------------------------------------- /config/joint/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023, YOUDAO 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | 17 | # with thanks to arjun-234 in https://github.com/netease-youdao/EmotiVoice/pull/38. 18 | def get_labels_length(file_path): 19 | """ 20 | Return labels and their count in a file. 21 | 22 | Args: 23 | file_path (str): The path to the file containing the labels. 24 | 25 | Returns: 26 | list: labels; int: The number of labels in the file. 27 | """ 28 | with open(file_path, encoding = "UTF-8") as f: 29 | tokens = [t.strip() for t in f.readlines()] 30 | return tokens, len(tokens) 31 | 32 | class Config: 33 | #### PATH #### 34 | ROOT_DIR = os.path.dirname(os.path.abspath("__file__")) 35 | DATA_DIR = ROOT_DIR + "/data/youdao/" 36 | train_data_path = DATA_DIR + "train_am/datalist.jsonl" 37 | valid_data_path = DATA_DIR + "valid_am/datalist.jsonl" 38 | output_directory = ROOT_DIR + "/outputs" 39 | speaker2id_path = DATA_DIR + "text/speaker2" 40 | emotion2id_path = DATA_DIR + "text/emotion" 41 | pitch2id_path = DATA_DIR + "text/pitch" 42 | energy2id_path = DATA_DIR + "text/energy" 43 | speed2id_path = DATA_DIR + "text/speed" 44 | bert_path = 'WangZeJun/simbert-base-chinese' 45 | token_list_path = DATA_DIR + "text/tokenlist" 46 | style_encoder_ckpt = ROOT_DIR + "/outputs/style_encoder/ckpt/checkpoint_163431" 47 | tmp_dir = ROOT_DIR + "/tmp" 48 | model_config_path = ROOT_DIR + "/config/joint/config.yaml" 49 | 50 | #### Model #### 51 | bert_hidden_size = 768 52 | style_dim = 128 53 | downsample_ratio = 1 # Whole Model 54 | 55 | #### Text #### 56 | tokens, n_symbols = get_labels_length(token_list_path) 57 | sep = " " 58 | 59 | #### Speaker #### 60 | speakers, speaker_n_labels = get_labels_length(speaker2id_path) 61 | 62 | #### Emotion #### 63 | emotions, emotion_n_labels = get_labels_length(emotion2id_path) 64 | 65 | #### Speed #### 66 | speeds, speed_n_labels = get_labels_length(speed2id_path) 67 | 68 | #### Pitch #### 69 | pitchs, pitch_n_labels = get_labels_length(pitch2id_path) 70 | 71 | #### Energy #### 72 | energys, energy_n_labels = get_labels_length(energy2id_path) 73 | 74 | #### Train #### 75 | # epochs = 10 76 | lr = 1e-3 77 | lr_warmup_steps = 4000 78 | kl_warmup_steps = 60_000 79 | grad_clip_thresh = 1.0 80 | batch_size = 16 81 | train_steps = 10_000_000 82 | opt_level = "O1" 83 | seed = 1234 84 | iters_per_validation= 1000 85 | iters_per_checkpoint= 10000 86 | 87 | 88 | #### Audio #### 89 | sampling_rate = 16_000 90 | max_db = 1 91 | min_db = 0 92 | trim = True 93 | 94 | #### Stft #### 95 | filter_length = 1024 96 | hop_length = 256 97 | win_length = 1024 98 | window = "hann" 99 | 100 | #### Mel #### 101 | n_mel_channels = 80 102 | mel_fmin = 0 103 | mel_fmax = 8000 104 | 105 | #### Pitch #### 106 | pitch_min = 80 107 | pitch_max = 400 108 | pitch_stats = [225.089, 53.78] 109 | 110 | #### Energy #### 111 | energy_stats = [30.610, 21.78] 112 | 113 | 114 | #### Infernce #### 115 | gta = False 116 | -------------------------------------------------------------------------------- /config/joint/config.yaml: -------------------------------------------------------------------------------- 1 | 
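# Feature-extraction, model and optimizer settings for the joint AM/vocoder
# recipe; config/joint/config.py points to this file via its model_config_path field.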
########################################################### 2 | # FEATURE EXTRACTION SETTING # 3 | ########################################################### 4 | 5 | sr: 16000 # sr 6 | n_fft: 1024 # FFT size (samples). 7 | hop_length: 256 # Hop size (samples). 12.5ms 8 | win_length: 1024 # Window length (samples). 50ms 9 | # If set to null it will be the same as fft_size. 10 | window: "hann" # Window function. 11 | 12 | fmin: 0 # Minimum frequency of Mel basis. 13 | fmax: null # Maximum frequency of Mel basis. 14 | n_mels: 80 # The number of mel basis. 15 | 16 | pitch_min: 80 # Minimum f0 in linear domain for pitch extraction. 17 | pitch_max: 400 # Maximum f0 in linear domain for pitch extraction. 18 | 19 | segment_size: 32 20 | 21 | 22 | cut_sil: True 23 | 24 | shuffle: True 25 | 26 | pretrained_am: "" # absolute path 27 | pretrained_vocoder: "" # absolute path 28 | pretrained_discriminator: "" # absolute path 29 | 30 | max_db: 1 31 | min_db: 0 32 | 33 | ########################################################### 34 | # MODEL SETTING # 35 | ########################################################### 36 | model: 37 | speaker_embed_dim: 384 38 | bert_embedding: 768 39 | #### encoder #### 40 | lang_embed_dim: 0 41 | encoder_n_layers: 4 42 | encoder_n_heads: 8 43 | encoder_n_hidden: 384 44 | encoder_p_dropout: 0.2 45 | encoder_kernel_size_conv_mod: 3 46 | encoder_kernel_size_depthwise: 7 47 | #### decoder #### 48 | decoder_n_layers: 4 49 | decoder_n_heads: 8 50 | decoder_n_hidden: 384 51 | decoder_p_dropout: 0.2 52 | decoder_kernel_size_conv_mod: 3 53 | decoder_kernel_size_depthwise: 31 54 | #### prosodic #### 55 | bottleneck_size_p: 4 56 | bottleneck_size_u: 256 57 | ref_enc_filters: [32, 32, 64, 64, 128, 128] 58 | ref_enc_size: 3 59 | ref_enc_strides: [1, 2, 1, 2, 1] 60 | ref_enc_pad: [1, 1] 61 | ref_enc_gru_size: 32 62 | ref_attention_dropout: 0.2 63 | token_num: 32 64 | predictor_kernel_size: 5 65 | stop_prosodic_gradient: False 66 | ref_p_dropout: 0.1 67 | ref_n_heads: 4 68 | #### variance #### 69 | variance_n_hidden: 384 70 | variance_n_layers: 3 71 | variance_kernel_size: 3 72 | variance_p_dropout: 0.1 73 | variance_embed_kernel_size: 9 74 | variance_embde_p_dropout: 0.0 75 | stop_pitch_gradient: False 76 | stop_duration_gradient: False 77 | duration_p_dropout: 0.5 78 | duration_n_layers: 2 79 | duration_kernel_size: 3 80 | #### postnet #### 81 | postnet_layers: 0 82 | postnet_chans: 256 83 | postnet_filts: 5 84 | use_batch_norm: True 85 | postnet_dropout_rate: 0.5 86 | #### generator #### 87 | resblock: "1" 88 | upsample_rates: [8,8,2,2] 89 | upsample_kernel_sizes: [16,16,4,4] 90 | initial_channel: 80 91 | upsample_initial_channel: 512 92 | resblock_kernel_sizes: [3,7,11] 93 | resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]] 94 | r: 1 95 | ########################################################### 96 | # OPTIMIZER SETTING # 97 | ########################################################### 98 | optimizer: 99 | lr: 1.25e-5 100 | betas: [0.5, 0.9] 101 | eps: 1.0e-9 102 | weight_decay: 0.0 103 | scheduler: 104 | gamma: 0.999875 105 | -------------------------------------------------------------------------------- /config/template.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023, YOUDAO 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | 17 | # with thanks to arjun-234 in https://github.com/netease-youdao/EmotiVoice/pull/38. 18 | def get_labels_length(file_path): 19 | """ 20 | Return labels and their count in a file. 21 | 22 | Args: 23 | file_path (str): The path to the file containing the labels. 24 | 25 | Returns: 26 | list: labels; int: The number of labels in the file. 27 | """ 28 | with open(file_path, encoding = "UTF-8") as f: 29 | tokens = [t.strip() for t in f.readlines()] 30 | return tokens, len(tokens) 31 | 32 | class Config: 33 | #### PATH #### 34 | ROOT_DIR = os.path.dirname(os.path.abspath("__file__")) 35 | DATA_DIR = ROOT_DIR + "/" 36 | # Change datalist.jsonl to datalist_mfa.jsonl if you have run MFA 37 | train_data_path = DATA_DIR + "/train/datalist.jsonl" 38 | valid_data_path = DATA_DIR + "/valid/datalist.jsonl" 39 | output_directory = ROOT_DIR + "/" 40 | speaker2id_path = ROOT_DIR + "//speaker" 41 | emotion2id_path = ROOT_DIR + "//emotion" 42 | pitch2id_path = ROOT_DIR + "//pitch" 43 | energy2id_path = ROOT_DIR + "//energy" 44 | speed2id_path = ROOT_DIR + "//speed" 45 | bert_path = 'WangZeJun/simbert-base-chinese' 46 | token_list_path = ROOT_DIR + "//tokenlist" 47 | style_encoder_ckpt = ROOT_DIR + "/outputs/style_encoder/ckpt/checkpoint_163431" 48 | tmp_dir = output_directory + "/tmp" 49 | model_config_path = ROOT_DIR + "/config/joint/config.yaml" 50 | 51 | #### Model #### 52 | bert_hidden_size = 768 53 | style_dim = 128 54 | downsample_ratio = 1 # Whole Model 55 | 56 | #### Text #### 57 | tokens, n_symbols = get_labels_length(token_list_path) 58 | sep = " " 59 | 60 | #### Speaker #### 61 | speakers, speaker_n_labels = get_labels_length(speaker2id_path) 62 | 63 | #### Emotion #### 64 | emotions, emotion_n_labels = get_labels_length(emotion2id_path) 65 | 66 | #### Speed #### 67 | speeds, speed_n_labels = get_labels_length(speed2id_path) 68 | 69 | #### Pitch #### 70 | pitchs, pitch_n_labels = get_labels_length(pitch2id_path) 71 | 72 | #### Energy #### 73 | energys, energy_n_labels = get_labels_length(energy2id_path) 74 | 75 | #### Train #### 76 | # epochs = 10 77 | lr = 1e-3 78 | lr_warmup_steps = 4000 79 | kl_warmup_steps = 60_000 80 | grad_clip_thresh = 1.0 81 | batch_size = 8 82 | train_steps = 10_000_000 83 | opt_level = "O1" 84 | seed = 1234 85 | iters_per_validation= 1000 86 | iters_per_checkpoint= 5000 87 | 88 | 89 | #### Audio #### 90 | sampling_rate = 16_000 91 | max_db = 1 92 | min_db = 0 93 | trim = True 94 | 95 | #### Stft #### 96 | filter_length = 1024 97 | hop_length = 256 98 | win_length = 1024 99 | window = "hann" 100 | 101 | #### Mel #### 102 | n_mel_channels = 80 103 | mel_fmin = 0 104 | mel_fmax = 8000 105 | 106 | #### Pitch #### 107 | pitch_min = 80 108 | pitch_max = 400 109 | pitch_stats = [225.089, 53.78] 110 | 111 | #### Energy #### 112 | energy_stats = [30.610, 21.78] 113 | 114 | 115 | #### Infernce #### 116 | gta = False 117 | -------------------------------------------------------------------------------- /data/DataBaker/README.md: -------------------------------------------------------------------------------- 1 | 2 | 
3 | # 😊 DataBaker Recipe 4 | 5 | This is the recipe of Chinese single female speaker TTS model with DataBaker corpus. 6 | 7 | ## Guide For Finetuning 8 | - [Environments Installation](#environments-installation) 9 | - [Step0 Download Data](#step0-download-data) 10 | - [Step1 Preprocess Data](#step1-preprocess-data) 11 | - [Step2 Run MFA (Optional)](#step2-run-mfa-optional-since-we-already-have-labeled-prosody) 12 | - [Step3 Prepare for training](#step3-prepare-for-training) 13 | - [Step4 Start training](#step4-finetune-your-model) 14 | - [Step5 Inference](#step5-inference) 15 | 16 | ### Environments Installation 17 | 18 | create conda enviroment 19 | ```bash 20 | conda create -n EmotiVoice python=3.8 -y 21 | conda activate EmotiVoice 22 | ``` 23 | then run: 24 | ```bash 25 | pip install EmotiVoice[train] 26 | # or 27 | git clone https://github.com/netease-youdao/EmotiVoice 28 | pip install -e .[train] 29 | ``` 30 | Additionally, it is important to prepare the pre-trained models as mentioned in the [pretrained models](https://github.com/netease-youdao/EmotiVoice/wiki/Pretrained-models). 31 | 32 | ### Step0 Download Data 33 | 34 | ```bash 35 | mkdir data/DataBaker/raw 36 | 37 | # download 38 | # please download the data from https://en.data-baker.com/datasets/freeDatasets/, and place the extracted BZNSYP folder under data/DataBaker/raw 39 | ``` 40 | 41 | ### Step1 Preprocess Data 42 | 43 | For this recipe, since DataBaker has already provided phoneme labels, we will simply utilize that information. 44 | 45 | ```bash 46 | # format data 47 | python data/DataBaker/src/step1_clean_raw_data.py \ 48 | --data_dir data/DataBaker 49 | 50 | # get phoneme 51 | python data/DataBaker/src/step2_get_phoneme.py \ 52 | --data_dir data/DataBaker 53 | ``` 54 | 55 | If you have prepared your own data with only text labels, you can obtain phonemes using the Text-to-Speech (TTS) frontend. For example, you can run the following command: `python data/DataBaker/src/step2_get_phoneme.py --data_dir data/DataBaker --generate_phoneme True`. However, please note that in this specific DataBaker's recipe, you should omit this command. 56 | 57 | 58 | 59 | ### Step2 Run MFA (Optional, since we already have labeled prosody) 60 | 61 | Please be aware that in this particular DataBaker's recipe, **you should skip this step**. 
Nonetheless, if you have already prepared your own data with only text labels, the following commands might assist you: 62 | 63 | ```bash 64 | # MFA environment install 65 | conda install -c conda-forge kaldi sox librosa biopython praatio tqdm requests colorama pyyaml pynini openfst baumwelch ngram postgresql -y 66 | pip install pgvector hdbscan montreal-forced-aligner 67 | 68 | # MFA Step1 69 | python mfa/step1_create_dataset.py \ 70 | --data_dir data/DataBaker 71 | 72 | # MFA Step2 73 | python mfa/step2_prepare_data.py \ 74 | --dataset_dir data/DataBaker/mfa \ 75 | --wav data/DataBaker/mfa/wav.txt \ 76 | --speaker data/DataBaker/mfa/speaker.txt \ 77 | --text data/DataBaker/mfa/text.txt 78 | 79 | # MFA Step3 80 | python mfa/step3_prepare_special_tokens.py \ 81 | --special_tokens data/DataBaker/mfa/special_token.txt 82 | 83 | # MFA Step4 84 | python mfa/step4_convert_text_to_phn.py \ 85 | --text data/DataBaker/mfa/text.txt \ 86 | --special_tokens data/DataBaker/mfa/special_token.txt \ 87 | --output data/DataBaker/mfa/text.txt 88 | 89 | # MFA Step5 90 | python mfa/step5_prepare_alignment.py \ 91 | --wav data/DataBaker/mfa/wav.txt \ 92 | --speaker data/DataBaker/mfa/speaker.txt \ 93 | --text data/DataBaker/mfa/text.txt \ 94 | --special_tokens data/DataBaker/mfa/special_token.txt \ 95 | --pronounciation_dict data/DataBaker/mfa/mfa_pronounciation_dict.txt \ 96 | --output_dir data/DataBaker/mfa/lab 97 | 98 | # MFA Step6 99 | mfa validate \ 100 | --overwrite \ 101 | --clean \ 102 | --single_speaker \ 103 | data/DataBaker/mfa/lab \ 104 | data/DataBaker/mfa/mfa_pronounciation_dict.txt 105 | 106 | mfa train \ 107 | --overwrite \ 108 | --clean \ 109 | --single_speaker \ 110 | data/DataBaker/mfa/lab \ 111 | data/DataBaker/mfa/mfa_pronounciation_dict.txt \ 112 | data/DataBaker/mfa/mfa/mfa_model.zip \ 113 | data/DataBaker/mfa/TextGrid 114 | 115 | mfa align \ 116 | --single_speaker \ 117 | data/DataBaker/mfa/lab \ 118 | data/DataBaker/mfa/mfa_pronounciation_dict.txt \ 119 | data/DataBaker/mfa/mfa/mfa_model.zip \ 120 | data/DataBaker/mfa/TextGrid 121 | 122 | # MFA Step7 123 | python mfa/step7_gen_alignment_from_textgrid.py \ 124 | --wav data/DataBaker/mfa/wav.txt \ 125 | --speaker data/DataBaker/mfa/speaker.txt \ 126 | --text data/DataBaker/mfa/text.txt \ 127 | --special_tokens data/DataBaker/mfa/special_token.txt \ 128 | --text_grid data/DataBaker/mfa/TextGrid \ 129 | --aligned_wav data/DataBaker/mfa/aligned_wav.txt \ 130 | --aligned_speaker data/DataBaker/mfa/aligned_speaker.txt \ 131 | --duration data/DataBaker/mfa/duration.txt \ 132 | --aligned_text data/DataBaker/mfa/aligned_text.txt \ 133 | --reassign_sp True 134 | 135 | # MFA Step8 136 | python mfa/step8_make_data_list.py \ 137 | --wav data/DataBaker/mfa/aligned_wav.txt \ 138 | --speaker data/DataBaker/mfa/aligned_speaker.txt \ 139 | --text data/DataBaker/mfa/aligned_text.txt \ 140 | --duration data/DataBaker/mfa/duration.txt \ 141 | --datalist_path data/DataBaker/mfa/datalist.jsonl 142 | 143 | # MFA Step9 144 | python mfa/step9_datalist_from_mfa.py \ 145 | --data_dir data/DataBaker 146 | ``` 147 | 148 | ### Step3 Prepare for training 149 | 150 | ```bash 151 | python prepare_for_training.py --data_dir data/DataBaker --exp_dir exp/DataBaker 152 | ``` 153 | __Please check and change the training and valid file path in the `exp/DataBaker/config/config.py`, especially:__ 154 | - `model_config_path`: corresponing model config file 155 | - `DATA_DIR`: data dir 156 | - `train_data_path` and `valid_data_path`: training file and valid file. 
Change to `datalist_mfa.jsonl` if you run Step2 157 | - `batch_size` 158 | 159 | ### Step4 Finetune Your Model 160 | 161 | ```bash 162 | torchrun \ 163 | --nproc_per_node=1 \ 164 | --master_port 8008 \ 165 | train_am_vocoder_joint.py \ 166 | --config_folder exp/DataBaker/config \ 167 | --load_pretrained_model True 168 | ``` 169 | 170 | Training tips: 171 | 172 | - You can run tensorboad by 173 | ``` 174 | tensorboard --logdir=exp/DataBaker 175 | ``` 176 | - The model checkpoints are saved at `exp/DataBaker/ckpt`. 177 | - The bert features are extracted in the first epoch and saved in `exp/DataBaker/tmp/` folder, you can change the path in `exp/DataBaker/config/config.py`. 178 | 179 | 180 | ### Step5 Inference 181 | 182 | 183 | ```bash 184 | TEXT=data/inference/text 185 | python inference_am_vocoder_exp.py \ 186 | --config_folder exp/DataBaker/config \ 187 | --checkpoint g_00010000 \ 188 | --test_file $TEXT 189 | ``` 190 | __Please change the speaker names in the `data/inference/text`__ 191 | 192 | the synthesized speech is under `exp/DataBaker/test_audio`. -------------------------------------------------------------------------------- /data/DataBaker/src/step0_download.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | # please download the data from https://en.data-baker.com/datasets/freeDatasets/, and place the extracted BZNSYP folder under data/DataBaker/raw 4 | 5 | -------------------------------------------------------------------------------- /data/DataBaker/src/step1_clean_raw_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is modified from https://github.com/wenet-e2e/wetts. 3 | """ 4 | 5 | import os 6 | import argparse 7 | import soundfile as sf 8 | import librosa 9 | import jsonlines 10 | from tqdm import tqdm 11 | import re 12 | 13 | def main(args): 14 | 15 | ROOT_DIR=os.path.abspath(args.data_dir) 16 | RAW_DIR=f"{ROOT_DIR}/raw" 17 | WAV_DIR=f"{ROOT_DIR}/wavs" 18 | TEXT_DIR=f"{ROOT_DIR}/text" 19 | 20 | os.makedirs(WAV_DIR, exist_ok=True) 21 | os.makedirs(TEXT_DIR, exist_ok=True) 22 | 23 | 24 | with open(f"{RAW_DIR}/BZNSYP/ProsodyLabeling/000001-010000.txt", encoding="utf-8") as f, \ 25 | jsonlines.open(f"{TEXT_DIR}/data.jsonl", "w") as fout1: 26 | 27 | lines = f.readlines() 28 | for i in tqdm(range(0, len(lines), 2)): 29 | key = lines[i][:6] 30 | 31 | ### Text 32 | content_org = lines[i][7:].strip() 33 | content = re.sub("[。,、“”?:……!( )—;]", "", content_org) 34 | content_org = re.sub("#\d", "", content_org) 35 | 36 | chars = [] 37 | prosody = {} 38 | j = 0 39 | while j < len(content): 40 | if content[j] == "#": 41 | prosody[len(chars) - 1] = content[j : j + 2] 42 | j += 2 43 | else: 44 | chars.append(content[j]) 45 | j += 1 46 | 47 | if key == "005107": 48 | lines[i + 1] = lines[i + 1].replace(" ng1", " en1") 49 | if key == "002365": 50 | continue 51 | 52 | syllable = lines[i + 1].strip().split() 53 | s_index = 0 54 | phones = [] 55 | phone = [] 56 | for k, char in enumerate(chars): 57 | # 儿化音处理 58 | er_flag = False 59 | if char == "儿" and (s_index == len(syllable) or syllable[s_index][0:2] != "er"): 60 | er_flag = True 61 | else: 62 | phones.append(syllable[s_index]) 63 | #phones.extend(lexicon[syllable[s_index]]) 64 | s_index += 1 65 | 66 | 67 | if k in prosody: 68 | if er_flag: 69 | phones[-1] = prosody[k] 70 | else: 71 | phones.append(prosody[k]) 72 | else: 73 | phones.append("#0") 74 | 75 | ### Wav 76 | path = f"{RAW_DIR}/BZNSYP/Wave/{key}.wav" 77 | wav_path = 
f"{WAV_DIR}/{key}.wav" 78 | y, sr = sf.read(path) 79 | y_16=librosa.resample(y, orig_sr=sr, target_sr=16_000) 80 | sf.write(wav_path, y_16, 16_000) 81 | 82 | fout1.write({ 83 | "key":key, 84 | "wav_path":wav_path, 85 | "speaker":"BZNSYP", 86 | "text":[""] + phones[:-1] + [""], 87 | "original_text":content_org, 88 | }) 89 | 90 | 91 | return 92 | 93 | 94 | if __name__ == "__main__": 95 | p = argparse.ArgumentParser() 96 | p.add_argument('--data_dir', type=str, required=True) 97 | args = p.parse_args() 98 | 99 | main(args) 100 | -------------------------------------------------------------------------------- /data/DataBaker/src/step2_get_phoneme.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023, YOUDAO 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import argparse 16 | import os 17 | import jsonlines 18 | import json 19 | from tqdm import tqdm 20 | from multiprocessing.pool import ThreadPool 21 | from functools import partial 22 | 23 | import re 24 | import sys 25 | DIR=os.path.dirname(os.path.abspath("__file__")) 26 | sys.path.append(DIR) 27 | 28 | from frontend_cn import split_py, tn_chinese 29 | from frontend_en import read_lexicon, G2p 30 | from frontend import contains_chinese, re_digits, g2p_cn 31 | 32 | # re_english_word = re.compile('([a-z\-\.\']+|\d+[\d\.]*)', re.I) 33 | re_english_word = re.compile('([^\u4e00-\u9fa5]+|[ \u3002\uff0c\uff1f\uff01\uff1b\uff1a\u201c\u201d\u2018\u2019\u300a\u300b\u3008\u3009\u3010\u3011\u300e\u300f\u2014\u2026\u3001\uff08\uff09\u4e00-\u9fa5]+)', re.I) 34 | 35 | def g2p_cn_en(text, g2p, lexicon): 36 | # Our policy dictates that if the text contains Chinese, digits are to be converted into Chinese. 
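    # Illustrative example (assumed behaviour of tn_chinese): mixed input such as
    # "今天买了3个苹果" should be normalized to "今天买了三个苹果" here, so that raw
    # Arabic digits never reach the Chinese G2P branch below.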
37 | text=tn_chinese(text) 38 | parts = re_english_word.split(text) 39 | parts=list(filter(None, parts)) 40 | tts_text = [""] 41 | chartype = '' 42 | text_contains_chinese = contains_chinese(text) 43 | for part in parts: 44 | if part == ' ' or part == '': continue 45 | if re_digits.match(part) and (text_contains_chinese or chartype == '') or contains_chinese(part): 46 | if chartype == 'en': 47 | tts_text.append('eng_cn_sp') 48 | phoneme = g2p_cn(part).split()[1:-1] 49 | chartype = 'cn' 50 | elif re_english_word.match(part): 51 | if chartype == 'cn': 52 | if "sp" in tts_text[-1]: 53 | "" 54 | else: 55 | tts_text.append('cn_eng_sp') 56 | phoneme = get_eng_phoneme(part, g2p, lexicon).split() 57 | if not phoneme : 58 | # tts_text.pop() 59 | continue 60 | else: 61 | chartype = 'en' 62 | else: 63 | continue 64 | tts_text.extend( phoneme ) 65 | 66 | tts_text=" ".join(tts_text).split() 67 | if "sp" in tts_text[-1]: 68 | tts_text.pop() 69 | tts_text.append("") 70 | 71 | return " ".join(tts_text) 72 | 73 | def get_eng_phoneme(text, g2p, lexicon): 74 | """ 75 | english g2p 76 | """ 77 | filters = {",", " ", "'"} 78 | phones = [] 79 | words = list(filter(lambda x: x not in {"", " "}, re.split(r"([,;.\-\?\!\s+])", text))) 80 | 81 | for w in words: 82 | if w.lower() in lexicon: 83 | 84 | for ph in lexicon[w.lower()]: 85 | if ph not in filters: 86 | phones += ["[" + ph + "]"] 87 | 88 | if "sp" not in phones[-1]: 89 | phones += ["engsp1"] 90 | else: 91 | phone=g2p(w) 92 | if not phone: 93 | continue 94 | 95 | if phone[0].isalnum(): 96 | 97 | for ph in phone: 98 | if ph not in filters: 99 | phones += ["[" + ph + "]"] 100 | if ph == " " and "sp" not in phones[-1]: 101 | phones += ["engsp1"] 102 | elif phone == " ": 103 | continue 104 | elif phones: 105 | phones.pop() # pop engsp1 106 | phones.append("engsp4") 107 | if phones and "engsp" in phones[-1]: 108 | phones.pop() 109 | 110 | 111 | return " ".join(phones) 112 | 113 | 114 | def onetime(resource, sample): 115 | 116 | text=sample["text"] 117 | # del sample["original_text"] 118 | 119 | phoneme = get_phoneme(text, resource["g2p"]).split() 120 | 121 | sample["text"]=phoneme 122 | # sample["original_text"]=text 123 | sample["prompt"]=sample["original_text"] 124 | 125 | return sample 126 | 127 | def onetime2(resource, sample): 128 | 129 | text=sample["original_text"] 130 | del sample["original_text"] 131 | try: 132 | phoneme = g2p_cn_en(text, resource["g2p_en"], resource["lexicon"]).split()#g2p_cn_eng_mix(text, resource["g2p_en"], resource["lexicon"]).split() 133 | except: 134 | print("Warning!!! phoneme get error! 
" + \ 135 | "Please check text") 136 | print("Text is: ", text) 137 | return "" 138 | 139 | if not phoneme: 140 | return "" 141 | 142 | sample["text"]=phoneme 143 | sample["original_text"]=text 144 | sample["prompt"]=sample["original_text"] 145 | 146 | return sample 147 | 148 | def get_phoneme(text, g2p): 149 | special_tokens = {"#0":"sp0", "#1":"sp1", "#2":"sp2", "#3":"sp3", "#4":"sp4", "":""} 150 | phones = [] 151 | 152 | for ph in text: 153 | if ph not in special_tokens: 154 | phs = g2p(ph) 155 | phones.extend([ph for ph in phs if ph]) 156 | else: 157 | phones.append(special_tokens[ph]) 158 | 159 | return " ".join(phones) 160 | 161 | 162 | 163 | def main(args): 164 | 165 | ROOT_DIR=args.data_dir 166 | TRAIN_DIR=f"{ROOT_DIR}/train" 167 | VALID_DIR=f"{ROOT_DIR}/valid" 168 | TEXT_DIR=f"{ROOT_DIR}/text" 169 | 170 | os.makedirs(TRAIN_DIR, exist_ok=True) 171 | os.makedirs(VALID_DIR, exist_ok=True) 172 | 173 | lexicon = read_lexicon(f"{DIR}/lexicon/librispeech-lexicon.txt") 174 | 175 | g2p = G2p() 176 | 177 | resource={ 178 | "g2p":split_py, 179 | "g2p_en":g2p, 180 | "lexicon":lexicon, 181 | } 182 | 183 | with jsonlines.open(f"{TEXT_DIR}/data.jsonl") as f: 184 | data = list(f) 185 | 186 | new_data=[] 187 | with jsonlines.open(f"{TEXT_DIR}/datalist.jsonl", "w") as f: 188 | for sample in tqdm(data): 189 | if not args.generate_phoneme: 190 | sample = onetime(resource, sample) 191 | else: 192 | sample = onetime2(resource, sample) 193 | if not sample: 194 | continue 195 | f.write(sample) 196 | new_data.append(sample) 197 | 198 | with jsonlines.open(f"{TRAIN_DIR}/datalist.jsonl", "w") as f: 199 | for sample in tqdm(new_data[:-3]): 200 | f.write(sample) 201 | 202 | with jsonlines.open(f"{VALID_DIR}/datalist.jsonl", "w") as f: 203 | for sample in tqdm(data[-3:]): 204 | f.write(sample) 205 | 206 | 207 | return 208 | 209 | if __name__ == "__main__": 210 | 211 | p = argparse.ArgumentParser() 212 | p.add_argument('--data_dir', type=str, required=True) 213 | p.add_argument('--generate_phoneme', type=bool, default=False) 214 | args = p.parse_args() 215 | 216 | main(args) -------------------------------------------------------------------------------- /data/LJspeech/README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # 😊 LJSpeech Recipe 4 | 5 | This is the recipe of English single female speaker TTS model with LJSpeech corpus. 6 | 7 | ## Guide For Finetuning 8 | - [Environments Installation](#environments-installation) 9 | - [Step0 Download Data](#step0-download-data) 10 | - [Step1 Preprocess Data](#step1-preprocess-data) 11 | - [Step2 Run MFA (Optional, but Recommended)](#step2-run-mfa-optional-but-recommended) 12 | - [Step3 Prepare for training](#step3-prepare-for-training) 13 | - [Step4 Start training](#step4-finetune-your-model) 14 | - [Step5 Inference](#step5-inference) 15 | 16 | Run EmotiVoice Finetuning on Google Colab Notebook! 
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1dDAyjoYGcDGwYpHI3Oj2_OIV-7DIdx2L?usp=sharing) 17 | 18 | ### Environments Installation 19 | 20 | create conda enviroment 21 | ```bash 22 | conda create -n EmotiVoice python=3.8 -y 23 | conda activate EmotiVoice 24 | ``` 25 | then run: 26 | ```bash 27 | pip install EmotiVoice[train] 28 | # or 29 | git clone https://github.com/netease-youdao/EmotiVoice 30 | pip install -e .[train] 31 | ``` 32 | Additionally, it is important to prepare the pre-trained models as mentioned in the [pretrained models](https://github.com/netease-youdao/EmotiVoice/wiki/Pretrained-models). 33 | 34 | ### Step0 Download Data 35 | 36 | ```bash 37 | mkdir data/LJspeech/raw 38 | 39 | # download 40 | wget -P data/LJspeech/raw http://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2 41 | # extract 42 | tar -xjf data/LJspeech/raw/LJSpeech-1.1.tar.bz2 -C data/LJspeech/raw 43 | ``` 44 | 45 | ### Step1 Preprocess Data 46 | 47 | ```bash 48 | # format data 49 | python data/LJspeech/src/step1_clean_raw_data.py \ 50 | --data_dir data/LJspeech 51 | 52 | # get phoneme 53 | python data/LJspeech/src/step2_get_phoneme.py \ 54 | --data_dir data/LJspeech 55 | ``` 56 | 57 | ### Step2 Run MFA (Optional, but Recommended!) 58 | 59 | ```bash 60 | # MFA environment install 61 | conda install -c conda-forge kaldi sox librosa biopython praatio tqdm requests colorama pyyaml pynini openfst baumwelch ngram postgresql -y 62 | pip install pgvector hdbscan montreal-forced-aligner 63 | 64 | # MFA Step1 65 | python mfa/step1_create_dataset.py \ 66 | --data_dir data/LJspeech 67 | 68 | # MFA Step2 69 | python mfa/step2_prepare_data.py \ 70 | --dataset_dir data/LJspeech/mfa \ 71 | --wav data/LJspeech/mfa/wav.txt \ 72 | --speaker data/LJspeech/mfa/speaker.txt \ 73 | --text data/LJspeech/mfa/text.txt 74 | 75 | # MFA Step3 76 | python mfa/step3_prepare_special_tokens.py \ 77 | --special_tokens data/LJspeech/mfa/special_token.txt 78 | 79 | # MFA Step4 80 | python mfa/step4_convert_text_to_phn.py \ 81 | --text data/LJspeech/mfa/text.txt \ 82 | --special_tokens data/LJspeech/mfa/special_token.txt \ 83 | --output data/LJspeech/mfa/text.txt 84 | 85 | # MFA Step5 86 | python mfa/step5_prepare_alignment.py \ 87 | --wav data/LJspeech/mfa/wav.txt \ 88 | --speaker data/LJspeech/mfa/speaker.txt \ 89 | --text data/LJspeech/mfa/text.txt \ 90 | --special_tokens data/LJspeech/mfa/special_token.txt \ 91 | --pronounciation_dict data/LJspeech/mfa/mfa_pronounciation_dict.txt \ 92 | --output_dir data/LJspeech/mfa/lab 93 | 94 | # MFA Step6 95 | mfa validate \ 96 | --overwrite \ 97 | --clean \ 98 | --single_speaker \ 99 | data/LJspeech/mfa/lab \ 100 | data/LJspeech/mfa/mfa_pronounciation_dict.txt 101 | 102 | mfa train \ 103 | --overwrite \ 104 | --clean \ 105 | --single_speaker \ 106 | data/LJspeech/mfa/lab \ 107 | data/LJspeech/mfa/mfa_pronounciation_dict.txt \ 108 | data/LJspeech/mfa/mfa/mfa_model.zip \ 109 | data/LJspeech/mfa/TextGrid 110 | 111 | mfa align \ 112 | --single_speaker \ 113 | data/LJspeech/mfa/lab \ 114 | data/LJspeech/mfa/mfa_pronounciation_dict.txt \ 115 | data/LJspeech/mfa/mfa/mfa_model.zip \ 116 | data/LJspeech/mfa/TextGrid 117 | 118 | # MFA Step7 119 | python mfa/step7_gen_alignment_from_textgrid.py \ 120 | --wav data/LJspeech/mfa/wav.txt \ 121 | --speaker data/LJspeech/mfa/speaker.txt \ 122 | --text data/LJspeech/mfa/text.txt \ 123 | --special_tokens data/LJspeech/mfa/special_token.txt \ 124 | --text_grid data/LJspeech/mfa/TextGrid \ 125 | 
--aligned_wav data/LJspeech/mfa/aligned_wav.txt \ 126 | --aligned_speaker data/LJspeech/mfa/aligned_speaker.txt \ 127 | --duration data/LJspeech/mfa/duration.txt \ 128 | --aligned_text data/LJspeech/mfa/aligned_text.txt \ 129 | --reassign_sp True 130 | 131 | # MFA Step8 132 | python mfa/step8_make_data_list.py \ 133 | --wav data/LJspeech/mfa/aligned_wav.txt \ 134 | --speaker data/LJspeech/mfa/aligned_speaker.txt \ 135 | --text data/LJspeech/mfa/aligned_text.txt \ 136 | --duration data/LJspeech/mfa/duration.txt \ 137 | --datalist_path data/LJspeech/mfa/datalist.jsonl 138 | 139 | # MFA Step9 140 | python mfa/step9_datalist_from_mfa.py \ 141 | --data_dir data/LJspeech 142 | ``` 143 | 144 | ### Step3 Prepare for training 145 | 146 | ```bash 147 | python prepare_for_training.py --data_dir data/LJspeech --exp_dir exp/LJspeech 148 | ``` 149 | __Please check and update the training and validation file paths in `exp/LJspeech/config/config.py`, especially:__ 150 | - `model_config_path`: corresponding model config file 151 | - `DATA_DIR`: data directory 152 | - `train_data_path` and `valid_data_path`: training and validation files. Change them to `datalist_mfa.jsonl` if you ran Step2 153 | - `batch_size` 154 | 155 | ### Step4 Finetune Your Model 156 | 157 | ```bash 158 | torchrun \ 159 | --nproc_per_node=1 \ 160 | --master_port 8008 \ 161 | train_am_vocoder_joint.py \ 162 | --config_folder exp/LJspeech/config \ 163 | --load_pretrained_model True 164 | ``` 165 | 166 | Training tips: 167 | 168 | - You can monitor training with TensorBoard: 169 | ``` 170 | tensorboard --logdir=exp/LJspeech 171 | ``` 172 | - The model checkpoints are saved at `exp/LJspeech/ckpt`. 173 | - The BERT features are extracted during the first epoch and saved in the `exp/LJspeech/tmp/` folder; you can change this path in `exp/LJspeech/config/config.py`. 174 | 175 | 176 | ### Step5 Inference 177 | 178 | 179 | ```bash 180 | TEXT=data/inference/text 181 | python inference_am_vocoder_exp.py \ 182 | --config_folder exp/LJspeech/config \ 183 | --checkpoint g_00010000 \ 184 | --test_file $TEXT 185 | ``` 186 | __Please change the speaker name in `data/inference/text`.__ 187 | 188 | The synthesized speech is saved under `exp/LJspeech/test_audio`. -------------------------------------------------------------------------------- /data/LJspeech/src/step0_download.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | wget http://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2 5 | 6 | tar -xjf LJSpeech-1.1.tar.bz2 7 | -------------------------------------------------------------------------------- /data/LJspeech/src/step1_clean_raw_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023, YOUDAO 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
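#
# Overview (descriptive comment; see the code below): this script reads
# raw/LJSpeech-1.1/metadata.csv under --data_dir, resamples every wav to
# 16 kHz with librosa, writes the resampled audio to <data_dir>/wavs/, and
# emits <data_dir>/text/data.jsonl with one record per utterance
# (key, wav_path, speaker, original_text).
#
# Typical invocation, as in data/LJspeech/README.md:
#   python data/LJspeech/src/step1_clean_raw_data.py --data_dir data/LJspeech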
14 | 15 | import os 16 | import argparse 17 | import soundfile as sf 18 | import librosa 19 | import jsonlines 20 | from tqdm import tqdm 21 | 22 | def main(args): 23 | 24 | ROOT_DIR=os.path.abspath(args.data_dir) 25 | RAW_DIR=f"{ROOT_DIR}/raw" 26 | WAV_DIR=f"{ROOT_DIR}/wavs" 27 | TEXT_DIR=f"{ROOT_DIR}/text" 28 | 29 | os.makedirs(WAV_DIR, exist_ok=True) 30 | os.makedirs(TEXT_DIR, exist_ok=True) 31 | 32 | with open(f"{RAW_DIR}/LJSpeech-1.1/metadata.csv") as f, \ 33 | jsonlines.open(f"{TEXT_DIR}/data.jsonl", "w") as fout1: 34 | # open(f"{TEXT_DIR}/text_raw", "w") as fout2: 35 | for line in tqdm(f): 36 | #### Text #### 37 | line = line.strip().split("|") 38 | name = line[0] 39 | text=line[1] 40 | 41 | #### Wav ##### 42 | path = f"{RAW_DIR}/LJSpeech-1.1/wavs/{name}.wav" 43 | wav_path = f"{WAV_DIR}/{name}.wav" 44 | y, sr = sf.read(path) 45 | y_16=librosa.resample(y, orig_sr=sr, target_sr=16_000) 46 | sf.write(wav_path, y_16, 16_000) 47 | 48 | #### Write #### 49 | fout1.write({ 50 | "key":name, 51 | "wav_path":wav_path, 52 | "speaker":"LJ", 53 | "original_text":text 54 | }) 55 | # fout2.write(text+"\n") 56 | 57 | 58 | 59 | 60 | return 61 | 62 | 63 | if __name__ == "__main__": 64 | p = argparse.ArgumentParser() 65 | p.add_argument('--data_dir', type=str, required=True) 66 | args = p.parse_args() 67 | 68 | main(args) 69 | -------------------------------------------------------------------------------- /data/LJspeech/src/step2_get_phoneme.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023, YOUDAO 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
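#
# Overview (descriptive comment; see the code below): this script converts
# <data_dir>/text/data.jsonl into phoneme datalists. Words found in the
# LibriSpeech lexicon are mapped directly to bracketed ARPAbet phones, with
# g2p_en as the fallback, and engsp1/engsp4 pause tokens are inserted at word
# and punctuation boundaries. It writes <data_dir>/text/datalist.jsonl, splits
# off the last three samples as <data_dir>/valid/datalist.jsonl, and keeps the
# rest for <data_dir>/train/datalist.jsonl.
#
# Typical invocation, as in data/LJspeech/README.md:
#   python data/LJspeech/src/step2_get_phoneme.py --data_dir data/LJspeech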
14 | 15 | import argparse 16 | import os 17 | import jsonlines 18 | import json 19 | from tqdm import tqdm 20 | from multiprocessing.pool import ThreadPool 21 | from functools import partial 22 | import re 23 | import sys 24 | DIR=os.path.dirname(os.path.abspath("__file__")) 25 | sys.path.append(DIR) 26 | 27 | from frontend_en import read_lexicon, G2p 28 | 29 | 30 | def onetime(resource, sample): 31 | 32 | text=sample["original_text"] 33 | del sample["original_text"] 34 | 35 | phoneme = get_phoneme(text, resource["g2p"], resource["lexicon"]).split() 36 | 37 | sample["text"]=phoneme 38 | sample["original_text"]=text 39 | sample["prompt"]=text 40 | 41 | return sample 42 | 43 | def get_phoneme(text, g2p, lexicon): 44 | filters = {",", " ", "'"} 45 | phones = [] 46 | words = list(filter(lambda x: x not in {"", " "}, re.split(r"([,;.\-\?\!\s+])", text))) 47 | 48 | for w in words: 49 | if w.lower() in lexicon: 50 | 51 | for ph in lexicon[w.lower()]: 52 | if ph not in filters: 53 | phones += ["[" + ph + "]"] 54 | 55 | if "sp" not in phones[-1]: 56 | phones += ["engsp1"] 57 | else: 58 | phone=g2p(w) 59 | if not phone: 60 | continue 61 | 62 | if phone[0].isalnum(): 63 | 64 | for ph in phone: 65 | if ph not in filters: 66 | phones += ["[" + ph + "]"] 67 | if ph == " " and "sp" not in phones[-1]: 68 | phones += ["engsp1"] 69 | elif phone == " ": 70 | continue 71 | elif phones: 72 | phones.pop() # pop engsp1 73 | phones.append("engsp4") 74 | if phones and "engsp" in phones[-1]: 75 | phones.pop() 76 | 77 | mark = "." if text[-1] != "?" else "?" 78 | phones = [""] + phones + [mark, ""] 79 | return " ".join(phones) 80 | 81 | 82 | 83 | def main(args): 84 | 85 | ROOT_DIR=args.data_dir 86 | TRAIN_DIR=f"{ROOT_DIR}/train" 87 | VALID_DIR=f"{ROOT_DIR}/valid" 88 | TEXT_DIR=f"{ROOT_DIR}/text" 89 | 90 | os.makedirs(TRAIN_DIR, exist_ok=True) 91 | os.makedirs(VALID_DIR, exist_ok=True) 92 | 93 | lexicon = read_lexicon(f"{DIR}/lexicon/librispeech-lexicon.txt") 94 | 95 | g2p = G2p() 96 | 97 | resource={ 98 | "g2p":g2p, 99 | "lexicon":lexicon, 100 | } 101 | 102 | with jsonlines.open(f"{TEXT_DIR}/data.jsonl") as f: 103 | data = list(f) 104 | 105 | new_data=[] 106 | with jsonlines.open(f"{TEXT_DIR}/datalist.jsonl", "w") as f: 107 | for sample in tqdm(data): 108 | sample = onetime(resource, sample) 109 | f.write(sample) 110 | new_data.append(sample) 111 | 112 | with jsonlines.open(f"{TRAIN_DIR}/datalist.jsonl", "w") as f: 113 | for sample in tqdm(new_data[:-3]): 114 | f.write(sample) 115 | 116 | with jsonlines.open(f"{VALID_DIR}/datalist.jsonl", "w") as f: 117 | for sample in tqdm(data[-3:]): 118 | f.write(sample) 119 | 120 | 121 | return 122 | 123 | if __name__ == "__main__": 124 | 125 | p = argparse.ArgumentParser() 126 | p.add_argument('--data_dir', type=str, required=True) 127 | args = p.parse_args() 128 | 129 | main(args) -------------------------------------------------------------------------------- /data/inference/text: -------------------------------------------------------------------------------- 1 | 8051|Happy| [IH0] [M] [AA1] [T] engsp4 [V] [OY1] [S] engsp4 [AH0] engsp1 [M] [AH1] [L] [T] [IY0] engsp4 [V] [OY1] [S] engsp1 [AE1] [N] [D] engsp1 [P] [R] [AA1] [M] [P] [T] engsp4 [K] [AH0] [N] [T] [R] [OW1] [L] [D] engsp1 [T] [IY1] engsp4 [T] [IY1] engsp4 [EH1] [S] engsp1 [EH1] [N] [JH] [AH0] [N] . 
|Emoti-Voice - a Multi-Voice and Prompt-Controlled T-T-S Engine 2 | 8051|哭唧唧| uo3 sp1 l ai2 sp0 d ao4 sp1 b ei3 sp0 j ing1 sp3 q ing1 sp0 h ua2 sp0 d a4 sp0 x ve2 |我来到北京,清华大学 3 | 11614|第一章| d i4 sp0 i1 sp0 zh ang1 |第一章 4 | 9017|在昏暗狭小的房间内,我父亲躺在窗前的地板上,全身素白,显得身子特别长。| z ai4 sp1 h uen1 sp0 an4 sp1 x ia2 sp0 x iao3 sp0 d e5 sp1 f ang2 sp0 j ian1 sp0 n ei4 sp3 uo3 sp1 f u4 sp0 q in1 sp1 t ang3 sp0 z ai4 sp1 ch uang1 sp0 q ian2 sp0 d e5 sp1 d i4 sp0 b an3 sp0 sh ang4 sp3 q van2 sp0 sh en1 sp1 s u4 sp0 b ai2 sp3 x ian3 sp0 d e5 sp1 sh en1 sp0 z ii5 sp1 t e4 sp0 b ie2 sp0 ch ang2 |在昏暗狭小的房间内,我父亲躺在窗前的地板上,全身素白,显得身子特别长。 5 | 6097|他光着双脚,脚趾头怪模怪样地向外翻着,一双亲切的手平静地放在胸前,手指头也是弯曲的。| t a1 sp1 g uang1 sp0 zh e5 sp1 sh uang1 sp0 j iao3 sp3 j iao2 sp0 zh iii3 sp0 t ou5 sp1 g uai4 sp0 m u2 sp1 g uai4 sp0 iang4 sp0 d e5 sp1 x iang4 sp0 uai4 sp1 f an1 sp0 zh e5 sp3 i4 sp0 sh uang1 sp1 q in1 sp0 q ie4 sp0 d e5 sp0 sh ou3 sp2 p ing2 sp0 j ing4 sp0 d e5 sp1 f ang4 sp0 z ai4 sp1 x iong1 sp0 q ian2 sp3 sh ou2 sp0 zh iii3 sp0 t ou5 sp1 ie3 sp0 sh iii4 sp1 uan1 sp0 q v1 sp0 d e5 |他光着双脚,脚趾头怪模怪样地向外翻着,一双亲切的手平静地放在胸前,手指头也是弯曲的。 6 | 6671|他双目紧闭,可以看见铜钱在上面留下的黑色圆圈;和善的面孔乌青发黑,龇牙咧嘴,挺吓人的。| t a1 sp1 sh uang1 sp0 m u4 sp1 j in3 sp0 b i4 sp3 k e2 sp0 i3 sp1 k an4 sp0 j ian4 sp1 t ong2 sp0 q ian2 sp1 z ai4 sp0 sh ang4 sp0 m ian4 sp1 l iou2 sp0 x ia4 sp0 d e5 sp1 h ei1 sp0 s e4 sp1 van2 sp0 q van1 sp3 h e2 sp0 sh an4 sp0 d e5 sp1 m ian4 sp0 k ong3 sp2 u1 sp0 q ing1 sp1 f a1 sp0 h ei1 sp3 z ii1 sp0 ia2 sp0 l ie2 sp0 z uei3 sp3 t ing3 sp1 x ia4 sp0 r en2 sp0 d e5 |他双目紧闭,可以看见铜钱在上面留下的黑色圆圈;和善的面孔乌青发黑,龇牙咧嘴,挺吓人的。 7 | 6670|母亲半光着上身,穿一条红裙子,跪在地上,正在用那把我常用来锯西瓜皮的小黑梳子,将父亲那又长又软的头发从前额向脑后梳去。| m u3 sp0 q in1 sp2 b an4 sp0 g uang1 sp0 zh e5 sp1 sh ang4 sp0 sh en1 sp3 ch uan1 sp0 i4 sp0 t iao2 sp1 h ong2 sp0 q vn2 sp0 z ii5 sp3 g uei4 sp0 z ai4 sp1 d i4 sp0 sh ang5 sp3 zh eng4 sp0 z ai4 sp1 iong4 sp0 n a4 sp1 b a2 sp0 uo3 sp1 ch ang2 sp0 iong4 sp0 l ai2 sp1 j v4 sp1 x i1 sp0 g ua1 sp0 p i2 sp0 d e5 sp1 x iao3 sp0 h ei1 sp1 sh u1 sp0 z ii5 sp3 j iang1 sp1 f u4 sp0 q in1 sp1 n a4 sp1 iou4 sp0 ch ang2 sp1 iou4 sp0 r uan3 sp0 d e5 sp1 t ou2 sp0 f a4 sp3 c ong2 sp1 q ian2 sp0 e2 sp1 x iang4 sp1 n ao3 sp0 h ou4 sp1 sh u1 sp0 q v4 |母亲半光着上身,穿一条红裙子,跪在地上,正在用那把我常用来锯西瓜皮的小黑梳子,将父亲那又长又软的头发从前额向脑后梳去。 8 | 9136|母亲一直在诉说着什么,声音嘶哑而低沉,她那双浅灰色的眼睛已经浮肿,仿佛融化了似的,眼泪大滴大滴地直往下落。| m u3 sp0 q in1 sp1 i4 sp0 zh iii2 sp1 z ai4 sp1 s u4 sp0 sh uo1 sp0 zh e5 sp1 sh en2 sp0 m e5 sp3 sh eng1 sp0 in1 sp1 s ii1 sp0 ia3 sp1 er2 sp1 d i1 sp0 ch en2 sp3 t a1 sp1 n a4 sp0 sh uang1 sp1 q ian3 sp0 h uei1 sp0 s e4 sp0 d e5 sp1 ian3 sp0 j ing5 sp2 i3 sp0 j ing1 sp1 f u2 sp0 zh ong3 sp3 f ang3 sp0 f u2 sp1 r ong2 sp0 h ua4 sp0 l e5 sp1 sh iii4 sp0 d e5 sp3 ian3 sp0 l ei4 sp1 d a4 sp0 d i1 sp1 d a4 sp0 d i1 sp0 d e5 sp1 zh iii2 sp0 uang3 sp0 x ia4 sp0 l uo4 |母亲一直在诉说着什么,声音嘶哑而低沉,她那双浅灰色的眼睛已经浮肿,仿佛融化了似的,眼泪大滴大滴地直往下落。 9 | 11697|外婆拽着我的手;她长得圆滚滚的,大脑袋、大眼睛和一只滑稽可笑的松弛的鼻子。| uai4 sp0 p o2 sp1 zh uai4 sp0 zh e5 sp1 uo3 sp0 d e5 sp0 sh ou3 sp3 t a1 sp1 zh ang3 sp0 d e5 sp1 van2 sp0 g uen2 sp0 g uen3 sp0 d e5 sp3 d a4 sp0 n ao3 sp0 d ai5 sp3 d a4 sp0 ian3 sp0 j ing5 sp3 h e2 sp1 i4 sp0 zh iii1 sp1 h ua2 sp0 j i1 sp1 k e3 sp0 x iao4 sp0 d e5 sp1 s ong1 sp0 ch iii2 sp0 d e5 sp1 b i2 sp0 z ii5 |外婆拽着我的手;她长得圆滚滚的,大脑袋、大眼睛和一只滑稽可笑的松弛的鼻子。 10 | 92|她穿一身黑衣服,身上软乎乎的,特别好玩。她也在哭,但哭得有些特别,和母亲的哭声交相呼应。她全身都在颤抖,而且老是把我往父亲跟前推。我扭动身子,直往她身后躲;我感到害怕,浑身不自在。| t a1 sp0 ch uan1 sp1 i4 sp0 sh en1 sp1 h ei1 sp0 i1 sp0 f u2 sp3 sh en1 sp0 sh ang4 sp1 r uan3 sp0 h u1 sp0 h u1 sp0 d e5 sp3 t e4 sp0 b ie2 sp1 h ao3 sp0 uan2 sp3 t a1 sp0 ie3 sp1 z ai4 sp0 k u1 sp3 d an4 sp1 k u1 
sp0 d e5 sp1 iou3 sp0 x ie1 sp1 t e4 sp0 b ie2 sp3 h e2 sp1 m u3 sp0 q in1 sp0 d e5 sp1 k u1 sp0 sh eng1 sp1 j iao1 sp0 x iang1 sp1 h u1 sp0 ing4 sp3 t a1 sp1 q van2 sp0 sh en1 sp1 d ou1 sp0 z ai4 sp1 ch an4 sp0 d ou3 sp3 er2 sp0 q ie3 sp1 l ao3 sp0 sh iii4 sp1 b a2 sp0 uo3 sp1 uang3 sp1 f u4 sp0 q in1 sp1 g en1 sp0 q ian5 sp0 t uei1 sp3 uo3 sp1 n iou3 sp0 d ong4 sp1 sh en1 sp0 z ii5 sp3 zh iii2 sp0 uang3 sp0 t a1 sp1 sh en1 sp0 h ou4 sp0 d uo3 sp3 uo3 sp1 g an3 sp0 d ao4 sp1 h ai4 sp0 p a4 sp3 h uen2 sp0 sh en1 sp1 b u2 sp0 z ii4 sp0 z ai5 |她穿一身黑衣服,身上软乎乎的,特别好玩。她也在哭,但哭得有些特别,和母亲的哭声交相呼应。她全身都在颤抖,而且老是把我往父亲跟前推。我扭动身子,直往她身后躲;我感到害怕,浑身不自在。 11 | 12787|我还从没有见过大人们哭,而且不明白外婆老说的那些话的意思:“跟你爹告个别吧,以后你再也看不到他啦,他死了,乖孩子,还不到年纪,不是时候啊……”| uo3 sp0 h ai2 sp1 c ong2 sp0 m ei2 sp0 iou3 sp1 j ian4 sp0 g uo4 sp1 d a4 sp0 r en2 sp0 m en5 sp1 k u1 sp3 er2 sp0 q ie3 sp1 b u4 sp0 m ing2 sp0 b ai2 sp1 uai4 sp0 p o2 sp1 l ao3 sp0 sh uo1 sp0 d e5 sp1 n a4 sp0 x ie1 sp0 h ua4 sp0 d e5 sp1 i4 sp0 s ii5 sp3 g en1 sp0 n i3 sp0 d ie1 sp1 g ao4 sp0 g e4 sp0 b ie2 sp0 b a5 sp3 i3 sp0 h ou4 sp3 n i3 sp1 z ai4 sp0 ie3 sp1 k an4 sp0 b u2 sp0 d ao4 sp1 t a1 sp0 l a5 sp3 t a1 sp0 s ii3 sp0 l e5 sp3 g uai1 sp0 h ai2 sp0 z ii5 sp3 h ai2 sp0 b u2 sp0 d ao4 sp1 n ian2 sp0 j i4 sp3 b u2 sp0 sh iii4 sp1 sh iii2 sp0 h ou5 sp0 a5 |我还从没有见过大人们哭,而且不明白外婆老说的那些话的意思:“跟你爹告个别吧,以后你再也看不到他啦,他死了,乖孩子,还不到年纪,不是时候啊……” 12 | 1006|我得过一场大病,这时刚刚能下地。生病期间一这一点我记得很清楚——父亲照看我时显得很高兴,后来他突然就不见了,换成了外婆这个怪里怪气的人。| uo3 sp1 d e2 sp0 g uo4 sp1 i4 sp0 ch ang3 sp1 d a4 sp0 b ing4 sp3 zh e4 sp0 sh iii2 sp1 g ang1 sp0 g ang1 sp1 n eng2 sp0 x ia4 sp0 d i4 sp3 sh eng1 sp0 b ing4 sp1 q i1 sp0 j ian1 sp3 i2 sp0 zh e4 sp0 i4 sp0 d ian3 sp1 uo3 sp1 j i4 sp0 d e5 sp1 h en3 sp0 q ing1 sp0 ch u5 sp3 f u4 sp0 q in1 sp1 zh ao4 sp0 k an4 sp1 uo3 sp0 sh iii2 sp2 x ian3 sp0 d e5 sp1 h en3 sp0 g ao1 sp0 x ing4 sp3 h ou4 sp0 l ai2 sp3 t a1 sp1 t u1 sp0 r an2 sp1 j iou4 sp1 b u2 sp0 j ian4 sp0 l e5 sp3 h uan4 sp0 ch eng2 sp0 l e5 sp1 uai4 sp0 p o2 sp1 zh e4 sp0 g e4 sp1 g uai4 sp0 l i3 sp1 g uai4 sp0 q i4 sp0 d e5 sp0 r en2 |我得过一场大病,这时刚刚能下地。生病期间一这一点我记得很清楚——父亲照看我时显得很高兴,后来他突然就不见了,换成了外婆这个怪里怪气的人。 13 | -------------------------------------------------------------------------------- /data/youdao/text/emotion: -------------------------------------------------------------------------------- 1 | 普通 2 | 生气 3 | 开心 4 | 惊讶 5 | 悲伤 6 | 厌恶 7 | 恐惧 -------------------------------------------------------------------------------- /data/youdao/text/energy: -------------------------------------------------------------------------------- 1 | 音量普通 2 | 音量很高 3 | 音量很低 -------------------------------------------------------------------------------- /data/youdao/text/pitch: -------------------------------------------------------------------------------- 1 | 音调普通 2 | 音调很高 3 | 音调很低 -------------------------------------------------------------------------------- /data/youdao/text/speed: -------------------------------------------------------------------------------- 1 | 语速普通 2 | 语速很快 3 | 语速很慢 -------------------------------------------------------------------------------- /data/youdao/text/tokenlist: -------------------------------------------------------------------------------- 1 | _ 2 | 3 | [AA0] 4 | [AA1] 5 | [AA2] 6 | [AE0] 7 | [AE1] 8 | [AE2] 9 | [AH0] 10 | [AH1] 11 | [AH2] 12 | [AO0] 13 | [AO1] 14 | [AO2] 15 | [AW0] 16 | [AW1] 17 | [AW2] 18 | [AY0] 19 | [AY1] 20 | [AY2] 21 | [B] 22 | [CH] 23 | [DH] 24 | [D] 25 | [EH0] 26 | [EH1] 27 | [EH2] 28 | [ER0] 29 | [ER1] 30 | [ER2] 31 | [EY0] 32 | [EY1] 33 | [EY2] 34 | [F] 35 | [G] 36 | 
[HH] 37 | [IH0] 38 | [IH1] 39 | [IH2] 40 | [IY0] 41 | [IY1] 42 | [IY2] 43 | [JH] 44 | [K] 45 | [L] 46 | [M] 47 | [NG] 48 | [N] 49 | [OW0] 50 | [OW1] 51 | [OW2] 52 | [OY0] 53 | [OY1] 54 | [OY2] 55 | [P] 56 | [R] 57 | [SH] 58 | [S] 59 | [TH] 60 | [T] 61 | [UH0] 62 | [UH1] 63 | [UH2] 64 | [UW0] 65 | [UW1] 66 | [UW2] 67 | [V] 68 | [W] 69 | [Y] 70 | [ZH] 71 | [Z] 72 | a1 73 | a2 74 | a3 75 | a4 76 | a5 77 | ai1 78 | ai2 79 | ai3 80 | ai4 81 | ai5 82 | air1 83 | air2 84 | air4 85 | air5 86 | an1 87 | an2 88 | an3 89 | an4 90 | an5 91 | ang1 92 | ang2 93 | ang3 94 | ang4 95 | ang5 96 | angr1 97 | angr2 98 | angr4 99 | anr1 100 | anr2 101 | anr3 102 | anr4 103 | ao1 104 | ao2 105 | ao3 106 | ao4 107 | ao5 108 | aor1 109 | aor2 110 | aor3 111 | aor4 112 | ar1 113 | ar2 114 | ar3 115 | ar4 116 | ar5 117 | arr4 118 | b 119 | c 120 | ch 121 | cn_eng_sp 122 | d 123 | e1 124 | e2 125 | e3 126 | e4 127 | e5 128 | ei1 129 | ei2 130 | ei3 131 | ei4 132 | ei5 133 | eir1 134 | eir4 135 | en1 136 | en2 137 | en3 138 | en4 139 | en5 140 | eng1 141 | eng2 142 | eng3 143 | eng4 144 | eng5 145 | eng_cn_sp 146 | engr1 147 | engr3 148 | engr4 149 | engsp1 150 | engsp2 151 | engsp4 152 | enr1 153 | enr2 154 | enr3 155 | enr4 156 | enr5 157 | er1 158 | er2 159 | er3 160 | er4 161 | er5 162 | f 163 | g 164 | h 165 | i1 166 | i2 167 | i3 168 | i4 169 | i5 170 | ia1 171 | ia2 172 | ia3 173 | ia4 174 | ia5 175 | ian1 176 | ian2 177 | ian3 178 | ian4 179 | ian5 180 | iang1 181 | iang2 182 | iang3 183 | iang4 184 | iang5 185 | iangr2 186 | iangr4 187 | ianr1 188 | ianr2 189 | ianr3 190 | ianr4 191 | ianr5 192 | iao1 193 | iao2 194 | iao3 195 | iao4 196 | iao5 197 | iaor2 198 | iaor3 199 | iaor4 200 | iar2 201 | iar3 202 | iar4 203 | ie1 204 | ie2 205 | ie3 206 | ie4 207 | ie5 208 | ier4 209 | ii1 210 | ii2 211 | ii3 212 | ii4 213 | ii5 214 | iii1 215 | iii2 216 | iii3 217 | iii4 218 | iii5 219 | iiir2 220 | iiir3 221 | iiir4 222 | iir2 223 | iir3 224 | iir4 225 | in1 226 | in2 227 | in3 228 | in4 229 | in5 230 | ing1 231 | ing2 232 | ing3 233 | ing4 234 | ing5 235 | ingr1 236 | ingr2 237 | ingr3 238 | ingr4 239 | inr1 240 | inr4 241 | iong1 242 | iong2 243 | iong3 244 | iong4 245 | iong5 246 | iou1 247 | iou2 248 | iou3 249 | iou4 250 | iou5 251 | iour2 252 | iour3 253 | iour4 254 | ir1 255 | ir2 256 | ir3 257 | ir4 258 | irr1 259 | j 260 | k 261 | l 262 | m 263 | n 264 | o1 265 | o2 266 | o3 267 | o4 268 | o5 269 | ong1 270 | ong2 271 | ong3 272 | ong4 273 | ong5 274 | ongr2 275 | ongr3 276 | ongr4 277 | or4 278 | ou1 279 | ou2 280 | ou3 281 | ou4 282 | ou5 283 | our1 284 | our2 285 | our3 286 | our4 287 | our5 288 | p 289 | q 290 | r 291 | s 292 | sh 293 | sp0 294 | sp1 295 | sp2 296 | sp3 297 | sp4 298 | t 299 | u1 300 | u2 301 | u3 302 | u4 303 | u5 304 | ua1 305 | ua2 306 | ua3 307 | ua4 308 | ua5 309 | uai1 310 | uai2 311 | uai3 312 | uai4 313 | uai5 314 | uair4 315 | uan1 316 | uan2 317 | uan3 318 | uan4 319 | uan5 320 | uang1 321 | uang2 322 | uang3 323 | uang4 324 | uang5 325 | uanr1 326 | uanr2 327 | uanr3 328 | uanr4 329 | uanr5 330 | uar1 331 | uar2 332 | uar3 333 | uar4 334 | uei1 335 | uei2 336 | uei3 337 | uei4 338 | uei5 339 | ueir1 340 | ueir2 341 | ueir3 342 | ueir4 343 | uen1 344 | uen2 345 | uen3 346 | uen4 347 | uen5 348 | ueng1 349 | ueng3 350 | ueng4 351 | uenr1 352 | uenr2 353 | uenr3 354 | uenr4 355 | uo1 356 | uo2 357 | uo3 358 | uo4 359 | uo5 360 | uor1 361 | uor2 362 | uor3 363 | uor4 364 | uor5 365 | ur1 366 | ur2 367 | ur3 368 | ur4 369 | ur5 370 | v1 371 | v2 372 | v3 373 | v4 374 | v5 375 | 
van1 376 | van2 377 | van3 378 | van4 379 | van5 380 | vanr1 381 | vanr2 382 | vanr3 383 | vanr4 384 | ve1 385 | ve2 386 | ve3 387 | ve4 388 | ve5 389 | ver2 390 | vn1 391 | vn2 392 | vn3 393 | vn4 394 | vn5 395 | vr2 396 | vr3 397 | vr4 398 | vr5 399 | x 400 | y 401 | z 402 | zh 403 | engsp0 404 | ? 405 | . 406 | spn 407 | ue2 408 | ! 409 | err1 410 | [LAUGH] 411 | rr 412 | ier2 413 | or1 414 | ueng2 415 | ir5 416 | iar1 417 | iour1 418 | uncased15 419 | uncased16 420 | uncased17 421 | uncased18 422 | uncased19 423 | uncased20 424 | uncased21 425 | uncased22 426 | uncased23 427 | uncased24 428 | uncased25 429 | uncased26 430 | uncased27 431 | uncased28 432 | uncased29 433 | uncased30 434 | uncased31 435 | uncased32 436 | uncased33 437 | uncased34 438 | uncased35 439 | uncased36 440 | uncased37 441 | uncased38 442 | uncased39 443 | uncased40 444 | uncased41 445 | uncased42 446 | uncased43 447 | uncased44 448 | uncased45 449 | uncased46 450 | uncased47 451 | uncased48 452 | uncased49 453 | uncased50 454 | uncased51 455 | uncased52 456 | uncased53 457 | uncased54 458 | uncased55 459 | uncased56 460 | uncased57 461 | uncased58 462 | uncased59 463 | uncased60 464 | uncased61 465 | uncased62 466 | uncased63 467 | uncased64 468 | uncased65 469 | uncased66 470 | uncased67 471 | uncased68 472 | uncased69 473 | uncased70 474 | uncased71 475 | uncased72 476 | uncased73 477 | uncased74 478 | uncased75 479 | uncased76 480 | uncased77 481 | uncased78 482 | uncased79 483 | uncased80 484 | uncased81 485 | uncased82 486 | uncased83 487 | uncased84 488 | uncased85 489 | uncased86 490 | uncased87 491 | uncased88 492 | uncased89 493 | uncased90 494 | uncased91 495 | uncased92 496 | uncased93 497 | uncased94 498 | uncased95 499 | uncased96 500 | uncased97 501 | uncased98 502 | uncased99 503 | -------------------------------------------------------------------------------- /demo_page.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023, YOUDAO 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import streamlit as st 16 | import os, glob 17 | import numpy as np 18 | from yacs import config as CONFIG 19 | import torch 20 | import re 21 | 22 | from frontend import g2p_cn_en, ROOT_DIR, read_lexicon, G2p 23 | from config.joint.config import Config 24 | from models.prompt_tts_modified.jets import JETSGenerator 25 | from models.prompt_tts_modified.simbert import StyleEncoder 26 | from transformers import AutoTokenizer 27 | 28 | import base64 29 | from pathlib import Path 30 | 31 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 32 | MAX_WAV_VALUE = 32768.0 33 | 34 | config = Config() 35 | 36 | def create_download_link(): 37 | pdf_path = Path("EmotiVoice_UserAgreement_易魔声用户协议.pdf") 38 | base64_pdf = base64.b64encode(pdf_path.read_bytes()).decode("utf-8") # val looks like b'...' 
39 | return f'EmotiVoice_UserAgreement_易魔声用户协议.pdf' 40 | 41 | html=create_download_link() 42 | 43 | st.set_page_config( 44 | page_title="demo page", 45 | page_icon="📕", 46 | ) 47 | st.write("# Text-To-Speech") 48 | st.markdown(f""" 49 | ### How to use: 50 | 51 | - Simply select a **Speaker ID**, type in the **text** you want to convert and the emotion **Prompt**, like a single word or even a sentence. Then click on the **Synthesize** button below to start voice synthesis. 52 | 53 | - You can download the audio by clicking on the vertical three points next to the displayed audio widget. 54 | 55 | - For more information on **'Speaker ID'**, please consult the [EmotiVoice voice wiki page](https://github.com/netease-youdao/EmotiVoice/tree/main/data/youdao/text) 56 | 57 | - This interactive demo page is provided under the {html} file. The audio is synthesized by AI. 音频由AI合成,仅供参考。 58 | 59 | """, unsafe_allow_html=True) 60 | 61 | def scan_checkpoint(cp_dir, prefix, c=8): 62 | pattern = os.path.join(cp_dir, prefix + '?'*c) 63 | cp_list = glob.glob(pattern) 64 | if len(cp_list) == 0: 65 | return None 66 | return sorted(cp_list)[-1] 67 | 68 | @st.cache_resource 69 | def get_models(): 70 | 71 | am_checkpoint_path = scan_checkpoint(f'{config.output_directory}/prompt_tts_open_source_joint/ckpt', 'g_') 72 | 73 | style_encoder_checkpoint_path = scan_checkpoint(f'{config.output_directory}/style_encoder/ckpt', 'checkpoint_', 6)#f'{config.output_directory}/style_encoder/ckpt/checkpoint_163431' 74 | 75 | with open(config.model_config_path, 'r') as fin: 76 | conf = CONFIG.load_cfg(fin) 77 | 78 | conf.n_vocab = config.n_symbols 79 | conf.n_speaker = config.speaker_n_labels 80 | 81 | style_encoder = StyleEncoder(config) 82 | model_CKPT = torch.load(style_encoder_checkpoint_path, map_location="cpu") 83 | model_ckpt = {} 84 | for key, value in model_CKPT['model'].items(): 85 | new_key = key[7:] 86 | model_ckpt[new_key] = value 87 | style_encoder.load_state_dict(model_ckpt, strict=False) 88 | generator = JETSGenerator(conf).to(DEVICE) 89 | 90 | model_CKPT = torch.load(am_checkpoint_path, map_location=DEVICE) 91 | generator.load_state_dict(model_CKPT['generator']) 92 | generator.eval() 93 | 94 | tokenizer = AutoTokenizer.from_pretrained(config.bert_path) 95 | 96 | with open(config.token_list_path, 'r') as f: 97 | token2id = {t.strip():idx for idx, t, in enumerate(f.readlines())} 98 | 99 | with open(config.speaker2id_path, encoding='utf-8') as f: 100 | speaker2id = {t.strip():idx for idx, t in enumerate(f.readlines())} 101 | 102 | 103 | return (style_encoder, generator, tokenizer, token2id, speaker2id) 104 | 105 | def get_style_embedding(prompt, tokenizer, style_encoder): 106 | prompt = tokenizer([prompt], return_tensors="pt") 107 | input_ids = prompt["input_ids"] 108 | token_type_ids = prompt["token_type_ids"] 109 | attention_mask = prompt["attention_mask"] 110 | with torch.no_grad(): 111 | output = style_encoder( 112 | input_ids=input_ids, 113 | token_type_ids=token_type_ids, 114 | attention_mask=attention_mask, 115 | ) 116 | style_embedding = output["pooled_output"].cpu().squeeze().numpy() 117 | return style_embedding 118 | 119 | def tts(name, text, prompt, content, speaker, models): 120 | (style_encoder, generator, tokenizer, token2id, speaker2id)=models 121 | 122 | 123 | style_embedding = get_style_embedding(prompt, tokenizer, style_encoder) 124 | content_embedding = get_style_embedding(content, tokenizer, style_encoder) 125 | 126 | speaker = speaker2id[speaker] 127 | 128 | text_int = [token2id[ph] for ph in 
text.split()] 129 | 130 | sequence = torch.from_numpy(np.array(text_int)).to(DEVICE).long().unsqueeze(0) 131 | sequence_len = torch.from_numpy(np.array([len(text_int)])).to(DEVICE) 132 | style_embedding = torch.from_numpy(style_embedding).to(DEVICE).unsqueeze(0) 133 | content_embedding = torch.from_numpy(content_embedding).to(DEVICE).unsqueeze(0) 134 | speaker = torch.from_numpy(np.array([speaker])).to(DEVICE) 135 | 136 | with torch.no_grad(): 137 | 138 | infer_output = generator( 139 | inputs_ling=sequence, 140 | inputs_style_embedding=style_embedding, 141 | input_lengths=sequence_len, 142 | inputs_content_embedding=content_embedding, 143 | inputs_speaker=speaker, 144 | alpha=1.0 145 | ) 146 | 147 | audio = infer_output["wav_predictions"].squeeze()* MAX_WAV_VALUE 148 | audio = audio.cpu().numpy().astype('int16') 149 | 150 | return audio 151 | 152 | speakers = config.speakers 153 | models = get_models() 154 | lexicon = read_lexicon(f"{ROOT_DIR}/lexicon/librispeech-lexicon.txt") 155 | g2p = G2p() 156 | 157 | def new_line(i): 158 | col1, col2, col3, col4 = st.columns([1.5, 1.5, 3.5, 1.3]) 159 | with col1: 160 | speaker=st.selectbox("Speaker ID (说话人)", speakers, key=f"{i}_speaker") 161 | with col2: 162 | prompt=st.text_input("Prompt (开心/悲伤)", "", key=f"{i}_prompt") 163 | with col3: 164 | content=st.text_input("Text to be synthesized into speech (合成文本)", "合成文本", key=f"{i}_text") 165 | with col4: 166 | lang=st.selectbox("Language (语言)", ["zh_us"], key=f"{i}_lang") 167 | 168 | flag = st.button(f"Synthesize (合成)", key=f"{i}_button1") 169 | if flag: 170 | text = g2p_cn_en(content, g2p, lexicon) 171 | path = tts(i, text, prompt, content, speaker, models) 172 | st.audio(path, sample_rate=config.sampling_rate) 173 | 174 | 175 | 176 | new_line(0) 177 | -------------------------------------------------------------------------------- /demo_page_databaker.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023, YOUDAO 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import streamlit as st 16 | import os, glob 17 | import numpy as np 18 | from yacs import config as CONFIG 19 | import torch 20 | import re 21 | 22 | from frontend import g2p_cn_en, ROOT_DIR, read_lexicon, G2p 23 | from exp.DataBaker.config.config import Config 24 | from models.prompt_tts_modified.jets import JETSGenerator 25 | from models.prompt_tts_modified.simbert import StyleEncoder 26 | from transformers import AutoTokenizer 27 | 28 | import base64 29 | from pathlib import Path 30 | 31 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 32 | MAX_WAV_VALUE = 32768.0 33 | 34 | config = Config() 35 | 36 | def create_download_link(): 37 | pdf_path = Path("EmotiVoice_UserAgreement_易魔声用户协议.pdf") 38 | base64_pdf = base64.b64encode(pdf_path.read_bytes()).decode("utf-8") # val looks like b'...' 
39 | return f'EmotiVoice_UserAgreement_易魔声用户协议.pdf' 40 | 41 | html=create_download_link() 42 | 43 | st.set_page_config( 44 | page_title="demo page", 45 | page_icon="📕", 46 | ) 47 | st.write("# Text-To-Speech") 48 | st.markdown(f""" 49 | ### How to use: 50 | 51 | - Simply select a **Speaker ID**, type in the **text** you want to convert and the emotion **Prompt**, like a single word or even a sentence. Then click on the **Synthesize** button below to start voice synthesis. 52 | 53 | - You can download the audio by clicking on the vertical three points next to the displayed audio widget. 54 | 55 | - For more information on **'Speaker ID'**, please consult the [EmotiVoice voice wiki page](https://github.com/netease-youdao/EmotiVoice/tree/main/data/youdao/text) 56 | 57 | - This interactive demo page is provided under the {html} file. The audio is synthesized by AI. 音频由AI合成,仅供参考。 58 | 59 | """, unsafe_allow_html=True) 60 | 61 | def scan_checkpoint(cp_dir, prefix, c=8): 62 | pattern = os.path.join(cp_dir, prefix + '?'*c) 63 | cp_list = glob.glob(pattern) 64 | if len(cp_list) == 0: 65 | return None 66 | return sorted(cp_list)[-1] 67 | 68 | @st.cache_resource 69 | def get_models(): 70 | 71 | am_checkpoint_path = scan_checkpoint(f'{config.output_directory}/ckpt', 'g_') 72 | 73 | style_encoder_checkpoint_path = config.style_encoder_ckpt 74 | 75 | with open(config.model_config_path, 'r') as fin: 76 | conf = CONFIG.load_cfg(fin) 77 | 78 | conf.n_vocab = config.n_symbols 79 | conf.n_speaker = config.speaker_n_labels 80 | 81 | style_encoder = StyleEncoder(config) 82 | model_CKPT = torch.load(style_encoder_checkpoint_path, map_location="cpu") 83 | model_ckpt = {} 84 | for key, value in model_CKPT['model'].items(): 85 | new_key = key[7:] 86 | model_ckpt[new_key] = value 87 | style_encoder.load_state_dict(model_ckpt, strict=False) 88 | generator = JETSGenerator(conf).to(DEVICE) 89 | 90 | model_CKPT = torch.load(am_checkpoint_path, map_location=DEVICE) 91 | generator.load_state_dict(model_CKPT['generator']) 92 | generator.eval() 93 | 94 | tokenizer = AutoTokenizer.from_pretrained(config.bert_path) 95 | 96 | with open(config.token_list_path, 'r') as f: 97 | token2id = {t.strip():idx for idx, t, in enumerate(f.readlines())} 98 | 99 | with open(config.speaker2id_path, encoding='utf-8') as f: 100 | speaker2id = {t.strip():idx for idx, t in enumerate(f.readlines())} 101 | 102 | 103 | return (style_encoder, generator, tokenizer, token2id, speaker2id) 104 | 105 | def get_style_embedding(prompt, tokenizer, style_encoder): 106 | prompt = tokenizer([prompt], return_tensors="pt") 107 | input_ids = prompt["input_ids"] 108 | token_type_ids = prompt["token_type_ids"] 109 | attention_mask = prompt["attention_mask"] 110 | with torch.no_grad(): 111 | output = style_encoder( 112 | input_ids=input_ids, 113 | token_type_ids=token_type_ids, 114 | attention_mask=attention_mask, 115 | ) 116 | style_embedding = output["pooled_output"].cpu().squeeze().numpy() 117 | return style_embedding 118 | 119 | def tts(name, text, prompt, content, speaker, models): 120 | (style_encoder, generator, tokenizer, token2id, speaker2id)=models 121 | 122 | 123 | style_embedding = get_style_embedding(prompt, tokenizer, style_encoder) 124 | content_embedding = get_style_embedding(content, tokenizer, style_encoder) 125 | 126 | speaker = speaker2id[speaker] 127 | 128 | text_int = [token2id[ph] for ph in text.split()] 129 | 130 | sequence = torch.from_numpy(np.array(text_int)).to(DEVICE).long().unsqueeze(0) 131 | sequence_len = 
torch.from_numpy(np.array([len(text_int)])).to(DEVICE) 132 | style_embedding = torch.from_numpy(style_embedding).to(DEVICE).unsqueeze(0) 133 | content_embedding = torch.from_numpy(content_embedding).to(DEVICE).unsqueeze(0) 134 | speaker = torch.from_numpy(np.array([speaker])).to(DEVICE) 135 | 136 | with torch.no_grad(): 137 | 138 | infer_output = generator( 139 | inputs_ling=sequence, 140 | inputs_style_embedding=style_embedding, 141 | input_lengths=sequence_len, 142 | inputs_content_embedding=content_embedding, 143 | inputs_speaker=speaker, 144 | alpha=1.0 145 | ) 146 | 147 | audio = infer_output["wav_predictions"].squeeze()* MAX_WAV_VALUE 148 | audio = audio.cpu().numpy().astype('int16') 149 | 150 | return audio 151 | 152 | speakers = config.speakers 153 | models = get_models() 154 | lexicon = read_lexicon(f"{ROOT_DIR}/lexicon/librispeech-lexicon.txt") 155 | g2p = G2p() 156 | 157 | def new_line(i): 158 | col1, col2, col3, col4 = st.columns([1.5, 1.5, 3.5, 1.3]) 159 | with col1: 160 | speaker=st.selectbox("Speaker ID (说话人)", speakers, key=f"{i}_speaker") 161 | with col2: 162 | prompt=st.text_input("Prompt (开心/悲伤)", "", key=f"{i}_prompt") 163 | with col3: 164 | content=st.text_input("Text to be synthesized into speech (合成文本)", "合成文本", key=f"{i}_text") 165 | with col4: 166 | lang=st.selectbox("Language (语言)", ["zh_us"], key=f"{i}_lang") 167 | 168 | flag = st.button(f"Synthesize (合成)", key=f"{i}_button1") 169 | if flag: 170 | text = g2p_cn_en(content, g2p, lexicon) 171 | path = tts(i, text, prompt, content, speaker, models) 172 | st.audio(path, sample_rate=config.sampling_rate) 173 | 174 | 175 | 176 | new_line(0) 177 | -------------------------------------------------------------------------------- /frontend.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023, YOUDAO 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import re 16 | from frontend_cn import g2p_cn, re_digits, tn_chinese 17 | from frontend_en import ROOT_DIR, read_lexicon, G2p, get_eng_phoneme 18 | 19 | # Thanks to GuGCoCo and PatroxGaurab for identifying the issue: 20 | # the results differ between frontend.py and frontend_en.py. Here's a quick fix. 21 | #re_english_word = re.compile('([a-z\-\.\'\s,;\:\!\?]+|\d+[\d\.]*)', re.I) 22 | re_english_word = re.compile('([^\u4e00-\u9fa5]+|[ \u3002\uff0c\uff1f\uff01\uff1b\uff1a\u201c\u201d\u2018\u2019\u300a\u300b\u3008\u3009\u3010\u3011\u300e\u300f\u2014\u2026\u3001\uff08\uff09\u4e00-\u9fa5]+)', re.I) 23 | def g2p_cn_en(text, g2p, lexicon): 24 | # Our policy dictates that if the text contains Chinese, digits are to be converted into Chinese. 
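# Rough flow of this function (descriptive comment): normalize digits with
# tn_chinese, split the input into Chinese and non-Chinese runs, convert each
# run with g2p_cn or get_eng_phoneme, and insert the transition tokens
# cn_eng_sp / eng_cn_sp when the language switches, so a mixed sentence such
# as "我爱 EmotiVoice" (illustrative) comes out as one phoneme sequence.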
25 | text=tn_chinese(text) 26 | parts = re_english_word.split(text) 27 | parts=list(filter(None, parts)) 28 | tts_text = [""] 29 | chartype = '' 30 | text_contains_chinese = contains_chinese(text) 31 | for part in parts: 32 | if part == ' ' or part == '': continue 33 | if re_digits.match(part) and (text_contains_chinese or chartype == '') or contains_chinese(part): 34 | if chartype == 'en': 35 | tts_text.append('eng_cn_sp') 36 | phoneme = g2p_cn(part).split()[1:-1] 37 | chartype = 'cn' 38 | elif re_english_word.match(part): 39 | if chartype == 'cn': 40 | if "sp" in tts_text[-1]: 41 | "" 42 | else: 43 | tts_text.append('cn_eng_sp') 44 | phoneme = get_eng_phoneme(part, g2p, lexicon, False).split() 45 | if not phoneme : 46 | # tts_text.pop() 47 | continue 48 | else: 49 | chartype = 'en' 50 | else: 51 | continue 52 | tts_text.extend( phoneme ) 53 | 54 | tts_text=" ".join(tts_text).split() 55 | if "sp" in tts_text[-1]: 56 | tts_text.pop() 57 | tts_text.append("") 58 | 59 | return " ".join(tts_text) 60 | 61 | def contains_chinese(text): 62 | pattern = re.compile(r'[\u4e00-\u9fa5]') 63 | match = re.search(pattern, text) 64 | return match is not None 65 | 66 | 67 | if __name__ == "__main__": 68 | import sys 69 | from os.path import isfile 70 | lexicon = read_lexicon(f"{ROOT_DIR}/lexicon/librispeech-lexicon.txt") 71 | 72 | g2p = G2p() 73 | if len(sys.argv) < 2: 74 | print("Usage: python %s " % sys.argv[0]) 75 | exit() 76 | text_file = sys.argv[1] 77 | if isfile(text_file): 78 | fp = open(text_file, 'r') 79 | for line in fp: 80 | phoneme = g2p_cn_en(line.rstrip(), g2p, lexicon) 81 | print(phoneme) 82 | fp.close() 83 | -------------------------------------------------------------------------------- /frontend_cn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023, YOUDAO 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
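#
# Overview (descriptive comment; see the code below): Chinese frontend.
# tn_chinese() rewrites Arabic numerals as Chinese characters via cn2an,
# g2p_cn() segments text with jieba and converts each word to pinyin
# (pypinyin, TONE3 style), and split_py() splits every syllable into
# initial/final units such as "zh ang1". Pauses are emitted as sp0 between
# syllables, sp1 between words and sp3 at punctuation.
# Running "python frontend_cn.py <text_file>" prints the phoneme sequence of
# each line in the file.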
14 | 15 | import re 16 | from pypinyin import pinyin, lazy_pinyin, Style 17 | import jieba 18 | import string 19 | from cn2an.an2cn import An2Cn 20 | from pypinyin_dict.phrase_pinyin_data import cc_cedict 21 | cc_cedict.load() 22 | re_special_pinyin = re.compile(r'^(n|ng|m)$') 23 | def split_py(py): 24 | tone = py[-1] 25 | py = py[:-1] 26 | sm = "" 27 | ym = "" 28 | suf_r = "" 29 | if re_special_pinyin.match(py): 30 | py = 'e' + py 31 | if py[-1] == 'r': 32 | suf_r = 'r' 33 | py = py[:-1] 34 | if py == 'zi' or py == 'ci' or py == 'si' or py == 'ri': 35 | sm = py[:1] 36 | ym = "ii" 37 | elif py == 'zhi' or py == 'chi' or py == 'shi': 38 | sm = py[:2] 39 | ym = "iii" 40 | elif py == 'ya' or py == 'yan' or py == 'yang' or py == 'yao' or py == 'ye' or py == 'yong' or py == 'you': 41 | sm = "" 42 | ym = 'i' + py[1:] 43 | elif py == 'yi' or py == 'yin' or py == 'ying': 44 | sm = "" 45 | ym = py[1:] 46 | elif py == 'yu' or py == 'yv' or py == 'yuan' or py == 'yvan' or py == 'yue ' or py == 'yve' or py == 'yun' or py == 'yvn': 47 | sm = "" 48 | ym = 'v' + py[2:] 49 | elif py == 'wu': 50 | sm = "" 51 | ym = "u" 52 | elif py[0] == 'w': 53 | sm = "" 54 | ym = "u" + py[1:] 55 | elif len(py) >= 2 and (py[0] == 'j' or py[0] == 'q' or py[0] == 'x') and py[1] == 'u': 56 | sm = py[0] 57 | ym = 'v' + py[2:] 58 | else: 59 | seg_pos = re.search('a|e|i|o|u|v', py) 60 | sm = py[:seg_pos.start()] 61 | ym = py[seg_pos.start():] 62 | if ym == 'ui': 63 | ym = 'uei' 64 | elif ym == 'iu': 65 | ym = 'iou' 66 | elif ym == 'un': 67 | ym = 'uen' 68 | elif ym == 'ue': 69 | ym = 've' 70 | ym += suf_r + tone 71 | return sm, ym 72 | 73 | 74 | chinese_punctuation_pattern = r'[\u3002\uff0c\uff1f\uff01\uff1b\uff1a\u201c\u201d\u2018\u2019\u300a\u300b\u3008\u3009\u3010\u3011\u300e\u300f\u2014\u2026\u3001\uff08\uff09]' 75 | 76 | 77 | def has_chinese_punctuation(text): 78 | match = re.search(chinese_punctuation_pattern, text) 79 | return match is not None 80 | def has_english_punctuation(text): 81 | return text in string.punctuation 82 | 83 | # with thanks to KimigaiiWuyi in https://github.com/netease-youdao/EmotiVoice/pull/17. 84 | # Updated on November 20, 2023: EmotiVoice now incorporates cn2an (https://github.com/Ailln/cn2an) for number processing. 
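# Illustrative examples of the number handling defined below (outputs are what
# cn2an's An2Cn is expected to produce, shown only for clarity):
#   number_to_chinese("3")  ->  "三"
#   tn_chinese("等了3天")    ->  "等了三天"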
85 | re_digits = re.compile('(\d[\d\.]*)') 86 | def number_to_chinese(number): 87 | an2cn = An2Cn() 88 | result = an2cn.an2cn(number) 89 | 90 | return result 91 | 92 | def tn_chinese(text): 93 | parts = re_digits.split(text) 94 | words = [] 95 | for part in parts: 96 | if re_digits.match(part): 97 | words.append(number_to_chinese(part)) 98 | else: 99 | words.append(part) 100 | return ''.join(words) 101 | 102 | def g2p_cn(text): 103 | res_text=[""] 104 | seg_list = jieba.cut(text) 105 | for seg in seg_list: 106 | if seg == " ": continue 107 | seg_tn = tn_chinese(seg) 108 | py =[_py[0] for _py in pinyin(seg_tn, style=Style.TONE3,neutral_tone_with_five=True)] 109 | 110 | if any([has_chinese_punctuation(_py) for _py in py]) or any([has_english_punctuation(_py) for _py in py]): 111 | res_text.pop() 112 | res_text.append("sp3") 113 | else: 114 | 115 | py = [" ".join(split_py(_py)) for _py in py] 116 | 117 | res_text.append(" sp0 ".join(py)) 118 | res_text.append("sp1") 119 | #res_text.pop() 120 | res_text.append("") 121 | return " ".join(res_text) 122 | 123 | if __name__ == "__main__": 124 | import sys 125 | from os.path import isfile 126 | if len(sys.argv) < 2: 127 | print("Usage: python %s " % sys.argv[0]) 128 | exit() 129 | text_file = sys.argv[1] 130 | if isfile(text_file): 131 | fp = open(text_file, 'r') 132 | for line in fp: 133 | phoneme=g2p_cn(line.rstrip()) 134 | print(phoneme) 135 | fp.close() 136 | -------------------------------------------------------------------------------- /frontend_en.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023, YOUDAO 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
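#
# Overview (descriptive comment; see the code below): English frontend.
# read_lexicon() loads the LibriSpeech lexicon, and get_eng_phoneme() maps
# each word to bracketed ARPAbet phones (e.g. "hello" -> [HH] [AH0] [L] [OW1],
# illustrative), using the lexicon entry when available and falling back to
# g2p_en otherwise; engsp1/engsp4 mark word- and punctuation-level pauses.
# Running "python frontend_en.py <text_file>" prints the phoneme sequence of
# each line in the file.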
14 | 15 | import re 16 | import argparse 17 | from string import punctuation 18 | import numpy as np 19 | 20 | from g2p_en import G2p 21 | 22 | import os 23 | 24 | 25 | ROOT_DIR = os.path.dirname(os.path.abspath("__file__")) 26 | 27 | def read_lexicon(lex_path): 28 | lexicon = {} 29 | with open(lex_path) as f: 30 | for line in f: 31 | temp = re.split(r"\s+", line.strip("\n")) 32 | word = temp[0] 33 | phones = temp[1:] 34 | if word.lower() not in lexicon: 35 | lexicon[word.lower()] = phones 36 | return lexicon 37 | 38 | def get_eng_phoneme(text, g2p, lexicon, pad_sos_eos=True): 39 | """ 40 | english g2p 41 | """ 42 | filters = {",", " ", "'"} 43 | phones = [] 44 | words = list(filter(lambda x: x not in {"", " "}, re.split(r"([,;.\-\?\!\s+])", text))) 45 | 46 | for w in words: 47 | if w.lower() in lexicon: 48 | 49 | for ph in lexicon[w.lower()]: 50 | if ph not in filters: 51 | phones += ["[" + ph + "]"] 52 | 53 | if "sp" not in phones[-1]: 54 | phones += ["engsp1"] 55 | else: 56 | phone=g2p(w) 57 | if not phone: 58 | continue 59 | 60 | if phone[0].isalnum(): 61 | 62 | for ph in phone: 63 | if ph not in filters: 64 | phones += ["[" + ph + "]"] 65 | if ph == " " and "sp" not in phones[-1]: 66 | phones += ["engsp1"] 67 | elif phone == " ": 68 | continue 69 | elif phones: 70 | phones.pop() # pop engsp1 71 | phones.append("engsp4") 72 | if phones and "engsp" in phones[-1]: 73 | phones.pop() 74 | 75 | # mark = "." if text[-1] != "?" else "?" 76 | if pad_sos_eos: 77 | phones = [""] + phones + [""] 78 | return " ".join(phones) 79 | 80 | 81 | if __name__ == "__main__": 82 | lexicon = read_lexicon(f"{ROOT_DIR}/lexicon/librispeech-lexicon.txt") 83 | g2p = G2p() 84 | phonemes= get_eng_phoneme("Happy New Year", g2p, lexicon) 85 | import sys 86 | from os.path import isfile 87 | if len(sys.argv) < 2: 88 | print("Usage: python %s " % sys.argv[0]) 89 | exit() 90 | text_file = sys.argv[1] 91 | if isfile(text_file): 92 | fp = open(text_file, 'r') 93 | for line in fp: 94 | phoneme=get_eng_phoneme(line.rstrip(), g2p, lexicon) 95 | print(phoneme) 96 | fp.close() -------------------------------------------------------------------------------- /inference_am_vocoder_exp.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023, YOUDAO 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
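#
# Overview (descriptive comment; see the code below): offline inference for a
# finetuned experiment. For every checkpoint under <output_directory>/ckpt
# (or only the one named by --checkpoint), the script loads the SimBERT style
# encoder and the JETS generator, reads the test file line by line in the
# format
#   speaker|prompt|phoneme sequence|original text
# (see data/inference/text), and writes 16-bit wavs to
# <output_directory>/test_audio/audio/<checkpoint>/.
#
# Typical invocation, as in data/LJspeech/README.md:
#   python inference_am_vocoder_exp.py \
#     --config_folder exp/LJspeech/config \
#     --checkpoint g_00010000 \
#     --test_file data/inference/text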
14 | 15 | from models.prompt_tts_modified.jets import JETSGenerator 16 | from models.prompt_tts_modified.simbert import StyleEncoder 17 | from transformers import AutoTokenizer 18 | import os, sys, warnings, torch, glob, argparse 19 | import numpy as np 20 | from models.hifigan.get_vocoder import MAX_WAV_VALUE 21 | import soundfile as sf 22 | from yacs import config as CONFIG 23 | from tqdm import tqdm 24 | 25 | def get_style_embedding(prompt, tokenizer, style_encoder): 26 | prompt = tokenizer([prompt], return_tensors="pt") 27 | input_ids = prompt["input_ids"] 28 | token_type_ids = prompt["token_type_ids"] 29 | attention_mask = prompt["attention_mask"] 30 | 31 | with torch.no_grad(): 32 | output = style_encoder( 33 | input_ids=input_ids, 34 | token_type_ids=token_type_ids, 35 | attention_mask=attention_mask, 36 | ) 37 | style_embedding = output["pooled_output"].cpu().squeeze().numpy() 38 | return style_embedding 39 | 40 | def main(args, config): 41 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 42 | root_path = os.path.join(config.output_directory) 43 | ckpt_path = os.path.join(root_path, "ckpt") 44 | files = os.listdir(ckpt_path) 45 | 46 | for file in files: 47 | if args.checkpoint: 48 | if file != args.checkpoint: 49 | continue 50 | 51 | checkpoint_path = os.path.join(ckpt_path, file) 52 | 53 | with open(config.model_config_path, 'r') as fin: 54 | conf = CONFIG.load_cfg(fin) 55 | 56 | 57 | conf.n_vocab = config.n_symbols 58 | conf.n_speaker = config.speaker_n_labels 59 | 60 | style_encoder = StyleEncoder(config) 61 | model_CKPT = torch.load(config.style_encoder_ckpt, map_location="cpu") 62 | model_ckpt = {} 63 | for key, value in model_CKPT['model'].items(): 64 | new_key = key[7:] 65 | model_ckpt[new_key] = value 66 | style_encoder.load_state_dict(model_ckpt, strict=False) 67 | 68 | 69 | 70 | generator = JETSGenerator(conf).to(device) 71 | 72 | model_CKPT = torch.load(checkpoint_path, map_location=device) 73 | generator.load_state_dict(model_CKPT['generator']) 74 | generator.eval() 75 | 76 | with open(config.token_list_path, 'r') as f: 77 | token2id = {t.strip():idx for idx, t, in enumerate(f.readlines())} 78 | 79 | with open(config.speaker2id_path, encoding='utf-8') as f: 80 | speaker2id = {t.strip():idx for idx, t in enumerate(f.readlines())} 81 | 82 | 83 | tokenizer = AutoTokenizer.from_pretrained(config.bert_path) 84 | 85 | text_path = args.test_file 86 | 87 | 88 | if os.path.exists(root_path + "/test_audio/audio/" +f"{file}/"): 89 | r = glob.glob(root_path + "/test_audio/audio/" +f"{file}/*") 90 | for j in r: 91 | os.remove(j) 92 | texts = [] 93 | prompts = [] 94 | speakers = [] 95 | contents = [] 96 | with open(text_path, "r") as f: 97 | for line in f: 98 | line = line.strip().split("|") 99 | speakers.append(line[0]) 100 | prompts.append(line[1]) 101 | texts.append(line[2].split()) 102 | contents.append(line[3]) 103 | 104 | for i, (speaker, prompt, text, content) in enumerate(tqdm(zip(speakers, prompts, texts, contents))): 105 | 106 | style_embedding = get_style_embedding(prompt, tokenizer, style_encoder) 107 | content_embedding = get_style_embedding(content, tokenizer, style_encoder) 108 | 109 | if speaker not in speaker2id: 110 | continue 111 | speaker = speaker2id[speaker] 112 | 113 | text_int = [token2id[ph] for ph in text] 114 | 115 | sequence = torch.from_numpy(np.array(text_int)).to(device).long().unsqueeze(0) 116 | sequence_len = torch.from_numpy(np.array([len(text_int)])).to(device) 117 | style_embedding = 
torch.from_numpy(style_embedding).to(device).unsqueeze(0) 118 | content_embedding = torch.from_numpy(content_embedding).to(device).unsqueeze(0) 119 | speaker = torch.from_numpy(np.array([speaker])).to(device) 120 | with torch.no_grad(): 121 | 122 | infer_output = generator( 123 | inputs_ling=sequence, 124 | inputs_style_embedding=style_embedding, 125 | input_lengths=sequence_len, 126 | inputs_content_embedding=content_embedding, 127 | inputs_speaker=speaker, 128 | alpha=1.0 129 | ) 130 | audio = infer_output["wav_predictions"].squeeze()* MAX_WAV_VALUE 131 | audio = audio.cpu().numpy().astype('int16') 132 | if not os.path.exists(root_path + "/test_audio/audio/" +f"{file}/"): 133 | os.makedirs(root_path + "/test_audio/audio/" +f"{file}/", exist_ok=True) 134 | sf.write(file=root_path + "/test_audio/audio/" +f"{file}/{i+1}.wav", data=audio, samplerate=config.sampling_rate) #h.sampling_rate 135 | 136 | 137 | 138 | 139 | 140 | 141 | if __name__ == '__main__': 142 | print("run!") 143 | p = argparse.ArgumentParser() 144 | p.add_argument("-c", "--config_folder", type=str, required=True) 145 | p.add_argument("--checkpoint", type=str, required=False, default='', help='inference specific checkpoint, e.g --checkpoint checkpoint_230000') 146 | p.add_argument('-t', '--test_file', type=str, required=True, help='the absolute path of test file that is going to inference') 147 | 148 | args = p.parse_args() 149 | ################################################## 150 | sys.path.append(os.path.dirname(os.path.abspath("__file__")) + "/" + args.config_folder) 151 | 152 | from config import Config 153 | config = Config() 154 | ################################################## 155 | main(args, config) 156 | 157 | 158 | -------------------------------------------------------------------------------- /inference_am_vocoder_joint.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023, YOUDAO 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
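#
# Overview (descriptive comment; see the code below): same inference flow as
# inference_am_vocoder_exp.py, with one difference: an extra required
# --logdir argument, so checkpoints are read from
# <output_directory>/<logdir>/ckpt and the synthesized wavs are written under
# <output_directory>/<logdir>/test_audio/audio/<checkpoint>/.
# A hypothetical invocation (placeholders, not taken from the README):
#   python inference_am_vocoder_joint.py \
#     --logdir <run_name> \
#     --config_folder config/joint \
#     --checkpoint <g_xxxxxxxx> \
#     --test_file data/inference/text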
14 | 15 | from models.prompt_tts_modified.jets import JETSGenerator 16 | from models.prompt_tts_modified.simbert import StyleEncoder 17 | from transformers import AutoTokenizer 18 | import os, sys, warnings, torch, glob, argparse 19 | import numpy as np 20 | from models.hifigan.get_vocoder import MAX_WAV_VALUE 21 | import soundfile as sf 22 | from yacs import config as CONFIG 23 | from tqdm import tqdm 24 | 25 | def get_style_embedding(prompt, tokenizer, style_encoder): 26 | prompt = tokenizer([prompt], return_tensors="pt") 27 | input_ids = prompt["input_ids"] 28 | token_type_ids = prompt["token_type_ids"] 29 | attention_mask = prompt["attention_mask"] 30 | 31 | with torch.no_grad(): 32 | output = style_encoder( 33 | input_ids=input_ids, 34 | token_type_ids=token_type_ids, 35 | attention_mask=attention_mask, 36 | ) 37 | style_embedding = output["pooled_output"].cpu().squeeze().numpy() 38 | return style_embedding 39 | 40 | def main(args, config): 41 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 42 | root_path = os.path.join(config.output_directory, args.logdir) 43 | ckpt_path = os.path.join(root_path, "ckpt") 44 | files = os.listdir(ckpt_path) 45 | 46 | for file in files: 47 | if args.checkpoint: 48 | if file != args.checkpoint: 49 | continue 50 | 51 | checkpoint_path = os.path.join(ckpt_path, file) 52 | 53 | with open(config.model_config_path, 'r') as fin: 54 | conf = CONFIG.load_cfg(fin) 55 | 56 | 57 | conf.n_vocab = config.n_symbols 58 | conf.n_speaker = config.speaker_n_labels 59 | 60 | style_encoder = StyleEncoder(config) 61 | model_CKPT = torch.load(config.style_encoder_ckpt, map_location="cpu") 62 | model_ckpt = {} 63 | for key, value in model_CKPT['model'].items(): 64 | new_key = key[7:] 65 | model_ckpt[new_key] = value 66 | style_encoder.load_state_dict(model_ckpt, strict=False) 67 | 68 | 69 | 70 | generator = JETSGenerator(conf).to(device) 71 | 72 | model_CKPT = torch.load(checkpoint_path, map_location=device) 73 | generator.load_state_dict(model_CKPT['generator']) 74 | generator.eval() 75 | 76 | with open(config.token_list_path, 'r') as f: 77 | token2id = {t.strip():idx for idx, t, in enumerate(f.readlines())} 78 | 79 | with open(config.speaker2id_path, encoding='utf-8') as f: 80 | speaker2id = {t.strip():idx for idx, t in enumerate(f.readlines())} 81 | 82 | 83 | tokenizer = AutoTokenizer.from_pretrained(config.bert_path) 84 | 85 | text_path = args.test_file 86 | 87 | 88 | if os.path.exists(root_path + "/test_audio/audio/" +f"{file}/"): 89 | r = glob.glob(root_path + "/test_audio/audio/" +f"{file}/*") 90 | for j in r: 91 | os.remove(j) 92 | texts = [] 93 | prompts = [] 94 | speakers = [] 95 | contents = [] 96 | with open(text_path, "r") as f: 97 | for line in f: 98 | line = line.strip().split("|") 99 | speakers.append(line[0]) 100 | prompts.append(line[1]) 101 | texts.append(line[2].split()) 102 | contents.append(line[3]) 103 | 104 | for i, (speaker, prompt, text, content) in enumerate(tqdm(zip(speakers, prompts, texts, contents))): 105 | 106 | style_embedding = get_style_embedding(prompt, tokenizer, style_encoder) 107 | content_embedding = get_style_embedding(content, tokenizer, style_encoder) 108 | 109 | if speaker not in speaker2id: 110 | continue 111 | speaker = speaker2id[speaker] 112 | 113 | text_int = [token2id[ph] for ph in text] 114 | 115 | sequence = torch.from_numpy(np.array(text_int)).to(device).long().unsqueeze(0) 116 | sequence_len = torch.from_numpy(np.array([len(text_int)])).to(device) 117 | style_embedding = 
torch.from_numpy(style_embedding).to(device).unsqueeze(0) 118 | content_embedding = torch.from_numpy(content_embedding).to(device).unsqueeze(0) 119 | speaker = torch.from_numpy(np.array([speaker])).to(device) 120 | with torch.no_grad(): 121 | 122 | infer_output = generator( 123 | inputs_ling=sequence, 124 | inputs_style_embedding=style_embedding, 125 | input_lengths=sequence_len, 126 | inputs_content_embedding=content_embedding, 127 | inputs_speaker=speaker, 128 | alpha=1.0 129 | ) 130 | audio = infer_output["wav_predictions"].squeeze()* MAX_WAV_VALUE 131 | audio = audio.cpu().numpy().astype('int16') 132 | if not os.path.exists(root_path + "/test_audio/audio/" +f"{file}/"): 133 | os.makedirs(root_path + "/test_audio/audio/" +f"{file}/", exist_ok=True) 134 | sf.write(file=root_path + "/test_audio/audio/" +f"{file}/{i+1}.wav", data=audio, samplerate=config.sampling_rate) #h.sampling_rate 135 | 136 | 137 | 138 | 139 | 140 | 141 | if __name__ == '__main__': 142 | print("run!") 143 | p = argparse.ArgumentParser() 144 | p.add_argument('-d', '--logdir', type=str, required=True) 145 | p.add_argument("-c", "--config_folder", type=str, required=True) 146 | p.add_argument("--checkpoint", type=str, required=False, default='', help='inference specific checkpoint, e.g --checkpoint checkpoint_230000') 147 | p.add_argument('-t', '--test_file', type=str, required=True, help='the absolute path of test file that is going to inference') 148 | 149 | args = p.parse_args() 150 | ################################################## 151 | sys.path.append(os.path.dirname(os.path.abspath("__file__")) + "/" + args.config_folder) 152 | 153 | from config import Config 154 | config = Config() 155 | ################################################## 156 | main(args, config) 157 | 158 | 159 | -------------------------------------------------------------------------------- /mel_process.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | import torch.utils.data 8 | import numpy as np 9 | import librosa 10 | import librosa.util as librosa_util 11 | from librosa.util import normalize, pad_center, tiny 12 | from scipy.signal import get_window 13 | from scipy.io.wavfile import read 14 | from librosa.filters import mel as librosa_mel_fn 15 | 16 | MAX_WAV_VALUE = 32768.0 17 | 18 | 19 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 20 | 21 | return torch.log(torch.clamp(x, min=clip_val) * C) 22 | 23 | 24 | def dynamic_range_decompression_torch(x, C=1): 25 | 26 | return torch.exp(x) / C 27 | 28 | 29 | def spectral_normalize_torch(magnitudes): 30 | output = dynamic_range_compression_torch(magnitudes) 31 | return output 32 | 33 | 34 | def spectral_de_normalize_torch(magnitudes): 35 | output = dynamic_range_decompression_torch(magnitudes) 36 | return output 37 | 38 | 39 | mel_basis = {} 40 | hann_window = {} 41 | 42 | 43 | def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): 44 | if torch.min(y) < -1.: 45 | print('min value is ', torch.min(y)) 46 | if torch.max(y) > 1.: 47 | print('max value is ', torch.max(y)) 48 | 49 | global hann_window 50 | dtype_device = str(y.dtype) + '_' + str(y.device) 51 | wnsize_dtype_device = str(win_size) + '_' + dtype_device 52 | if wnsize_dtype_device not in hann_window: 53 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) 54 | 55 | y = 
torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') 56 | y = y.squeeze(1) 57 | 58 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], 59 | center=center, pad_mode='reflect', normalized=False, onesided=True) 60 | 61 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 62 | return spec 63 | 64 | 65 | def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): 66 | global mel_basis 67 | dtype_device = str(spec.dtype) + '_' + str(spec.device) 68 | fmax_dtype_device = str(fmax) + '_' + dtype_device 69 | if fmax_dtype_device not in mel_basis: 70 | mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) 71 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) 72 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 73 | spec = spectral_normalize_torch(spec) 74 | return spec 75 | 76 | 77 | def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): 78 | if torch.min(y) < -1.: 79 | print('min value is ', torch.min(y)) 80 | if torch.max(y) > 1.: 81 | print('max value is ', torch.max(y)) 82 | 83 | global mel_basis, hann_window 84 | dtype_device = str(y.dtype) + '_' + str(y.device) 85 | fmax_dtype_device = str(fmax) + '_' + dtype_device 86 | wnsize_dtype_device = str(win_size) + '_' + dtype_device 87 | if fmax_dtype_device not in mel_basis: 88 | mel = librosa_mel_fn( 89 | sr=sampling_rate, 90 | n_fft=n_fft, 91 | n_mels=num_mels, 92 | fmin=fmin, 93 | fmax=fmax) 94 | 95 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) 96 | if wnsize_dtype_device not in hann_window: 97 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) 98 | 99 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') 100 | y = y.squeeze(1) 101 | 102 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], 103 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) 104 | 105 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 106 | 107 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 108 | spec = spectral_normalize_torch(spec) 109 | 110 | return spec -------------------------------------------------------------------------------- /mfa/step1_create_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | from tqdm import tqdm 3 | import jsonlines 4 | import re 5 | import argparse 6 | import os 7 | 8 | def main(args): 9 | ROOT_DIR=os.path.abspath(args.data_dir) 10 | TEXT_DIR=f"{ROOT_DIR}/text" 11 | MFA_DIR=f"{ROOT_DIR}/mfa" 12 | 13 | os.makedirs(MFA_DIR, exist_ok=True) 14 | 15 | 16 | 17 | with jsonlines.open(f"{TEXT_DIR}/datalist.jsonl") as f1, \ 18 | open(f"{MFA_DIR}/text_sp1-sp4", "w") as f2, \ 19 | open(f"{MFA_DIR}/wav.scp", "w") as f3: 20 | 21 | data=list({f'{sample["key"]}_{sample["speaker"]}':sample for sample in list(f1)}.values()) 22 | 23 | for sample in tqdm(data): 24 | text=[] 25 | for ph in sample["text"]: 26 | if ph[0] == '[': 27 | ph = ph[1:-1] 28 | elif ph == "cn_eng_sp": 29 | ph = "cnengsp" 30 | elif ph == "eng_cn_sp": 31 | ph = "engcnsp" 32 | text.append(ph) 33 | f2.write("{}|{} {}\n".format(re.sub(r" +", "", sample["speaker"]), sample["key"], " ".join(text))) 34 | f3.write("{} {}\n".format(sample["key"], 
sample["wav_path"])) 35 | 36 | if __name__ == "__main__": 37 | p = argparse.ArgumentParser() 38 | p.add_argument('--data_dir', type=str, required=True) 39 | args = p.parse_args() 40 | 41 | main(args) 42 | -------------------------------------------------------------------------------- /mfa/step2_prepare_data.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import collections 4 | import pathlib 5 | import os 6 | from typing import Iterable 7 | from tqdm import tqdm 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--dataset_dir', 12 | type=str, 13 | help='Path to cath dataset') 14 | parser.add_argument('--wav', type=str, help='Path to export paths of wavs.') 15 | parser.add_argument('--speaker', type=str, help='Path to export speakers.') 16 | parser.add_argument('--text', type=str, help='Path to export text of wavs.') 17 | return parser.parse_args() 18 | 19 | 20 | def save_scp_files(wav_scp_path: os.PathLike, speaker_scp_path: os.PathLike, 21 | text_scp_path: os.PathLike, content: Iterable[str]): 22 | wav_scp_path = pathlib.Path(wav_scp_path) 23 | speaker_scp_path = pathlib.Path(speaker_scp_path) 24 | text_scp_path = pathlib.Path(text_scp_path) 25 | 26 | wav_scp_path.parent.mkdir(parents=True, exist_ok=True) 27 | speaker_scp_path.parent.mkdir(parents=True, exist_ok=True) 28 | text_scp_path.parent.mkdir(parents=True, exist_ok=True) 29 | 30 | with open(wav_scp_path, 'w') as wav_scp_file: 31 | wav_scp_file.writelines([str(line[0]) + '\n' for line in content]) 32 | with open(speaker_scp_path, 'w') as speaker_scp_file: 33 | speaker_scp_file.writelines([line[1] + '\n' for line in content]) 34 | with open(text_scp_path, 'w') as text_scp_file: 35 | text_scp_file.writelines([line[2] + '\n' for line in content]) 36 | 37 | 38 | def main(args): 39 | dataset_dir = pathlib.Path(args.dataset_dir) 40 | 41 | with open(dataset_dir / 42 | 'text_sp1-sp4') as train_set_label_file: 43 | train_set_label = [ 44 | x.strip() for x in train_set_label_file.readlines() 45 | ] 46 | train_set_path={} 47 | with open(dataset_dir / 48 | 'wav.scp') as train_set_path_file: 49 | for line in train_set_path_file: 50 | line = line.strip().split() 51 | train_set_path[line[0]] = line[1] 52 | 53 | samples = collections.defaultdict(list) 54 | 55 | for line in tqdm(train_set_label): 56 | line = line.split() 57 | # sample_name = "_".join(line[0].split("_")[1:]) 58 | sample_name = line[0].split("|")[1] 59 | tokens = " ".join(line[1:]) 60 | speaker = line[0].split("|")[0] 61 | wav_path = train_set_path[sample_name] 62 | if os.path.exists(wav_path): 63 | samples[speaker].append((wav_path, speaker, tokens)) 64 | else: 65 | print(wav_path, "is not existed") 66 | 67 | sample_list = [] 68 | 69 | for speaker in sorted(samples): 70 | sample_list.extend(samples[speaker]) 71 | 72 | save_scp_files(args.wav, args.speaker, args.text, sample_list) 73 | 74 | 75 | if __name__ == "__main__": 76 | main(get_args()) 77 | -------------------------------------------------------------------------------- /mfa/step3_prepare_special_tokens.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_args(): 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('--special_tokens', 7 | type=str, 8 | help='Path to special_token.txt') 9 | return parser.parse_args() 10 | 11 | def main(args): 12 | with open(args.special_tokens, "w") as f: 13 | for line in {"sp0", "sp1", "sp2", "sp3", 
"sp4","engsp1", "engsp2", "engsp3", "engsp4", "", "cn_eng_sp", "eng_cn_sp", "." , "?", "LAUGH"}: 14 | f.write(f"{line}\n") 15 | 16 | if __name__ == '__main__': 17 | main(get_args()) 18 | -------------------------------------------------------------------------------- /mfa/step4_convert_text_to_phn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Tsinghua University. (authors: Jie Chen) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Convert full label pingyin sequences into phoneme sequences according to 15 | lexicon. 16 | """ 17 | 18 | import argparse 19 | 20 | 21 | def get_args(): 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--text', type=str, help='Path to text.txt.') 24 | parser.add_argument('--special_tokens', 25 | type=str, 26 | help='Path to special_token.txt') 27 | parser.add_argument('--output', type=str, help='Path to output file.') 28 | return parser.parse_args() 29 | 30 | 31 | def main(args): 32 | with open(args.special_tokens) as fin: 33 | special_tokens = set([x.strip() for x in fin.readlines()]) 34 | samples = [] 35 | with open(args.text) as fin: 36 | for line in fin: 37 | tokens = [] 38 | word = [] 39 | for ph in line.strip().split(): 40 | if ph in special_tokens: 41 | word = "_".join(word) 42 | 43 | tokens.append(word) 44 | tokens.append(ph) 45 | word = [] 46 | else: 47 | ph = ph #[A] -> A 48 | word.append(ph) 49 | 50 | samples.append(' '.join(tokens)) 51 | with open(args.output, 'w') as fout: 52 | fout.writelines([x + '\n' for x in samples]) 53 | 54 | 55 | if __name__ == '__main__': 56 | main(get_args()) 57 | -------------------------------------------------------------------------------- /mfa/step5_prepare_alignment.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright 2022 Binbin Zhang(binbzha@qq.com), Jie Chen(unrea1sama@outlook.com) 3 | """Generate lab files from data list for alignment 4 | """ 5 | 6 | import argparse 7 | import pathlib 8 | import random, os 9 | from tqdm import tqdm 10 | def get_args(): 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--wav", type=str, help='Path to wav.txt.') 13 | parser.add_argument("--speaker", type=str, help='Path to speaker.txt.') 14 | parser.add_argument( 15 | "--text", 16 | type=str, 17 | help=('Path to text.txt. 
', 18 | 'It should only contain phonemes and special tokens.')) 19 | parser.add_argument('--special_tokens', 20 | type=str, 21 | help='Path to special_token.txt.') 22 | parser.add_argument( 23 | '--pronounciation_dict', 24 | type=str, 25 | help='Path to export pronounciation dictionary for MFA.') 26 | parser.add_argument('--output_dir', 27 | type=str, 28 | help='Path to directory for exporting .lab files.') 29 | return parser.parse_args() 30 | 31 | 32 | def main(args): 33 | output_dir = pathlib.Path(args.output_dir) 34 | pronounciation_dict = set() 35 | with open(args.special_tokens) as fin: 36 | special_tokens = set([x.strip() for x in fin.readlines()]) 37 | 38 | num_speaker = 1 39 | with open(args.wav) as f: 40 | index = [i for i in range(len(f.readlines()))] 41 | _mfa_groups = [index[i::num_speaker] for i in range(num_speaker)] 42 | mfa_groups = [] 43 | for i, group in enumerate(_mfa_groups): 44 | mfa_groups.extend([i for _ in range(len(group))]) 45 | 46 | random.shuffle(mfa_groups) 47 | os.system(f"rm -rf {args.output_dir}/*") 48 | with open(args.wav) as fwav, open(args.speaker) as fspeaker, open( 49 | args.text) as ftext: 50 | for wav_path, speaker, text, i in tqdm(zip(fwav, fspeaker, ftext, mfa_groups)): 51 | i = speaker.strip()#str(i) 52 | wav_path, speaker, text = (pathlib.Path(wav_path.strip()), 53 | speaker.strip(), text.strip().split()) 54 | lab_dir = output_dir / i 55 | lab_dir.mkdir(parents=True, exist_ok=True) 56 | 57 | name=wav_path.stem.strip() 58 | 59 | lab_file = output_dir / i / f'{i}_{name}.lab' 60 | wav_file = output_dir / i / f'{i}_{name}.wav' 61 | try: 62 | os.symlink(wav_path, wav_file) 63 | except: 64 | print("ERROR PATH",wav_path) 65 | continue 66 | 67 | 68 | with lab_file.open('w') as fout: 69 | text_no_special_tokens = [ph for ph in text if ph not in special_tokens] 70 | pronounciation_dict |= set(text_no_special_tokens) 71 | fout.writelines([' '.join(text_no_special_tokens)]) 72 | with open(args.pronounciation_dict, 'w') as fout: 73 | fout.writelines([ 74 | '{} {}\n'.format(symbol, " ".join(symbol.split("_"))) for symbol in pronounciation_dict 75 | ]) 76 | 77 | 78 | if __name__ == '__main__': 79 | main(get_args()) 80 | -------------------------------------------------------------------------------- /mfa/step8_make_data_list.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Tsinghua University(Jie Chen) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
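# (Editor's note: hedged usage sketch, not part of the original file.)
# This script zips the per-utterance wav/speaker/text/duration lists produced
# by the earlier MFA steps into a single datalist.jsonl. A typical invocation
# looks like the following; the directory layout is illustrative only:
#
#   python mfa/step8_make_data_list.py \
#       --wav exp/mfa/wav.txt \
#       --speaker exp/mfa/speaker.txt \
#       --text exp/mfa/text.txt \
#       --duration exp/mfa/duration.txt \
#       --datalist_path exp/mfa/datalist.jsonl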
14 | 15 | import argparse 16 | import jsonlines 17 | import pathlib 18 | 19 | def read_lists(list_file): 20 | lists = [] 21 | with open(list_file, 'r', encoding='utf8') as fin: 22 | for line in fin: 23 | lists.append(line.strip()) 24 | return lists 25 | 26 | def get_args(): 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('--wav', type=str, help='Path to wav.txt.') 29 | parser.add_argument('--speaker', type=str, help='Path to speaker.txt.') 30 | parser.add_argument('--text', type=str, help='Path to text.txt.') 31 | parser.add_argument('--duration', type=str, help='Path to duration.txt.') 32 | parser.add_argument('--datalist_path', 33 | type=str, 34 | help='Path to export datalist.jsonl.') 35 | args = parser.parse_args() 36 | return args 37 | 38 | 39 | def main(args): 40 | wavs = read_lists(args.wav) 41 | speakers = read_lists(args.speaker) 42 | texts = read_lists(args.text) 43 | durations = read_lists(args.duration) 44 | with jsonlines.open(args.datalist_path, 'w') as fdatalist: 45 | for wav, speaker, text, duration in zip(wavs, speakers, texts, 46 | durations): 47 | key = pathlib.Path(wav).stem 48 | fdatalist.write({ 49 | 'key': key, 50 | 'wav_path': wav, 51 | 'speaker': speaker, 52 | 'text': text.split(), 53 | 'duration': [float(x) for x in duration.split()] 54 | }) 55 | 56 | 57 | if __name__ == '__main__': 58 | main(get_args()) 59 | -------------------------------------------------------------------------------- /mfa/step9_datalist_from_mfa.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import jsonlines 4 | import argparse 5 | import os 6 | 7 | 8 | def main(args): 9 | ROOT_DIR=os.path.abspath(args.data_dir) 10 | TEXT_DIR=f"{ROOT_DIR}/text" 11 | MFA_DIR=f"{ROOT_DIR}/mfa" 12 | TRAIN_DIR=f"{ROOT_DIR}/train" 13 | VALID_DIR=f"{ROOT_DIR}/valid" 14 | 15 | with jsonlines.open(f"{MFA_DIR}/datalist.jsonl") as f: 16 | data = list(f) 17 | 18 | with jsonlines.open(f"{TEXT_DIR}/datalist.jsonl") as f: 19 | data_ref = {sample["key"]:sample for sample in list(f)} 20 | 21 | new_data = [] 22 | with jsonlines.open(f"{TEXT_DIR}/datalist_mfa.jsonl", "w") as f: 23 | for sample in data: 24 | if "duration" in sample: 25 | del sample["duration"] 26 | 27 | 28 | 29 | # if "emotion" not in sample: 30 | # sample["emotion"]="default" 31 | 32 | 33 | for i, ph in enumerate(sample["text"]): 34 | if ph.isupper(): 35 | sample["text"][i] = "[" + ph + "]" 36 | 37 | if ph =="cnengsp": 38 | sample["text"][i] = "cn_eng_sp" 39 | if ph =="engcnsp": 40 | sample["text"][i] = "eng_cn_sp" 41 | 42 | sample_ref = data_ref[sample["key"]] 43 | 44 | sample["original_text"]=sample_ref["original_text"] 45 | sample["prompt"] = sample_ref["prompt"] 46 | new_data.append(sample) 47 | f.write(sample) 48 | 49 | with jsonlines.open(f"{TRAIN_DIR}/datalist_mfa.jsonl", "w") as f: 50 | for sample in new_data[:-3]: 51 | f.write(sample) 52 | 53 | with jsonlines.open(f"{VALID_DIR}/datalist_mfa.jsonl", "w") as f: 54 | for sample in data[-3:]: 55 | f.write(sample) 56 | 57 | if __name__ == "__main__": 58 | p = argparse.ArgumentParser() 59 | p.add_argument('--data_dir', type=str, required=True) 60 | args = p.parse_args() 61 | 62 | main(args) -------------------------------------------------------------------------------- /models/hifigan/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023, YOUDAO 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the 
License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import jsonlines 17 | from transformers import AutoTokenizer 18 | import os, sys 19 | import numpy as np 20 | from scipy.io.wavfile import read 21 | from torch.nn.utils.rnn import pad_sequence 22 | import copy 23 | from models.prompt_tts_modified.tacotron_stft import TacotronSTFT 24 | 25 | 26 | def get_mel(filename, stft, sampling_rate, trim=False): 27 | 28 | sr, wav = read(filename) 29 | if sr != sampling_rate: 30 | raise ValueError("{} SR doesn't match target {} SR".format(sr, sampling_rate)) 31 | 32 | wav = wav / 32768.0 33 | 34 | wav = torch.FloatTensor(wav.astype(np.float32)) 35 | ### trimming ### 36 | if trim: 37 | frac = 0.005 38 | start = torch.where( 39 | torch.abs(wav)>(torch.abs(wav).max()*frac) 40 | )[0][0] 41 | end = torch.where(torch.abs(wav)>(torch.abs(wav).max()*frac))[0][-1] 42 | ### 50ms silence padding ### 43 | wav = torch.nn.functional.pad(wav[start:end], (sampling_rate//20, sampling_rate//20)) 44 | melspec = stft.mel_spectrogram(wav.unsqueeze(0)) 45 | 46 | return melspec.squeeze(0), wav 47 | 48 | def pad_mel(data, downsample_ratio, max_len ): 49 | batch_size = len(data) 50 | num_mels = data[0].size(0) 51 | padded = torch.zeros((batch_size, num_mels, max_len)) 52 | for i in range(batch_size): 53 | lens = data[i].size(1) 54 | if lens % downsample_ratio!=0: 55 | data[i] = data[i][:,:-(lens % downsample_ratio)] 56 | padded[i, :, :data[i].size(1)] = data[i] 57 | 58 | return padded 59 | 60 | class DatasetTTS(torch.utils.data.Dataset): 61 | def __init__(self, data_path, config): 62 | self.sampling_rate=config.sampling_rate 63 | self.datalist = self.load_files(data_path) 64 | self.stft = TacotronSTFT( 65 | filter_length=config.filter_length, 66 | hop_length=config.hop_length, 67 | win_length=config.win_length, 68 | n_mel_channels=config.n_mel_channels, 69 | sampling_rate=config.sampling_rate, 70 | mel_fmin=config.mel_fmin, 71 | mel_fmax=config.mel_fmax 72 | ) 73 | self.trim = config.trim 74 | self.config=config 75 | 76 | 77 | def load_files(self, data_path): 78 | with jsonlines.open(data_path) as f: 79 | data = list(f) 80 | return data 81 | 82 | 83 | def __len__(self): 84 | return len(self.datalist) 85 | 86 | def __getitem__(self, index): 87 | 88 | uttid = self.datalist[index]["key"] 89 | 90 | 91 | mel, wav = get_mel(self.datalist[index]["wav_path"], self.stft, self.sampling_rate, trim=self.trim) 92 | 93 | return { 94 | "mel": mel, 95 | "uttid": uttid, 96 | "wav": wav, 97 | } 98 | 99 | 100 | def TextMelCollate(self, data): 101 | 102 | # Right zero-pad melspectrogram 103 | mel = [x['mel'] for x in data] 104 | max_target_len = max([x.shape[1] for x in mel]) 105 | 106 | # wav 107 | wav = [x["wav"] for x in data] 108 | 109 | padded_wav = pad_sequence(wav, 110 | batch_first=True, 111 | padding_value=0.0) 112 | padded_mel = pad_mel(mel, self.config.downsample_ratio, max_target_len) 113 | 114 | mel_lens = torch.LongTensor([x.shape[1] for x in mel]) 115 | 116 | res = { 117 | "mel" : padded_mel, 118 | "mel_lens" : mel_lens, 119 | "wav" : padded_wav, 120 | } 121 | return res 122 | 123 | 124 | 
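# --------------------------------------------------------------------------
# Editor's note: a minimal, hedged usage sketch, not part of the original
# file. It assumes a Config object exposing the fields read above
# (sampling_rate, filter_length, hop_length, win_length, n_mel_channels,
# mel_fmin, mel_fmax, trim, downsample_ratio) and a datalist.jsonl path;
# the command-line layout below is illustrative only.
if __name__ == "__main__":
    import argparse
    p = argparse.ArgumentParser()
    p.add_argument("--datalist", type=str, required=True)
    p.add_argument("--config_folder", type=str, required=True)
    args = p.parse_args()
    sys.path.append(args.config_folder)
    from config import Config  # same convention as the inference scripts
    cfg = Config()
    dataset = DatasetTTS(args.datalist, cfg)
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=4, shuffle=False,
        collate_fn=dataset.TextMelCollate)
    batch = next(iter(loader))  # keys: "mel", "mel_lens", "wav"
    print({k: tuple(v.shape) for k, v in batch.items()})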
-------------------------------------------------------------------------------- /models/hifigan/env.py: -------------------------------------------------------------------------------- 1 | """ 2 | from https://github.com/jik876/hifi-gan 3 | """ 4 | 5 | import os 6 | import shutil 7 | 8 | 9 | class AttrDict(dict): 10 | def __init__(self, *args, **kwargs): 11 | super(AttrDict, self).__init__(*args, **kwargs) 12 | self.__dict__ = self 13 | 14 | 15 | def build_env(config, config_name, path): 16 | t_path = os.path.join(path, config_name) 17 | if config != t_path: 18 | os.makedirs(path, exist_ok=True) 19 | shutil.copyfile(config, os.path.join(path, config_name)) 20 | -------------------------------------------------------------------------------- /models/hifigan/get_random_segments.py: -------------------------------------------------------------------------------- 1 | """ 2 | from https://github.com/espnet/espnet 3 | """ 4 | 5 | import torch 6 | 7 | 8 | def get_random_segments( x: torch.Tensor, x_lengths: torch.Tensor, segment_size: int): 9 | b, d, t = x.size() 10 | max_start_idx = x_lengths - segment_size 11 | max_start_idx = torch.clamp(max_start_idx, min=0) 12 | start_idxs = (torch.rand([b]).to(x.device) * max_start_idx).to( 13 | dtype=torch.long, 14 | ) 15 | segments = get_segments(x, start_idxs, segment_size) 16 | return segments, start_idxs, segment_size 17 | 18 | 19 | def get_segments( x: torch.Tensor, start_idxs: torch.Tensor, segment_size: int): 20 | b, c, t = x.size() 21 | segments = x.new_zeros(b, c, segment_size) 22 | if t < segment_size: 23 | x = torch.nn.functional.pad(x, (0, segment_size - t), 'constant') 24 | for i, start_idx in enumerate(start_idxs): 25 | segment = x[i, :, start_idx : start_idx + segment_size] 26 | segments[i,:,:segment.size(1)] = segment 27 | return segments 28 | -------------------------------------------------------------------------------- /models/hifigan/get_vocoder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023, YOUDAO 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
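# (Editor's note: hedged usage note, not part of the original file.)
# `vocoder()` below loads a standalone HiFi-GAN generator: with a trailing
# slash on `hifi_gan_path`, config.json is looked up inside that directory and
# the weights are loaded from the plain string concatenation
# `hifi_gan_path + hifi_gan_name`. `vocoder_inference()` appears to expect a
# mel that was min-max normalized into [0, 1] with the same max_db/min_db and
# maps it back to dB before synthesis, e.g. (names illustrative only):
#
#   g = vocoder("path/to/hifigan_ckpt/", "g_00100000")
#   wav = vocoder_inference(g, mel_norm, max_db, min_db)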
14 | 15 | import os, json, torch 16 | from models.hifigan.env import AttrDict 17 | from models.hifigan.models import Generator 18 | 19 | MAX_WAV_VALUE = 32768.0 20 | 21 | def vocoder(hifi_gan_path, hifi_gan_name): 22 | device = torch.device('cpu') 23 | config_file = os.path.join(os.path.split(hifi_gan_path)[0], 'config.json') 24 | with open(config_file) as f: 25 | data = f.read() 26 | global h 27 | json_config = json.loads(data) 28 | h = AttrDict(json_config) 29 | torch.manual_seed(h.seed) 30 | generator = Generator(h).to(device) 31 | 32 | state_dict_g = torch.load(hifi_gan_path+hifi_gan_name, map_location=device) 33 | 34 | generator.load_state_dict(state_dict_g['generator']) 35 | generator.eval() 36 | generator.remove_weight_norm() 37 | return generator 38 | 39 | def vocoder2(config,hifi_gan_ckpt_path): 40 | device = torch.device('cpu') 41 | global h 42 | generator = Generator(config.model).to(device) 43 | 44 | state_dict_g = torch.load(hifi_gan_ckpt_path, map_location=device) 45 | 46 | generator.load_state_dict(state_dict_g['generator']) 47 | generator.eval() 48 | generator.remove_weight_norm() 49 | return generator 50 | 51 | 52 | def vocoder_inference(vocoder, melspec, max_db, min_db): 53 | with torch.no_grad(): 54 | x = melspec*(max_db-min_db)+min_db 55 | device = torch.device('cpu') 56 | x = torch.FloatTensor(x).to(device) 57 | y_g_hat = vocoder(x) 58 | audio = y_g_hat.squeeze().numpy() 59 | return audio -------------------------------------------------------------------------------- /models/hifigan/pretrained_discriminator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023, YOUDAO 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
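# (Editor's note: descriptive comment, not part of the original file.)
# The Discriminator class below simply bundles HiFi-GAN's multi-scale and
# multi-period discriminators behind one forward(). If
# `config.pretrained_discriminator` is a non-empty checkpoint path, the
# checkpoint is expected to store 'mpd' and 'msd' state dicts, and those
# weights are loaded so adversarial training starts from a warm discriminator.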
14 | 15 | import torch.nn as nn 16 | import torch 17 | from models.hifigan.models import MultiScaleDiscriminator, MultiPeriodDiscriminator 18 | 19 | 20 | 21 | class Discriminator(nn.Module): 22 | def __init__(self, config) -> None: 23 | super().__init__() 24 | 25 | self.msd = MultiScaleDiscriminator() 26 | self.mpd = MultiPeriodDiscriminator() 27 | if config.pretrained_discriminator: 28 | state_dict_do = torch.load(config.pretrained_discriminator,map_location="cpu") 29 | 30 | self.mpd.load_state_dict(state_dict_do['mpd']) 31 | self.msd.load_state_dict(state_dict_do['msd']) 32 | print("pretrained discriminator is loaded") 33 | def forward(self, y, y_hat): 34 | y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = self.mpd(y, y_hat) 35 | y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = self.msd(y, y_hat) 36 | 37 | return y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g, y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g -------------------------------------------------------------------------------- /models/prompt_tts_modified/audio_processing.py: -------------------------------------------------------------------------------- 1 | """ 2 | from https://github.com/espnet/espnet 3 | """ 4 | 5 | import torch 6 | import numpy as np 7 | from scipy.signal import get_window 8 | import librosa.util as librosa_util 9 | 10 | 11 | def window_sumsquare(window, 12 | n_frames, 13 | hop_length=200, 14 | win_length=800, 15 | n_fft=800, 16 | dtype=np.float32, 17 | norm=None): 18 | if win_length is None: 19 | win_length = n_fft 20 | 21 | n = n_fft + hop_length * (n_frames - 1) 22 | x = np.zeros(n, dtype=dtype) 23 | 24 | # Compute the squared window at the desired length 25 | win_sq = get_window(window, win_length, fftbins=True) 26 | win_sq = librosa_util.normalize(win_sq, norm=norm)**2 27 | win_sq = librosa_util.pad_center(win_sq, n_fft) 28 | 29 | # Fill the envelope 30 | for i in range(n_frames): 31 | sample = i * hop_length 32 | x[sample:min(n, sample+n_fft)] += win_sq[:max(0, min(n_fft, n - sample))] 33 | return x 34 | 35 | 36 | def griffin_lim(magnitudes, stft_fn, n_iters=30): 37 | 38 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) 39 | angles = angles.astype(np.float32) 40 | angles = torch.autograd.Variable(torch.from_numpy(angles)) 41 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 42 | 43 | for i in range(n_iters): 44 | _, angles = stft_fn.transform(signal) 45 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 46 | return signal 47 | 48 | 49 | 50 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 51 | return torch.log(torch.clamp(x, min=clip_val) * C) 52 | 53 | 54 | def dynamic_range_decompression(x, C=1): 55 | return torch.exp(x) / C 56 | -------------------------------------------------------------------------------- /models/prompt_tts_modified/jets.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023, YOUDAO 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
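# (Editor's note: descriptive comment, not part of the original file.)
# JETSGenerator couples the PromptTTS acoustic model with a HiFi-GAN
# generator inside one nn.Module (JETS-style joint training). When
# mel_targets are provided and cut_flag is True, only a random
# `segment_size`-frame slice of the decoder output is passed to the vocoder,
# which bounds the memory cost of adversarial training; at inference time the
# full frame sequence is vocoded.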
14 | 15 | import torch 16 | import torch.nn as nn 17 | import numpy as np 18 | from typing import Optional 19 | 20 | from models.prompt_tts_modified.model_open_source import PromptTTS 21 | from models.hifigan.models import Generator as HiFiGANGenerator 22 | 23 | from models.hifigan.get_random_segments import get_random_segments, get_segments 24 | 25 | 26 | class JETSGenerator(nn.Module): 27 | def __init__(self, config) -> None: 28 | 29 | super().__init__() 30 | 31 | self.upsample_factor=int(np.prod(config.model.upsample_rates)) 32 | 33 | self.segment_size = config.segment_size 34 | 35 | self.am = PromptTTS(config) 36 | 37 | self.generator = HiFiGANGenerator(config.model) 38 | 39 | # try: 40 | # model_CKPT = torch.load(config.pretrained_am, map_location="cpu") 41 | # self.am.load_state_dict(model_CKPT['model']) 42 | # state_dict_g = torch.load(config.pretrained_vocoder,map_location="cpu") 43 | # self.generator.load_state_dict(state_dict_g['generator']) 44 | # print("pretrained generator is loaded") 45 | # except: 46 | # print("pretrained generator is not loaded for training") 47 | self.config=config 48 | 49 | 50 | def forward(self, inputs_ling, input_lengths, inputs_speaker, inputs_style_embedding , inputs_content_embedding, mel_targets=None, output_lengths=None, pitch_targets=None, energy_targets=None, alpha=1.0, cut_flag=True): 51 | 52 | outputs = self.am(inputs_ling, input_lengths, inputs_speaker, inputs_style_embedding , inputs_content_embedding, mel_targets , output_lengths , pitch_targets , energy_targets , alpha) 53 | 54 | 55 | if mel_targets is not None and cut_flag: 56 | z_segments, z_start_idxs, segment_size = get_random_segments( 57 | outputs["dec_outputs"].transpose(1,2), 58 | output_lengths, 59 | self.segment_size, 60 | ) 61 | else: 62 | z_segments = outputs["dec_outputs"].transpose(1,2) 63 | z_start_idxs=None 64 | segment_size=self.segment_size 65 | 66 | wav = self.generator(z_segments) 67 | 68 | outputs["wav_predictions"] = wav 69 | outputs["z_start_idxs"]= z_start_idxs 70 | outputs["segment_size"] = segment_size 71 | return outputs 72 | -------------------------------------------------------------------------------- /models/prompt_tts_modified/loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is modified from https://github.com/alibaba-damo-academy/KAN-TTS. 
3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import numpy as np 9 | 10 | def get_mask_from_lengths(lengths, max_len=None): 11 | batch_size = lengths.shape[0] 12 | if max_len is None: 13 | max_len = torch.max(lengths).item() 14 | 15 | ids = ( 16 | torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(lengths.device) 17 | ) 18 | mask = ids >= lengths.unsqueeze(1).expand(-1, max_len) 19 | 20 | return mask 21 | 22 | class MelReconLoss(torch.nn.Module): 23 | def __init__(self, loss_type="mae"): 24 | super(MelReconLoss, self).__init__() 25 | self.loss_type = loss_type 26 | if loss_type == "mae": 27 | self.criterion = torch.nn.L1Loss(reduction="none") 28 | elif loss_type == "mse": 29 | self.criterion = torch.nn.MSELoss(reduction="none") 30 | else: 31 | raise ValueError("Unknown loss type: {}".format(loss_type)) 32 | 33 | def forward(self, output_lengths, mel_targets, dec_outputs, postnet_outputs=None): 34 | """ 35 | mel_targets: B, C, T 36 | """ 37 | output_masks = get_mask_from_lengths( 38 | output_lengths, max_len=mel_targets.size(1) 39 | ) 40 | output_masks = ~output_masks 41 | valid_outputs = output_masks.sum() 42 | 43 | mel_loss_ = torch.sum( 44 | self.criterion(mel_targets, dec_outputs) * output_masks.unsqueeze(-1) 45 | ) / (valid_outputs * mel_targets.size(-1)) 46 | 47 | if postnet_outputs is not None: 48 | mel_loss = torch.sum( 49 | self.criterion(mel_targets, postnet_outputs) 50 | * output_masks.unsqueeze(-1) 51 | ) / (valid_outputs * mel_targets.size(-1)) 52 | else: 53 | mel_loss = 0.0 54 | 55 | return mel_loss_, mel_loss 56 | 57 | 58 | 59 | class ForwardSumLoss(torch.nn.Module): 60 | 61 | def __init__(self): 62 | super().__init__() 63 | 64 | def forward( 65 | self, 66 | log_p_attn: torch.Tensor, 67 | ilens: torch.Tensor, 68 | olens: torch.Tensor, 69 | blank_prob: float = np.e**-1, 70 | ) -> torch.Tensor: 71 | B = log_p_attn.size(0) 72 | 73 | # a row must be added to the attention matrix to account for 74 | # blank token of CTC loss 75 | # (B,T_feats,T_text+1) 76 | log_p_attn_pd = F.pad(log_p_attn, (1, 0, 0, 0, 0, 0), value=np.log(blank_prob)) 77 | 78 | loss = 0 79 | for bidx in range(B): 80 | # construct target sequnece. 81 | # Every text token is mapped to a unique sequnece number. 
82 | target_seq = torch.arange(1, ilens[bidx] + 1).unsqueeze(0) 83 | cur_log_p_attn_pd = log_p_attn_pd[ 84 | bidx, : olens[bidx], : ilens[bidx] + 1 85 | ].unsqueeze( 86 | 1 87 | ) # (T_feats,1,T_text+1) 88 | cur_log_p_attn_pd = F.log_softmax(cur_log_p_attn_pd, dim=-1) 89 | loss += F.ctc_loss( 90 | log_probs=cur_log_p_attn_pd, 91 | targets=target_seq, 92 | input_lengths=olens[bidx : bidx + 1], 93 | target_lengths=ilens[bidx : bidx + 1], 94 | zero_infinity=True, 95 | ) 96 | loss = loss / B 97 | return loss 98 | 99 | class ProsodyReconLoss(torch.nn.Module): 100 | def __init__(self, loss_type="mae"): 101 | super(ProsodyReconLoss, self).__init__() 102 | self.loss_type = loss_type 103 | if loss_type == "mae": 104 | self.criterion = torch.nn.L1Loss(reduction="none") 105 | elif loss_type == "mse": 106 | self.criterion = torch.nn.MSELoss(reduction="none") 107 | else: 108 | raise ValueError("Unknown loss type: {}".format(loss_type)) 109 | 110 | def forward( 111 | self, 112 | input_lengths, 113 | duration_targets, 114 | pitch_targets, 115 | energy_targets, 116 | log_duration_predictions, 117 | pitch_predictions, 118 | energy_predictions, 119 | ): 120 | input_masks = get_mask_from_lengths( 121 | input_lengths, max_len=duration_targets.size(1) 122 | ) 123 | input_masks = ~input_masks 124 | valid_inputs = input_masks.sum() 125 | 126 | dur_loss = ( 127 | torch.sum( 128 | self.criterion( 129 | torch.log(duration_targets.float() + 1), log_duration_predictions 130 | ) 131 | * input_masks 132 | ) 133 | / valid_inputs 134 | ) 135 | pitch_loss = ( 136 | torch.sum(self.criterion(pitch_targets, pitch_predictions) * input_masks) 137 | / valid_inputs 138 | ) 139 | energy_loss = ( 140 | torch.sum(self.criterion(energy_targets, energy_predictions) * input_masks) 141 | / valid_inputs 142 | ) 143 | 144 | return dur_loss, pitch_loss, energy_loss 145 | 146 | 147 | class TTSLoss(torch.nn.Module): 148 | def __init__(self, loss_type="mae") -> None: 149 | super().__init__() 150 | 151 | self.Mel_Loss = MelReconLoss() 152 | self.Prosodu_Loss = ProsodyReconLoss(loss_type) 153 | self.ForwardSum_Loss = ForwardSumLoss() 154 | 155 | def forward(self, outputs): 156 | 157 | dec_outputs = outputs["dec_outputs"] 158 | postnet_outputs = outputs["postnet_outputs"] 159 | log_duration_predictions = outputs["log_duration_predictions"] 160 | pitch_predictions = outputs["pitch_predictions"] 161 | energy_predictions = outputs["energy_predictions"] 162 | duration_targets = outputs["duration_targets"] 163 | pitch_targets = outputs["pitch_targets"] 164 | energy_targets = outputs["energy_targets"] 165 | output_lengths = outputs["output_lengths"] 166 | input_lengths = outputs["input_lengths"] 167 | mel_targets = outputs["mel_targets"].transpose(1,2) 168 | log_p_attn = outputs["log_p_attn"] 169 | bin_loss = outputs["bin_loss"] 170 | 171 | dec_mel_loss, postnet_mel_loss = self.Mel_Loss(output_lengths, mel_targets, dec_outputs, postnet_outputs) 172 | dur_loss, pitch_loss, energy_loss = self.Prosodu_Loss(input_lengths, duration_targets, pitch_targets, energy_targets, log_duration_predictions, pitch_predictions, energy_predictions) 173 | forwardsum_loss = self.ForwardSum_Loss(log_p_attn, input_lengths, output_lengths) 174 | 175 | res = { 176 | "dec_mel_loss": dec_mel_loss, 177 | "postnet_mel_loss": postnet_mel_loss, 178 | "dur_loss": dur_loss, 179 | "pitch_loss": pitch_loss, 180 | "energy_loss": energy_loss, 181 | "forwardsum_loss": forwardsum_loss, 182 | "bin_loss": bin_loss, 183 | } 184 | 185 | return res 
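# --------------------------------------------------------------------------
# Editor's note: hedged sketch, not part of the original file. The joint
# training script may weight the individual terms differently; the plain sum
# below only illustrates the dict interface returned by TTSLoss.forward().
def _example_total_loss(outputs):
    criterion = TTSLoss(loss_type="mae")
    losses = criterion(outputs)          # dict of scalar loss tensors
    return sum(losses.values()), losses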
-------------------------------------------------------------------------------- /models/prompt_tts_modified/modules/alignment.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is modified from https://github.com/espnet/espnet. 3 | """ 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from numba import jit 10 | from scipy.stats import betabinom 11 | 12 | 13 | class AlignmentModule(nn.Module): 14 | 15 | def __init__(self, adim, odim, cache_prior=True): 16 | super().__init__() 17 | self.cache_prior = cache_prior 18 | self._cache = {} 19 | 20 | self.t_conv1 = nn.Conv1d(adim, adim, kernel_size=3, padding=1) 21 | self.t_conv2 = nn.Conv1d(adim, adim, kernel_size=1, padding=0) 22 | 23 | self.f_conv1 = nn.Conv1d(odim, adim, kernel_size=3, padding=1) 24 | self.f_conv2 = nn.Conv1d(adim, adim, kernel_size=3, padding=1) 25 | self.f_conv3 = nn.Conv1d(adim, adim, kernel_size=1, padding=0) 26 | 27 | def forward(self, text, feats, text_lengths, feats_lengths, x_masks=None): 28 | 29 | text = text.transpose(1, 2) 30 | text = F.relu(self.t_conv1(text)) 31 | text = self.t_conv2(text) 32 | text = text.transpose(1, 2) 33 | 34 | feats = feats.transpose(1, 2) 35 | feats = F.relu(self.f_conv1(feats)) 36 | feats = F.relu(self.f_conv2(feats)) 37 | feats = self.f_conv3(feats) 38 | feats = feats.transpose(1, 2) 39 | 40 | dist = feats.unsqueeze(2) - text.unsqueeze(1) 41 | dist = torch.norm(dist, p=2, dim=3) 42 | score = -dist 43 | 44 | if x_masks is not None: 45 | x_masks = x_masks.unsqueeze(-2) 46 | score = score.masked_fill(x_masks, -np.inf) 47 | 48 | log_p_attn = F.log_softmax(score, dim=-1) 49 | # add beta-binomial prior 50 | bb_prior = self._generate_prior( 51 | text_lengths, 52 | feats_lengths, 53 | ).to(dtype=log_p_attn.dtype, device=log_p_attn.device) 54 | 55 | log_p_attn = log_p_attn + bb_prior 56 | 57 | return log_p_attn 58 | 59 | def _generate_prior(self, text_lengths, feats_lengths, w=1) -> torch.Tensor: 60 | 61 | B = len(text_lengths) 62 | T_text = text_lengths.max() 63 | T_feats = feats_lengths.max() 64 | 65 | bb_prior = torch.full((B, T_feats, T_text), fill_value=-np.inf) 66 | for bidx in range(B): 67 | T = feats_lengths[bidx].item() 68 | N = text_lengths[bidx].item() 69 | 70 | key = str(T) + "," + str(N) 71 | if self.cache_prior and key in self._cache: 72 | prob = self._cache[key] 73 | else: 74 | alpha = w * np.arange(1, T + 1, dtype=float) # (T,) 75 | beta = w * np.array([T - t + 1 for t in alpha]) 76 | k = np.arange(N) 77 | batched_k = k[..., None] # (N,1) 78 | prob = betabinom.logpmf(batched_k, N, alpha, beta) # (N,T) 79 | 80 | # store cache 81 | if self.cache_prior and key not in self._cache: 82 | self._cache[key] = prob 83 | 84 | prob = torch.from_numpy(prob).transpose(0, 1) # -> (T,N) 85 | bb_prior[bidx, :T, :N] = prob 86 | 87 | return bb_prior 88 | 89 | 90 | 91 | 92 | @jit(nopython=True) 93 | def _monotonic_alignment_search(log_p_attn): 94 | 95 | T_mel = log_p_attn.shape[0] 96 | T_inp = log_p_attn.shape[1] 97 | Q = np.full((T_inp, T_mel), fill_value=-np.inf) 98 | 99 | log_prob = log_p_attn.transpose(1, 0) # -> (T_inp,T_mel) 100 | # 1. Q <- init first row for all j 101 | for j in range(T_mel): 102 | Q[0, j] = log_prob[0, : j + 1].sum() 103 | 104 | # 2. 105 | for j in range(1, T_mel): 106 | for i in range(1, min(j + 1, T_inp)): 107 | Q[i, j] = max(Q[i - 1, j - 1], Q[i, j - 1]) + log_prob[i, j] 108 | 109 | # 3. 
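    # (Editor's note: added clarification.) Backtracking: starting from the
    # last text token, each mel frame j is assigned either the same text index
    # as frame j + 1 or the previous one, whichever has the higher accumulated
    # score Q.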
110 | A = np.full((T_mel,), fill_value=T_inp - 1) 111 | for j in range(T_mel - 2, -1, -1): # T_mel-2, ..., 0 112 | # 'i' in {A[j+1]-1, A[j+1]} 113 | i_a = A[j + 1] - 1 114 | i_b = A[j + 1] 115 | if i_b == 0: 116 | argmax_i = 0 117 | elif Q[i_a, j] >= Q[i_b, j]: 118 | argmax_i = i_a 119 | else: 120 | argmax_i = i_b 121 | A[j] = argmax_i 122 | return A 123 | 124 | 125 | def viterbi_decode(log_p_attn, text_lengths, feats_lengths): 126 | 127 | B = log_p_attn.size(0) 128 | T_text = log_p_attn.size(2) 129 | device = log_p_attn.device 130 | 131 | bin_loss = 0 132 | ds = torch.zeros((B, T_text), device=device) 133 | for b in range(B): 134 | cur_log_p_attn = log_p_attn[b, : feats_lengths[b], : text_lengths[b]] 135 | viterbi = _monotonic_alignment_search(cur_log_p_attn.detach().cpu().numpy()) 136 | _ds = np.bincount(viterbi) 137 | ds[b, : len(_ds)] = torch.from_numpy(_ds).to(device) 138 | 139 | t_idx = torch.arange(feats_lengths[b]) 140 | bin_loss = bin_loss - cur_log_p_attn[t_idx, viterbi].mean() 141 | bin_loss = bin_loss / B 142 | return ds, bin_loss 143 | 144 | 145 | @jit(nopython=True) 146 | def _average_by_duration(ds, xs, text_lengths, feats_lengths): 147 | B = ds.shape[0] 148 | xs_avg = np.zeros_like(ds) 149 | ds = ds.astype(np.int32) 150 | for b in range(B): 151 | t_text = text_lengths[b] 152 | t_feats = feats_lengths[b] 153 | d = ds[b, :t_text] 154 | d_cumsum = d.cumsum() 155 | d_cumsum = [0] + list(d_cumsum) 156 | x = xs[b, :t_feats] 157 | for n, (start, end) in enumerate(zip(d_cumsum[:-1], d_cumsum[1:])): 158 | if len(x[start:end]) != 0: 159 | xs_avg[b, n] = x[start:end].mean() 160 | else: 161 | xs_avg[b, n] = 0 162 | return xs_avg 163 | 164 | 165 | def average_by_duration(ds, xs, text_lengths, feats_lengths): 166 | 167 | device = ds.device 168 | args = [ds, xs, text_lengths, feats_lengths] 169 | args = [arg.detach().cpu().numpy() for arg in args] 170 | xs_avg = _average_by_duration(*args) 171 | xs_avg = torch.from_numpy(xs_avg).to(device) 172 | return xs_avg 173 | 174 | 175 | class GaussianUpsampling(torch.nn.Module): 176 | 177 | def __init__(self, delta=0.1): 178 | super().__init__() 179 | self.delta = delta 180 | def forward(self, hs, ds, h_masks=None, d_masks=None, alpha=1.0): 181 | 182 | 183 | ds = ds * alpha 184 | 185 | B = ds.size(0) 186 | device = ds.device 187 | if ds.sum() == 0: 188 | # NOTE(kan-bayashi): This case must not be happened in teacher forcing. 189 | # It will be happened in inference with a bad duration predictor. 190 | # So we do not need to care the padded sequence case here. 
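            # (Editor's note: added clarification.) Rows whose predicted
            # durations sum to zero would yield an empty output, so every
            # duration in such a row is set to one frame before upsampling.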
191 | ds[ds.sum(dim=1).eq(0)] = 1 192 | 193 | if h_masks is None: 194 | mel_lenghs = torch.sum(ds, dim=-1).int() # lengths = [5, 3, 2] 195 | T_feats = mel_lenghs.max().item() # T_feats = 5 196 | else: 197 | T_feats = h_masks.size(-1) 198 | t = torch.arange(0, T_feats).unsqueeze(0).repeat(B,1).to(device).float() 199 | if h_masks is not None: 200 | t = t * h_masks.float() 201 | 202 | c = ds.cumsum(dim=-1) - ds/2 203 | 204 | energy = -1 * self.delta * (t.unsqueeze(-1) - c.unsqueeze(1)) ** 2 205 | 206 | if d_masks is not None: 207 | energy = energy.masked_fill(~(d_masks.unsqueeze(1).repeat(1,T_feats,1)), -float("inf")) 208 | 209 | p_attn = torch.softmax(energy, dim=2) # (B, T_feats, T_text) 210 | hs = torch.matmul(p_attn, hs) 211 | return hs -------------------------------------------------------------------------------- /models/prompt_tts_modified/modules/initialize.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is modified from https://github.com/espnet/espnet. 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import numpy as np 9 | 10 | 11 | def initialize(model: torch.nn.Module, init: str): 12 | for p in model.parameters(): 13 | if p.dim() > 1: 14 | if init == "xavier_uniform": 15 | torch.nn.init.xavier_uniform_(p.data) 16 | elif init == "xavier_normal": 17 | torch.nn.init.xavier_normal_(p.data) 18 | elif init == "kaiming_uniform": 19 | torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu") 20 | elif init == "kaiming_normal": 21 | torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu") 22 | else: 23 | raise ValueError("Unknown initialization: " + init) 24 | # bias init 25 | for p in model.parameters(): 26 | if p.dim() == 1: 27 | p.data.zero_() 28 | 29 | # reset some modules with default init 30 | for m in model.modules(): 31 | if isinstance( 32 | m, (torch.nn.Embedding, torch.nn.LayerNorm, torch.nn.GroupNorm) 33 | ): 34 | m.reset_parameters() 35 | if hasattr(m, "espnet_initialization_fn"): 36 | m.espnet_initialization_fn() 37 | 38 | # TODO(xkc): Hacking s3prl_frontend and wav2vec2encoder initialization 39 | if getattr(model, "encoder", None) and getattr( 40 | model.encoder, "reload_pretrained_parameters", None 41 | ): 42 | model.encoder.reload_pretrained_parameters() 43 | if getattr(model, "frontend", None) and getattr( 44 | model.frontend, "reload_pretrained_parameters", None 45 | ): 46 | model.frontend.reload_pretrained_parameters() 47 | if getattr(model, "postencoder", None) and getattr( 48 | model.postencoder, "reload_pretrained_parameters", None 49 | ): 50 | model.postencoder.reload_pretrained_parameters() -------------------------------------------------------------------------------- /models/prompt_tts_modified/modules/variance.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is modified from https://github.com/espnet/espnet. 
3 | """ 4 | 5 | import torch 6 | 7 | from models.prompt_tts_modified.modules.encoder import LayerNorm 8 | 9 | class DurationPredictor(torch.nn.Module): 10 | 11 | def __init__( 12 | self, idim, n_layers=2, n_chans=384, kernel_size=3, dropout_rate=0.1, offset=1.0 13 | ): 14 | 15 | super(DurationPredictor, self).__init__() 16 | self.offset = offset 17 | self.conv = torch.nn.ModuleList() 18 | for idx in range(n_layers): 19 | in_chans = idim if idx == 0 else n_chans 20 | self.conv += [ 21 | torch.nn.Sequential( 22 | torch.nn.Conv1d( 23 | in_chans, 24 | n_chans, 25 | kernel_size, 26 | stride=1, 27 | padding=(kernel_size - 1) // 2, 28 | ), 29 | torch.nn.ReLU(), 30 | LayerNorm(n_chans, dim=1), 31 | torch.nn.Dropout(dropout_rate), 32 | ) 33 | ] 34 | self.linear = torch.nn.Linear(n_chans, 1) 35 | 36 | def _forward(self, xs, x_masks=None, is_inference=False): 37 | 38 | if x_masks is not None: 39 | xs = xs.masked_fill(x_masks, 0.0) 40 | 41 | xs = xs.transpose(1, -1) # (B, idim, Tmax) 42 | for f in self.conv: 43 | xs = f(xs) # (B, C, Tmax) 44 | 45 | # NOTE: calculate in log domain 46 | xs = self.linear(xs.transpose(1, -1)) # (B, Tmax) 47 | if is_inference: 48 | # NOTE: calculate in linear domain 49 | xs = torch.clamp( 50 | torch.round(xs.exp() - self.offset), min=0 51 | ).long() # avoid negative value 52 | 53 | if x_masks is not None: 54 | xs = xs.masked_fill(x_masks, 0.0) 55 | 56 | return xs.squeeze(-1) 57 | 58 | def forward(self, xs, x_masks=None): 59 | 60 | return self._forward(xs, x_masks, False) 61 | 62 | def inference(self, xs, x_masks=None): 63 | 64 | return self._forward(xs, x_masks, True) 65 | 66 | 67 | 68 | class VariancePredictor(torch.nn.Module): 69 | 70 | 71 | def __init__( 72 | self, 73 | idim: int, 74 | n_layers: int = 2, 75 | n_chans: int = 384, 76 | kernel_size: int = 3, 77 | bias: bool = True, 78 | dropout_rate: float = 0.5, 79 | ): 80 | super().__init__() 81 | self.conv = torch.nn.ModuleList() 82 | for idx in range(n_layers): 83 | in_chans = idim if idx == 0 else n_chans 84 | self.conv += [ 85 | torch.nn.Sequential( 86 | torch.nn.Conv1d( 87 | in_chans, 88 | n_chans, 89 | kernel_size, 90 | stride=1, 91 | padding=(kernel_size - 1) // 2, 92 | bias=bias, 93 | ), 94 | torch.nn.ReLU(), 95 | LayerNorm(n_chans, dim=1), 96 | torch.nn.Dropout(dropout_rate), 97 | ) 98 | ] 99 | self.linear = torch.nn.Linear(n_chans, 1) 100 | 101 | def forward(self, xs: torch.Tensor, x_masks: torch.Tensor = None) -> torch.Tensor: 102 | """Calculate forward propagation. 103 | 104 | Args: 105 | xs (Tensor): Batch of input sequences (B, Tmax, idim). 106 | x_masks (ByteTensor): Batch of masks indicating padded part (B, Tmax). 107 | 108 | Returns: 109 | Tensor: Batch of predicted sequences (B, Tmax, 1). 110 | 111 | """ 112 | if x_masks is not None: 113 | xs = xs.masked_fill(x_masks, 0.0) 114 | 115 | xs = xs.transpose(1, -1) # (B, idim, Tmax) 116 | for f in self.conv: 117 | xs = f(xs) # (B, C, Tmax) 118 | 119 | xs = self.linear(xs.transpose(1, 2)) # (B, Tmax, 1) 120 | 121 | if x_masks is not None: 122 | xs = xs.masked_fill(x_masks, 0.0) 123 | 124 | return xs.squeeze(-1) -------------------------------------------------------------------------------- /models/prompt_tts_modified/scheduler.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is modified from https://github.com/alibaba-damo-academy/KAN-TTS. 
3 | """ 4 | 5 | from torch.optim.lr_scheduler import * 6 | from torch.optim.lr_scheduler import _LRScheduler 7 | 8 | class FindLR(_LRScheduler): 9 | 10 | 11 | def __init__(self, optimizer, max_steps, max_lr=10): 12 | self.max_steps = max_steps 13 | self.max_lr = max_lr 14 | super().__init__(optimizer) 15 | 16 | def get_lr(self): 17 | return [ 18 | base_lr 19 | * ((self.max_lr / base_lr) ** (self.last_epoch / (self.max_steps - 1))) 20 | for base_lr in self.base_lrs 21 | ] 22 | 23 | 24 | class NoamLR(_LRScheduler): 25 | def __init__(self, optimizer, warmup_steps): 26 | self.warmup_steps = warmup_steps 27 | super().__init__(optimizer) 28 | 29 | def get_lr(self): 30 | last_epoch = max(1, self.last_epoch) 31 | scale = self.warmup_steps ** 0.5 * min( 32 | last_epoch ** (-0.5), last_epoch * self.warmup_steps ** (-1.5) 33 | ) 34 | return [base_lr * scale for base_lr in self.base_lrs] -------------------------------------------------------------------------------- /models/prompt_tts_modified/simbert.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023, YOUDAO 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import torch.nn as nn 17 | 18 | from transformers import AutoModel 19 | import numpy as np 20 | 21 | class ClassificationHead(nn.Module): 22 | def __init__(self, hidden_size, num_labels, dropout_rate=0.1) -> None: 23 | super().__init__() 24 | 25 | 26 | self.dropout = nn.Dropout(dropout_rate) 27 | self.classifier = nn.Linear(hidden_size, num_labels) 28 | 29 | def forward(self, pooled_output): 30 | 31 | return self.classifier(self.dropout(pooled_output)) 32 | 33 | class StyleEncoder(nn.Module): 34 | def __init__(self, config) -> None: 35 | super().__init__() 36 | 37 | self.bert = AutoModel.from_pretrained(config.bert_path) 38 | 39 | self.pitch_clf = ClassificationHead(config.bert_hidden_size, config.pitch_n_labels) 40 | self.speed_clf = ClassificationHead(config.bert_hidden_size, config.speed_n_labels) 41 | self.energy_clf = ClassificationHead(config.bert_hidden_size, config.energy_n_labels) 42 | self.emotion_clf = ClassificationHead(config.bert_hidden_size, config.emotion_n_labels) 43 | self.style_embed_proj = nn.Linear(config.bert_hidden_size, config.style_dim) 44 | 45 | 46 | 47 | 48 | def forward(self, input_ids, token_type_ids, attention_mask): 49 | outputs = self.bert( 50 | input_ids, 51 | attention_mask=attention_mask, 52 | token_type_ids=token_type_ids, 53 | ) # return a dict having ['last_hidden_state', 'pooler_output'] 54 | 55 | pooled_output = outputs["pooler_output"] 56 | 57 | pitch_outputs = self.pitch_clf(pooled_output) 58 | speed_outputs = self.speed_clf(pooled_output) 59 | energy_outputs = self.energy_clf(pooled_output) 60 | emotion_outputs = self.emotion_clf(pooled_output) 61 | pred_style_embed = self.style_embed_proj(pooled_output) 62 | 63 | res = { 64 | "pooled_output":pooled_output, 65 | "pitch_outputs":pitch_outputs, 66 | "speed_outputs":speed_outputs, 67 | 
"energy_outputs":energy_outputs, 68 | "emotion_outputs":emotion_outputs, 69 | # "pred_style_embed":pred_style_embed, 70 | } 71 | 72 | return res 73 | 74 | 75 | 76 | class StylePretrainLoss(nn.Module): 77 | def __init__(self) -> None: 78 | super().__init__() 79 | 80 | self.loss = nn.CrossEntropyLoss() 81 | 82 | def forward(self, inputs, outputs): 83 | 84 | pitch_loss = self.loss(outputs["pitch_outputs"], inputs["pitch"]) 85 | energy_loss = self.loss(outputs["energy_outputs"], inputs["energy"]) 86 | speed_loss = self.loss(outputs["speed_outputs"], inputs["speed"]) 87 | emotion_loss = self.loss(outputs["emotion_outputs"], inputs["emotion"]) 88 | 89 | return { 90 | "pitch_loss" : pitch_loss, 91 | "energy_loss": energy_loss, 92 | "speed_loss" : speed_loss, 93 | "emotion_loss" : emotion_loss, 94 | } 95 | 96 | 97 | class StylePretrainLoss2(StylePretrainLoss): 98 | def __init__(self) -> None: 99 | super().__init__() 100 | 101 | self.loss = nn.CrossEntropyLoss() 102 | 103 | def forward(self, inputs, outputs): 104 | res = super().forward(inputs, outputs) 105 | speaker_loss = self.loss(outputs["speaker_outputs"], inputs["speaker"]) 106 | res["speaker_loss"] = speaker_loss 107 | return res 108 | 109 | def flat_accuracy(preds, labels): 110 | """ 111 | Function to calculate the accuracy of our predictions vs labels 112 | """ 113 | pred_flat = np.argmax(preds, axis=1).flatten() 114 | labels_flat = labels.flatten() 115 | return np.sum(pred_flat == labels_flat) / len(labels_flat) 116 | -------------------------------------------------------------------------------- /models/prompt_tts_modified/stft.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is modified from https://github.com/pseeth/pytorch-stft. 3 | """ 4 | 5 | import torch 6 | import numpy as np 7 | import torch.nn.functional as F 8 | from torch.autograd import Variable 9 | from scipy.signal import get_window 10 | from librosa.util import pad_center, tiny 11 | from models.prompt_tts_modified.audio_processing import window_sumsquare 12 | 13 | 14 | class STFT(torch.nn.Module): 15 | def __init__(self, filter_length=800, hop_length=200, win_length=800, 16 | window='hann'): 17 | super(STFT, self).__init__() 18 | self.filter_length = filter_length 19 | self.hop_length = hop_length 20 | self.win_length = win_length 21 | self.window = window 22 | self.forward_transform = None 23 | scale = self.filter_length / self.hop_length 24 | fourier_basis = np.fft.fft(np.eye(self.filter_length)) 25 | 26 | cutoff = int((self.filter_length / 2 + 1)) 27 | fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]), 28 | np.imag(fourier_basis[:cutoff, :])]) 29 | 30 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) 31 | inverse_basis = torch.FloatTensor( 32 | np.linalg.pinv(scale * fourier_basis).T[:, None, :]) 33 | 34 | if window is not None: 35 | assert(filter_length >= win_length) 36 | # get window and zero center pad it to filter_length 37 | fft_window = get_window(window, win_length, fftbins=True) 38 | fft_window = pad_center(data=fft_window, size=filter_length) 39 | fft_window = torch.from_numpy(fft_window).float() 40 | 41 | # window the bases 42 | forward_basis *= fft_window 43 | inverse_basis *= fft_window 44 | 45 | self.register_buffer('forward_basis', forward_basis.float()) 46 | self.register_buffer('inverse_basis', inverse_basis.float()) 47 | 48 | def transform(self, input_data): 49 | num_batches = input_data.size(0) 50 | num_samples = input_data.size(1) 51 | 52 | self.num_samples = 
num_samples 53 | 54 | # similar to librosa, reflect-pad the input 55 | input_data = input_data.view(num_batches, 1, num_samples) 56 | input_data = F.pad( 57 | input_data.unsqueeze(1), 58 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), 59 | mode='reflect') 60 | input_data = input_data.squeeze(1) 61 | 62 | forward_transform = F.conv1d( 63 | input_data, 64 | Variable(self.forward_basis, requires_grad=False), 65 | stride=self.hop_length, 66 | padding=0) 67 | 68 | cutoff = int((self.filter_length / 2) + 1) 69 | real_part = forward_transform[:, :cutoff, :] 70 | imag_part = forward_transform[:, cutoff:, :] 71 | 72 | magnitude = torch.sqrt(real_part**2 + imag_part**2) 73 | phase = torch.autograd.Variable( 74 | torch.atan2(imag_part.data, real_part.data)) 75 | 76 | return magnitude, phase 77 | 78 | def inverse(self, magnitude, phase): 79 | recombine_magnitude_phase = torch.cat( 80 | [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1) 81 | 82 | inverse_transform = F.conv_transpose1d( 83 | recombine_magnitude_phase, 84 | Variable(self.inverse_basis, requires_grad=False), 85 | stride=self.hop_length, 86 | padding=0) 87 | 88 | if self.window is not None: 89 | window_sum = window_sumsquare( 90 | self.window, magnitude.size(-1), hop_length=self.hop_length, 91 | win_length=self.win_length, n_fft=self.filter_length, 92 | dtype=np.float32) 93 | # remove modulation effects 94 | approx_nonzero_indices = torch.from_numpy( 95 | np.where(window_sum > tiny(window_sum))[0]) 96 | window_sum = torch.autograd.Variable( 97 | torch.from_numpy(window_sum), requires_grad=False) 98 | window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum 99 | inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices] 100 | 101 | # scale by hop ratio 102 | inverse_transform *= float(self.filter_length) / self.hop_length 103 | 104 | inverse_transform = inverse_transform[:, :, int(self.filter_length/2):] 105 | inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):] 106 | 107 | return inverse_transform 108 | 109 | def forward(self, input_data): 110 | self.magnitude, self.phase = self.transform(input_data) 111 | reconstruction = self.inverse(self.magnitude, self.phase) 112 | return reconstruction 113 | -------------------------------------------------------------------------------- /models/prompt_tts_modified/style_encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is modified from https://github.com/yl4579/StyleTTS. 
3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.nn.utils import spectral_norm 8 | 9 | import math 10 | 11 | class LearnedDownSample(nn.Module): 12 | def __init__(self, layer_type, dim_in): 13 | super().__init__() 14 | self.layer_type = layer_type 15 | 16 | if self.layer_type == 'none': 17 | self.conv = nn.Identity() 18 | elif self.layer_type == 'timepreserve': 19 | self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, padding=(1, 0))) 20 | elif self.layer_type == 'half': 21 | self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, padding=1)) 22 | else: 23 | raise RuntimeError('Got unexpected donwsampletype %s, expected is [none, timepreserve, half]' % self.layer_type) 24 | 25 | def forward(self, x): 26 | return self.conv(x) 27 | 28 | class LearnedUpSample(nn.Module): 29 | def __init__(self, layer_type, dim_in): 30 | super().__init__() 31 | self.layer_type = layer_type 32 | 33 | if self.layer_type == 'none': 34 | self.conv = nn.Identity() 35 | elif self.layer_type == 'timepreserve': 36 | self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, output_padding=(1, 0), padding=(1, 0)) 37 | elif self.layer_type == 'half': 38 | self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, output_padding=1, padding=1) 39 | else: 40 | raise RuntimeError('Got unexpected upsampletype %s, expected is [none, timepreserve, half]' % self.layer_type) 41 | 42 | 43 | def forward(self, x): 44 | return self.conv(x) 45 | 46 | class DownSample(nn.Module): 47 | def __init__(self, layer_type): 48 | super().__init__() 49 | self.layer_type = layer_type 50 | 51 | def forward(self, x): 52 | if self.layer_type == 'none': 53 | return x 54 | elif self.layer_type == 'timepreserve': 55 | return F.avg_pool2d(x, (2, 1)) 56 | elif self.layer_type == 'half': 57 | if x.shape[-1] % 2 != 0: 58 | x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1) 59 | return F.avg_pool2d(x, 2) 60 | else: 61 | raise RuntimeError('Got unexpected donwsampletype %s, expected is [none, timepreserve, half]' % self.layer_type) 62 | 63 | 64 | class UpSample(nn.Module): 65 | def __init__(self, layer_type): 66 | super().__init__() 67 | self.layer_type = layer_type 68 | 69 | def forward(self, x): 70 | if self.layer_type == 'none': 71 | return x 72 | elif self.layer_type == 'timepreserve': 73 | return F.interpolate(x, scale_factor=(2, 1), mode='nearest') 74 | elif self.layer_type == 'half': 75 | return F.interpolate(x, scale_factor=2, mode='nearest') 76 | else: 77 | raise RuntimeError('Got unexpected upsampletype %s, expected is [none, timepreserve, half]' % self.layer_type) 78 | 79 | 80 | class ResBlk(nn.Module): 81 | def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), 82 | normalize=False, downsample='none'): 83 | super().__init__() 84 | self.actv = actv 85 | self.normalize = normalize 86 | self.downsample = DownSample(downsample) 87 | self.downsample_res = LearnedDownSample(downsample, dim_in) 88 | self.learned_sc = dim_in != dim_out 89 | self._build_weights(dim_in, dim_out) 90 | 91 | def _build_weights(self, dim_in, dim_out): 92 | self.conv1 = spectral_norm(nn.Conv2d(dim_in, dim_in, 3, 1, 1)) 93 | self.conv2 = spectral_norm(nn.Conv2d(dim_in, dim_out, 3, 1, 1)) 94 | if self.normalize: 95 | self.norm1 = nn.InstanceNorm2d(dim_in, affine=True) 96 | self.norm2 = nn.InstanceNorm2d(dim_in, affine=True) 97 | if self.learned_sc: 98 
| self.conv1x1 = spectral_norm(nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)) 99 | 100 | def _shortcut(self, x): 101 | if self.learned_sc: 102 | x = self.conv1x1(x) 103 | if self.downsample: 104 | x = self.downsample(x) 105 | return x 106 | 107 | def _residual(self, x): 108 | if self.normalize: 109 | x = self.norm1(x) 110 | x = self.actv(x) 111 | x = self.conv1(x) 112 | x = self.downsample_res(x) 113 | if self.normalize: 114 | x = self.norm2(x) 115 | x = self.actv(x) 116 | x = self.conv2(x) 117 | return x 118 | 119 | def forward(self, x): 120 | x = self._shortcut(x) + self._residual(x) 121 | return x / math.sqrt(2) # unit variance 122 | 123 | class StyleEncoder(nn.Module): 124 | def __init__(self, dim_in=48, style_dim=48, max_conv_dim=384): 125 | super().__init__() 126 | blocks = [] 127 | blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))] 128 | 129 | repeat_num = 4 130 | for _ in range(repeat_num): 131 | dim_out = min(dim_in*2, max_conv_dim) 132 | blocks += [ResBlk(dim_in, dim_out, downsample='half')] 133 | dim_in = dim_out 134 | 135 | blocks += [nn.LeakyReLU(0.2)] 136 | blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))] 137 | blocks += [nn.AdaptiveAvgPool2d(1)] 138 | blocks += [nn.LeakyReLU(0.2)] 139 | self.shared = nn.Sequential(*blocks) 140 | 141 | self.unshared = nn.Linear(dim_out, style_dim) 142 | 143 | def forward(self, x): 144 | h = self.shared(x) 145 | h = h.view(h.size(0), -1) 146 | s = self.unshared(h) 147 | 148 | return s 149 | 150 | 151 | class CosineSimilarityLoss(nn.Module): 152 | def __init__(self) -> None: 153 | super().__init__() 154 | 155 | self.loss_fn = torch.nn.CosineEmbeddingLoss() 156 | 157 | def forward(self, output1, output2): 158 | B = output1.size(0) 159 | target = torch.ones(B, device=output1.device, requires_grad=False) 160 | loss = self.loss_fn(output1, output2, target) 161 | return loss 162 | 163 | -------------------------------------------------------------------------------- /models/prompt_tts_modified/tacotron_stft.py: -------------------------------------------------------------------------------- 1 | """ 2 | from https://github.com/NVIDIA/tacotron2 3 | """ 4 | 5 | import torch 6 | from librosa.filters import mel as librosa_mel_fn 7 | from models.prompt_tts_modified.audio_processing import dynamic_range_compression 8 | from models.prompt_tts_modified.audio_processing import dynamic_range_decompression 9 | from models.prompt_tts_modified.stft import STFT 10 | 11 | 12 | class LinearNorm(torch.nn.Module): 13 | def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): 14 | super(LinearNorm, self).__init__() 15 | self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) 16 | 17 | torch.nn.init.xavier_uniform_( 18 | self.linear_layer.weight, 19 | gain=torch.nn.init.calculate_gain(w_init_gain)) 20 | 21 | def forward(self, x): 22 | return self.linear_layer(x) 23 | 24 | 25 | class ConvNorm(torch.nn.Module): 26 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, 27 | padding=None, dilation=1, bias=True, w_init_gain='linear'): 28 | super(ConvNorm, self).__init__() 29 | if padding is None: 30 | assert(kernel_size % 2 == 1) 31 | padding = int(dilation * (kernel_size - 1) / 2) 32 | 33 | self.conv = torch.nn.Conv1d(in_channels, out_channels, 34 | kernel_size=kernel_size, stride=stride, 35 | padding=padding, dilation=dilation, 36 | bias=bias) 37 | 38 | torch.nn.init.xavier_uniform_( 39 | self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)) 40 | 41 | def forward(self, signal): 42 | conv_signal = 
self.conv(signal) 43 | return conv_signal 44 | 45 | 46 | class TacotronSTFT(torch.nn.Module): 47 | def __init__(self, filter_length=1024, hop_length=256, win_length=1024, 48 | n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0, 49 | mel_fmax=8000.0): 50 | super(TacotronSTFT, self).__init__() 51 | self.n_mel_channels = n_mel_channels 52 | self.sampling_rate = sampling_rate 53 | self.stft_fn = STFT(filter_length, hop_length, win_length) 54 | mel_basis = librosa_mel_fn( 55 | sr=sampling_rate, 56 | n_fft=filter_length, 57 | n_mels=n_mel_channels, 58 | fmin=mel_fmin, 59 | fmax=mel_fmax) 60 | mel_basis = torch.from_numpy(mel_basis).float() 61 | self.register_buffer('mel_basis', mel_basis) 62 | 63 | def spectral_normalize(self, magnitudes): 64 | output = dynamic_range_compression(magnitudes) 65 | return output 66 | 67 | def spectral_de_normalize(self, magnitudes): 68 | output = dynamic_range_decompression(magnitudes) 69 | return output 70 | 71 | def mel_spectrogram(self, y): 72 | 73 | assert(torch.min(y.data) >= -1) 74 | assert(torch.max(y.data) <= 1) 75 | 76 | magnitudes, phases = self.stft_fn.transform(y) 77 | magnitudes = magnitudes.data 78 | mel_output = torch.matmul(self.mel_basis, magnitudes) 79 | mel_output = self.spectral_normalize(mel_output) 80 | return mel_output 81 | 82 | 83 | 84 | 85 | -------------------------------------------------------------------------------- /openaiapi.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import io 4 | import torch 5 | import glob 6 | 7 | from fastapi import FastAPI, Response 8 | from pydantic import BaseModel 9 | 10 | from frontend import g2p_cn_en, ROOT_DIR, read_lexicon, G2p 11 | from models.prompt_tts_modified.jets import JETSGenerator 12 | from models.prompt_tts_modified.simbert import StyleEncoder 13 | from transformers import AutoTokenizer 14 | import numpy as np 15 | import soundfile as sf 16 | import pyrubberband as pyrb 17 | from pydub import AudioSegment 18 | from yacs import config as CONFIG 19 | from config.joint.config import Config 20 | 21 | LOGGER = logging.getLogger(__name__) 22 | 23 | DEFAULTS = { 24 | } 25 | 26 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 27 | print(DEVICE) 28 | config = Config() 29 | MAX_WAV_VALUE = 32768.0 30 | 31 | 32 | def get_env(key): 33 | return os.environ.get(key, DEFAULTS.get(key)) 34 | 35 | 36 | def get_int_env(key): 37 | return int(get_env(key)) 38 | 39 | 40 | def get_float_env(key): 41 | return float(get_env(key)) 42 | 43 | 44 | def get_bool_env(key): 45 | return get_env(key).lower() == 'true' 46 | 47 | 48 | def scan_checkpoint(cp_dir, prefix, c=8): 49 | pattern = os.path.join(cp_dir, prefix + '?'*c) 50 | cp_list = glob.glob(pattern) 51 | if len(cp_list) == 0: 52 | return None 53 | return sorted(cp_list)[-1] 54 | 55 | 56 | def get_models(): 57 | 58 | am_checkpoint_path = scan_checkpoint( 59 | f'{config.output_directory}/prompt_tts_open_source_joint/ckpt', 'g_') 60 | 61 | # f'{config.output_directory}/style_encoder/ckpt/checkpoint_163431' 62 | style_encoder_checkpoint_path = scan_checkpoint( 63 | f'{config.output_directory}/style_encoder/ckpt', 'checkpoint_', 6) 64 | 65 | with open(config.model_config_path, 'r') as fin: 66 | conf = CONFIG.load_cfg(fin) 67 | 68 | conf.n_vocab = config.n_symbols 69 | conf.n_speaker = config.speaker_n_labels 70 | 71 | style_encoder = StyleEncoder(config) 72 | model_CKPT = torch.load(style_encoder_checkpoint_path, map_location="cpu") 73 | model_ckpt = {} 74 | for key, 
value in model_CKPT['model'].items(): 75 | new_key = key[7:] 76 | model_ckpt[new_key] = value 77 | style_encoder.load_state_dict(model_ckpt, strict=False) 78 | generator = JETSGenerator(conf).to(DEVICE) 79 | 80 | model_CKPT = torch.load(am_checkpoint_path, map_location=DEVICE) 81 | generator.load_state_dict(model_CKPT['generator']) 82 | generator.eval() 83 | 84 | tokenizer = AutoTokenizer.from_pretrained(config.bert_path) 85 | 86 | with open(config.token_list_path, 'r') as f: 87 | token2id = {t.strip(): idx for idx, t, in enumerate(f.readlines())} 88 | 89 | with open(config.speaker2id_path, encoding='utf-8') as f: 90 | speaker2id = {t.strip(): idx for idx, t in enumerate(f.readlines())} 91 | 92 | return (style_encoder, generator, tokenizer, token2id, speaker2id) 93 | 94 | 95 | def get_style_embedding(prompt, tokenizer, style_encoder): 96 | prompt = tokenizer([prompt], return_tensors="pt") 97 | input_ids = prompt["input_ids"] 98 | token_type_ids = prompt["token_type_ids"] 99 | attention_mask = prompt["attention_mask"] 100 | with torch.no_grad(): 101 | output = style_encoder( 102 | input_ids=input_ids, 103 | token_type_ids=token_type_ids, 104 | attention_mask=attention_mask, 105 | ) 106 | style_embedding = output["pooled_output"].cpu().squeeze().numpy() 107 | return style_embedding 108 | 109 | 110 | def emotivoice_tts(text, prompt, content, speaker, models): 111 | (style_encoder, generator, tokenizer, token2id, speaker2id) = models 112 | 113 | style_embedding = get_style_embedding(prompt, tokenizer, style_encoder) 114 | content_embedding = get_style_embedding(content, tokenizer, style_encoder) 115 | 116 | speaker = speaker2id[speaker] 117 | 118 | text_int = [token2id[ph] for ph in text.split()] 119 | 120 | sequence = torch.from_numpy(np.array(text_int)).to( 121 | DEVICE).long().unsqueeze(0) 122 | sequence_len = torch.from_numpy(np.array([len(text_int)])).to(DEVICE) 123 | style_embedding = torch.from_numpy(style_embedding).to(DEVICE).unsqueeze(0) 124 | content_embedding = torch.from_numpy( 125 | content_embedding).to(DEVICE).unsqueeze(0) 126 | speaker = torch.from_numpy(np.array([speaker])).to(DEVICE) 127 | 128 | with torch.no_grad(): 129 | 130 | infer_output = generator( 131 | inputs_ling=sequence, 132 | inputs_style_embedding=style_embedding, 133 | input_lengths=sequence_len, 134 | inputs_content_embedding=content_embedding, 135 | inputs_speaker=speaker, 136 | alpha=1.0 137 | ) 138 | 139 | audio = infer_output["wav_predictions"].squeeze() * MAX_WAV_VALUE 140 | audio = audio.cpu().numpy().astype('int16') 141 | 142 | return audio 143 | 144 | 145 | speakers = config.speakers 146 | models = get_models() 147 | app = FastAPI() 148 | lexicon = read_lexicon(f"{ROOT_DIR}/lexicon/librispeech-lexicon.txt") 149 | g2p = G2p() 150 | 151 | from typing import Optional 152 | class SpeechRequest(BaseModel): 153 | input: str 154 | voice: str = '8051' 155 | prompt: Optional[str] = '' 156 | language: Optional[str] = 'zh_us' 157 | model: Optional[str] = 'emoti-voice' 158 | response_format: Optional[str] = 'mp3' 159 | speed: Optional[float] = 1.0 160 | 161 | 162 | @app.post("/v1/audio/speech") 163 | def text_to_speech(speechRequest: SpeechRequest): 164 | 165 | text = g2p_cn_en(speechRequest.input, g2p, lexicon) 166 | np_audio = emotivoice_tts(text, speechRequest.prompt, 167 | speechRequest.input, speechRequest.voice, 168 | models) 169 | y_stretch = np_audio 170 | if speechRequest.speed != 1.0: 171 | y_stretch = pyrb.time_stretch(np_audio, config.sampling_rate, speechRequest.speed) 172 | wav_buffer = io.BytesIO() 
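# The remaining steps render the waveform into an in-memory WAV buffer and, when the
# client requested another container (the default response_format is mp3), transcode
# that buffer with pydub before returning it as the HTTP response body.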
173 | sf.write(file=wav_buffer, data=y_stretch, 174 | samplerate=config.sampling_rate, format='WAV') 175 | buffer = wav_buffer 176 | response_format = speechRequest.response_format 177 | if response_format != 'wav': 178 | wav_audio = AudioSegment.from_wav(wav_buffer) 179 | wav_audio.frame_rate=config.sampling_rate 180 | buffer = io.BytesIO() 181 | wav_audio.export(buffer, format=response_format) 182 | 183 | return Response(content=buffer.getvalue(), 184 | media_type=f"audio/{response_format}") 185 | -------------------------------------------------------------------------------- /plot_image.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import torch.nn.functional as F 3 | 4 | import os 5 | 6 | def plot_image_sambert(target, melspec, mel_lengths=None, text_lengths=None, save_dir=None, global_step=None, name=None): 7 | # Draw mel_plots 8 | mel_plots, axes = plt.subplots(2,1,figsize=(20,15)) 9 | 10 | T = mel_lengths[-1] 11 | L=100 12 | 13 | 14 | axes[0].imshow(target[-1].detach().cpu()[:,:T], 15 | origin='lower', 16 | aspect='auto') 17 | 18 | axes[1].imshow(melspec[-1].detach().cpu()[:,:T], 19 | origin='lower', 20 | aspect='auto') 21 | for i in range(2): 22 | tmp_dir = save_dir+'/att/'+name+'_'+str(global_step) 23 | if not os.path.exists(tmp_dir): 24 | os.makedirs(tmp_dir) 25 | plt.savefig(tmp_dir+'/'+name+'_'+str(global_step)+'_melspec_%s.png'%i) 26 | 27 | return mel_plots -------------------------------------------------------------------------------- /prepare_for_training.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023, YOUDAO 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
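# prepare_for_training.py assembles an experiment directory before fine-tuning:
#   prepare_info   copies the emotion/energy/pitch/speed/tokenlist label files and builds a
#                  combined speaker list from the pretrained speakers plus those found in
#                  <data_dir>/train/datalist.jsonl
#   prepare_config fills config/template.py with the data, info and experiment paths and
#                  writes the result to <exp_dir>/config/config.py
#   prepare_ckpt   widens the pretrained generator's speaker-embedding table with randomly
#                  initialised rows for the new speakers and stages the pretrained
#                  generator/discriminator checkpoints under <exp_dir>/ckpt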
14 | 15 | import torch 16 | import os 17 | import shutil 18 | import argparse 19 | 20 | 21 | def main(args): 22 | from os.path import join 23 | data_dir = args.data_dir 24 | exp_dir = args.exp_dir 25 | os.makedirs(exp_dir, exist_ok=True) 26 | 27 | info_dir = join(exp_dir, 'info') 28 | prepare_info(data_dir, info_dir) 29 | 30 | config_dir = join(exp_dir, 'config') 31 | prepare_config(data_dir, info_dir, exp_dir, config_dir) 32 | 33 | ckpt_dir = join(exp_dir, 'ckpt') 34 | prepare_ckpt(data_dir, info_dir, ckpt_dir) 35 | 36 | 37 | ROOT_DIR = os.path.dirname(os.path.abspath("__file__")) 38 | def prepare_info(data_dir, info_dir): 39 | import jsonlines 40 | print('prepare_info: %s' %info_dir) 41 | os.makedirs(info_dir, exist_ok=True) 42 | 43 | for name in ["emotion", "energy", "pitch", "speed", "tokenlist"]: 44 | shutil.copy(f"{ROOT_DIR}/data/youdao/text/{name}", f"{info_dir}/{name}") 45 | 46 | d_speaker = {} # get all the speakers from datalist 47 | with jsonlines.open(f"{data_dir}/train/datalist.jsonl") as reader: 48 | for obj in reader: 49 | speaker = obj["speaker"] 50 | if not speaker in d_speaker: 51 | d_speaker[speaker] = 1 52 | else: 53 | d_speaker[speaker] += 1 54 | 55 | with open(f"{ROOT_DIR}/data/youdao/text/speaker2") as f, \ 56 | open(f"{info_dir}/speaker", "w") as fout: 57 | 58 | for line in f: 59 | speaker = line.strip() 60 | if speaker in d_speaker: 61 | print('warning: duplicate of speaker [%s] in [%s]' % (speaker, data_dir)) 62 | continue 63 | fout.write(line.strip()+"\n") 64 | 65 | for speaker in sorted(d_speaker.keys()): 66 | fout.write(speaker + "\n") 67 | 68 | 69 | def prepare_config(data_dir, info_dir, exp_dir, config_dir): 70 | print('prepare_config: %s' %config_dir) 71 | os.makedirs(config_dir, exist_ok=True) 72 | 73 | with open(f"{ROOT_DIR}/config/template.py") as f, \ 74 | open(f"{config_dir}/config.py", "w") as fout: 75 | 76 | for line in f: 77 | fout.write(line.replace('', data_dir).replace('', info_dir).replace('', exp_dir)) 78 | 79 | 80 | def prepare_ckpt(data_dir, info_dir, ckpt_dir): 81 | print('prepare_ckpt: %s' %ckpt_dir) 82 | os.makedirs(ckpt_dir, exist_ok=True) 83 | 84 | with open(f"{info_dir}/speaker") as f: 85 | speaker_list=[line.strip() for line in f] 86 | assert len(speaker_list) >= 2014 87 | 88 | gen_ckpt_path = f"{ROOT_DIR}/outputs/prompt_tts_open_source_joint/ckpt/g_00140000" 89 | disc_ckpt_path = f"{ROOT_DIR}/outputs/prompt_tts_open_source_joint/ckpt/do_00140000" 90 | 91 | gen_ckpt = torch.load(gen_ckpt_path, map_location="cpu") 92 | 93 | speaker_embeddings = gen_ckpt["generator"]["am.spk_tokenizer.weight"].clone() 94 | 95 | new_embedding = torch.randn((len(speaker_list)-speaker_embeddings.size(0), speaker_embeddings.size(1))) 96 | 97 | gen_ckpt["generator"]["am.spk_tokenizer.weight"] = torch.cat([speaker_embeddings, new_embedding], dim=0) 98 | 99 | 100 | torch.save(gen_ckpt, f"{ckpt_dir}/pretrained_generator") 101 | shutil.copy(disc_ckpt_path, f"{ckpt_dir}/pretrained_discriminator") 102 | 103 | 104 | 105 | if __name__ == "__main__": 106 | 107 | p = argparse.ArgumentParser() 108 | p.add_argument('--data_dir', type=str, required=True) 109 | p.add_argument('--exp_dir', type=str, required=True) 110 | args = p.parse_args() 111 | 112 | main(args) 113 | -------------------------------------------------------------------------------- /requirements.openaiapi.txt: -------------------------------------------------------------------------------- 1 | fastapi 2 | python-multipart 3 | uvicorn[standard] 4 | pydub 5 | pyrubberband 6 | 
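# These packages back the OpenAI-compatible server in openaiapi.py. Note that pyrubberband
# shells out to the 'rubberband' command-line tool and pydub typically relies on ffmpeg for
# mp3 export, so both system utilities are assumed to be installed alongside these wheels.

Since the list above exists only to serve the /v1/audio/speech endpoint defined in openaiapi.py, a minimal client sketch may be helpful. The host, port and output filename are illustrative assumptions (they depend on how the server is launched, for example with uvicorn), not values taken from this repository:

import requests

# Ask the locally running server for an mp3 of a short sentence, using the fields of
# the SpeechRequest model (input, voice, prompt, response_format, speed).
payload = {
    "input": "Hello, this is EmotiVoice.",
    "voice": "8051",
    "prompt": "happy",
    "response_format": "mp3",
    "speed": 1.0,
}
resp = requests.post("http://127.0.0.1:8000/v1/audio/speech", json=payload)  # assumed local address
resp.raise_for_status()
with open("emotivoice_demo.mp3", "wb") as f:  # hypothetical output path
    f.write(resp.content)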
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchaudio 3 | numpy 4 | numba 5 | scipy 6 | transformers 7 | soundfile 8 | yacs 9 | g2p_en 10 | jieba 11 | pypinyin 12 | pypinyin_dict 13 | streamlit 14 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from setuptools import find_packages, setup 4 | 5 | requirements={ 6 | "infer": [ 7 | "numpy>=1.24.3", 8 | "scipy>=1.10.1", 9 | "torch>=2.1", 10 | "torchaudio", 11 | "soundfile>=0.12.0", 12 | "librosa>=0.10.0", 13 | "scikit-learn", 14 | "numba==0.58.1", 15 | "inflect>=5.6.0", 16 | "tqdm>=4.64.1", 17 | "pyyaml>=6.0", 18 | "transformers==4.26.1", 19 | "yacs", 20 | "g2p_en", 21 | "jieba", 22 | "pypinyin", 23 | "streamlit", 24 | "pandas>=1.4,<2.0", 25 | ], 26 | "openai": [ 27 | "fastapi", 28 | "python-multipart", 29 | "uvicorn[standard]", 30 | "pydub", 31 | ], 32 | "train": [ 33 | "jsonlines", 34 | "praatio", 35 | "pyworld", 36 | "flake8", 37 | "flake8-bugbear", 38 | "flake8-comprehensions", 39 | "flake8-executable", 40 | "flake8-pyi", 41 | "mccabe", 42 | "pycodestyle", 43 | "pyflakes", 44 | "tensorboard", 45 | "einops", 46 | "matplotlib", 47 | ] 48 | } 49 | 50 | infer_requires = requirements["infer"] 51 | openai_requires = requirements["infer"] + requirements["openai"] 52 | train_requires = requirements["infer"] + requirements["train"] 53 | 54 | VERSION = '0.2.0' 55 | 56 | with open("README.md", "r", encoding="utf-8") as readme_file: 57 | README = readme_file.read() 58 | 59 | 60 | setup( 61 | name="EmotiVoice", 62 | version=VERSION, 63 | url="https://github.com/netease-youdao/EmotiVoice", 64 | author="Huaxuan Wang", 65 | author_email="wanghx04@rd.netease.com", 66 | description="EmotiVoice 😊: a Multi-Voice and Prompt-Controlled TTS Engine", 67 | long_description=README, 68 | long_description_content_type="text/markdown", 69 | license="Apache Software License", 70 | # package 71 | packages=find_packages(), 72 | project_urls={ 73 | "Documentation": "https://github.com/netease-youdao/EmotiVoice/wiki", 74 | "Tracker": "https://github.com/netease-youdao/EmotiVoice/issues", 75 | "Repository": "https://github.com/netease-youdao/EmotiVoice", 76 | }, 77 | install_requires=infer_requires, 78 | extras_require={ 79 | "train": train_requires, 80 | "openai": openai_requires, 81 | }, 82 | python_requires=">=3.8.0", 83 | classifiers=[ 84 | "Programming Language :: Python", 85 | "Programming Language :: Python :: 3", 86 | "Programming Language :: Python :: 3.8", 87 | "Programming Language :: Python :: 3.9", 88 | "Programming Language :: Python :: 3.10", 89 | "Programming Language :: Python :: 3.11", 90 | "Development Status :: 3 - Alpha", 91 | "Intended Audience :: Science/Research", 92 | "Operating System :: POSIX :: Linux", 93 | "License :: OSI Approved :: Apache Software License", 94 | "Topic :: Software Development :: Libraries :: Python Modules", 95 | "Topic :: Multimedia :: Sound/Audio :: Speech", 96 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 97 | ], 98 | ) -------------------------------------------------------------------------------- /text/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | from https://github.com/keithito/tacotron 3 | """ 4 | 5 | import re 6 | from text import cleaners 7 | 
from text.symbols import symbols 8 | 9 | 10 | # Mappings from symbol to numeric ID and vice versa: 11 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 12 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 13 | 14 | # Regular expression matching text enclosed in curly braces: 15 | _curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)") 16 | 17 | 18 | def text_to_sequence(text, cleaner_names): 19 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 20 | 21 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded 22 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." 23 | 24 | Args: 25 | text: string to convert to a sequence 26 | cleaner_names: names of the cleaner functions to run the text through 27 | 28 | Returns: 29 | List of integers corresponding to the symbols in the text 30 | """ 31 | sequence = [] 32 | 33 | # Check for curly braces and treat their contents as ARPAbet: 34 | while len(text): 35 | m = _curly_re.match(text) 36 | 37 | if not m: 38 | sequence += _symbols_to_sequence(_clean_text(text, cleaner_names)) 39 | break 40 | sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names)) 41 | 42 | sequence += _arpabet_to_sequence(m.group(2)) 43 | text = m.group(3) 44 | 45 | 46 | return sequence 47 | 48 | 49 | def sequence_to_text(sequence): 50 | """Converts a sequence of IDs back to a string""" 51 | result = "" 52 | for symbol_id in sequence: 53 | if symbol_id in _id_to_symbol: 54 | s = _id_to_symbol[symbol_id] 55 | # Enclose ARPAbet back in curly braces: 56 | if len(s) > 1 and s[0] == "@": 57 | s = "{%s}" % s[1:] 58 | result += s 59 | return result.replace("}{", " ") 60 | 61 | 62 | def _clean_text(text, cleaner_names): 63 | for name in cleaner_names: 64 | cleaner = getattr(cleaners, name) 65 | if not cleaner: 66 | raise Exception("Unknown cleaner: %s" % name) 67 | text = cleaner(text) 68 | return text 69 | 70 | 71 | def _symbols_to_sequence(symbols): 72 | return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)] 73 | 74 | 75 | def _arpabet_to_sequence(text): 76 | return _symbols_to_sequence(["@" + s for s in text.split()]) 77 | 78 | 79 | def _should_keep_symbol(s): 80 | return s in _symbol_to_id and s != "_" and s != "~" 81 | -------------------------------------------------------------------------------- /text/cleaners.py: -------------------------------------------------------------------------------- 1 | """ 2 | from https://github.com/keithito/tacotron 3 | """ 4 | 5 | ''' 6 | Cleaners are transformations that run over the input text at both training and eval time. 7 | 8 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 9 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 10 | 1. "english_cleaners" for English text 11 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 12 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 13 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 14 | the symbols in symbols.py to match your data). 15 | ''' 16 | 17 | 18 | # Regular expression matching whitespace: 19 | import re 20 | from unidecode import unidecode 21 | from .numbers import normalize_numbers 22 | _whitespace_re = re.compile(r'\s+') 23 | 24 | # List of (regular expression, replacement) pairs for abbreviations: 25 | _abbreviations = [(re.compile('\\b%s\\.' 
% x[0], re.IGNORECASE), x[1]) for x in [ 26 | ('mrs', 'misess'), 27 | ('mr', 'mister'), 28 | ('dr', 'doctor'), 29 | ('st', 'saint'), 30 | ('co', 'company'), 31 | ('jr', 'junior'), 32 | ('maj', 'major'), 33 | ('gen', 'general'), 34 | ('drs', 'doctors'), 35 | ('rev', 'reverend'), 36 | ('lt', 'lieutenant'), 37 | ('hon', 'honorable'), 38 | ('sgt', 'sergeant'), 39 | ('capt', 'captain'), 40 | ('esq', 'esquire'), 41 | ('ltd', 'limited'), 42 | ('col', 'colonel'), 43 | ('ft', 'fort'), 44 | ]] 45 | 46 | 47 | def expand_abbreviations(text): 48 | for regex, replacement in _abbreviations: 49 | text = re.sub(regex, replacement, text) 50 | return text 51 | 52 | 53 | def expand_numbers(text): 54 | return normalize_numbers(text) 55 | 56 | 57 | def lowercase(text): 58 | return text.lower() 59 | 60 | 61 | def collapse_whitespace(text): 62 | return re.sub(_whitespace_re, ' ', text) 63 | 64 | 65 | def convert_to_ascii(text): 66 | return unidecode(text) 67 | 68 | 69 | def basic_cleaners(text): 70 | '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' 71 | text = lowercase(text) 72 | text = collapse_whitespace(text) 73 | return text 74 | 75 | 76 | def transliteration_cleaners(text): 77 | '''Pipeline for non-English text that transliterates to ASCII.''' 78 | text = convert_to_ascii(text) 79 | text = lowercase(text) 80 | text = collapse_whitespace(text) 81 | return text 82 | 83 | 84 | def english_cleaners(text): 85 | '''Pipeline for English text, including number and abbreviation expansion.''' 86 | text = convert_to_ascii(text) 87 | text = lowercase(text) 88 | text = expand_numbers(text) 89 | text = expand_abbreviations(text) 90 | text = collapse_whitespace(text) 91 | return text 92 | -------------------------------------------------------------------------------- /text/cmudict.py: -------------------------------------------------------------------------------- 1 | """ 2 | from https://github.com/keithito/tacotron 3 | """ 4 | 5 | import re 6 | 7 | 8 | valid_symbols = [ 9 | "AA", 10 | "AA0", 11 | "AA1", 12 | "AA2", 13 | "AE", 14 | "AE0", 15 | "AE1", 16 | "AE2", 17 | "AH", 18 | "AH0", 19 | "AH1", 20 | "AH2", 21 | "AO", 22 | "AO0", 23 | "AO1", 24 | "AO2", 25 | "AW", 26 | "AW0", 27 | "AW1", 28 | "AW2", 29 | "AY", 30 | "AY0", 31 | "AY1", 32 | "AY2", 33 | "B", 34 | "CH", 35 | "D", 36 | "DH", 37 | "EH", 38 | "EH0", 39 | "EH1", 40 | "EH2", 41 | "ER", 42 | "ER0", 43 | "ER1", 44 | "ER2", 45 | "EY", 46 | "EY0", 47 | "EY1", 48 | "EY2", 49 | "F", 50 | "G", 51 | "HH", 52 | "IH", 53 | "IH0", 54 | "IH1", 55 | "IH2", 56 | "IY", 57 | "IY0", 58 | "IY1", 59 | "IY2", 60 | "JH", 61 | "K", 62 | "L", 63 | "M", 64 | "N", 65 | "NG", 66 | "OW", 67 | "OW0", 68 | "OW1", 69 | "OW2", 70 | "OY", 71 | "OY0", 72 | "OY1", 73 | "OY2", 74 | "P", 75 | "R", 76 | "S", 77 | "SH", 78 | "T", 79 | "TH", 80 | "UH", 81 | "UH0", 82 | "UH1", 83 | "UH2", 84 | "UW", 85 | "UW0", 86 | "UW1", 87 | "UW2", 88 | "V", 89 | "W", 90 | "Y", 91 | "Z", 92 | "ZH", 93 | ] 94 | 95 | _valid_symbol_set = set(valid_symbols) 96 | 97 | 98 | class CMUDict: 99 | """Thin wrapper around CMUDict data. 
http://www.speech.cs.cmu.edu/cgi-bin/cmudict""" 100 | 101 | def __init__(self, file_or_path, keep_ambiguous=True): 102 | if isinstance(file_or_path, str): 103 | with open(file_or_path, encoding="latin-1") as f: 104 | entries = _parse_cmudict(f) 105 | else: 106 | entries = _parse_cmudict(file_or_path) 107 | if not keep_ambiguous: 108 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1} 109 | self._entries = entries 110 | 111 | def __len__(self): 112 | return len(self._entries) 113 | 114 | def lookup(self, word): 115 | """Returns list of ARPAbet pronunciations of the given word.""" 116 | return self._entries.get(word.upper()) 117 | 118 | 119 | _alt_re = re.compile(r"\([0-9]+\)") 120 | 121 | 122 | def _parse_cmudict(file): 123 | cmudict = {} 124 | for line in file: 125 | if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"): 126 | parts = line.split(" ") 127 | word = re.sub(_alt_re, "", parts[0]) 128 | pronunciation = _get_pronunciation(parts[1]) 129 | if pronunciation: 130 | if word in cmudict: 131 | cmudict[word].append(pronunciation) 132 | else: 133 | cmudict[word] = [pronunciation] 134 | return cmudict 135 | 136 | 137 | def _get_pronunciation(s): 138 | parts = s.strip().split(" ") 139 | for part in parts: 140 | if part not in _valid_symbol_set: 141 | return None 142 | return " ".join(parts) 143 | -------------------------------------------------------------------------------- /text/numbers.py: -------------------------------------------------------------------------------- 1 | """ 2 | from https://github.com/keithito/tacotron 3 | """ 4 | 5 | import inflect 6 | import re 7 | 8 | 9 | _inflect = inflect.engine() 10 | _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") 11 | _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") 12 | _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") 13 | _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") 14 | _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") 15 | _number_re = re.compile(r"[0-9]+") 16 | 17 | 18 | def _remove_commas(m): 19 | return m.group(1).replace(",", "") 20 | 21 | 22 | def _expand_decimal_point(m): 23 | return m.group(1).replace(".", " point ") 24 | 25 | 26 | def _expand_dollars(m): 27 | match = m.group(1) 28 | parts = match.split(".") 29 | if len(parts) > 2: 30 | return match + " dollars" # Unexpected format 31 | dollars = int(parts[0]) if parts[0] else 0 32 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 33 | if dollars and cents: 34 | dollar_unit = "dollar" if dollars == 1 else "dollars" 35 | cent_unit = "cent" if cents == 1 else "cents" 36 | return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit) 37 | elif dollars: 38 | dollar_unit = "dollar" if dollars == 1 else "dollars" 39 | return "%s %s" % (dollars, dollar_unit) 40 | elif cents: 41 | cent_unit = "cent" if cents == 1 else "cents" 42 | return "%s %s" % (cents, cent_unit) 43 | else: 44 | return "zero dollars" 45 | 46 | 47 | def _expand_ordinal(m): 48 | return _inflect.number_to_words(m.group(0)) 49 | 50 | 51 | def _expand_number(m): 52 | num = int(m.group(0)) 53 | if num > 1000 and num < 3000: 54 | if num == 2000: 55 | return "two thousand" 56 | elif num > 2000 and num < 2010: 57 | return "two thousand " + _inflect.number_to_words(num % 100) 58 | elif num % 100 == 0: 59 | return _inflect.number_to_words(num // 100) + " hundred" 60 | else: 61 | return _inflect.number_to_words( 62 | num, andword="", zero="oh", group=2 63 | ).replace(", ", " ") 64 | else: 65 | return _inflect.number_to_words(num, andword="") 66 | 67 | 68 | def 
normalize_numbers(text): 69 | text = re.sub(_comma_number_re, _remove_commas, text) 70 | text = re.sub(_pounds_re, r"\1 pounds", text) 71 | text = re.sub(_dollars_re, _expand_dollars, text) 72 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 73 | text = re.sub(_ordinal_re, _expand_ordinal, text) 74 | text = re.sub(_number_re, _expand_number, text) 75 | return text 76 | -------------------------------------------------------------------------------- /text/symbols.py: -------------------------------------------------------------------------------- 1 | """ 2 | from https://github.com/keithito/tacotron 3 | """ 4 | 5 | """ 6 | Defines the set of symbols used in text input to the model. 7 | 8 | The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. """ 9 | 10 | from text import cmudict 11 | 12 | _pad = "_" 13 | _punctuation = "!'(),.:;? " 14 | _special = "-" 15 | _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" 16 | _silences = ["@sp", "@spn", "@sil"] 17 | 18 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): 19 | _arpabet = ["@" + s for s in cmudict.valid_symbols] 20 | 21 | 22 | # Export all symbols: 23 | symbols = ( 24 | [_pad] 25 | + list(_special) 26 | + list(_punctuation) 27 | + list(_letters) 28 | + _arpabet 29 | + _silences 30 | ) 31 | --------------------------------------------------------------------------------
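The keithito-derived text frontend under text/ can be exercised on its own. A small sketch, assuming the unidecode and inflect packages that back cleaners.py and numbers.py are installed, showing the round trip between raw text with an embedded ARPAbet span and the integer IDs defined by the symbols table:

from text import text_to_sequence, sequence_to_text

# Curly braces mark raw ARPAbet phonemes; everything else passes through english_cleaners
# (ASCII transliteration, lowercasing, number and abbreviation expansion).
ids = text_to_sequence("Turn left on {HH AW1 S S T AH0 N} Street in 2009.", ["english_cleaners"])
print(ids)                    # list of integer symbol IDs
print(sequence_to_text(ids))  # turn left on {HH AW1 S S T AH0 N} street in two thousand nine.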