├── .dockerignore
├── .gitignore
├── Dockerfile
├── EmotiVoice_UserAgreement_易魔声用户协议.pdf
├── HTTP_API_TtsDemo
│   ├── README.md
│   └── apidemo
│       ├── TtsDemo.py
│       └── utils
│           └── AuthV3Util.py
├── LICENSE
├── README.md
├── README.zh.md
├── README_小白安装教程.md
├── ROADMAP.md
├── assets
│   └── audio
│       ├── emotivoice_intro_cn.wav
│       └── emotivoice_intro_en.wav
├── cn2an
│   ├── an2cn.py
│   └── conf.py
├── cog.yaml
├── config
│   ├── joint
│   │   ├── config.py
│   │   └── config.yaml
│   └── template.py
├── data
│   ├── DataBaker
│   │   ├── README.md
│   │   └── src
│   │       ├── step0_download.sh
│   │       ├── step1_clean_raw_data.py
│   │       └── step2_get_phoneme.py
│   ├── LJspeech
│   │   ├── README.md
│   │   └── src
│   │       ├── step0_download.sh
│   │       ├── step1_clean_raw_data.py
│   │       └── step2_get_phoneme.py
│   ├── inference
│   │   └── text
│   └── youdao
│       └── text
│           ├── README.md
│           ├── emotion
│           ├── energy
│           ├── pitch
│           ├── speaker2
│           ├── speed
│           └── tokenlist
├── demo_page.py
├── demo_page_databaker.py
├── frontend.py
├── frontend_cn.py
├── frontend_en.py
├── inference_am_vocoder_exp.py
├── inference_am_vocoder_joint.py
├── inference_tts.py
├── lexicon
│   └── librispeech-lexicon.txt
├── mel_process.py
├── mfa
│   ├── step1_create_dataset.py
│   ├── step2_prepare_data.py
│   ├── step3_prepare_special_tokens.py
│   ├── step4_convert_text_to_phn.py
│   ├── step5_prepare_alignment.py
│   ├── step7_gen_alignment_from_textgrid.py
│   ├── step8_make_data_list.py
│   └── step9_datalist_from_mfa.py
├── models
│   ├── hifigan
│   │   ├── dataset.py
│   │   ├── env.py
│   │   ├── get_random_segments.py
│   │   ├── get_vocoder.py
│   │   ├── models.py
│   │   └── pretrained_discriminator.py
│   └── prompt_tts_modified
│       ├── audio_processing.py
│       ├── feats.py
│       ├── jets.py
│       ├── loss.py
│       ├── model_open_source.py
│       ├── modules
│       │   ├── alignment.py
│       │   ├── encoder.py
│       │   ├── initialize.py
│       │   └── variance.py
│       ├── prompt_dataset.py
│       ├── scheduler.py
│       ├── simbert.py
│       ├── stft.py
│       ├── style_encoder.py
│       └── tacotron_stft.py
├── openaiapi.py
├── plot_image.py
├── predict.py
├── prepare_for_training.py
├── requirements.openaiapi.txt
├── requirements.txt
├── setup.py
├── text
│   ├── __init__.py
│   ├── cleaners.py
│   ├── cmudict.py
│   ├── numbers.py
│   └── symbols.py
└── train_am_vocoder_joint.py
/.dockerignore:
--------------------------------------------------------------------------------
1 | # The .dockerignore file excludes files from the container build process.
2 | #
3 | # https://docs.docker.com/engine/reference/builder/#dockerignore-file
4 |
5 | # Exclude Git files
6 | .git
7 | .github
8 | .gitignore
9 |
10 | # Exclude Python cache files
11 | __pycache__
12 | .mypy_cache
13 | .pytest_cache
14 | .ruff_cache
15 |
16 | # Exclude Python virtual environment
17 | /venv
18 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | outputs/
2 | WangZeJun/
3 | *.pyc
4 | .vscode/
5 | __pycache__/
6 | .idea/
7 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # syntax=docker/dockerfile:1
2 | FROM ubuntu:22.04
3 |
4 | # install app dependencies
5 | RUN apt-get update && apt-get install -y python3 python3-pip libsndfile1
6 | RUN python3 -m pip install torch==1.11.0 torchaudio numpy numba scipy transformers==4.26.1 soundfile yacs
7 | RUN python3 -m pip install pypinyin jieba
8 |
9 | # install app
10 | RUN mkdir /EmotiVoice
11 | COPY . /EmotiVoice/
12 |
13 | # final configuration
14 | EXPOSE 8501
15 | RUN python3 -m pip install streamlit g2p_en
16 | WORKDIR /EmotiVoice
17 | RUN python3 frontend_en.py
18 | CMD streamlit run demo_page.py --server.port 8501
19 |
--------------------------------------------------------------------------------
/EmotiVoice_UserAgreement_易魔声用户协议.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/netease-youdao/EmotiVoice/bc2de8c9eb1121237958ef154cb171e7faefc769/EmotiVoice_UserAgreement_易魔声用户协议.pdf
--------------------------------------------------------------------------------
/HTTP_API_TtsDemo/README.md:
--------------------------------------------------------------------------------
1 | # Description
2 | This project is a Python sample for calling the Youdao Zhiyun (有道智云) PaaS APIs. You can quickly invoke the relevant Youdao Zhiyun API services by running the main function in this project.
3 |
4 | # Runtime environment
5 | 1. Python 3.6 or later.
6 |
7 | # How to run
8 | 1. Before running, fill in the relevant API parameters according to the comments in the code; the exact parameter values are documented on the [Zhiyun website](https://ai.youdao.com).
9 | 2. You also need your Zhiyun application ID and application secret. See the [getting started guide](https://ai.youdao.com/doc.s#guide) for how to obtain them.
10 |
11 | # Notes
12 | 1. Some of the code in this project is for demonstration and reference only; adapt it to your actual business needs before using it in production.
13 | 2. In this demo the data returned by the API is only printed to the console; parse it as appropriate in real use.
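
For reference, here is a minimal sketch of the form fields that `apidemo/TtsDemo.py` ends up POSTing to `https://openapi.youdao.com/ttsapi`. The last five fields are filled in by `AuthV3Util.addAuthParams`, and all values shown here are placeholders:

```python
data = {
    'q': 'Hello EmotiVoice',      # text to synthesize
    'voiceName': 'Maria Kasper',  # voice to use
    'format': 'mp3',              # audio format of the response
    # added by AuthV3Util.addAuthParams:
    'appKey': '<your application ID>',
    'salt': '<random UUID>',
    'curtime': '<current unix timestamp, in seconds>',
    'signType': 'v3',
    'sign': '<sha256(appKey + input(q) + salt + curtime + appSecret)>',
}
```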
--------------------------------------------------------------------------------
/HTTP_API_TtsDemo/apidemo/TtsDemo.py:
--------------------------------------------------------------------------------
1 | import requests
2 |
3 | from utils.AuthV3Util import addAuthParams
4 |
5 | # Your application ID
6 | APP_KEY = ''
7 | # Your application secret
8 | APP_SECRET = ''
9 |
10 | # Path where the synthesized audio is saved, e.g. a Windows path: PATH = "C:\\tts\\media.mp3"
11 | PATH = 'EmotiVoice-8051.mp3'
12 |
13 |
14 | def createRequest():
15 | '''
16 | note: replace the variables below with the parameters for your request
17 | '''
18 | q = 'Emoti-Voice - a Multi-Voice and Prompt-Controlled T-T-S Engine,大家好'
19 | voiceName = 'Maria Kasper' # 'Cori Samuel'
20 | format = 'mp3'
21 |
22 | data = {'q': q, 'voiceName': voiceName, 'format': format}
23 |
24 | addAuthParams(APP_KEY, APP_SECRET, data)
25 |
26 | header = {'Content-Type': 'application/x-www-form-urlencoded'}
27 | res = doCall('https://openapi.youdao.com/ttsapi', header, data, 'post')
28 | saveFile(res)
29 |
30 |
31 | def doCall(url, header, params, method):
32 | if 'get' == method:
33 | return requests.get(url, params)
34 | elif 'post' == method:
35 | # requests.post's third positional parameter is `json`, not `headers`,
36 | # so pass the form data and the headers explicitly by keyword
37 | return requests.post(url, data=params, headers=header)
36 |
37 |
38 | def saveFile(res):
39 | contentType = res.headers['Content-Type']
40 | if 'audio' in contentType:
41 | fo = open(PATH, 'wb')
42 | fo.write(res.content)
43 | fo.close()
44 | print('save file path: ' + PATH)
45 | else:
46 | print(str(res.content, 'utf-8'))
47 |
48 | # Demo for calling the NetEase Youdao Zhiyun speech synthesis (TTS) API
49 | # API endpoint: https://openapi.youdao.com/ttsapi
50 | if __name__ == '__main__':
51 | createRequest()
52 |
--------------------------------------------------------------------------------
/HTTP_API_TtsDemo/apidemo/utils/AuthV3Util.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import time
3 | import uuid
4 |
5 | '''
6 | Add authentication parameters -
7 | appKey : application ID
8 | salt : random value
9 | curtime : current timestamp (seconds)
10 | signType : signature version
11 | sign : request signature
12 |
13 | @param appKey your application ID
14 | @param appSecret your application secret
15 | @param paramsMap the request parameter dict
16 | '''
17 | def addAuthParams(appKey, appSecret, params):
18 | q = params.get('q')
19 | if q is None:
20 | q = params.get('img')
21 | salt = str(uuid.uuid1())
22 | curtime = str(int(time.time()))
23 | sign = calculateSign(appKey, appSecret, q, salt, curtime)
24 | params['appKey'] = appKey
25 | params['salt'] = salt
26 | params['curtime'] = curtime
27 | params['signType'] = 'v3'
28 | params['sign'] = sign
29 |
30 | '''
31 | Compute the authentication signature -
32 | Formula: sign = sha256(appKey + input(q) + salt + curtime + appSecret)
33 | @param appKey your application ID
34 | @param appSecret your application secret
35 | @param q the request content
36 | @param salt random value
37 | @param curtime current timestamp (seconds)
38 | @return the signature sign
39 | '''
40 | def calculateSign(appKey, appSecret, q, salt, curtime):
41 | strSrc = appKey + getInput(q) + salt + curtime + appSecret
42 | return encrypt(strSrc)
43 |
44 |
45 | def encrypt(strSrc):
46 | hash_algorithm = hashlib.sha256()
47 | hash_algorithm.update(strSrc.encode('utf-8'))
48 | return hash_algorithm.hexdigest()
49 |
50 |
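# The "input(q)" term used in the signature above: if q is no longer than 20
# characters it is used as-is; otherwise it is truncated to the first 10
# characters + str(len(q)) + the last 10 characters (see calculateSign).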
51 | def getInput(input):
52 | if input is None:
53 | return input
54 | inputLen = len(input)
55 | return input if inputLen <= 20 else input[0:10] + str(inputLen) + input[inputLen - 10:inputLen]
56 |
--------------------------------------------------------------------------------
/README.zh.md:
--------------------------------------------------------------------------------
1 | README: EN | 中文
2 |
3 | # EmotiVoice 易魔声 😊: a Multi-Voice and Prompt-Controlled TTS Engine
4 |
18 | **EmotiVoice** is a powerful, **completely free** open-source TTS engine. It supports both Chinese and English, ships with more than 2,000 distinct voices, and features **emotional synthesis**: it can generate speech covering a wide range of emotions such as happiness, excitement, sadness, and anger.
19 |
20 | EmotiVoice provides an easy-to-use web interface, plus a scripting interface for batch generation.
21 |
22 | Here are a few samples generated by EmotiVoice:
23 |
24 | - [Chinese audio sample](https://github.com/netease-youdao/EmotiVoice/assets/3909232/6426d7c1-d620-4bfc-ba03-cd7fc046a4fb)
25 |
26 | - [English audio sample](https://github.com/netease-youdao/EmotiVoice/assets/3909232/8f272eba-49db-493b-b479-2d9e5a419e26)
27 |
28 | - [Fun Chinese English audio sample](https://github.com/netease-youdao/EmotiVoice/assets/3909232/a0709012-c3ef-4182-bb0e-b7a2ba386f1c)
29 |
30 | ## Hot news
31 |
32 | - [x] The OpenAI-compatible TTS API now supports speed control, thanks to [@john9405](https://github.com/john9405). [#90](https://github.com/netease-youdao/EmotiVoice/pull/90) [#67](https://github.com/netease-youdao/EmotiVoice/issues/67) [#77](https://github.com/netease-youdao/EmotiVoice/issues/77)
33 | - [x] The [one-click installer for macOS](https://github.com/netease-youdao/EmotiVoice/releases/download/v0.3/emotivoice-1.0.0-arm64.dmg) was released on December 28, 2023. **Free and easy to use; strongly recommended!**
34 | - [x] The [EmotiVoice HTTP API](https://github.com/netease-youdao/EmotiVoice/wiki/HTTP-API) went live on December 6, 2023. It is easier to get started with (no installation or setup required), faster, and more stable, with **more than 13,000 free calls** per account. Users can also access the other appealing voices provided by [Zhiyun](https://ai.youdao.com/).
35 | - [x] [Voice cloning with your own data](https://github.com/netease-youdao/EmotiVoice/wiki/Voice-Cloning-with-your-personal-data) was released on December 13, 2023, together with two recipe tutorials: [DataBaker Recipe](https://github.com/netease-youdao/EmotiVoice/tree/main/data/DataBaker) and [LJSpeech Recipe](https://github.com/netease-youdao/EmotiVoice/tree/main/data/LJspeech).
36 |
37 | ## Features under development
38 |
39 | - [ ] Support for more languages, such as Japanese and Korean. [#19](https://github.com/netease-youdao/EmotiVoice/issues/19) [#22](https://github.com/netease-youdao/EmotiVoice/issues/22)
40 |
41 | EmotiVoice listens to the community and responds actively; your feedback is welcome!
42 |
43 | ## Quickstart
44 |
45 | ### EmotiVoice Docker image
46 |
47 | The easiest way to try EmotiVoice is to run the Docker image. You need a machine with an NVIDIA GPU. First install the NVIDIA Container Toolkit, following the instructions for [Linux](https://www.server-world.info/en/note?os=Ubuntu_22.04&p=nvidia&f=2) or [Windows WSL2](https://zhuanlan.zhihu.com/p/653173679). Then you can run the EmotiVoice image directly:
48 |
49 | ```sh
50 | docker run -dp 127.0.0.1:8501:8501 syq163/emoti-voice:latest
51 | ```
52 |
53 | The Docker image was last updated on January 4, 2024. If you are using an older version, it is recommended to update with:
54 | ```sh
55 | docker pull syq163/emoti-voice:latest
56 | docker run -dp 127.0.0.1:8501:8501 -p 127.0.0.1:8000:8000 syq163/emoti-voice:latest
57 | ```
58 |
59 | Now open your browser and navigate to http://localhost:8501 to experience EmotiVoice's TTS capabilities. Starting from the 2024 Docker images, the OpenAI-compatible TTS API is also available at http://localhost:8000/.
60 |
61 | ### Full installation
62 |
63 | ```sh
64 | conda create -n EmotiVoice python=3.8 -y
65 | conda activate EmotiVoice
66 | pip install torch torchaudio
67 | pip install numpy numba scipy transformers soundfile yacs g2p_en jieba pypinyin pypinyin_dict
68 | python -m nltk.downloader "averaged_perceptron_tagger_eng"
69 | ```
70 |
71 | ### Prepare the model files
72 |
73 | We strongly recommend reading the wiki page on [how to download the pretrained model files](https://github.com/netease-youdao/EmotiVoice/wiki/Pretrained-models), especially if you run into problems.
74 |
75 | ```sh
76 | git lfs install
77 | git lfs clone https://huggingface.co/WangZeJun/simbert-base-chinese WangZeJun/simbert-base-chinese
78 | ```
79 |
80 | Alternatively, you can run:
81 | ```sh
82 | git clone https://www.modelscope.cn/syq163/WangZeJun.git
83 | ```
84 |
85 | ### Inference
86 |
87 | 1. Download the [pretrained models](https://drive.google.com/drive/folders/1y6Xwj_GG9ulsAonca_unSGbJ4lxbNymM?usp=sharing) by simply running:
88 |
89 | ```sh
90 | git clone https://www.modelscope.cn/syq163/outputs.git
91 | ```
92 |
93 | 2. The inference input format is `<speaker id>|<style prompt>|<phonemes>|<original text>` (four pipe-separated fields, as in the example below; a small helper sketch follows these steps).
94 |
95 | - For example: `8051|非常开心| uo3 sp1 l ai2 sp0 d ao4 sp1 b ei3 sp0 j ing1 sp3 q ing1 sp0 h ua2 sp0 d a4 sp0 x ve2 |我来到北京,清华大学`.
96 | 3. The phonemes can be obtained by running: `python frontend.py data/my_text.txt > data/my_text_for_tts.txt`.
97 |
98 | 4. Then run:
99 | ```sh
100 | TEXT=data/inference/text
101 | python inference_am_vocoder_joint.py \
102 | --logdir prompt_tts_open_source_joint \
103 | --config_folder config/joint \
104 | --checkpoint g_00140000 \
105 | --test_file $TEXT
106 | ```
107 | The synthesized speech is saved under `outputs/prompt_tts_open_source_joint/test_audio`.
108 |
109 | 5. Alternatively, you can use the interactive web UI directly:
110 | ```sh
111 | pip install streamlit
112 | streamlit run demo_page.py
113 | ```
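
As a small illustration of the input format described in step 2, this sketch assembles one line of the inference text file from its four fields (the helper name `make_tts_line` is only for illustration; the phonemes come from `frontend.py`):

```python
def make_tts_line(speaker_id: str, prompt: str, phonemes: str, text: str) -> str:
    """Join the four fields of one EmotiVoice inference line with '|'."""
    return f"{speaker_id}|{prompt}|{phonemes}|{text}"

# Reproduces the example line shown above.
line = make_tts_line(
    "8051",
    "非常开心",
    " uo3 sp1 l ai2 sp0 d ao4 sp1 b ei3 sp0 j ing1 sp3 q ing1 sp0 h ua2 sp0 d a4 sp0 x ve2 ",
    "我来到北京,清华大学",
)
print(line)
```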
114 |
115 | ### OpenAI-compatible TTS API
116 |
117 | Many thanks to @lewangdev for the related work [#60](../../issues/60). Set it up by running the commands below (a request sketch follows them):
118 |
119 | ```sh
120 | pip install fastapi pydub uvicorn[standard] pyrubberband
121 | uvicorn openaiapi:app --reload
122 | ```
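
A minimal request sketch, assuming the server mirrors OpenAI's `/v1/audio/speech` route with the usual `model`/`input`/`voice` fields; check `openaiapi.py` or the generated docs at `/docs` for the exact schema:

```python
import requests

# Assumption: OpenAI-style speech endpoint served by `uvicorn openaiapi:app` on port 8000.
resp = requests.post(
    "http://localhost:8000/v1/audio/speech",
    json={
        "model": "emoti-voice",          # placeholder model name
        "input": "大家好,这是 EmotiVoice。",  # text to synthesize
        "voice": "8051",                 # a speaker id, as used elsewhere in this README
    },
)
resp.raise_for_status()
with open("speech.mp3", "wb") as f:
    f.write(resp.content)
```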
123 |
124 | ### Wiki
125 |
126 | If you run into problems, or want more details, please see the [wiki](https://github.com/netease-youdao/EmotiVoice/wiki) pages.
127 |
128 | ## Training
129 |
130 | [Voice cloning with your own data](https://github.com/netease-youdao/EmotiVoice/wiki/Voice-Cloning-with-your-personal-data) was released on December 13, 2023.
131 |
132 | ## Roadmap and future work
133 |
134 | - Our future plans can be found in the [ROADMAP](./ROADMAP.md) file.
135 |
136 | - The current implementation focuses on emotion/style control via prompts. It only uses pitch, speed, energy, and emotion as style factors, and does not use gender. Changing it to full style and timbre control would not be complicated, similar to PromptTTS's original closed-source implementation.
137 |
138 | ## WeChat group
139 |
140 | Scan the QR code on the lower left to join the WeChat group. For business cooperation, scan the personal QR code on the right.
141 |
142 |
143 |
144 |
145 |
146 | ## Acknowledgements
147 |
148 | - [PromptTTS](https://speechresearch.github.io/prompttts/). The PromptTTS paper is an important basis of this work.
149 | - [LibriTTS](https://www.openslr.org/60/). The LibriTTS open dataset was used for training.
150 | - [HiFiTTS](https://www.openslr.org/109/). The HiFi TTS open dataset was used for training.
151 | - [ESPnet](https://github.com/espnet/espnet).
152 | - [WeTTS](https://github.com/wenet-e2e/wetts)
153 | - [HiFi-GAN](https://github.com/jik876/hifi-gan)
154 | - [Transformers](https://github.com/huggingface/transformers)
155 | - [tacotron](https://github.com/keithito/tacotron)
156 | - [KAN-TTS](https://github.com/alibaba-damo-academy/KAN-TTS)
157 | - [StyleTTS](https://github.com/yl4579/StyleTTS)
158 | - [Simbert](https://github.com/ZhuiyiTechnology/simbert)
159 | - [cn2an](https://github.com/Ailln/cn2an). EmotiVoice integrates cn2an for number handling.
160 |
161 | ## License
162 |
163 | EmotiVoice is provided under the Apache-2.0 License - see the [LICENSE file](./LICENSE) for details.
164 |
165 | The interactive web page is provided under the [user agreement](./EmotiVoice_UserAgreement_易魔声用户协议.pdf).
166 |
--------------------------------------------------------------------------------
/README_小白安装教程.md:
--------------------------------------------------------------------------------
1 | ## Beginner Installation Tutorial
2 |
3 | #### Prerequisites: a machine with a GPU and CUDA already installed
4 |
5 | Note: this tutorial is written for Linux; other systems can use it as a reference.
6 |
7 | #### 1. Create and activate a conda environment
8 |
9 | ```
10 | conda create -n EmotiVoice python=3.8
11 | conda init
12 | conda activate EmotiVoice
13 | ```
14 |
15 | If you do not want to use a conda environment, you can skip this step, but make sure your Python version is 3.8.
16 |
17 |
18 | #### 2. Install git-lfs
19 |
20 | On Ubuntu, run
21 |
22 | ```
23 | sudo apt update
24 | sudo apt install git
25 | sudo apt-get install git-lfs
26 | ```
27 |
28 | On CentOS, run
29 |
30 | ```
31 | sudo yum update
32 | sudo yum install git
33 | sudo yum install git-lfs
34 | ```
35 |
36 |
37 |
38 | #### 3. Clone the repository
39 |
40 | ```
41 | git lfs install
42 | git lfs clone https://github.com/netease-youdao/EmotiVoice.git
43 | ```
44 |
45 |
46 |
47 | #### 4. Install dependencies
48 |
49 | ```
50 | pip install torch torchaudio
51 | pip install numpy numba scipy transformers soundfile yacs g2p_en jieba pypinyin pypinyin_dict
52 | python -m nltk.downloader "averaged_perceptron_tagger_eng"
53 | ```
54 |
55 |
56 |
57 |
58 |
59 | #### 5. Download the pretrained model files
60 |
61 | (1) First, enter the project folder
62 |
63 | ```
64 | cd EmotiVoice
65 | ```
66 |
67 | (2) Run the following command
68 |
69 | ```
70 | git lfs clone https://huggingface.co/WangZeJun/simbert-base-chinese WangZeJun/simbert-base-chinese
71 | ```
72 |
73 | or
74 |
75 | ```
76 | git clone https://www.modelscope.cn/syq163/WangZeJun.git
77 | ```
78 |
79 | Either of the two download methods above will work.
80 |
81 | (3) Then download the checkpoint models
82 |
83 | ```
84 | git clone https://www.modelscope.cn/syq163/outputs.git
85 | ```
86 |
87 | After the steps above, the project folder will contain two new directories, `WangZeJun` and `outputs`. The project structure looks like this:
88 |
89 | ```
90 | ├── Dockerfile
91 | ├── EmotiVoice_UserAgreement_易魔声用户协议.pdf
92 | ├── demo_page.py
93 | ├── frontend.py
94 | ├── frontend_cn.py
95 | ├── frontend_en.py
96 | ├── WangZeJun
97 | │ └── simbert-base-chinese
98 | │ ├── README.md
99 | │ ├── config.json
100 | │ ├── pytorch_model.bin
101 | │ └── vocab.txt
102 | ├── outputs
103 | │ ├── README.md
104 | │ ├── configuration.json
105 | │ ├── prompt_tts_open_source_joint
106 | │ │ └── ckpt
107 | │ │ ├── do_00140000
108 | │ │ └── g_00140000
109 | │ └── style_encoder
110 | │ └── ckpt
111 | │ └── checkpoint_163431
112 | ```
113 |
114 |
115 |
116 | #### 6. Run the interactive web UI
117 |
118 | (1) Install streamlit
119 |
120 | ```
121 | pip install streamlit
122 | ```
123 |
124 | (2) Start it
125 |
126 | Run the command below, then open the server address it prints; if the page renders normally, the deployment is complete.
127 |
128 | ```
129 | streamlit run demo_page.py --server.port 6006 --logger.level debug
130 | ```
131 |
132 |
133 |
134 | #### 7. Start the API service
135 |
136 | Install the dependencies
137 |
138 | ```
139 | pip install fastapi pydub uvicorn[standard] pyrubberband
140 | ```
141 |
142 | Start the service on port 6006 (change the port as needed)
143 |
144 | ```
145 | uvicorn openaiapi:app --reload --port 6006
146 | ```
147 |
148 | API documentation: your service address + `/docs`
149 |
150 |
151 |
152 | #### 8. Troubleshooting
153 |
154 | **(1) After launching the web UI, the page keeps showing "Please wait..." or stays blank**
155 |
156 | Cause:
157 |
158 | This is likely caused by a misconfigured CORS (cross-origin resource sharing) protection.
159 |
160 | Fix:
161 |
162 | Add the `server.enableCORS=false` option at startup, i.e. launch the program with the following command
163 |
164 | ```
165 | streamlit run demo_page.py --server.port 6006 --logger.level debug --server.enableCORS=false
166 | ```
167 |
168 | If temporarily disabling CORS protection fixes the problem, it is recommended to re-enable CORS protection afterwards and set the correct URL and port.
169 |
170 |
171 |
172 | **(2) Runtime error: raise BadZipFile("File is not a zip file") zipfile.BadZipFile: File is not a zip file**
173 |
174 | Cause:
175 |
176 | This is probably because `averaged_perceptron_tagger` is missing. It is an `nltk` data package for part-of-speech tagging, containing a tagger based on the averaged perceptron algorithm. If the code uses this tagger but the corresponding data package has not been downloaded, you will get an error saying that `averaged_perceptron_tagger.zip` is missing. It may also be the `cmudict` CMU Pronouncing Dictionary data package that is missing.
177 |
178 | Normally, NLTK downloads the required data packages automatically the first time the program runs; in debug mode you will see messages like the following
179 |
180 | ```
181 | [nltk_data] Downloading package averaged_perceptron_tagger to
182 | [nltk_data] /root/nltk_data...
183 | [nltk_data] Unzipping taggers/averaged_perceptron_tagger.zip.
184 | [nltk_data] Downloading package cmudict to /root/nltk_data...
185 | [nltk_data] Unzipping corpora/cmudict.zip.
186 | ```
187 |
188 | The automatic download may have failed for network reasons (a proxy may be required), so the missing files cause a loading error.
189 |
190 |
191 |
192 | Fix: re-download the missing data package files
193 |
194 |
195 |
196 | 1) Method 1
197 |
198 | Create a file named download.py with the following code
199 |
200 | ```
201 | import nltk
202 | print(nltk.data.path)
203 | nltk.download('averaged_perceptron_tagger')
204 | nltk.download('cmudict')
205 | ```
206 |
207 | Save it and run
208 |
209 | ```
210 | python download.py
211 | ```
212 |
213 | This prints the NLTK data search paths and automatically downloads the missing `averaged_perceptron_tagger.zip` and `cmudict.zip` files into subdirectories of /root/nltk_data. After the download finishes, check that an `nltk_data` folder exists under the root directory and unzip the archives inside it.
214 |
215 |
216 |
217 | 2) Method 2
218 |
219 | If the packages still cannot be downloaded with the code above, you can open the following address and search for and download the archives manually (a proxy may be required)
220 |
221 | ```
222 | https://raw.githubusercontent.com/nltk/nltk_data/gh-pages/index.xml
223 | ```
224 |
225 | The download addresses of the `averaged_perceptron_tagger.zip` and `cmudict.zip` data packages are:
226 |
227 | ```
228 | https://raw.githubusercontent.com/nltk/nltk_data/gh-pages/packages/taggers/averaged_perceptron_tagger.zip
229 | https://raw.githubusercontent.com/nltk/nltk_data/gh-pages/packages/corpora/cmudict.zip
230 | ```
231 |
232 | Then upload the archives to the data search path printed when you ran `python download.py` in Method 1, e.g. `/root/nltk_data` or `/root/miniconda3/envs/EmotiVoice/nltk_data`. Create the directory if it does not exist, then unzip the archives there (a small script for this step is sketched below).
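
If you prefer to script this step, here is a minimal sketch (assuming the target directory is `/root/nltk_data`; adjust it to the path printed by `download.py`) that downloads the two archives and unzips them into the expected subdirectories:

```
import os
import urllib.request
import zipfile

NLTK_DATA = "/root/nltk_data"  # adjust to the path printed by download.py
packages = {
    "taggers": "https://raw.githubusercontent.com/nltk/nltk_data/gh-pages/packages/taggers/averaged_perceptron_tagger.zip",
    "corpora": "https://raw.githubusercontent.com/nltk/nltk_data/gh-pages/packages/corpora/cmudict.zip",
}

for subdir, url in packages.items():
    target_dir = os.path.join(NLTK_DATA, subdir)
    os.makedirs(target_dir, exist_ok=True)
    zip_path = os.path.join(target_dir, os.path.basename(url))
    urllib.request.urlretrieve(url, zip_path)   # may require a proxy
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(target_dir)               # leaves e.g. taggers/averaged_perceptron_tagger/
```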
233 |
234 |
235 |
236 | After unzipping, the nltk_data directory structure should look like this
237 |
238 | ```
239 | ├── nltk_data
240 | │ ├── corpora
241 | │ │ ├── cmudict
242 | │ │ │ ├── README
243 | │ │ │ └── cmudict
244 | │ │ └── cmudict.zip
245 | │ └── taggers
246 | │ ├── averaged_perceptron_tagger
247 | │ │ └── averaged_perceptron_tagger.pickle
248 | │ └── averaged_perceptron_tagger.zip
249 | ```
250 |
251 |
252 |
253 | **(3) Error: AttributeError: 'NoneType' object has no attribute 'seek'.**
254 |
255 | Cause: the model files were not found
256 |
257 | Fix: most likely you have not downloaded the model files, or they are in the wrong location. Check that the downloaded model files exist, i.e. that the `outputs` folder is in the correct location and contains the model files; the correct layout is shown in the project structure in [Step 5](#step5).
258 |
259 |
260 |
261 | **(4) API service error: ImportError: cannot import name 'Doc' from 'typing_extensions'**
262 |
263 | Cause: a typing_extensions version problem
264 |
265 | Fix:
266 |
267 | Try upgrading `typing_extensions` to the latest version; if it is already the latest, downgrade it slightly. The following version was tested to work with `fastapi` v0.104.1.
268 |
269 | ```
270 | pip install typing_extensions==4.8.0 --force-reinstall
271 | ```
272 |
273 |
274 |
275 | **(5) Calling the text-to-speech API returns 500 Internal Server Error, FileNotFoundError: [Errno 2] No such file or directory: 'ffmpeg'**
276 |
277 | Cause: ffmpeg is not installed
278 |
279 | Fix:
280 |
281 | Install it with the following commands. On Ubuntu, run
282 |
283 | ```
284 | sudo apt update
285 | sudo apt install ffmpeg
286 | ```
287 |
288 | On CentOS, run
289 |
290 | ```
291 | sudo yum install epel-release
292 | sudo yum install ffmpeg
293 | ```
294 |
295 | After installation, verify that ffmpeg was installed successfully by running the following command in a terminal:
296 |
297 | ```
298 | ffmpeg -version
299 | ```
300 |
301 | If the installation succeeded, you will see the ffmpeg version information.
302 |
303 |
304 |
--------------------------------------------------------------------------------
/ROADMAP.md:
--------------------------------------------------------------------------------
1 | # EmotiVoice Roadmap
2 |
3 | This roadmap is for EmotiVoice (易魔声), a project driven by the community. We value your feedback and suggestions on our future direction.
4 |
5 | Please visit https://github.com/netease-youdao/EmotiVoice/issues on GitHub to submit your proposals.
6 | If you are interested, feel free to volunteer for any tasks, even if they are not listed.
7 |
8 | The plan is to finish 0.2 to 0.4 in Q4 2023.
9 |
10 | ## EmotiVoice 0.4
11 |
12 | - [ ] Updated model with potentially improved quality.
13 | - [ ] First version of desktop application.
14 | - [ ] Support longer text.
15 |
16 | ## EmotiVoice 0.3 (2023.12.13)
17 |
18 | - [x] Release [The EmotiVoice HTTP API](https://github.com/netease-youdao/EmotiVoice/wiki/HTTP-API) provided by [Zhiyun](https://mp.weixin.qq.com/s/_Fbj4TI4ifC6N7NFOUrqKQ).
19 | - [x] Release [Voice Cloning with your personal data](https://github.com/netease-youdao/EmotiVoice/wiki/Voice-Cloning-with-your-personal-data) along with [DataBaker Recipe](https://github.com/netease-youdao/EmotiVoice/tree/main/data/DataBaker) and [LJSpeech Recipe](https://github.com/netease-youdao/EmotiVoice/tree/main/data/LJspeech).
20 | - [x] Documentation: wiki page for hardware requirements. [#30](../../issues/30)
21 |
22 | ## EmotiVoice 0.2 (2023.11.17)
23 |
24 | - [x] Support mixed Chinese and English input text. [#28](../../issues/28)
25 | - [x] Resolve bugs related to certain modal particles, to make it more robust. [#18](../../issues/18)
26 | - [x] Documentation: voice list wiki page
27 | - [x] Documentation: this roadmap.
28 |
29 | ## EmotiVoice 0.1 (2023.11.10) first public version
30 |
31 | - [x] We offer a pretrained model with over 2000 voices, supporting both Chinese and English languages.
32 | - [x] You can perform inference using the command line interface. We also offer a user-friendly web demo for easy usage.
33 | - [x] For convenient deployment, we offer a Docker image.
34 |
35 |
--------------------------------------------------------------------------------
/assets/audio/emotivoice_intro_cn.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/netease-youdao/EmotiVoice/bc2de8c9eb1121237958ef154cb171e7faefc769/assets/audio/emotivoice_intro_cn.wav
--------------------------------------------------------------------------------
/assets/audio/emotivoice_intro_en.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/netease-youdao/EmotiVoice/bc2de8c9eb1121237958ef154cb171e7faefc769/assets/audio/emotivoice_intro_en.wav
--------------------------------------------------------------------------------
/cn2an/conf.py:
--------------------------------------------------------------------------------
1 | """
2 | This code is modified from https://github.com/Ailln/cn2an.
3 | """
4 |
5 | NUMBER_CN2AN = {
6 | "零": 0,
7 | "〇": 0,
8 | "一": 1,
9 | "壹": 1,
10 | "幺": 1,
11 | "二": 2,
12 | "贰": 2,
13 | "两": 2,
14 | "三": 3,
15 | "叁": 3,
16 | "四": 4,
17 | "肆": 4,
18 | "五": 5,
19 | "伍": 5,
20 | "六": 6,
21 | "陆": 6,
22 | "七": 7,
23 | "柒": 7,
24 | "八": 8,
25 | "捌": 8,
26 | "九": 9,
27 | "玖": 9,
28 | }
29 | UNIT_CN2AN = {
30 | "十": 10,
31 | "拾": 10,
32 | "百": 100,
33 | "佰": 100,
34 | "千": 1000,
35 | "仟": 1000,
36 | "万": 10000,
37 | "亿": 100000000,
38 | }
39 | UNIT_LOW_AN2CN = {
40 | 10: "十",
41 | 100: "百",
42 | 1000: "千",
43 | 10000: "万",
44 | 100000000: "亿",
45 | }
46 | NUMBER_LOW_AN2CN = {
47 | 0: "零",
48 | 1: "一",
49 | 2: "二",
50 | 3: "三",
51 | 4: "四",
52 | 5: "五",
53 | 6: "六",
54 | 7: "七",
55 | 8: "八",
56 | 9: "九",
57 | }
58 | NUMBER_UP_AN2CN = {
59 | 0: "零",
60 | 1: "壹",
61 | 2: "贰",
62 | 3: "叁",
63 | 4: "肆",
64 | 5: "伍",
65 | 6: "陆",
66 | 7: "柒",
67 | 8: "捌",
68 | 9: "玖",
69 | }
70 | UNIT_LOW_ORDER_AN2CN = [
71 | "",
72 | "十",
73 | "百",
74 | "千",
75 | "万",
76 | "十",
77 | "百",
78 | "千",
79 | "亿",
80 | "十",
81 | "百",
82 | "千",
83 | "万",
84 | "十",
85 | "百",
86 | "千",
87 | ]
88 | UNIT_UP_ORDER_AN2CN = [
89 | "",
90 | "拾",
91 | "佰",
92 | "仟",
93 | "万",
94 | "拾",
95 | "佰",
96 | "仟",
97 | "亿",
98 | "拾",
99 | "佰",
100 | "仟",
101 | "万",
102 | "拾",
103 | "佰",
104 | "仟",
105 | ]
106 | STRICT_CN_NUMBER = {
107 | "零": "零",
108 | "一": "一壹",
109 | "二": "二贰",
110 | "三": "三叁",
111 | "四": "四肆",
112 | "五": "五伍",
113 | "六": "六陆",
114 | "七": "七柒",
115 | "八": "八捌",
116 | "九": "九玖",
117 | "十": "十拾",
118 | "百": "百佰",
119 | "千": "千仟",
120 | "万": "万",
121 | "亿": "亿",
122 | }
123 | NORMAL_CN_NUMBER = {
124 | "零": "零〇",
125 | "一": "一壹幺",
126 | "二": "二贰两",
127 | "三": "三叁仨",
128 | "四": "四肆",
129 | "五": "五伍",
130 | "六": "六陆",
131 | "七": "七柒",
132 | "八": "八捌",
133 | "九": "九玖",
134 | "十": "十拾",
135 | "百": "百佰",
136 | "千": "千仟",
137 | "万": "万",
138 | "亿": "亿",
139 | }
140 |
--------------------------------------------------------------------------------
/cog.yaml:
--------------------------------------------------------------------------------
1 | # Configuration for Cog ⚙️
2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
3 |
4 | build:
5 | gpu: true
6 |
7 | # a list of ubuntu apt packages to install
8 | # system_packages:
9 | # - "libgl1-mesa-glx"
10 | # - "libglib2.0-0"
11 |
12 | python_version: "3.8"
13 | python_packages:
14 | - "torch==2.0.1"
15 | - "torchaudio==2.0.2"
16 | - "g2p-en==2.1.0"
17 | - "jieba==0.42.1"
18 | - "numba==0.58.1"
19 | - "numpy==1.24.4"
20 | - "pypinyin==0.49.0"
21 | - "scipy==1.10.1"
22 | - "soundfile==0.12.1"
23 | - "transformers==4.26.1"
24 | - "yacs==0.1.8"
25 |
26 | run:
27 | - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.0.3/pget" && chmod +x /usr/local/bin/pget
28 |
29 | # predict.py defines how predictions are run on your model
30 | predict: "predict.py:Predictor"
31 |
--------------------------------------------------------------------------------
/config/joint/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023, YOUDAO
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 |
17 | # with thanks to arjun-234 in https://github.com/netease-youdao/EmotiVoice/pull/38.
18 | def get_labels_length(file_path):
19 | """
20 | Return labels and their count in a file.
21 |
22 | Args:
23 | file_path (str): The path to the file containing the labels.
24 |
25 | Returns:
26 | list: labels; int: The number of labels in the file.
27 | """
28 | with open(file_path, encoding = "UTF-8") as f:
29 | tokens = [t.strip() for t in f.readlines()]
30 | return tokens, len(tokens)
31 |
32 | class Config:
33 | #### PATH ####
34 | ROOT_DIR = os.path.dirname(os.path.abspath("__file__"))
35 | DATA_DIR = ROOT_DIR + "/data/youdao/"
36 | train_data_path = DATA_DIR + "train_am/datalist.jsonl"
37 | valid_data_path = DATA_DIR + "valid_am/datalist.jsonl"
38 | output_directory = ROOT_DIR + "/outputs"
39 | speaker2id_path = DATA_DIR + "text/speaker2"
40 | emotion2id_path = DATA_DIR + "text/emotion"
41 | pitch2id_path = DATA_DIR + "text/pitch"
42 | energy2id_path = DATA_DIR + "text/energy"
43 | speed2id_path = DATA_DIR + "text/speed"
44 | bert_path = 'WangZeJun/simbert-base-chinese'
45 | token_list_path = DATA_DIR + "text/tokenlist"
46 | style_encoder_ckpt = ROOT_DIR + "/outputs/style_encoder/ckpt/checkpoint_163431"
47 | tmp_dir = ROOT_DIR + "/tmp"
48 | model_config_path = ROOT_DIR + "/config/joint/config.yaml"
49 |
50 | #### Model ####
51 | bert_hidden_size = 768
52 | style_dim = 128
53 | downsample_ratio = 1 # Whole Model
54 |
55 | #### Text ####
56 | tokens, n_symbols = get_labels_length(token_list_path)
57 | sep = " "
58 |
59 | #### Speaker ####
60 | speakers, speaker_n_labels = get_labels_length(speaker2id_path)
61 |
62 | #### Emotion ####
63 | emotions, emotion_n_labels = get_labels_length(emotion2id_path)
64 |
65 | #### Speed ####
66 | speeds, speed_n_labels = get_labels_length(speed2id_path)
67 |
68 | #### Pitch ####
69 | pitchs, pitch_n_labels = get_labels_length(pitch2id_path)
70 |
71 | #### Energy ####
72 | energys, energy_n_labels = get_labels_length(energy2id_path)
73 |
74 | #### Train ####
75 | # epochs = 10
76 | lr = 1e-3
77 | lr_warmup_steps = 4000
78 | kl_warmup_steps = 60_000
79 | grad_clip_thresh = 1.0
80 | batch_size = 16
81 | train_steps = 10_000_000
82 | opt_level = "O1"
83 | seed = 1234
84 | iters_per_validation= 1000
85 | iters_per_checkpoint= 10000
86 |
87 |
88 | #### Audio ####
89 | sampling_rate = 16_000
90 | max_db = 1
91 | min_db = 0
92 | trim = True
93 |
94 | #### Stft ####
95 | filter_length = 1024
96 | hop_length = 256
97 | win_length = 1024
98 | window = "hann"
99 |
100 | #### Mel ####
101 | n_mel_channels = 80
102 | mel_fmin = 0
103 | mel_fmax = 8000
104 |
105 | #### Pitch ####
106 | pitch_min = 80
107 | pitch_max = 400
108 | pitch_stats = [225.089, 53.78]
109 |
110 | #### Energy ####
111 | energy_stats = [30.610, 21.78]
112 |
113 |
114 | #### Inference ####
115 | gta = False
116 |
--------------------------------------------------------------------------------
/config/joint/config.yaml:
--------------------------------------------------------------------------------
1 | ###########################################################
2 | # FEATURE EXTRACTION SETTING #
3 | ###########################################################
4 |
5 | sr: 16000 # sr
6 | n_fft: 1024 # FFT size (samples).
7 | hop_length: 256 # Hop size (samples). 12.5ms
8 | win_length: 1024 # Window length (samples). 50ms
9 | # If set to null it will be the same as fft_size.
10 | window: "hann" # Window function.
11 |
12 | fmin: 0 # Minimum frequency of Mel basis.
13 | fmax: null # Maximum frequency of Mel basis.
14 | n_mels: 80 # The number of mel basis.
15 |
16 | pitch_min: 80 # Minimum f0 in linear domain for pitch extraction.
17 | pitch_max: 400 # Maximum f0 in linear domain for pitch extraction.
18 |
19 | segment_size: 32
20 |
21 |
22 | cut_sil: True
23 |
24 | shuffle: True
25 |
26 | pretrained_am: "" # absolute path
27 | pretrained_vocoder: "" # absolute path
28 | pretrained_discriminator: "" # absolute path
29 |
30 | max_db: 1
31 | min_db: 0
32 |
33 | ###########################################################
34 | # MODEL SETTING #
35 | ###########################################################
36 | model:
37 | speaker_embed_dim: 384
38 | bert_embedding: 768
39 | #### encoder ####
40 | lang_embed_dim: 0
41 | encoder_n_layers: 4
42 | encoder_n_heads: 8
43 | encoder_n_hidden: 384
44 | encoder_p_dropout: 0.2
45 | encoder_kernel_size_conv_mod: 3
46 | encoder_kernel_size_depthwise: 7
47 | #### decoder ####
48 | decoder_n_layers: 4
49 | decoder_n_heads: 8
50 | decoder_n_hidden: 384
51 | decoder_p_dropout: 0.2
52 | decoder_kernel_size_conv_mod: 3
53 | decoder_kernel_size_depthwise: 31
54 | #### prosodic ####
55 | bottleneck_size_p: 4
56 | bottleneck_size_u: 256
57 | ref_enc_filters: [32, 32, 64, 64, 128, 128]
58 | ref_enc_size: 3
59 | ref_enc_strides: [1, 2, 1, 2, 1]
60 | ref_enc_pad: [1, 1]
61 | ref_enc_gru_size: 32
62 | ref_attention_dropout: 0.2
63 | token_num: 32
64 | predictor_kernel_size: 5
65 | stop_prosodic_gradient: False
66 | ref_p_dropout: 0.1
67 | ref_n_heads: 4
68 | #### variance ####
69 | variance_n_hidden: 384
70 | variance_n_layers: 3
71 | variance_kernel_size: 3
72 | variance_p_dropout: 0.1
73 | variance_embed_kernel_size: 9
74 | variance_embde_p_dropout: 0.0
75 | stop_pitch_gradient: False
76 | stop_duration_gradient: False
77 | duration_p_dropout: 0.5
78 | duration_n_layers: 2
79 | duration_kernel_size: 3
80 | #### postnet ####
81 | postnet_layers: 0
82 | postnet_chans: 256
83 | postnet_filts: 5
84 | use_batch_norm: True
85 | postnet_dropout_rate: 0.5
86 | #### generator ####
87 | resblock: "1"
88 | upsample_rates: [8,8,2,2]
89 | upsample_kernel_sizes: [16,16,4,4]
90 | initial_channel: 80
91 | upsample_initial_channel: 512
92 | resblock_kernel_sizes: [3,7,11]
93 | resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
94 | r: 1
95 | ###########################################################
96 | # OPTIMIZER SETTING #
97 | ###########################################################
98 | optimizer:
99 | lr: 1.25e-5
100 | betas: [0.5, 0.9]
101 | eps: 1.0e-9
102 | weight_decay: 0.0
103 | scheduler:
104 | gamma: 0.999875
105 |
--------------------------------------------------------------------------------
/config/template.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023, YOUDAO
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 |
17 | # with thanks to arjun-234 in https://github.com/netease-youdao/EmotiVoice/pull/38.
18 | def get_labels_length(file_path):
19 | """
20 | Return labels and their count in a file.
21 |
22 | Args:
23 | file_path (str): The path to the file containing the labels.
24 |
25 | Returns:
26 | list: labels; int: The number of labels in the file.
27 | """
28 | with open(file_path, encoding = "UTF-8") as f:
29 | tokens = [t.strip() for t in f.readlines()]
30 | return tokens, len(tokens)
31 |
32 | class Config:
33 | #### PATH ####
34 | ROOT_DIR = os.path.dirname(os.path.abspath("__file__"))
35 | DATA_DIR = ROOT_DIR + "/"
36 | # Change datalist.jsonl to datalist_mfa.jsonl if you have run MFA
37 | train_data_path = DATA_DIR + "/train/datalist.jsonl"
38 | valid_data_path = DATA_DIR + "/valid/datalist.jsonl"
39 | output_directory = ROOT_DIR + "/"
40 | speaker2id_path = ROOT_DIR + "//speaker"
41 | emotion2id_path = ROOT_DIR + "//emotion"
42 | pitch2id_path = ROOT_DIR + "//pitch"
43 | energy2id_path = ROOT_DIR + "//energy"
44 | speed2id_path = ROOT_DIR + "//speed"
45 | bert_path = 'WangZeJun/simbert-base-chinese'
46 | token_list_path = ROOT_DIR + "//tokenlist"
47 | style_encoder_ckpt = ROOT_DIR + "/outputs/style_encoder/ckpt/checkpoint_163431"
48 | tmp_dir = output_directory + "/tmp"
49 | model_config_path = ROOT_DIR + "/config/joint/config.yaml"
50 |
51 | #### Model ####
52 | bert_hidden_size = 768
53 | style_dim = 128
54 | downsample_ratio = 1 # Whole Model
55 |
56 | #### Text ####
57 | tokens, n_symbols = get_labels_length(token_list_path)
58 | sep = " "
59 |
60 | #### Speaker ####
61 | speakers, speaker_n_labels = get_labels_length(speaker2id_path)
62 |
63 | #### Emotion ####
64 | emotions, emotion_n_labels = get_labels_length(emotion2id_path)
65 |
66 | #### Speed ####
67 | speeds, speed_n_labels = get_labels_length(speed2id_path)
68 |
69 | #### Pitch ####
70 | pitchs, pitch_n_labels = get_labels_length(pitch2id_path)
71 |
72 | #### Energy ####
73 | energys, energy_n_labels = get_labels_length(energy2id_path)
74 |
75 | #### Train ####
76 | # epochs = 10
77 | lr = 1e-3
78 | lr_warmup_steps = 4000
79 | kl_warmup_steps = 60_000
80 | grad_clip_thresh = 1.0
81 | batch_size = 8
82 | train_steps = 10_000_000
83 | opt_level = "O1"
84 | seed = 1234
85 | iters_per_validation= 1000
86 | iters_per_checkpoint= 5000
87 |
88 |
89 | #### Audio ####
90 | sampling_rate = 16_000
91 | max_db = 1
92 | min_db = 0
93 | trim = True
94 |
95 | #### Stft ####
96 | filter_length = 1024
97 | hop_length = 256
98 | win_length = 1024
99 | window = "hann"
100 |
101 | #### Mel ####
102 | n_mel_channels = 80
103 | mel_fmin = 0
104 | mel_fmax = 8000
105 |
106 | #### Pitch ####
107 | pitch_min = 80
108 | pitch_max = 400
109 | pitch_stats = [225.089, 53.78]
110 |
111 | #### Energy ####
112 | energy_stats = [30.610, 21.78]
113 |
114 |
115 | #### Inference ####
116 | gta = False
117 |
--------------------------------------------------------------------------------
/data/DataBaker/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # 😊 DataBaker Recipe
4 |
5 | This is the recipe of Chinese single female speaker TTS model with DataBaker corpus.
6 |
7 | ## Guide For Finetuning
8 | - [Environments Installation](#environments-installation)
9 | - [Step0 Download Data](#step0-download-data)
10 | - [Step1 Preprocess Data](#step1-preprocess-data)
11 | - [Step2 Run MFA (Optional)](#step2-run-mfa-optional-since-we-already-have-labeled-prosody)
12 | - [Step3 Prepare for training](#step3-prepare-for-training)
13 | - [Step4 Start training](#step4-finetune-your-model)
14 | - [Step5 Inference](#step5-inference)
15 |
16 | ### Environments Installation
17 |
18 | Create a conda environment
19 | ```bash
20 | conda create -n EmotiVoice python=3.8 -y
21 | conda activate EmotiVoice
22 | ```
23 | then run:
24 | ```bash
25 | pip install EmotiVoice[train]
26 | # or
27 | git clone https://github.com/netease-youdao/EmotiVoice
28 | pip install -e .[train]
29 | ```
30 | Additionally, it is important to prepare the pre-trained models as mentioned in the [pretrained models](https://github.com/netease-youdao/EmotiVoice/wiki/Pretrained-models).
31 |
32 | ### Step0 Download Data
33 |
34 | ```bash
35 | mkdir data/DataBaker/raw
36 |
37 | # download
38 | # please download the data from https://en.data-baker.com/datasets/freeDatasets/, and place the extracted BZNSYP folder under data/DataBaker/raw
39 | ```
40 |
41 | ### Step1 Preprocess Data
42 |
43 | For this recipe, since DataBaker has already provided phoneme labels, we will simply utilize that information.
44 |
45 | ```bash
46 | # format data
47 | python data/DataBaker/src/step1_clean_raw_data.py \
48 | --data_dir data/DataBaker
49 |
50 | # get phoneme
51 | python data/DataBaker/src/step2_get_phoneme.py \
52 | --data_dir data/DataBaker
53 | ```
54 |
55 | If you have prepared your own data with only text labels, you can obtain phonemes using the Text-to-Speech (TTS) frontend. For example, you can run the following command: `python data/DataBaker/src/step2_get_phoneme.py --data_dir data/DataBaker --generate_phoneme True`. However, please note that in this specific DataBaker's recipe, you should omit this command.
56 |
57 |
58 |
59 | ### Step2 Run MFA (Optional, since we already have labeled prosody)
60 |
61 | Please be aware that in this particular DataBaker's recipe, **you should skip this step**. Nonetheless, if you have already prepared your own data with only text labels, the following commands might assist you:
62 |
63 | ```bash
64 | # MFA environment install
65 | conda install -c conda-forge kaldi sox librosa biopython praatio tqdm requests colorama pyyaml pynini openfst baumwelch ngram postgresql -y
66 | pip install pgvector hdbscan montreal-forced-aligner
67 |
68 | # MFA Step1
69 | python mfa/step1_create_dataset.py \
70 | --data_dir data/DataBaker
71 |
72 | # MFA Step2
73 | python mfa/step2_prepare_data.py \
74 | --dataset_dir data/DataBaker/mfa \
75 | --wav data/DataBaker/mfa/wav.txt \
76 | --speaker data/DataBaker/mfa/speaker.txt \
77 | --text data/DataBaker/mfa/text.txt
78 |
79 | # MFA Step3
80 | python mfa/step3_prepare_special_tokens.py \
81 | --special_tokens data/DataBaker/mfa/special_token.txt
82 |
83 | # MFA Step4
84 | python mfa/step4_convert_text_to_phn.py \
85 | --text data/DataBaker/mfa/text.txt \
86 | --special_tokens data/DataBaker/mfa/special_token.txt \
87 | --output data/DataBaker/mfa/text.txt
88 |
89 | # MFA Step5
90 | python mfa/step5_prepare_alignment.py \
91 | --wav data/DataBaker/mfa/wav.txt \
92 | --speaker data/DataBaker/mfa/speaker.txt \
93 | --text data/DataBaker/mfa/text.txt \
94 | --special_tokens data/DataBaker/mfa/special_token.txt \
95 | --pronounciation_dict data/DataBaker/mfa/mfa_pronounciation_dict.txt \
96 | --output_dir data/DataBaker/mfa/lab
97 |
98 | # MFA Step6
99 | mfa validate \
100 | --overwrite \
101 | --clean \
102 | --single_speaker \
103 | data/DataBaker/mfa/lab \
104 | data/DataBaker/mfa/mfa_pronounciation_dict.txt
105 |
106 | mfa train \
107 | --overwrite \
108 | --clean \
109 | --single_speaker \
110 | data/DataBaker/mfa/lab \
111 | data/DataBaker/mfa/mfa_pronounciation_dict.txt \
112 | data/DataBaker/mfa/mfa/mfa_model.zip \
113 | data/DataBaker/mfa/TextGrid
114 |
115 | mfa align \
116 | --single_speaker \
117 | data/DataBaker/mfa/lab \
118 | data/DataBaker/mfa/mfa_pronounciation_dict.txt \
119 | data/DataBaker/mfa/mfa/mfa_model.zip \
120 | data/DataBaker/mfa/TextGrid
121 |
122 | # MFA Step7
123 | python mfa/step7_gen_alignment_from_textgrid.py \
124 | --wav data/DataBaker/mfa/wav.txt \
125 | --speaker data/DataBaker/mfa/speaker.txt \
126 | --text data/DataBaker/mfa/text.txt \
127 | --special_tokens data/DataBaker/mfa/special_token.txt \
128 | --text_grid data/DataBaker/mfa/TextGrid \
129 | --aligned_wav data/DataBaker/mfa/aligned_wav.txt \
130 | --aligned_speaker data/DataBaker/mfa/aligned_speaker.txt \
131 | --duration data/DataBaker/mfa/duration.txt \
132 | --aligned_text data/DataBaker/mfa/aligned_text.txt \
133 | --reassign_sp True
134 |
135 | # MFA Step8
136 | python mfa/step8_make_data_list.py \
137 | --wav data/DataBaker/mfa/aligned_wav.txt \
138 | --speaker data/DataBaker/mfa/aligned_speaker.txt \
139 | --text data/DataBaker/mfa/aligned_text.txt \
140 | --duration data/DataBaker/mfa/duration.txt \
141 | --datalist_path data/DataBaker/mfa/datalist.jsonl
142 |
143 | # MFA Step9
144 | python mfa/step9_datalist_from_mfa.py \
145 | --data_dir data/DataBaker
146 | ```
147 |
148 | ### Step3 Prepare for training
149 |
150 | ```bash
151 | python prepare_for_training.py --data_dir data/DataBaker --exp_dir exp/DataBaker
152 | ```
153 | __Please check and change the training and validation file paths in `exp/DataBaker/config/config.py`, especially the fields below (a sketch of these edits follows the list):__
154 | - `model_config_path`: the corresponding model config file
155 | - `DATA_DIR`: the data directory
156 | - `train_data_path` and `valid_data_path`: the training and validation files. Change them to `datalist_mfa.jsonl` if you ran Step2
157 | - `batch_size`
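
For instance, the fields you typically touch inside the `Config` class of the generated `exp/DataBaker/config/config.py` look roughly like this (the values here are examples, not requirements):

```python
import os

ROOT_DIR = os.path.abspath(".")                        # repo root, as in the generated config
DATA_DIR = ROOT_DIR + "/data/DataBaker"
train_data_path = DATA_DIR + "/train/datalist.jsonl"   # or .../datalist_mfa.jsonl after Step2
valid_data_path = DATA_DIR + "/valid/datalist.jsonl"   # or .../datalist_mfa.jsonl after Step2
model_config_path = ROOT_DIR + "/config/joint/config.yaml"
batch_size = 8                                         # lower this if you hit GPU out-of-memory
```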
158 |
159 | ### Step4 Finetune Your Model
160 |
161 | ```bash
162 | torchrun \
163 | --nproc_per_node=1 \
164 | --master_port 8008 \
165 | train_am_vocoder_joint.py \
166 | --config_folder exp/DataBaker/config \
167 | --load_pretrained_model True
168 | ```
169 |
170 | Training tips:
171 |
172 | - You can run TensorBoard with
173 | ```
174 | tensorboard --logdir=exp/DataBaker
175 | ```
176 | - The model checkpoints are saved at `exp/DataBaker/ckpt`.
177 | - The BERT features are extracted during the first epoch and saved in the `exp/DataBaker/tmp/` folder; you can change this path in `exp/DataBaker/config/config.py`.
178 |
179 |
180 | ### Step5 Inference
181 |
182 |
183 | ```bash
184 | TEXT=data/inference/text
185 | python inference_am_vocoder_exp.py \
186 | --config_folder exp/DataBaker/config \
187 | --checkpoint g_00010000 \
188 | --test_file $TEXT
189 | ```
190 | __Please change the speaker names in `data/inference/text` (see the example line below).__
191 |
192 | The synthesized speech is saved under `exp/DataBaker/test_audio`.
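
For reference, a line in `data/inference/text` for this recipe could look like the sketch below, assuming the speaker label `BZNSYP` written by `step1_clean_raw_data.py`; the prompt, phonemes, and text are only illustrative (generate the phonemes with the TTS frontend):

```
BZNSYP|非常开心| uo3 sp1 l ai2 sp0 d ao4 sp1 b ei3 sp0 j ing1 sp3 q ing1 sp0 h ua2 sp0 d a4 sp0 x ve2 |我来到北京,清华大学
```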
--------------------------------------------------------------------------------
/data/DataBaker/src/step0_download.sh:
--------------------------------------------------------------------------------
1 |
2 |
3 | # please download the data from https://en.data-baker.com/datasets/freeDatasets/, and place the extracted BZNSYP folder under data/DataBaker/raw
4 |
5 |
--------------------------------------------------------------------------------
/data/DataBaker/src/step1_clean_raw_data.py:
--------------------------------------------------------------------------------
1 | """
2 | This code is modified from https://github.com/wenet-e2e/wetts.
3 | """
4 |
5 | import os
6 | import argparse
7 | import soundfile as sf
8 | import librosa
9 | import jsonlines
10 | from tqdm import tqdm
11 | import re
12 |
13 | def main(args):
14 |
15 | ROOT_DIR=os.path.abspath(args.data_dir)
16 | RAW_DIR=f"{ROOT_DIR}/raw"
17 | WAV_DIR=f"{ROOT_DIR}/wavs"
18 | TEXT_DIR=f"{ROOT_DIR}/text"
19 |
20 | os.makedirs(WAV_DIR, exist_ok=True)
21 | os.makedirs(TEXT_DIR, exist_ok=True)
22 |
23 |
24 | with open(f"{RAW_DIR}/BZNSYP/ProsodyLabeling/000001-010000.txt", encoding="utf-8") as f, \
25 | jsonlines.open(f"{TEXT_DIR}/data.jsonl", "w") as fout1:
26 |
27 | lines = f.readlines()
28 | for i in tqdm(range(0, len(lines), 2)):
29 | key = lines[i][:6]
30 |
31 | ### Text
32 | content_org = lines[i][7:].strip()
33 | content = re.sub("[。,、“”?:……!( )—;]", "", content_org)
34 | content_org = re.sub("#\d", "", content_org)
35 |
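# Parse the annotated transcript: plain characters go into `chars`, while a
# prosody marker "#1".."#4" is recorded in `prosody`, keyed by the index of the
# character it follows.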
36 | chars = []
37 | prosody = {}
38 | j = 0
39 | while j < len(content):
40 | if content[j] == "#":
41 | prosody[len(chars) - 1] = content[j : j + 2]
42 | j += 2
43 | else:
44 | chars.append(content[j])
45 | j += 1
46 |
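# Patch one mislabeled pinyin token and skip one utterance; these two keys are
# hard-coded fixes, apparently for annotation quirks in the BZNSYP labels.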
47 | if key == "005107":
48 | lines[i + 1] = lines[i + 1].replace(" ng1", " en1")
49 | if key == "002365":
50 | continue
51 |
52 | syllable = lines[i + 1].strip().split()
53 | s_index = 0
54 | phones = []
55 | phone = []
56 | for k, char in enumerate(chars):
57 | # Erhua ('儿') handling: if the '儿' has no standalone 'er' syllable (it was merged into the previous syllable), do not consume a pinyin token for it
58 | er_flag = False
59 | if char == "儿" and (s_index == len(syllable) or syllable[s_index][0:2] != "er"):
60 | er_flag = True
61 | else:
62 | phones.append(syllable[s_index])
63 | #phones.extend(lexicon[syllable[s_index]])
64 | s_index += 1
65 |
66 |
67 | if k in prosody:
68 | if er_flag:
69 | phones[-1] = prosody[k]
70 | else:
71 | phones.append(prosody[k])
72 | else:
73 | phones.append("#0")
74 |
75 | ### Wav
76 | path = f"{RAW_DIR}/BZNSYP/Wave/{key}.wav"
77 | wav_path = f"{WAV_DIR}/{key}.wav"
78 | y, sr = sf.read(path)
79 | y_16=librosa.resample(y, orig_sr=sr, target_sr=16_000)
80 | sf.write(wav_path, y_16, 16_000)
81 |
82 | fout1.write({
83 | "key":key,
84 | "wav_path":wav_path,
85 | "speaker":"BZNSYP",
86 | "text":[""] + phones[:-1] + [""],
87 | "original_text":content_org,
88 | })
89 |
90 |
91 | return
92 |
93 |
94 | if __name__ == "__main__":
95 | p = argparse.ArgumentParser()
96 | p.add_argument('--data_dir', type=str, required=True)
97 | args = p.parse_args()
98 |
99 | main(args)
100 |
--------------------------------------------------------------------------------
/data/DataBaker/src/step2_get_phoneme.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023, YOUDAO
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import argparse
16 | import os
17 | import jsonlines
18 | import json
19 | from tqdm import tqdm
20 | from multiprocessing.pool import ThreadPool
21 | from functools import partial
22 |
23 | import re
24 | import sys
25 | DIR=os.path.dirname(os.path.abspath("__file__"))
26 | sys.path.append(DIR)
27 |
28 | from frontend_cn import split_py, tn_chinese
29 | from frontend_en import read_lexicon, G2p
30 | from frontend import contains_chinese, re_digits, g2p_cn
31 |
32 | # re_english_word = re.compile('([a-z\-\.\']+|\d+[\d\.]*)', re.I)
33 | re_english_word = re.compile('([^\u4e00-\u9fa5]+|[ \u3002\uff0c\uff1f\uff01\uff1b\uff1a\u201c\u201d\u2018\u2019\u300a\u300b\u3008\u3009\u3010\u3011\u300e\u300f\u2014\u2026\u3001\uff08\uff09\u4e00-\u9fa5]+)', re.I)
34 |
35 | def g2p_cn_en(text, g2p, lexicon):
36 | # Our policy dictates that if the text contains Chinese, digits are to be converted into Chinese.
37 | text=tn_chinese(text)
38 | parts = re_english_word.split(text)
39 | parts=list(filter(None, parts))
40 | tts_text = [""]
41 | chartype = ''
42 | text_contains_chinese = contains_chinese(text)
43 | for part in parts:
44 | if part == ' ' or part == '': continue
45 | if re_digits.match(part) and (text_contains_chinese or chartype == '') or contains_chinese(part):
46 | if chartype == 'en':
47 | tts_text.append('eng_cn_sp')
48 | phoneme = g2p_cn(part).split()[1:-1]
49 | chartype = 'cn'
50 | elif re_english_word.match(part):
51 | if chartype == 'cn':
52 | if "sp" in tts_text[-1]:
53 | ""
54 | else:
55 | tts_text.append('cn_eng_sp')
56 | phoneme = get_eng_phoneme(part, g2p, lexicon).split()
57 | if not phoneme :
58 | # tts_text.pop()
59 | continue
60 | else:
61 | chartype = 'en'
62 | else:
63 | continue
64 | tts_text.extend( phoneme )
65 |
66 | tts_text=" ".join(tts_text).split()
67 | if "sp" in tts_text[-1]:
68 | tts_text.pop()
69 | tts_text.append("")
70 |
71 | return " ".join(tts_text)
72 |
73 | def get_eng_phoneme(text, g2p, lexicon):
74 | """
75 | English g2p: look up each word in the lexicon, falling back to g2p for out-of-vocabulary words.
76 | """
77 | filters = {",", " ", "'"}
78 | phones = []
79 | words = list(filter(lambda x: x not in {"", " "}, re.split(r"([,;.\-\?\!\s+])", text)))
80 |
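# Words found in the lexicon are expanded directly; out-of-vocabulary words fall
# back to g2p. Each phone is wrapped in brackets, "engsp1" marks word boundaries,
# and punctuation upgrades the previous pause to "engsp4".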
81 | for w in words:
82 | if w.lower() in lexicon:
83 |
84 | for ph in lexicon[w.lower()]:
85 | if ph not in filters:
86 | phones += ["[" + ph + "]"]
87 |
88 | if "sp" not in phones[-1]:
89 | phones += ["engsp1"]
90 | else:
91 | phone=g2p(w)
92 | if not phone:
93 | continue
94 |
95 | if phone[0].isalnum():
96 |
97 | for ph in phone:
98 | if ph not in filters:
99 | phones += ["[" + ph + "]"]
100 | if ph == " " and "sp" not in phones[-1]:
101 | phones += ["engsp1"]
102 | elif phone == " ":
103 | continue
104 | elif phones:
105 | phones.pop() # pop engsp1
106 | phones.append("engsp4")
107 | if phones and "engsp" in phones[-1]:
108 | phones.pop()
109 |
110 |
111 | return " ".join(phones)
112 |
113 |
114 | def onetime(resource, sample):
115 |
116 | text=sample["text"]
117 | # del sample["original_text"]
118 |
119 | phoneme = get_phoneme(text, resource["g2p"]).split()
120 |
121 | sample["text"]=phoneme
122 | # sample["original_text"]=text
123 | sample["prompt"]=sample["original_text"]
124 |
125 | return sample
126 |
127 | def onetime2(resource, sample):
128 |
129 | text=sample["original_text"]
130 | del sample["original_text"]
131 | try:
132 | phoneme = g2p_cn_en(text, resource["g2p_en"], resource["lexicon"]).split()#g2p_cn_eng_mix(text, resource["g2p_en"], resource["lexicon"]).split()
133 | except:
134 | print("Warning!!! phoneme get error! " + \
135 | "Please check text")
136 | print("Text is: ", text)
137 | return ""
138 |
139 | if not phoneme:
140 | return ""
141 |
142 | sample["text"]=phoneme
143 | sample["original_text"]=text
144 | sample["prompt"]=sample["original_text"]
145 |
146 | return sample
147 |
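# Convert DataBaker's annotated text into phoneme tokens: "#0".."#4" break
# markers become "sp0".."sp4", and every other character is expanded into
# pinyin phonemes via the provided g2p function (split_py).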
148 | def get_phoneme(text, g2p):
149 | special_tokens = {"#0":"sp0", "#1":"sp1", "#2":"sp2", "#3":"sp3", "#4":"sp4", "":""}
150 | phones = []
151 |
152 | for ph in text:
153 | if ph not in special_tokens:
154 | phs = g2p(ph)
155 | phones.extend([ph for ph in phs if ph])
156 | else:
157 | phones.append(special_tokens[ph])
158 |
159 | return " ".join(phones)
160 |
161 |
162 |
163 | def main(args):
164 |
165 | ROOT_DIR=args.data_dir
166 | TRAIN_DIR=f"{ROOT_DIR}/train"
167 | VALID_DIR=f"{ROOT_DIR}/valid"
168 | TEXT_DIR=f"{ROOT_DIR}/text"
169 |
170 | os.makedirs(TRAIN_DIR, exist_ok=True)
171 | os.makedirs(VALID_DIR, exist_ok=True)
172 |
173 | lexicon = read_lexicon(f"{DIR}/lexicon/librispeech-lexicon.txt")
174 |
175 | g2p = G2p()
176 |
177 | resource={
178 | "g2p":split_py,
179 | "g2p_en":g2p,
180 | "lexicon":lexicon,
181 | }
182 |
183 | with jsonlines.open(f"{TEXT_DIR}/data.jsonl") as f:
184 | data = list(f)
185 |
186 | new_data=[]
187 | with jsonlines.open(f"{TEXT_DIR}/datalist.jsonl", "w") as f:
188 | for sample in tqdm(data):
189 | if not args.generate_phoneme:
190 | sample = onetime(resource, sample)
191 | else:
192 | sample = onetime2(resource, sample)
193 | if not sample:
194 | continue
195 | f.write(sample)
196 | new_data.append(sample)
197 |
198 | with jsonlines.open(f"{TRAIN_DIR}/datalist.jsonl", "w") as f:
199 | for sample in tqdm(new_data[:-3]):
200 | f.write(sample)
201 |
202 | with jsonlines.open(f"{VALID_DIR}/datalist.jsonl", "w") as f:
203 | for sample in tqdm(data[-3:]):
204 | f.write(sample)
205 |
206 |
207 | return
208 |
209 | if __name__ == "__main__":
210 |
211 | p = argparse.ArgumentParser()
212 | p.add_argument('--data_dir', type=str, required=True)
213 | p.add_argument('--generate_phoneme', type=bool, default=False)
214 | args = p.parse_args()
215 |
216 | main(args)
--------------------------------------------------------------------------------
/data/LJspeech/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # 😊 LJSpeech Recipe
4 |
5 | This is the recipe of English single female speaker TTS model with LJSpeech corpus.
6 |
7 | ## Guide For Finetuning
8 | - [Environments Installation](#environments-installation)
9 | - [Step0 Download Data](#step0-download-data)
10 | - [Step1 Preprocess Data](#step1-preprocess-data)
11 | - [Step2 Run MFA (Optional, but Recommended)](#step2-run-mfa-optional-but-recommended)
12 | - [Step3 Prepare for training](#step3-prepare-for-training)
13 | - [Step4 Start training](#step4-finetune-your-model)
14 | - [Step5 Inference](#step5-inference)
15 |
16 | Run EmotiVoice finetuning in a Google Colab notebook: [Open in Colab](https://colab.research.google.com/drive/1dDAyjoYGcDGwYpHI3Oj2_OIV-7DIdx2L?usp=sharing)
17 |
18 | ### Environments Installation
19 |
20 | Create a conda environment
21 | ```bash
22 | conda create -n EmotiVoice python=3.8 -y
23 | conda activate EmotiVoice
24 | ```
25 | then run:
26 | ```bash
27 | pip install EmotiVoice[train]
28 | # or
29 | git clone https://github.com/netease-youdao/EmotiVoice
30 | pip install -e .[train]
31 | ```
32 | Additionally, it is important to prepare the pre-trained models as mentioned in the [pretrained models](https://github.com/netease-youdao/EmotiVoice/wiki/Pretrained-models).
33 |
34 | ### Step0 Download Data
35 |
36 | ```bash
37 | mkdir data/LJspeech/raw
38 |
39 | # download
40 | wget -P data/LJspeech/raw http://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
41 | # extract
42 | tar -xjf data/LJspeech/raw/LJSpeech-1.1.tar.bz2 -C data/LJspeech/raw
43 | ```
44 |
45 | ### Step1 Preprocess Data
46 |
47 | ```bash
48 | # format data
49 | python data/LJspeech/src/step1_clean_raw_data.py \
50 | --data_dir data/LJspeech
51 |
52 | # get phoneme
53 | python data/LJspeech/src/step2_get_phoneme.py \
54 | --data_dir data/LJspeech
55 | ```
56 |
57 | ### Step2 Run MFA (Optional, but Recommended!)
58 |
59 | ```bash
60 | # MFA environment install
61 | conda install -c conda-forge kaldi sox librosa biopython praatio tqdm requests colorama pyyaml pynini openfst baumwelch ngram postgresql -y
62 | pip install pgvector hdbscan montreal-forced-aligner
63 |
64 | # MFA Step1
65 | python mfa/step1_create_dataset.py \
66 | --data_dir data/LJspeech
67 |
68 | # MFA Step2
69 | python mfa/step2_prepare_data.py \
70 | --dataset_dir data/LJspeech/mfa \
71 | --wav data/LJspeech/mfa/wav.txt \
72 | --speaker data/LJspeech/mfa/speaker.txt \
73 | --text data/LJspeech/mfa/text.txt
74 |
75 | # MFA Step3
76 | python mfa/step3_prepare_special_tokens.py \
77 | --special_tokens data/LJspeech/mfa/special_token.txt
78 |
79 | # MFA Step4
80 | python mfa/step4_convert_text_to_phn.py \
81 | --text data/LJspeech/mfa/text.txt \
82 | --special_tokens data/LJspeech/mfa/special_token.txt \
83 | --output data/LJspeech/mfa/text.txt
84 |
85 | # MFA Step5
86 | python mfa/step5_prepare_alignment.py \
87 | --wav data/LJspeech/mfa/wav.txt \
88 | --speaker data/LJspeech/mfa/speaker.txt \
89 | --text data/LJspeech/mfa/text.txt \
90 | --special_tokens data/LJspeech/mfa/special_token.txt \
91 | --pronounciation_dict data/LJspeech/mfa/mfa_pronounciation_dict.txt \
92 | --output_dir data/LJspeech/mfa/lab
93 |
94 | # MFA Step6
95 | mfa validate \
96 | --overwrite \
97 | --clean \
98 | --single_speaker \
99 | data/LJspeech/mfa/lab \
100 | data/LJspeech/mfa/mfa_pronounciation_dict.txt
101 |
102 | mfa train \
103 | --overwrite \
104 | --clean \
105 | --single_speaker \
106 | data/LJspeech/mfa/lab \
107 | data/LJspeech/mfa/mfa_pronounciation_dict.txt \
108 | data/LJspeech/mfa/mfa/mfa_model.zip \
109 | data/LJspeech/mfa/TextGrid
110 |
111 | mfa align \
112 | --single_speaker \
113 | data/LJspeech/mfa/lab \
114 | data/LJspeech/mfa/mfa_pronounciation_dict.txt \
115 | data/LJspeech/mfa/mfa/mfa_model.zip \
116 | data/LJspeech/mfa/TextGrid
117 |
118 | # MFA Step7
119 | python mfa/step7_gen_alignment_from_textgrid.py \
120 | --wav data/LJspeech/mfa/wav.txt \
121 | --speaker data/LJspeech/mfa/speaker.txt \
122 | --text data/LJspeech/mfa/text.txt \
123 | --special_tokens data/LJspeech/mfa/special_token.txt \
124 | --text_grid data/LJspeech/mfa/TextGrid \
125 | --aligned_wav data/LJspeech/mfa/aligned_wav.txt \
126 | --aligned_speaker data/LJspeech/mfa/aligned_speaker.txt \
127 | --duration data/LJspeech/mfa/duration.txt \
128 | --aligned_text data/LJspeech/mfa/aligned_text.txt \
129 | --reassign_sp True
130 |
131 | # MFA Step8
132 | python mfa/step8_make_data_list.py \
133 | --wav data/LJspeech/mfa/aligned_wav.txt \
134 | --speaker data/LJspeech/mfa/aligned_speaker.txt \
135 | --text data/LJspeech/mfa/aligned_text.txt \
136 | --duration data/LJspeech/mfa/duration.txt \
137 | --datalist_path data/LJspeech/mfa/datalist.jsonl
138 |
139 | # MFA Step9
140 | python mfa/step9_datalist_from_mfa.py \
141 | --data_dir data/LJspeech
142 | ```
143 |
144 | ### Step3 Prepare for training
145 |
146 | ```bash
147 | python prepare_for_training.py --data_dir data/LJspeech --exp_dir exp/LJspeech
148 | ```
149 | __Please check and, if needed, change the training and validation file paths in `exp/LJspeech/config/config.py` (a hypothetical example is sketched below), especially:__
150 | - `model_config_path`: the corresponding model config file
151 | - `DATA_DIR`: the data directory
152 | - `train_data_path` and `valid_data_path`: the training and validation files. Change them to `datalist_mfa.jsonl` if you ran Step2
153 | - `batch_size`
154 |
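For reference, the fields listed above might look roughly like the sketch below. This is a minimal, hypothetical excerpt with placeholder values, assuming the generated config exposes them as attributes of a `Config` class; keep whatever `exp/LJspeech/config/config.py` actually defines and only adjust the paths and batch size.
```python
# Hypothetical excerpt of exp/LJspeech/config/config.py -- placeholder values, not the generated defaults.
class Config:
    model_config_path = "exp/LJspeech/config/config.yaml"  # corresponding model config file
    DATA_DIR = "data/LJspeech"                             # data directory
    train_data_path = f"{DATA_DIR}/train/datalist.jsonl"   # switch to datalist_mfa.jsonl if you ran Step2
    valid_data_path = f"{DATA_DIR}/valid/datalist.jsonl"
    batch_size = 16
```
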
155 | ### Step4 Finetune Your Model
156 |
157 | ```bash
158 | torchrun \
159 | --nproc_per_node=1 \
160 | --master_port 8008 \
161 | train_am_vocoder_joint.py \
162 | --config_folder exp/LJspeech/config \
163 | --load_pretrained_model True
164 | ```
165 |
166 | Training tips:
167 |
168 | - You can run TensorBoard with:
169 | ```
170 | tensorboard --logdir=exp/LJspeech
171 | ```
172 | - The model checkpoints are saved at `exp/LJspeech/ckpt`.
173 | - The BERT features are extracted during the first epoch and cached in the `exp/LJspeech/tmp/` folder; you can change this path in `exp/LJspeech/config/config.py`.
174 |
175 |
176 | ### Step5 Inference
177 |
178 |
179 | ```bash
180 | TEXT=data/inference/text
181 | python inference_am_vocoder_exp.py \
182 | --config_folder exp/LJspeech/config \
183 | --checkpoint g_00010000 \
184 | --test_file $TEXT
185 | ```
186 | __Please change the speaker name in `data/inference/text`; the file format is illustrated below.__
187 |
188 | The synthesized speech is saved under `exp/LJspeech/test_audio`.
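
Each line of the test file has four `|`-separated fields, `speaker|prompt|phoneme sequence|original text`, which is how `inference_am_vocoder_exp.py` parses it. Below is a hedged sketch of generating such a line with the English frontend in this repository (run it from the repository root so the lexicon path resolves; the speaker label `LJ` comes from `step1_clean_raw_data.py`, so double-check it against the speaker list generated for your experiment):
```python
# Build one line of data/inference/text in the format speaker|prompt|phonemes|original text.
from frontend_en import read_lexicon, get_eng_phoneme, G2p

lexicon = read_lexicon("lexicon/librispeech-lexicon.txt")
g2p = G2p()

content = "Happy New Year"
prompt = "Happy"                                    # free-text emotion/style prompt
phonemes = get_eng_phoneme(content, g2p, lexicon)   # space-separated phoneme tokens
print(f"LJ|{prompt}|{phonemes}|{content}")
```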
--------------------------------------------------------------------------------
/data/LJspeech/src/step0_download.sh:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | wget http://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
5 |
6 | tar -xjf LJSpeech-1.1.tar.bz2
7 |
--------------------------------------------------------------------------------
/data/LJspeech/src/step1_clean_raw_data.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023, YOUDAO
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import argparse
17 | import soundfile as sf
18 | import librosa
19 | import jsonlines
20 | from tqdm import tqdm
21 |
22 | def main(args):
23 |
24 | ROOT_DIR=os.path.abspath(args.data_dir)
25 | RAW_DIR=f"{ROOT_DIR}/raw"
26 | WAV_DIR=f"{ROOT_DIR}/wavs"
27 | TEXT_DIR=f"{ROOT_DIR}/text"
28 |
29 | os.makedirs(WAV_DIR, exist_ok=True)
30 | os.makedirs(TEXT_DIR, exist_ok=True)
31 |
32 | with open(f"{RAW_DIR}/LJSpeech-1.1/metadata.csv") as f, \
33 | jsonlines.open(f"{TEXT_DIR}/data.jsonl", "w") as fout1:
34 | # open(f"{TEXT_DIR}/text_raw", "w") as fout2:
35 | for line in tqdm(f):
36 | #### Text ####
37 | line = line.strip().split("|")
38 | name = line[0]
39 | text=line[1]
40 |
41 | #### Wav #####
42 | path = f"{RAW_DIR}/LJSpeech-1.1/wavs/{name}.wav"
43 | wav_path = f"{WAV_DIR}/{name}.wav"
44 | y, sr = sf.read(path)
45 | y_16=librosa.resample(y, orig_sr=sr, target_sr=16_000)
46 | sf.write(wav_path, y_16, 16_000)
47 |
48 | #### Write ####
49 | fout1.write({
50 | "key":name,
51 | "wav_path":wav_path,
52 | "speaker":"LJ",
53 | "original_text":text
54 | })
55 | # fout2.write(text+"\n")
56 |
57 |
58 |
59 |
60 | return
61 |
62 |
63 | if __name__ == "__main__":
64 | p = argparse.ArgumentParser()
65 | p.add_argument('--data_dir', type=str, required=True)
66 | args = p.parse_args()
67 |
68 | main(args)
69 |
--------------------------------------------------------------------------------
/data/LJspeech/src/step2_get_phoneme.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023, YOUDAO
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import argparse
16 | import os
17 | import jsonlines
18 | import json
19 | from tqdm import tqdm
20 | from multiprocessing.pool import ThreadPool
21 | from functools import partial
22 | import re
23 | import sys
24 | DIR=os.path.dirname(os.path.abspath("__file__"))
25 | sys.path.append(DIR)
26 |
27 | from frontend_en import read_lexicon, G2p
28 |
29 |
30 | def onetime(resource, sample):
31 |
32 | text=sample["original_text"]
33 | del sample["original_text"]
34 |
35 | phoneme = get_phoneme(text, resource["g2p"], resource["lexicon"]).split()
36 |
37 | sample["text"]=phoneme
38 | sample["original_text"]=text
39 | sample["prompt"]=text
40 |
41 | return sample
42 |
43 | def get_phoneme(text, g2p, lexicon):
44 | filters = {",", " ", "'"}
45 | phones = []
46 | words = list(filter(lambda x: x not in {"", " "}, re.split(r"([,;.\-\?\!\s+])", text)))
47 |
48 | for w in words:
49 | if w.lower() in lexicon:
50 |
51 | for ph in lexicon[w.lower()]:
52 | if ph not in filters:
53 | phones += ["[" + ph + "]"]
54 |
55 | if "sp" not in phones[-1]:
56 | phones += ["engsp1"]
57 | else:
58 | phone=g2p(w)
59 | if not phone:
60 | continue
61 |
62 | if phone[0].isalnum():
63 |
64 | for ph in phone:
65 | if ph not in filters:
66 | phones += ["[" + ph + "]"]
67 | if ph == " " and "sp" not in phones[-1]:
68 | phones += ["engsp1"]
69 | elif phone == " ":
70 | continue
71 | elif phones:
72 | phones.pop() # pop engsp1
73 | phones.append("engsp4")
74 | if phones and "engsp" in phones[-1]:
75 | phones.pop()
76 |
77 | mark = "." if text[-1] != "?" else "?"
78 | phones = [""] + phones + [mark, ""]
79 | return " ".join(phones)
80 |
81 |
82 |
83 | def main(args):
84 |
85 | ROOT_DIR=args.data_dir
86 | TRAIN_DIR=f"{ROOT_DIR}/train"
87 | VALID_DIR=f"{ROOT_DIR}/valid"
88 | TEXT_DIR=f"{ROOT_DIR}/text"
89 |
90 | os.makedirs(TRAIN_DIR, exist_ok=True)
91 | os.makedirs(VALID_DIR, exist_ok=True)
92 |
93 | lexicon = read_lexicon(f"{DIR}/lexicon/librispeech-lexicon.txt")
94 |
95 | g2p = G2p()
96 |
97 | resource={
98 | "g2p":g2p,
99 | "lexicon":lexicon,
100 | }
101 |
102 | with jsonlines.open(f"{TEXT_DIR}/data.jsonl") as f:
103 | data = list(f)
104 |
105 | new_data=[]
106 | with jsonlines.open(f"{TEXT_DIR}/datalist.jsonl", "w") as f:
107 | for sample in tqdm(data):
108 | sample = onetime(resource, sample)
109 | f.write(sample)
110 | new_data.append(sample)
111 |
112 | with jsonlines.open(f"{TRAIN_DIR}/datalist.jsonl", "w") as f:
113 | for sample in tqdm(new_data[:-3]):
114 | f.write(sample)
115 |
116 | with jsonlines.open(f"{VALID_DIR}/datalist.jsonl", "w") as f:
117 | for sample in tqdm(data[-3:]):
118 | f.write(sample)
119 |
120 |
121 | return
122 |
123 | if __name__ == "__main__":
124 |
125 | p = argparse.ArgumentParser()
126 | p.add_argument('--data_dir', type=str, required=True)
127 | args = p.parse_args()
128 |
129 | main(args)
--------------------------------------------------------------------------------
/data/inference/text:
--------------------------------------------------------------------------------
1 | 8051|Happy| [IH0] [M] [AA1] [T] engsp4 [V] [OY1] [S] engsp4 [AH0] engsp1 [M] [AH1] [L] [T] [IY0] engsp4 [V] [OY1] [S] engsp1 [AE1] [N] [D] engsp1 [P] [R] [AA1] [M] [P] [T] engsp4 [K] [AH0] [N] [T] [R] [OW1] [L] [D] engsp1 [T] [IY1] engsp4 [T] [IY1] engsp4 [EH1] [S] engsp1 [EH1] [N] [JH] [AH0] [N] . |Emoti-Voice - a Multi-Voice and Prompt-Controlled T-T-S Engine
2 | 8051|哭唧唧| uo3 sp1 l ai2 sp0 d ao4 sp1 b ei3 sp0 j ing1 sp3 q ing1 sp0 h ua2 sp0 d a4 sp0 x ve2 |我来到北京,清华大学
3 | 11614|第一章| d i4 sp0 i1 sp0 zh ang1 |第一章
4 | 9017|在昏暗狭小的房间内,我父亲躺在窗前的地板上,全身素白,显得身子特别长。| z ai4 sp1 h uen1 sp0 an4 sp1 x ia2 sp0 x iao3 sp0 d e5 sp1 f ang2 sp0 j ian1 sp0 n ei4 sp3 uo3 sp1 f u4 sp0 q in1 sp1 t ang3 sp0 z ai4 sp1 ch uang1 sp0 q ian2 sp0 d e5 sp1 d i4 sp0 b an3 sp0 sh ang4 sp3 q van2 sp0 sh en1 sp1 s u4 sp0 b ai2 sp3 x ian3 sp0 d e5 sp1 sh en1 sp0 z ii5 sp1 t e4 sp0 b ie2 sp0 ch ang2 |在昏暗狭小的房间内,我父亲躺在窗前的地板上,全身素白,显得身子特别长。
5 | 6097|他光着双脚,脚趾头怪模怪样地向外翻着,一双亲切的手平静地放在胸前,手指头也是弯曲的。| t a1 sp1 g uang1 sp0 zh e5 sp1 sh uang1 sp0 j iao3 sp3 j iao2 sp0 zh iii3 sp0 t ou5 sp1 g uai4 sp0 m u2 sp1 g uai4 sp0 iang4 sp0 d e5 sp1 x iang4 sp0 uai4 sp1 f an1 sp0 zh e5 sp3 i4 sp0 sh uang1 sp1 q in1 sp0 q ie4 sp0 d e5 sp0 sh ou3 sp2 p ing2 sp0 j ing4 sp0 d e5 sp1 f ang4 sp0 z ai4 sp1 x iong1 sp0 q ian2 sp3 sh ou2 sp0 zh iii3 sp0 t ou5 sp1 ie3 sp0 sh iii4 sp1 uan1 sp0 q v1 sp0 d e5 |他光着双脚,脚趾头怪模怪样地向外翻着,一双亲切的手平静地放在胸前,手指头也是弯曲的。
6 | 6671|他双目紧闭,可以看见铜钱在上面留下的黑色圆圈;和善的面孔乌青发黑,龇牙咧嘴,挺吓人的。| t a1 sp1 sh uang1 sp0 m u4 sp1 j in3 sp0 b i4 sp3 k e2 sp0 i3 sp1 k an4 sp0 j ian4 sp1 t ong2 sp0 q ian2 sp1 z ai4 sp0 sh ang4 sp0 m ian4 sp1 l iou2 sp0 x ia4 sp0 d e5 sp1 h ei1 sp0 s e4 sp1 van2 sp0 q van1 sp3 h e2 sp0 sh an4 sp0 d e5 sp1 m ian4 sp0 k ong3 sp2 u1 sp0 q ing1 sp1 f a1 sp0 h ei1 sp3 z ii1 sp0 ia2 sp0 l ie2 sp0 z uei3 sp3 t ing3 sp1 x ia4 sp0 r en2 sp0 d e5 |他双目紧闭,可以看见铜钱在上面留下的黑色圆圈;和善的面孔乌青发黑,龇牙咧嘴,挺吓人的。
7 | 6670|母亲半光着上身,穿一条红裙子,跪在地上,正在用那把我常用来锯西瓜皮的小黑梳子,将父亲那又长又软的头发从前额向脑后梳去。| m u3 sp0 q in1 sp2 b an4 sp0 g uang1 sp0 zh e5 sp1 sh ang4 sp0 sh en1 sp3 ch uan1 sp0 i4 sp0 t iao2 sp1 h ong2 sp0 q vn2 sp0 z ii5 sp3 g uei4 sp0 z ai4 sp1 d i4 sp0 sh ang5 sp3 zh eng4 sp0 z ai4 sp1 iong4 sp0 n a4 sp1 b a2 sp0 uo3 sp1 ch ang2 sp0 iong4 sp0 l ai2 sp1 j v4 sp1 x i1 sp0 g ua1 sp0 p i2 sp0 d e5 sp1 x iao3 sp0 h ei1 sp1 sh u1 sp0 z ii5 sp3 j iang1 sp1 f u4 sp0 q in1 sp1 n a4 sp1 iou4 sp0 ch ang2 sp1 iou4 sp0 r uan3 sp0 d e5 sp1 t ou2 sp0 f a4 sp3 c ong2 sp1 q ian2 sp0 e2 sp1 x iang4 sp1 n ao3 sp0 h ou4 sp1 sh u1 sp0 q v4 |母亲半光着上身,穿一条红裙子,跪在地上,正在用那把我常用来锯西瓜皮的小黑梳子,将父亲那又长又软的头发从前额向脑后梳去。
8 | 9136|母亲一直在诉说着什么,声音嘶哑而低沉,她那双浅灰色的眼睛已经浮肿,仿佛融化了似的,眼泪大滴大滴地直往下落。| m u3 sp0 q in1 sp1 i4 sp0 zh iii2 sp1 z ai4 sp1 s u4 sp0 sh uo1 sp0 zh e5 sp1 sh en2 sp0 m e5 sp3 sh eng1 sp0 in1 sp1 s ii1 sp0 ia3 sp1 er2 sp1 d i1 sp0 ch en2 sp3 t a1 sp1 n a4 sp0 sh uang1 sp1 q ian3 sp0 h uei1 sp0 s e4 sp0 d e5 sp1 ian3 sp0 j ing5 sp2 i3 sp0 j ing1 sp1 f u2 sp0 zh ong3 sp3 f ang3 sp0 f u2 sp1 r ong2 sp0 h ua4 sp0 l e5 sp1 sh iii4 sp0 d e5 sp3 ian3 sp0 l ei4 sp1 d a4 sp0 d i1 sp1 d a4 sp0 d i1 sp0 d e5 sp1 zh iii2 sp0 uang3 sp0 x ia4 sp0 l uo4 |母亲一直在诉说着什么,声音嘶哑而低沉,她那双浅灰色的眼睛已经浮肿,仿佛融化了似的,眼泪大滴大滴地直往下落。
9 | 11697|外婆拽着我的手;她长得圆滚滚的,大脑袋、大眼睛和一只滑稽可笑的松弛的鼻子。| uai4 sp0 p o2 sp1 zh uai4 sp0 zh e5 sp1 uo3 sp0 d e5 sp0 sh ou3 sp3 t a1 sp1 zh ang3 sp0 d e5 sp1 van2 sp0 g uen2 sp0 g uen3 sp0 d e5 sp3 d a4 sp0 n ao3 sp0 d ai5 sp3 d a4 sp0 ian3 sp0 j ing5 sp3 h e2 sp1 i4 sp0 zh iii1 sp1 h ua2 sp0 j i1 sp1 k e3 sp0 x iao4 sp0 d e5 sp1 s ong1 sp0 ch iii2 sp0 d e5 sp1 b i2 sp0 z ii5 |外婆拽着我的手;她长得圆滚滚的,大脑袋、大眼睛和一只滑稽可笑的松弛的鼻子。
10 | 92|她穿一身黑衣服,身上软乎乎的,特别好玩。她也在哭,但哭得有些特别,和母亲的哭声交相呼应。她全身都在颤抖,而且老是把我往父亲跟前推。我扭动身子,直往她身后躲;我感到害怕,浑身不自在。| t a1 sp0 ch uan1 sp1 i4 sp0 sh en1 sp1 h ei1 sp0 i1 sp0 f u2 sp3 sh en1 sp0 sh ang4 sp1 r uan3 sp0 h u1 sp0 h u1 sp0 d e5 sp3 t e4 sp0 b ie2 sp1 h ao3 sp0 uan2 sp3 t a1 sp0 ie3 sp1 z ai4 sp0 k u1 sp3 d an4 sp1 k u1 sp0 d e5 sp1 iou3 sp0 x ie1 sp1 t e4 sp0 b ie2 sp3 h e2 sp1 m u3 sp0 q in1 sp0 d e5 sp1 k u1 sp0 sh eng1 sp1 j iao1 sp0 x iang1 sp1 h u1 sp0 ing4 sp3 t a1 sp1 q van2 sp0 sh en1 sp1 d ou1 sp0 z ai4 sp1 ch an4 sp0 d ou3 sp3 er2 sp0 q ie3 sp1 l ao3 sp0 sh iii4 sp1 b a2 sp0 uo3 sp1 uang3 sp1 f u4 sp0 q in1 sp1 g en1 sp0 q ian5 sp0 t uei1 sp3 uo3 sp1 n iou3 sp0 d ong4 sp1 sh en1 sp0 z ii5 sp3 zh iii2 sp0 uang3 sp0 t a1 sp1 sh en1 sp0 h ou4 sp0 d uo3 sp3 uo3 sp1 g an3 sp0 d ao4 sp1 h ai4 sp0 p a4 sp3 h uen2 sp0 sh en1 sp1 b u2 sp0 z ii4 sp0 z ai5 |她穿一身黑衣服,身上软乎乎的,特别好玩。她也在哭,但哭得有些特别,和母亲的哭声交相呼应。她全身都在颤抖,而且老是把我往父亲跟前推。我扭动身子,直往她身后躲;我感到害怕,浑身不自在。
11 | 12787|我还从没有见过大人们哭,而且不明白外婆老说的那些话的意思:“跟你爹告个别吧,以后你再也看不到他啦,他死了,乖孩子,还不到年纪,不是时候啊……”| uo3 sp0 h ai2 sp1 c ong2 sp0 m ei2 sp0 iou3 sp1 j ian4 sp0 g uo4 sp1 d a4 sp0 r en2 sp0 m en5 sp1 k u1 sp3 er2 sp0 q ie3 sp1 b u4 sp0 m ing2 sp0 b ai2 sp1 uai4 sp0 p o2 sp1 l ao3 sp0 sh uo1 sp0 d e5 sp1 n a4 sp0 x ie1 sp0 h ua4 sp0 d e5 sp1 i4 sp0 s ii5 sp3 g en1 sp0 n i3 sp0 d ie1 sp1 g ao4 sp0 g e4 sp0 b ie2 sp0 b a5 sp3 i3 sp0 h ou4 sp3 n i3 sp1 z ai4 sp0 ie3 sp1 k an4 sp0 b u2 sp0 d ao4 sp1 t a1 sp0 l a5 sp3 t a1 sp0 s ii3 sp0 l e5 sp3 g uai1 sp0 h ai2 sp0 z ii5 sp3 h ai2 sp0 b u2 sp0 d ao4 sp1 n ian2 sp0 j i4 sp3 b u2 sp0 sh iii4 sp1 sh iii2 sp0 h ou5 sp0 a5 |我还从没有见过大人们哭,而且不明白外婆老说的那些话的意思:“跟你爹告个别吧,以后你再也看不到他啦,他死了,乖孩子,还不到年纪,不是时候啊……”
12 | 1006|我得过一场大病,这时刚刚能下地。生病期间一这一点我记得很清楚——父亲照看我时显得很高兴,后来他突然就不见了,换成了外婆这个怪里怪气的人。| uo3 sp1 d e2 sp0 g uo4 sp1 i4 sp0 ch ang3 sp1 d a4 sp0 b ing4 sp3 zh e4 sp0 sh iii2 sp1 g ang1 sp0 g ang1 sp1 n eng2 sp0 x ia4 sp0 d i4 sp3 sh eng1 sp0 b ing4 sp1 q i1 sp0 j ian1 sp3 i2 sp0 zh e4 sp0 i4 sp0 d ian3 sp1 uo3 sp1 j i4 sp0 d e5 sp1 h en3 sp0 q ing1 sp0 ch u5 sp3 f u4 sp0 q in1 sp1 zh ao4 sp0 k an4 sp1 uo3 sp0 sh iii2 sp2 x ian3 sp0 d e5 sp1 h en3 sp0 g ao1 sp0 x ing4 sp3 h ou4 sp0 l ai2 sp3 t a1 sp1 t u1 sp0 r an2 sp1 j iou4 sp1 b u2 sp0 j ian4 sp0 l e5 sp3 h uan4 sp0 ch eng2 sp0 l e5 sp1 uai4 sp0 p o2 sp1 zh e4 sp0 g e4 sp1 g uai4 sp0 l i3 sp1 g uai4 sp0 q i4 sp0 d e5 sp0 r en2 |我得过一场大病,这时刚刚能下地。生病期间一这一点我记得很清楚——父亲照看我时显得很高兴,后来他突然就不见了,换成了外婆这个怪里怪气的人。
13 |
--------------------------------------------------------------------------------
/data/youdao/text/emotion:
--------------------------------------------------------------------------------
1 | 普通
2 | 生气
3 | 开心
4 | 惊讶
5 | 悲伤
6 | 厌恶
7 | 恐惧
--------------------------------------------------------------------------------
/data/youdao/text/energy:
--------------------------------------------------------------------------------
1 | 音量普通
2 | 音量很高
3 | 音量很低
--------------------------------------------------------------------------------
/data/youdao/text/pitch:
--------------------------------------------------------------------------------
1 | 音调普通
2 | 音调很高
3 | 音调很低
--------------------------------------------------------------------------------
/data/youdao/text/speed:
--------------------------------------------------------------------------------
1 | 语速普通
2 | 语速很快
3 | 语速很慢
--------------------------------------------------------------------------------
/data/youdao/text/tokenlist:
--------------------------------------------------------------------------------
1 | _
2 |
3 | [AA0]
4 | [AA1]
5 | [AA2]
6 | [AE0]
7 | [AE1]
8 | [AE2]
9 | [AH0]
10 | [AH1]
11 | [AH2]
12 | [AO0]
13 | [AO1]
14 | [AO2]
15 | [AW0]
16 | [AW1]
17 | [AW2]
18 | [AY0]
19 | [AY1]
20 | [AY2]
21 | [B]
22 | [CH]
23 | [DH]
24 | [D]
25 | [EH0]
26 | [EH1]
27 | [EH2]
28 | [ER0]
29 | [ER1]
30 | [ER2]
31 | [EY0]
32 | [EY1]
33 | [EY2]
34 | [F]
35 | [G]
36 | [HH]
37 | [IH0]
38 | [IH1]
39 | [IH2]
40 | [IY0]
41 | [IY1]
42 | [IY2]
43 | [JH]
44 | [K]
45 | [L]
46 | [M]
47 | [NG]
48 | [N]
49 | [OW0]
50 | [OW1]
51 | [OW2]
52 | [OY0]
53 | [OY1]
54 | [OY2]
55 | [P]
56 | [R]
57 | [SH]
58 | [S]
59 | [TH]
60 | [T]
61 | [UH0]
62 | [UH1]
63 | [UH2]
64 | [UW0]
65 | [UW1]
66 | [UW2]
67 | [V]
68 | [W]
69 | [Y]
70 | [ZH]
71 | [Z]
72 | a1
73 | a2
74 | a3
75 | a4
76 | a5
77 | ai1
78 | ai2
79 | ai3
80 | ai4
81 | ai5
82 | air1
83 | air2
84 | air4
85 | air5
86 | an1
87 | an2
88 | an3
89 | an4
90 | an5
91 | ang1
92 | ang2
93 | ang3
94 | ang4
95 | ang5
96 | angr1
97 | angr2
98 | angr4
99 | anr1
100 | anr2
101 | anr3
102 | anr4
103 | ao1
104 | ao2
105 | ao3
106 | ao4
107 | ao5
108 | aor1
109 | aor2
110 | aor3
111 | aor4
112 | ar1
113 | ar2
114 | ar3
115 | ar4
116 | ar5
117 | arr4
118 | b
119 | c
120 | ch
121 | cn_eng_sp
122 | d
123 | e1
124 | e2
125 | e3
126 | e4
127 | e5
128 | ei1
129 | ei2
130 | ei3
131 | ei4
132 | ei5
133 | eir1
134 | eir4
135 | en1
136 | en2
137 | en3
138 | en4
139 | en5
140 | eng1
141 | eng2
142 | eng3
143 | eng4
144 | eng5
145 | eng_cn_sp
146 | engr1
147 | engr3
148 | engr4
149 | engsp1
150 | engsp2
151 | engsp4
152 | enr1
153 | enr2
154 | enr3
155 | enr4
156 | enr5
157 | er1
158 | er2
159 | er3
160 | er4
161 | er5
162 | f
163 | g
164 | h
165 | i1
166 | i2
167 | i3
168 | i4
169 | i5
170 | ia1
171 | ia2
172 | ia3
173 | ia4
174 | ia5
175 | ian1
176 | ian2
177 | ian3
178 | ian4
179 | ian5
180 | iang1
181 | iang2
182 | iang3
183 | iang4
184 | iang5
185 | iangr2
186 | iangr4
187 | ianr1
188 | ianr2
189 | ianr3
190 | ianr4
191 | ianr5
192 | iao1
193 | iao2
194 | iao3
195 | iao4
196 | iao5
197 | iaor2
198 | iaor3
199 | iaor4
200 | iar2
201 | iar3
202 | iar4
203 | ie1
204 | ie2
205 | ie3
206 | ie4
207 | ie5
208 | ier4
209 | ii1
210 | ii2
211 | ii3
212 | ii4
213 | ii5
214 | iii1
215 | iii2
216 | iii3
217 | iii4
218 | iii5
219 | iiir2
220 | iiir3
221 | iiir4
222 | iir2
223 | iir3
224 | iir4
225 | in1
226 | in2
227 | in3
228 | in4
229 | in5
230 | ing1
231 | ing2
232 | ing3
233 | ing4
234 | ing5
235 | ingr1
236 | ingr2
237 | ingr3
238 | ingr4
239 | inr1
240 | inr4
241 | iong1
242 | iong2
243 | iong3
244 | iong4
245 | iong5
246 | iou1
247 | iou2
248 | iou3
249 | iou4
250 | iou5
251 | iour2
252 | iour3
253 | iour4
254 | ir1
255 | ir2
256 | ir3
257 | ir4
258 | irr1
259 | j
260 | k
261 | l
262 | m
263 | n
264 | o1
265 | o2
266 | o3
267 | o4
268 | o5
269 | ong1
270 | ong2
271 | ong3
272 | ong4
273 | ong5
274 | ongr2
275 | ongr3
276 | ongr4
277 | or4
278 | ou1
279 | ou2
280 | ou3
281 | ou4
282 | ou5
283 | our1
284 | our2
285 | our3
286 | our4
287 | our5
288 | p
289 | q
290 | r
291 | s
292 | sh
293 | sp0
294 | sp1
295 | sp2
296 | sp3
297 | sp4
298 | t
299 | u1
300 | u2
301 | u3
302 | u4
303 | u5
304 | ua1
305 | ua2
306 | ua3
307 | ua4
308 | ua5
309 | uai1
310 | uai2
311 | uai3
312 | uai4
313 | uai5
314 | uair4
315 | uan1
316 | uan2
317 | uan3
318 | uan4
319 | uan5
320 | uang1
321 | uang2
322 | uang3
323 | uang4
324 | uang5
325 | uanr1
326 | uanr2
327 | uanr3
328 | uanr4
329 | uanr5
330 | uar1
331 | uar2
332 | uar3
333 | uar4
334 | uei1
335 | uei2
336 | uei3
337 | uei4
338 | uei5
339 | ueir1
340 | ueir2
341 | ueir3
342 | ueir4
343 | uen1
344 | uen2
345 | uen3
346 | uen4
347 | uen5
348 | ueng1
349 | ueng3
350 | ueng4
351 | uenr1
352 | uenr2
353 | uenr3
354 | uenr4
355 | uo1
356 | uo2
357 | uo3
358 | uo4
359 | uo5
360 | uor1
361 | uor2
362 | uor3
363 | uor4
364 | uor5
365 | ur1
366 | ur2
367 | ur3
368 | ur4
369 | ur5
370 | v1
371 | v2
372 | v3
373 | v4
374 | v5
375 | van1
376 | van2
377 | van3
378 | van4
379 | van5
380 | vanr1
381 | vanr2
382 | vanr3
383 | vanr4
384 | ve1
385 | ve2
386 | ve3
387 | ve4
388 | ve5
389 | ver2
390 | vn1
391 | vn2
392 | vn3
393 | vn4
394 | vn5
395 | vr2
396 | vr3
397 | vr4
398 | vr5
399 | x
400 | y
401 | z
402 | zh
403 | engsp0
404 | ?
405 | .
406 | spn
407 | ue2
408 | !
409 | err1
410 | [LAUGH]
411 | rr
412 | ier2
413 | or1
414 | ueng2
415 | ir5
416 | iar1
417 | iour1
418 | uncased15
419 | uncased16
420 | uncased17
421 | uncased18
422 | uncased19
423 | uncased20
424 | uncased21
425 | uncased22
426 | uncased23
427 | uncased24
428 | uncased25
429 | uncased26
430 | uncased27
431 | uncased28
432 | uncased29
433 | uncased30
434 | uncased31
435 | uncased32
436 | uncased33
437 | uncased34
438 | uncased35
439 | uncased36
440 | uncased37
441 | uncased38
442 | uncased39
443 | uncased40
444 | uncased41
445 | uncased42
446 | uncased43
447 | uncased44
448 | uncased45
449 | uncased46
450 | uncased47
451 | uncased48
452 | uncased49
453 | uncased50
454 | uncased51
455 | uncased52
456 | uncased53
457 | uncased54
458 | uncased55
459 | uncased56
460 | uncased57
461 | uncased58
462 | uncased59
463 | uncased60
464 | uncased61
465 | uncased62
466 | uncased63
467 | uncased64
468 | uncased65
469 | uncased66
470 | uncased67
471 | uncased68
472 | uncased69
473 | uncased70
474 | uncased71
475 | uncased72
476 | uncased73
477 | uncased74
478 | uncased75
479 | uncased76
480 | uncased77
481 | uncased78
482 | uncased79
483 | uncased80
484 | uncased81
485 | uncased82
486 | uncased83
487 | uncased84
488 | uncased85
489 | uncased86
490 | uncased87
491 | uncased88
492 | uncased89
493 | uncased90
494 | uncased91
495 | uncased92
496 | uncased93
497 | uncased94
498 | uncased95
499 | uncased96
500 | uncased97
501 | uncased98
502 | uncased99
503 |
--------------------------------------------------------------------------------
/demo_page.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023, YOUDAO
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import streamlit as st
16 | import os, glob
17 | import numpy as np
18 | from yacs import config as CONFIG
19 | import torch
20 | import re
21 |
22 | from frontend import g2p_cn_en, ROOT_DIR, read_lexicon, G2p
23 | from config.joint.config import Config
24 | from models.prompt_tts_modified.jets import JETSGenerator
25 | from models.prompt_tts_modified.simbert import StyleEncoder
26 | from transformers import AutoTokenizer
27 |
28 | import base64
29 | from pathlib import Path
30 |
31 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32 | MAX_WAV_VALUE = 32768.0
33 |
34 | config = Config()
35 |
36 | def create_download_link():
37 | pdf_path = Path("EmotiVoice_UserAgreement_易魔声用户协议.pdf")
38 |     base64_pdf = base64.b64encode(pdf_path.read_bytes()).decode("utf-8")  # base64-encoded PDF contents
39 |     return f'<a href="data:application/octet-stream;base64,{base64_pdf}" download="EmotiVoice_UserAgreement_易魔声用户协议.pdf">EmotiVoice_UserAgreement_易魔声用户协议.pdf</a>'
40 |
41 | html=create_download_link()
42 |
43 | st.set_page_config(
44 | page_title="demo page",
45 | page_icon="📕",
46 | )
47 | st.write("# Text-To-Speech")
48 | st.markdown(f"""
49 | ### How to use:
50 |
51 | - Simply select a **Speaker ID**, type in the **text** you want to convert and the emotion **Prompt**, like a single word or even a sentence. Then click on the **Synthesize** button below to start voice synthesis.
52 |
53 | - You can download the audio by clicking on the three vertical dots next to the displayed audio widget.
54 |
55 | - For more information on **'Speaker ID'**, please consult the [EmotiVoice voice wiki page](https://github.com/netease-youdao/EmotiVoice/tree/main/data/youdao/text)
56 |
57 | - This interactive demo page is provided under the terms of the {html}. The audio is synthesized by AI. 音频由AI合成,仅供参考。
58 |
59 | """, unsafe_allow_html=True)
60 |
61 | def scan_checkpoint(cp_dir, prefix, c=8):
62 | pattern = os.path.join(cp_dir, prefix + '?'*c)
63 | cp_list = glob.glob(pattern)
64 | if len(cp_list) == 0:
65 | return None
66 | return sorted(cp_list)[-1]
67 |
68 | @st.cache_resource
69 | def get_models():
70 |
71 | am_checkpoint_path = scan_checkpoint(f'{config.output_directory}/prompt_tts_open_source_joint/ckpt', 'g_')
72 |
73 | style_encoder_checkpoint_path = scan_checkpoint(f'{config.output_directory}/style_encoder/ckpt', 'checkpoint_', 6)#f'{config.output_directory}/style_encoder/ckpt/checkpoint_163431'
74 |
75 | with open(config.model_config_path, 'r') as fin:
76 | conf = CONFIG.load_cfg(fin)
77 |
78 | conf.n_vocab = config.n_symbols
79 | conf.n_speaker = config.speaker_n_labels
80 |
81 | style_encoder = StyleEncoder(config)
82 | model_CKPT = torch.load(style_encoder_checkpoint_path, map_location="cpu")
83 | model_ckpt = {}
84 | for key, value in model_CKPT['model'].items():
85 | new_key = key[7:]
86 | model_ckpt[new_key] = value
87 | style_encoder.load_state_dict(model_ckpt, strict=False)
88 | generator = JETSGenerator(conf).to(DEVICE)
89 |
90 | model_CKPT = torch.load(am_checkpoint_path, map_location=DEVICE)
91 | generator.load_state_dict(model_CKPT['generator'])
92 | generator.eval()
93 |
94 | tokenizer = AutoTokenizer.from_pretrained(config.bert_path)
95 |
96 | with open(config.token_list_path, 'r') as f:
97 | token2id = {t.strip():idx for idx, t, in enumerate(f.readlines())}
98 |
99 | with open(config.speaker2id_path, encoding='utf-8') as f:
100 | speaker2id = {t.strip():idx for idx, t in enumerate(f.readlines())}
101 |
102 |
103 | return (style_encoder, generator, tokenizer, token2id, speaker2id)
104 |
105 | def get_style_embedding(prompt, tokenizer, style_encoder):
106 | prompt = tokenizer([prompt], return_tensors="pt")
107 | input_ids = prompt["input_ids"]
108 | token_type_ids = prompt["token_type_ids"]
109 | attention_mask = prompt["attention_mask"]
110 | with torch.no_grad():
111 | output = style_encoder(
112 | input_ids=input_ids,
113 | token_type_ids=token_type_ids,
114 | attention_mask=attention_mask,
115 | )
116 | style_embedding = output["pooled_output"].cpu().squeeze().numpy()
117 | return style_embedding
118 |
119 | def tts(name, text, prompt, content, speaker, models):
120 | (style_encoder, generator, tokenizer, token2id, speaker2id)=models
121 |
122 |
123 | style_embedding = get_style_embedding(prompt, tokenizer, style_encoder)
124 | content_embedding = get_style_embedding(content, tokenizer, style_encoder)
125 |
126 | speaker = speaker2id[speaker]
127 |
128 | text_int = [token2id[ph] for ph in text.split()]
129 |
130 | sequence = torch.from_numpy(np.array(text_int)).to(DEVICE).long().unsqueeze(0)
131 | sequence_len = torch.from_numpy(np.array([len(text_int)])).to(DEVICE)
132 | style_embedding = torch.from_numpy(style_embedding).to(DEVICE).unsqueeze(0)
133 | content_embedding = torch.from_numpy(content_embedding).to(DEVICE).unsqueeze(0)
134 | speaker = torch.from_numpy(np.array([speaker])).to(DEVICE)
135 |
136 | with torch.no_grad():
137 |
138 | infer_output = generator(
139 | inputs_ling=sequence,
140 | inputs_style_embedding=style_embedding,
141 | input_lengths=sequence_len,
142 | inputs_content_embedding=content_embedding,
143 | inputs_speaker=speaker,
144 | alpha=1.0
145 | )
146 |
147 | audio = infer_output["wav_predictions"].squeeze()* MAX_WAV_VALUE
148 | audio = audio.cpu().numpy().astype('int16')
149 |
150 | return audio
151 |
152 | speakers = config.speakers
153 | models = get_models()
154 | lexicon = read_lexicon(f"{ROOT_DIR}/lexicon/librispeech-lexicon.txt")
155 | g2p = G2p()
156 |
157 | def new_line(i):
158 | col1, col2, col3, col4 = st.columns([1.5, 1.5, 3.5, 1.3])
159 | with col1:
160 | speaker=st.selectbox("Speaker ID (说话人)", speakers, key=f"{i}_speaker")
161 | with col2:
162 | prompt=st.text_input("Prompt (开心/悲伤)", "", key=f"{i}_prompt")
163 | with col3:
164 | content=st.text_input("Text to be synthesized into speech (合成文本)", "合成文本", key=f"{i}_text")
165 | with col4:
166 | lang=st.selectbox("Language (语言)", ["zh_us"], key=f"{i}_lang")
167 |
168 | flag = st.button(f"Synthesize (合成)", key=f"{i}_button1")
169 | if flag:
170 | text = g2p_cn_en(content, g2p, lexicon)
171 | path = tts(i, text, prompt, content, speaker, models)
172 | st.audio(path, sample_rate=config.sampling_rate)
173 |
174 |
175 |
176 | new_line(0)
177 |
--------------------------------------------------------------------------------
/demo_page_databaker.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023, YOUDAO
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import streamlit as st
16 | import os, glob
17 | import numpy as np
18 | from yacs import config as CONFIG
19 | import torch
20 | import re
21 |
22 | from frontend import g2p_cn_en, ROOT_DIR, read_lexicon, G2p
23 | from exp.DataBaker.config.config import Config
24 | from models.prompt_tts_modified.jets import JETSGenerator
25 | from models.prompt_tts_modified.simbert import StyleEncoder
26 | from transformers import AutoTokenizer
27 |
28 | import base64
29 | from pathlib import Path
30 |
31 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32 | MAX_WAV_VALUE = 32768.0
33 |
34 | config = Config()
35 |
36 | def create_download_link():
37 | pdf_path = Path("EmotiVoice_UserAgreement_易魔声用户协议.pdf")
38 |     base64_pdf = base64.b64encode(pdf_path.read_bytes()).decode("utf-8")  # base64-encoded PDF contents
39 |     return f'<a href="data:application/octet-stream;base64,{base64_pdf}" download="EmotiVoice_UserAgreement_易魔声用户协议.pdf">EmotiVoice_UserAgreement_易魔声用户协议.pdf</a>'
40 |
41 | html=create_download_link()
42 |
43 | st.set_page_config(
44 | page_title="demo page",
45 | page_icon="📕",
46 | )
47 | st.write("# Text-To-Speech")
48 | st.markdown(f"""
49 | ### How to use:
50 |
51 | - Simply select a **Speaker ID**, type in the **text** you want to convert and the emotion **Prompt**, like a single word or even a sentence. Then click on the **Synthesize** button below to start voice synthesis.
52 |
53 | - You can download the audio by clicking on the three vertical dots next to the displayed audio widget.
54 |
55 | - For more information on **'Speaker ID'**, please consult the [EmotiVoice voice wiki page](https://github.com/netease-youdao/EmotiVoice/tree/main/data/youdao/text)
56 |
57 | - This interactive demo page is provided under the terms of the {html}. The audio is synthesized by AI. 音频由AI合成,仅供参考。
58 |
59 | """, unsafe_allow_html=True)
60 |
61 | def scan_checkpoint(cp_dir, prefix, c=8):
62 | pattern = os.path.join(cp_dir, prefix + '?'*c)
63 | cp_list = glob.glob(pattern)
64 | if len(cp_list) == 0:
65 | return None
66 | return sorted(cp_list)[-1]
67 |
68 | @st.cache_resource
69 | def get_models():
70 |
71 | am_checkpoint_path = scan_checkpoint(f'{config.output_directory}/ckpt', 'g_')
72 |
73 | style_encoder_checkpoint_path = config.style_encoder_ckpt
74 |
75 | with open(config.model_config_path, 'r') as fin:
76 | conf = CONFIG.load_cfg(fin)
77 |
78 | conf.n_vocab = config.n_symbols
79 | conf.n_speaker = config.speaker_n_labels
80 |
81 | style_encoder = StyleEncoder(config)
82 | model_CKPT = torch.load(style_encoder_checkpoint_path, map_location="cpu")
83 | model_ckpt = {}
84 | for key, value in model_CKPT['model'].items():
85 | new_key = key[7:]
86 | model_ckpt[new_key] = value
87 | style_encoder.load_state_dict(model_ckpt, strict=False)
88 | generator = JETSGenerator(conf).to(DEVICE)
89 |
90 | model_CKPT = torch.load(am_checkpoint_path, map_location=DEVICE)
91 | generator.load_state_dict(model_CKPT['generator'])
92 | generator.eval()
93 |
94 | tokenizer = AutoTokenizer.from_pretrained(config.bert_path)
95 |
96 | with open(config.token_list_path, 'r') as f:
97 | token2id = {t.strip():idx for idx, t, in enumerate(f.readlines())}
98 |
99 | with open(config.speaker2id_path, encoding='utf-8') as f:
100 | speaker2id = {t.strip():idx for idx, t in enumerate(f.readlines())}
101 |
102 |
103 | return (style_encoder, generator, tokenizer, token2id, speaker2id)
104 |
105 | def get_style_embedding(prompt, tokenizer, style_encoder):
106 | prompt = tokenizer([prompt], return_tensors="pt")
107 | input_ids = prompt["input_ids"]
108 | token_type_ids = prompt["token_type_ids"]
109 | attention_mask = prompt["attention_mask"]
110 | with torch.no_grad():
111 | output = style_encoder(
112 | input_ids=input_ids,
113 | token_type_ids=token_type_ids,
114 | attention_mask=attention_mask,
115 | )
116 | style_embedding = output["pooled_output"].cpu().squeeze().numpy()
117 | return style_embedding
118 |
119 | def tts(name, text, prompt, content, speaker, models):
120 | (style_encoder, generator, tokenizer, token2id, speaker2id)=models
121 |
122 |
123 | style_embedding = get_style_embedding(prompt, tokenizer, style_encoder)
124 | content_embedding = get_style_embedding(content, tokenizer, style_encoder)
125 |
126 | speaker = speaker2id[speaker]
127 |
128 | text_int = [token2id[ph] for ph in text.split()]
129 |
130 | sequence = torch.from_numpy(np.array(text_int)).to(DEVICE).long().unsqueeze(0)
131 | sequence_len = torch.from_numpy(np.array([len(text_int)])).to(DEVICE)
132 | style_embedding = torch.from_numpy(style_embedding).to(DEVICE).unsqueeze(0)
133 | content_embedding = torch.from_numpy(content_embedding).to(DEVICE).unsqueeze(0)
134 | speaker = torch.from_numpy(np.array([speaker])).to(DEVICE)
135 |
136 | with torch.no_grad():
137 |
138 | infer_output = generator(
139 | inputs_ling=sequence,
140 | inputs_style_embedding=style_embedding,
141 | input_lengths=sequence_len,
142 | inputs_content_embedding=content_embedding,
143 | inputs_speaker=speaker,
144 | alpha=1.0
145 | )
146 |
147 | audio = infer_output["wav_predictions"].squeeze()* MAX_WAV_VALUE
148 | audio = audio.cpu().numpy().astype('int16')
149 |
150 | return audio
151 |
152 | speakers = config.speakers
153 | models = get_models()
154 | lexicon = read_lexicon(f"{ROOT_DIR}/lexicon/librispeech-lexicon.txt")
155 | g2p = G2p()
156 |
157 | def new_line(i):
158 | col1, col2, col3, col4 = st.columns([1.5, 1.5, 3.5, 1.3])
159 | with col1:
160 | speaker=st.selectbox("Speaker ID (说话人)", speakers, key=f"{i}_speaker")
161 | with col2:
162 | prompt=st.text_input("Prompt (开心/悲伤)", "", key=f"{i}_prompt")
163 | with col3:
164 | content=st.text_input("Text to be synthesized into speech (合成文本)", "合成文本", key=f"{i}_text")
165 | with col4:
166 | lang=st.selectbox("Language (语言)", ["zh_us"], key=f"{i}_lang")
167 |
168 | flag = st.button(f"Synthesize (合成)", key=f"{i}_button1")
169 | if flag:
170 | text = g2p_cn_en(content, g2p, lexicon)
171 | path = tts(i, text, prompt, content, speaker, models)
172 | st.audio(path, sample_rate=config.sampling_rate)
173 |
174 |
175 |
176 | new_line(0)
177 |
--------------------------------------------------------------------------------
/frontend.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023, YOUDAO
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import re
16 | from frontend_cn import g2p_cn, re_digits, tn_chinese
17 | from frontend_en import ROOT_DIR, read_lexicon, G2p, get_eng_phoneme
18 |
19 | # Thanks to GuGCoCo and PatroxGaurab for identifying the issue:
20 | # the results differ between frontend.py and frontend_en.py. Here's a quick fix.
21 | #re_english_word = re.compile('([a-z\-\.\'\s,;\:\!\?]+|\d+[\d\.]*)', re.I)
22 | re_english_word = re.compile('([^\u4e00-\u9fa5]+|[ \u3002\uff0c\uff1f\uff01\uff1b\uff1a\u201c\u201d\u2018\u2019\u300a\u300b\u3008\u3009\u3010\u3011\u300e\u300f\u2014\u2026\u3001\uff08\uff09\u4e00-\u9fa5]+)', re.I)
23 | def g2p_cn_en(text, g2p, lexicon):
24 | # Our policy dictates that if the text contains Chinese, digits are to be converted into Chinese.
25 | text=tn_chinese(text)
26 | parts = re_english_word.split(text)
27 | parts=list(filter(None, parts))
28 | tts_text = [""]
29 | chartype = ''
30 | text_contains_chinese = contains_chinese(text)
31 | for part in parts:
32 | if part == ' ' or part == '': continue
33 | if re_digits.match(part) and (text_contains_chinese or chartype == '') or contains_chinese(part):
34 | if chartype == 'en':
35 | tts_text.append('eng_cn_sp')
36 | phoneme = g2p_cn(part).split()[1:-1]
37 | chartype = 'cn'
38 | elif re_english_word.match(part):
39 | if chartype == 'cn':
40 | if "sp" in tts_text[-1]:
41 |                     pass  # the previous token already carries a pause ("sp"); no cn_eng_sp needed
42 | else:
43 | tts_text.append('cn_eng_sp')
44 | phoneme = get_eng_phoneme(part, g2p, lexicon, False).split()
45 | if not phoneme :
46 | # tts_text.pop()
47 | continue
48 | else:
49 | chartype = 'en'
50 | else:
51 | continue
52 | tts_text.extend( phoneme )
53 |
54 | tts_text=" ".join(tts_text).split()
55 | if "sp" in tts_text[-1]:
56 | tts_text.pop()
57 | tts_text.append("")
58 |
59 | return " ".join(tts_text)
60 |
61 | def contains_chinese(text):
62 | pattern = re.compile(r'[\u4e00-\u9fa5]')
63 | match = re.search(pattern, text)
64 | return match is not None
65 |
66 |
67 | if __name__ == "__main__":
68 | import sys
69 | from os.path import isfile
70 | lexicon = read_lexicon(f"{ROOT_DIR}/lexicon/librispeech-lexicon.txt")
71 |
72 | g2p = G2p()
73 | if len(sys.argv) < 2:
74 | print("Usage: python %s " % sys.argv[0])
75 | exit()
76 | text_file = sys.argv[1]
77 | if isfile(text_file):
78 | fp = open(text_file, 'r')
79 | for line in fp:
80 | phoneme = g2p_cn_en(line.rstrip(), g2p, lexicon)
81 | print(phoneme)
82 | fp.close()
83 |
--------------------------------------------------------------------------------
/frontend_cn.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023, YOUDAO
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import re
16 | from pypinyin import pinyin, lazy_pinyin, Style
17 | import jieba
18 | import string
19 | from cn2an.an2cn import An2Cn
20 | from pypinyin_dict.phrase_pinyin_data import cc_cedict
21 | cc_cedict.load()
22 | re_special_pinyin = re.compile(r'^(n|ng|m)$')
23 | def split_py(py):
24 | tone = py[-1]
25 | py = py[:-1]
26 | sm = ""
27 | ym = ""
28 | suf_r = ""
29 | if re_special_pinyin.match(py):
30 | py = 'e' + py
31 | if py[-1] == 'r':
32 | suf_r = 'r'
33 | py = py[:-1]
34 | if py == 'zi' or py == 'ci' or py == 'si' or py == 'ri':
35 | sm = py[:1]
36 | ym = "ii"
37 | elif py == 'zhi' or py == 'chi' or py == 'shi':
38 | sm = py[:2]
39 | ym = "iii"
40 | elif py == 'ya' or py == 'yan' or py == 'yang' or py == 'yao' or py == 'ye' or py == 'yong' or py == 'you':
41 | sm = ""
42 | ym = 'i' + py[1:]
43 | elif py == 'yi' or py == 'yin' or py == 'ying':
44 | sm = ""
45 | ym = py[1:]
46 | elif py == 'yu' or py == 'yv' or py == 'yuan' or py == 'yvan' or py == 'yue ' or py == 'yve' or py == 'yun' or py == 'yvn':
47 | sm = ""
48 | ym = 'v' + py[2:]
49 | elif py == 'wu':
50 | sm = ""
51 | ym = "u"
52 | elif py[0] == 'w':
53 | sm = ""
54 | ym = "u" + py[1:]
55 | elif len(py) >= 2 and (py[0] == 'j' or py[0] == 'q' or py[0] == 'x') and py[1] == 'u':
56 | sm = py[0]
57 | ym = 'v' + py[2:]
58 | else:
59 | seg_pos = re.search('a|e|i|o|u|v', py)
60 | sm = py[:seg_pos.start()]
61 | ym = py[seg_pos.start():]
62 | if ym == 'ui':
63 | ym = 'uei'
64 | elif ym == 'iu':
65 | ym = 'iou'
66 | elif ym == 'un':
67 | ym = 'uen'
68 | elif ym == 'ue':
69 | ym = 've'
70 | ym += suf_r + tone
71 | return sm, ym
72 |
73 |
74 | chinese_punctuation_pattern = r'[\u3002\uff0c\uff1f\uff01\uff1b\uff1a\u201c\u201d\u2018\u2019\u300a\u300b\u3008\u3009\u3010\u3011\u300e\u300f\u2014\u2026\u3001\uff08\uff09]'
75 |
76 |
77 | def has_chinese_punctuation(text):
78 | match = re.search(chinese_punctuation_pattern, text)
79 | return match is not None
80 | def has_english_punctuation(text):
81 | return text in string.punctuation
82 |
83 | # with thanks to KimigaiiWuyi in https://github.com/netease-youdao/EmotiVoice/pull/17.
84 | # Updated on November 20, 2023: EmotiVoice now incorporates cn2an (https://github.com/Ailln/cn2an) for number processing.
85 | re_digits = re.compile('(\d[\d\.]*)')
86 | def number_to_chinese(number):
87 | an2cn = An2Cn()
88 | result = an2cn.an2cn(number)
89 |
90 | return result
91 |
92 | def tn_chinese(text):
93 | parts = re_digits.split(text)
94 | words = []
95 | for part in parts:
96 | if re_digits.match(part):
97 | words.append(number_to_chinese(part))
98 | else:
99 | words.append(part)
100 | return ''.join(words)
101 |
102 | def g2p_cn(text):
103 | res_text=[""]
104 | seg_list = jieba.cut(text)
105 | for seg in seg_list:
106 | if seg == " ": continue
107 | seg_tn = tn_chinese(seg)
108 | py =[_py[0] for _py in pinyin(seg_tn, style=Style.TONE3,neutral_tone_with_five=True)]
109 |
110 | if any([has_chinese_punctuation(_py) for _py in py]) or any([has_english_punctuation(_py) for _py in py]):
111 | res_text.pop()
112 | res_text.append("sp3")
113 | else:
114 |
115 | py = [" ".join(split_py(_py)) for _py in py]
116 |
117 | res_text.append(" sp0 ".join(py))
118 | res_text.append("sp1")
119 | #res_text.pop()
120 | res_text.append("")
121 | return " ".join(res_text)
122 |
123 | if __name__ == "__main__":
124 | import sys
125 | from os.path import isfile
126 | if len(sys.argv) < 2:
127 | print("Usage: python %s " % sys.argv[0])
128 | exit()
129 | text_file = sys.argv[1]
130 | if isfile(text_file):
131 | fp = open(text_file, 'r')
132 | for line in fp:
133 | phoneme=g2p_cn(line.rstrip())
134 | print(phoneme)
135 | fp.close()
136 |
--------------------------------------------------------------------------------
/frontend_en.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023, YOUDAO
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import re
16 | import argparse
17 | from string import punctuation
18 | import numpy as np
19 |
20 | from g2p_en import G2p
21 |
22 | import os
23 |
24 |
25 | ROOT_DIR = os.path.dirname(os.path.abspath("__file__"))
26 |
27 | def read_lexicon(lex_path):
28 | lexicon = {}
29 | with open(lex_path) as f:
30 | for line in f:
31 | temp = re.split(r"\s+", line.strip("\n"))
32 | word = temp[0]
33 | phones = temp[1:]
34 | if word.lower() not in lexicon:
35 | lexicon[word.lower()] = phones
36 | return lexicon
37 |
38 | def get_eng_phoneme(text, g2p, lexicon, pad_sos_eos=True):
39 | """
40 | english g2p
41 | """
42 | filters = {",", " ", "'"}
43 | phones = []
44 | words = list(filter(lambda x: x not in {"", " "}, re.split(r"([,;.\-\?\!\s+])", text)))
45 |
46 | for w in words:
47 | if w.lower() in lexicon:
48 |
49 | for ph in lexicon[w.lower()]:
50 | if ph not in filters:
51 | phones += ["[" + ph + "]"]
52 |
53 | if "sp" not in phones[-1]:
54 | phones += ["engsp1"]
55 | else:
56 | phone=g2p(w)
57 | if not phone:
58 | continue
59 |
60 | if phone[0].isalnum():
61 |
62 | for ph in phone:
63 | if ph not in filters:
64 | phones += ["[" + ph + "]"]
65 | if ph == " " and "sp" not in phones[-1]:
66 | phones += ["engsp1"]
67 | elif phone == " ":
68 | continue
69 | elif phones:
70 | phones.pop() # pop engsp1
71 | phones.append("engsp4")
72 | if phones and "engsp" in phones[-1]:
73 | phones.pop()
74 |
75 | # mark = "." if text[-1] != "?" else "?"
76 | if pad_sos_eos:
77 | phones = [""] + phones + [""]
78 | return " ".join(phones)
79 |
80 |
81 | if __name__ == "__main__":
82 | lexicon = read_lexicon(f"{ROOT_DIR}/lexicon/librispeech-lexicon.txt")
83 | g2p = G2p()
84 | phonemes= get_eng_phoneme("Happy New Year", g2p, lexicon)
85 | import sys
86 | from os.path import isfile
87 | if len(sys.argv) < 2:
88 | print("Usage: python %s " % sys.argv[0])
89 | exit()
90 | text_file = sys.argv[1]
91 | if isfile(text_file):
92 | fp = open(text_file, 'r')
93 | for line in fp:
94 | phoneme=get_eng_phoneme(line.rstrip(), g2p, lexicon)
95 | print(phoneme)
96 | fp.close()
--------------------------------------------------------------------------------
/inference_am_vocoder_exp.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023, YOUDAO
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from models.prompt_tts_modified.jets import JETSGenerator
16 | from models.prompt_tts_modified.simbert import StyleEncoder
17 | from transformers import AutoTokenizer
18 | import os, sys, warnings, torch, glob, argparse
19 | import numpy as np
20 | from models.hifigan.get_vocoder import MAX_WAV_VALUE
21 | import soundfile as sf
22 | from yacs import config as CONFIG
23 | from tqdm import tqdm
24 |
25 | def get_style_embedding(prompt, tokenizer, style_encoder):
26 | prompt = tokenizer([prompt], return_tensors="pt")
27 | input_ids = prompt["input_ids"]
28 | token_type_ids = prompt["token_type_ids"]
29 | attention_mask = prompt["attention_mask"]
30 |
31 | with torch.no_grad():
32 | output = style_encoder(
33 | input_ids=input_ids,
34 | token_type_ids=token_type_ids,
35 | attention_mask=attention_mask,
36 | )
37 | style_embedding = output["pooled_output"].cpu().squeeze().numpy()
38 | return style_embedding
39 |
40 | def main(args, config):
41 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42 | root_path = os.path.join(config.output_directory)
43 | ckpt_path = os.path.join(root_path, "ckpt")
44 | files = os.listdir(ckpt_path)
45 |
46 | for file in files:
47 | if args.checkpoint:
48 | if file != args.checkpoint:
49 | continue
50 |
51 | checkpoint_path = os.path.join(ckpt_path, file)
52 |
53 | with open(config.model_config_path, 'r') as fin:
54 | conf = CONFIG.load_cfg(fin)
55 |
56 |
57 | conf.n_vocab = config.n_symbols
58 | conf.n_speaker = config.speaker_n_labels
59 |
60 | style_encoder = StyleEncoder(config)
61 | model_CKPT = torch.load(config.style_encoder_ckpt, map_location="cpu")
62 | model_ckpt = {}
63 | for key, value in model_CKPT['model'].items():
64 | new_key = key[7:]
65 | model_ckpt[new_key] = value
66 | style_encoder.load_state_dict(model_ckpt, strict=False)
67 |
68 |
69 |
70 | generator = JETSGenerator(conf).to(device)
71 |
72 | model_CKPT = torch.load(checkpoint_path, map_location=device)
73 | generator.load_state_dict(model_CKPT['generator'])
74 | generator.eval()
75 |
76 | with open(config.token_list_path, 'r') as f:
77 | token2id = {t.strip():idx for idx, t, in enumerate(f.readlines())}
78 |
79 | with open(config.speaker2id_path, encoding='utf-8') as f:
80 | speaker2id = {t.strip():idx for idx, t in enumerate(f.readlines())}
81 |
82 |
83 | tokenizer = AutoTokenizer.from_pretrained(config.bert_path)
84 |
85 | text_path = args.test_file
86 |
87 |
88 | if os.path.exists(root_path + "/test_audio/audio/" +f"{file}/"):
89 | r = glob.glob(root_path + "/test_audio/audio/" +f"{file}/*")
90 | for j in r:
91 | os.remove(j)
92 | texts = []
93 | prompts = []
94 | speakers = []
95 | contents = []
96 | with open(text_path, "r") as f:
97 | for line in f:
98 | line = line.strip().split("|")
99 | speakers.append(line[0])
100 | prompts.append(line[1])
101 | texts.append(line[2].split())
102 | contents.append(line[3])
103 |
104 | for i, (speaker, prompt, text, content) in enumerate(tqdm(zip(speakers, prompts, texts, contents))):
105 |
106 | style_embedding = get_style_embedding(prompt, tokenizer, style_encoder)
107 | content_embedding = get_style_embedding(content, tokenizer, style_encoder)
108 |
109 | if speaker not in speaker2id:
110 | continue
111 | speaker = speaker2id[speaker]
112 |
113 | text_int = [token2id[ph] for ph in text]
114 |
115 | sequence = torch.from_numpy(np.array(text_int)).to(device).long().unsqueeze(0)
116 | sequence_len = torch.from_numpy(np.array([len(text_int)])).to(device)
117 | style_embedding = torch.from_numpy(style_embedding).to(device).unsqueeze(0)
118 | content_embedding = torch.from_numpy(content_embedding).to(device).unsqueeze(0)
119 | speaker = torch.from_numpy(np.array([speaker])).to(device)
120 | with torch.no_grad():
121 |
122 | infer_output = generator(
123 | inputs_ling=sequence,
124 | inputs_style_embedding=style_embedding,
125 | input_lengths=sequence_len,
126 | inputs_content_embedding=content_embedding,
127 | inputs_speaker=speaker,
128 | alpha=1.0
129 | )
130 | audio = infer_output["wav_predictions"].squeeze()* MAX_WAV_VALUE
131 | audio = audio.cpu().numpy().astype('int16')
132 | if not os.path.exists(root_path + "/test_audio/audio/" +f"{file}/"):
133 | os.makedirs(root_path + "/test_audio/audio/" +f"{file}/", exist_ok=True)
134 | sf.write(file=root_path + "/test_audio/audio/" +f"{file}/{i+1}.wav", data=audio, samplerate=config.sampling_rate) #h.sampling_rate
135 |
136 |
137 |
138 |
139 |
140 |
141 | if __name__ == '__main__':
142 | print("run!")
143 | p = argparse.ArgumentParser()
144 | p.add_argument("-c", "--config_folder", type=str, required=True)
145 |     p.add_argument("--checkpoint", type=str, required=False, default='', help='run inference with a specific checkpoint only, e.g. --checkpoint g_00010000')
146 |     p.add_argument('-t', '--test_file', type=str, required=True, help='path to the test file to run inference on')
147 |
148 | args = p.parse_args()
149 | ##################################################
150 | sys.path.append(os.path.dirname(os.path.abspath("__file__")) + "/" + args.config_folder)
151 |
152 | from config import Config
153 | config = Config()
154 | ##################################################
155 | main(args, config)
156 |
157 |
158 |
--------------------------------------------------------------------------------
/inference_am_vocoder_joint.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023, YOUDAO
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from models.prompt_tts_modified.jets import JETSGenerator
16 | from models.prompt_tts_modified.simbert import StyleEncoder
17 | from transformers import AutoTokenizer
18 | import os, sys, warnings, torch, glob, argparse
19 | import numpy as np
20 | from models.hifigan.get_vocoder import MAX_WAV_VALUE
21 | import soundfile as sf
22 | from yacs import config as CONFIG
23 | from tqdm import tqdm
24 |
25 | def get_style_embedding(prompt, tokenizer, style_encoder):
26 | prompt = tokenizer([prompt], return_tensors="pt")
27 | input_ids = prompt["input_ids"]
28 | token_type_ids = prompt["token_type_ids"]
29 | attention_mask = prompt["attention_mask"]
30 |
31 | with torch.no_grad():
32 | output = style_encoder(
33 | input_ids=input_ids,
34 | token_type_ids=token_type_ids,
35 | attention_mask=attention_mask,
36 | )
37 | style_embedding = output["pooled_output"].cpu().squeeze().numpy()
38 | return style_embedding
39 |
40 | def main(args, config):
41 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42 | root_path = os.path.join(config.output_directory, args.logdir)
43 | ckpt_path = os.path.join(root_path, "ckpt")
44 | files = os.listdir(ckpt_path)
45 |
46 | for file in files:
47 | if args.checkpoint:
48 | if file != args.checkpoint:
49 | continue
50 |
51 | checkpoint_path = os.path.join(ckpt_path, file)
52 |
53 | with open(config.model_config_path, 'r') as fin:
54 | conf = CONFIG.load_cfg(fin)
55 |
56 |
57 | conf.n_vocab = config.n_symbols
58 | conf.n_speaker = config.speaker_n_labels
59 |
60 | style_encoder = StyleEncoder(config)
61 | model_CKPT = torch.load(config.style_encoder_ckpt, map_location="cpu")
62 | model_ckpt = {}
63 | for key, value in model_CKPT['model'].items():
64 | new_key = key[7:]
65 | model_ckpt[new_key] = value
66 | style_encoder.load_state_dict(model_ckpt, strict=False)
67 |
68 |
69 |
70 | generator = JETSGenerator(conf).to(device)
71 |
72 | model_CKPT = torch.load(checkpoint_path, map_location=device)
73 | generator.load_state_dict(model_CKPT['generator'])
74 | generator.eval()
75 |
76 | with open(config.token_list_path, 'r') as f:
77 | token2id = {t.strip():idx for idx, t, in enumerate(f.readlines())}
78 |
79 | with open(config.speaker2id_path, encoding='utf-8') as f:
80 | speaker2id = {t.strip():idx for idx, t in enumerate(f.readlines())}
81 |
82 |
83 | tokenizer = AutoTokenizer.from_pretrained(config.bert_path)
84 |
85 | text_path = args.test_file
86 |
87 |
88 | if os.path.exists(root_path + "/test_audio/audio/" +f"{file}/"):
89 | r = glob.glob(root_path + "/test_audio/audio/" +f"{file}/*")
90 | for j in r:
91 | os.remove(j)
92 | texts = []
93 | prompts = []
94 | speakers = []
95 | contents = []
96 | with open(text_path, "r") as f:
97 | for line in f:
98 | line = line.strip().split("|")
99 | speakers.append(line[0])
100 | prompts.append(line[1])
101 | texts.append(line[2].split())
102 | contents.append(line[3])
103 |
104 | for i, (speaker, prompt, text, content) in enumerate(tqdm(zip(speakers, prompts, texts, contents))):
105 |
106 | style_embedding = get_style_embedding(prompt, tokenizer, style_encoder)
107 | content_embedding = get_style_embedding(content, tokenizer, style_encoder)
108 |
109 | if speaker not in speaker2id:
110 | continue
111 | speaker = speaker2id[speaker]
112 |
113 | text_int = [token2id[ph] for ph in text]
114 |
115 | sequence = torch.from_numpy(np.array(text_int)).to(device).long().unsqueeze(0)
116 | sequence_len = torch.from_numpy(np.array([len(text_int)])).to(device)
117 | style_embedding = torch.from_numpy(style_embedding).to(device).unsqueeze(0)
118 | content_embedding = torch.from_numpy(content_embedding).to(device).unsqueeze(0)
119 | speaker = torch.from_numpy(np.array([speaker])).to(device)
120 | with torch.no_grad():
121 |
122 | infer_output = generator(
123 | inputs_ling=sequence,
124 | inputs_style_embedding=style_embedding,
125 | input_lengths=sequence_len,
126 | inputs_content_embedding=content_embedding,
127 | inputs_speaker=speaker,
128 | alpha=1.0
129 | )
130 |                 audio = infer_output["wav_predictions"].squeeze() * MAX_WAV_VALUE
131 |                 audio = audio.cpu().numpy().astype('int16')
132 |                 save_dir = root_path + "/test_audio/audio/" + f"{file}/"
133 |                 os.makedirs(save_dir, exist_ok=True)
134 |                 sf.write(file=save_dir + f"{i+1}.wav", data=audio, samplerate=config.sampling_rate)
135 |
136 |
137 |
138 |
139 |
140 |
141 | if __name__ == '__main__':
142 | print("run!")
143 | p = argparse.ArgumentParser()
144 | p.add_argument('-d', '--logdir', type=str, required=True)
145 | p.add_argument("-c", "--config_folder", type=str, required=True)
146 |     p.add_argument("--checkpoint", type=str, required=False, default='', help='run inference with a specific checkpoint, e.g. --checkpoint checkpoint_230000')
147 |     p.add_argument('-t', '--test_file', type=str, required=True, help='the absolute path of the test file to run inference on')
148 |
149 | args = p.parse_args()
150 | ##################################################
151 | sys.path.append(os.path.dirname(os.path.abspath("__file__")) + "/" + args.config_folder)
152 |
153 | from config import Config
154 | config = Config()
155 | ##################################################
156 | main(args, config)
157 |
158 |
159 |
--------------------------------------------------------------------------------
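
For reference, the inference loop above expects each line of --test_file to be pipe-separated as speaker|prompt|phoneme sequence|plain text. A minimal sketch of building such a file, assuming placeholder values for the speaker id, style prompt, and phoneme tokens (none of these are shipped defaults):

# Build a one-line test file in the speaker|prompt|phonemes|content format
# parsed above. Every field value here is illustrative only.
line = "|".join([
    "0001",            # speaker id; must appear in config.speaker2id_path
    "Happy",           # style prompt fed to the StyleEncoder
    "HH AH0 L OW1",    # space-separated tokens from config.token_list_path
    "Hello",           # raw text used for the content embedding
])
with open("my_test.txt", "w") as f:
    f.write(line + "\n")
# then run this script with:  python <this_script>.py -d <logdir> -c <config_folder> -t $(pwd)/my_test.txt
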
/mel_process.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import random
4 | import torch
5 | from torch import nn
6 | import torch.nn.functional as F
7 | import torch.utils.data
8 | import numpy as np
9 | import librosa
10 | import librosa.util as librosa_util
11 | from librosa.util import normalize, pad_center, tiny
12 | from scipy.signal import get_window
13 | from scipy.io.wavfile import read
14 | from librosa.filters import mel as librosa_mel_fn
15 |
16 | MAX_WAV_VALUE = 32768.0
17 |
18 |
19 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
20 |
21 | return torch.log(torch.clamp(x, min=clip_val) * C)
22 |
23 |
24 | def dynamic_range_decompression_torch(x, C=1):
25 |
26 | return torch.exp(x) / C
27 |
28 |
29 | def spectral_normalize_torch(magnitudes):
30 | output = dynamic_range_compression_torch(magnitudes)
31 | return output
32 |
33 |
34 | def spectral_de_normalize_torch(magnitudes):
35 | output = dynamic_range_decompression_torch(magnitudes)
36 | return output
37 |
38 |
39 | mel_basis = {}
40 | hann_window = {}
41 |
42 |
43 | def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
44 | if torch.min(y) < -1.:
45 | print('min value is ', torch.min(y))
46 | if torch.max(y) > 1.:
47 | print('max value is ', torch.max(y))
48 |
49 | global hann_window
50 | dtype_device = str(y.dtype) + '_' + str(y.device)
51 | wnsize_dtype_device = str(win_size) + '_' + dtype_device
52 | if wnsize_dtype_device not in hann_window:
53 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
54 |
55 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
56 | y = y.squeeze(1)
57 |
58 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
59 |                       center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
60 |
61 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
62 | return spec
63 |
64 |
65 | def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
66 | global mel_basis
67 | dtype_device = str(spec.dtype) + '_' + str(spec.device)
68 | fmax_dtype_device = str(fmax) + '_' + dtype_device
69 | if fmax_dtype_device not in mel_basis:
70 |         mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
71 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
72 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
73 | spec = spectral_normalize_torch(spec)
74 | return spec
75 |
76 |
77 | def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
78 | if torch.min(y) < -1.:
79 | print('min value is ', torch.min(y))
80 | if torch.max(y) > 1.:
81 | print('max value is ', torch.max(y))
82 |
83 | global mel_basis, hann_window
84 | dtype_device = str(y.dtype) + '_' + str(y.device)
85 | fmax_dtype_device = str(fmax) + '_' + dtype_device
86 | wnsize_dtype_device = str(win_size) + '_' + dtype_device
87 | if fmax_dtype_device not in mel_basis:
88 | mel = librosa_mel_fn(
89 | sr=sampling_rate,
90 | n_fft=n_fft,
91 | n_mels=num_mels,
92 | fmin=fmin,
93 | fmax=fmax)
94 |
95 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
96 | if wnsize_dtype_device not in hann_window:
97 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
98 |
99 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
100 | y = y.squeeze(1)
101 |
102 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
103 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
104 |
105 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
106 |
107 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
108 | spec = spectral_normalize_torch(spec)
109 |
110 | return spec
--------------------------------------------------------------------------------
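
A minimal usage sketch for mel_spectrogram_torch above; the STFT/mel parameters below are common HiFi-GAN-style defaults, not values taken from this repository's config:

import torch
from mel_process import mel_spectrogram_torch

# One second of dummy audio in [-1, 1]; all parameters are illustrative.
wav = torch.randn(1, 16000).clamp(-1.0, 1.0)
mel = mel_spectrogram_torch(
    wav, n_fft=1024, num_mels=80, sampling_rate=16000,
    hop_size=256, win_size=1024, fmin=0, fmax=8000,
)
print(mel.shape)  # (1, 80, n_frames): log-compressed mel spectrogram
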
/mfa/step1_create_dataset.py:
--------------------------------------------------------------------------------
1 |
2 | from tqdm import tqdm
3 | import jsonlines
4 | import re
5 | import argparse
6 | import os
7 |
8 | def main(args):
9 | ROOT_DIR=os.path.abspath(args.data_dir)
10 | TEXT_DIR=f"{ROOT_DIR}/text"
11 | MFA_DIR=f"{ROOT_DIR}/mfa"
12 |
13 | os.makedirs(MFA_DIR, exist_ok=True)
14 |
15 |
16 |
17 | with jsonlines.open(f"{TEXT_DIR}/datalist.jsonl") as f1, \
18 | open(f"{MFA_DIR}/text_sp1-sp4", "w") as f2, \
19 | open(f"{MFA_DIR}/wav.scp", "w") as f3:
20 |
21 | data=list({f'{sample["key"]}_{sample["speaker"]}':sample for sample in list(f1)}.values())
22 |
23 | for sample in tqdm(data):
24 | text=[]
25 | for ph in sample["text"]:
26 | if ph[0] == '[':
27 | ph = ph[1:-1]
28 | elif ph == "cn_eng_sp":
29 | ph = "cnengsp"
30 | elif ph == "eng_cn_sp":
31 | ph = "engcnsp"
32 | text.append(ph)
33 | f2.write("{}|{} {}\n".format(re.sub(r" +", "", sample["speaker"]), sample["key"], " ".join(text)))
34 | f3.write("{} {}\n".format(sample["key"], sample["wav_path"]))
35 |
36 | if __name__ == "__main__":
37 | p = argparse.ArgumentParser()
38 | p.add_argument('--data_dir', type=str, required=True)
39 | args = p.parse_args()
40 |
41 | main(args)
42 |
--------------------------------------------------------------------------------
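
For reference, step1_create_dataset.py turns each record of <data_dir>/text/datalist.jsonl into one line of text_sp1-sp4 and one line of wav.scp. A sketch with invented field values:

# One input record (illustrative values only):
sample = {
    "key": "utt_0001",
    "speaker": "speaker1",
    "wav_path": "/path/to/utt_0001.wav",
    "text": ["[HH]", "[AH0]", "[L]", "[OW1]", "sp1"],
}
# text_sp1-sp4 line: "<speaker>|<key> <phonemes with surrounding brackets removed>"
#   speaker1|utt_0001 HH AH0 L OW1 sp1
# wav.scp line: "<key> <wav_path>"
#   utt_0001 /path/to/utt_0001.wav
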
/mfa/step2_prepare_data.py:
--------------------------------------------------------------------------------
1 |
2 | import argparse
3 | import collections
4 | import pathlib
5 | import os
6 | from typing import Iterable
7 | from tqdm import tqdm
8 |
9 | def get_args():
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--dataset_dir',
12 | type=str,
13 |                         help='Path to the directory containing text_sp1-sp4 and wav.scp.')
14 | parser.add_argument('--wav', type=str, help='Path to export paths of wavs.')
15 | parser.add_argument('--speaker', type=str, help='Path to export speakers.')
16 | parser.add_argument('--text', type=str, help='Path to export text of wavs.')
17 | return parser.parse_args()
18 |
19 |
20 | def save_scp_files(wav_scp_path: os.PathLike, speaker_scp_path: os.PathLike,
21 | text_scp_path: os.PathLike, content: Iterable[str]):
22 | wav_scp_path = pathlib.Path(wav_scp_path)
23 | speaker_scp_path = pathlib.Path(speaker_scp_path)
24 | text_scp_path = pathlib.Path(text_scp_path)
25 |
26 | wav_scp_path.parent.mkdir(parents=True, exist_ok=True)
27 | speaker_scp_path.parent.mkdir(parents=True, exist_ok=True)
28 | text_scp_path.parent.mkdir(parents=True, exist_ok=True)
29 |
30 | with open(wav_scp_path, 'w') as wav_scp_file:
31 | wav_scp_file.writelines([str(line[0]) + '\n' for line in content])
32 | with open(speaker_scp_path, 'w') as speaker_scp_file:
33 | speaker_scp_file.writelines([line[1] + '\n' for line in content])
34 | with open(text_scp_path, 'w') as text_scp_file:
35 | text_scp_file.writelines([line[2] + '\n' for line in content])
36 |
37 |
38 | def main(args):
39 | dataset_dir = pathlib.Path(args.dataset_dir)
40 |
41 | with open(dataset_dir /
42 | 'text_sp1-sp4') as train_set_label_file:
43 | train_set_label = [
44 | x.strip() for x in train_set_label_file.readlines()
45 | ]
46 | train_set_path={}
47 | with open(dataset_dir /
48 | 'wav.scp') as train_set_path_file:
49 | for line in train_set_path_file:
50 | line = line.strip().split()
51 | train_set_path[line[0]] = line[1]
52 |
53 | samples = collections.defaultdict(list)
54 |
55 | for line in tqdm(train_set_label):
56 | line = line.split()
57 | # sample_name = "_".join(line[0].split("_")[1:])
58 | sample_name = line[0].split("|")[1]
59 | tokens = " ".join(line[1:])
60 | speaker = line[0].split("|")[0]
61 | wav_path = train_set_path[sample_name]
62 | if os.path.exists(wav_path):
63 | samples[speaker].append((wav_path, speaker, tokens))
64 | else:
65 |             print(wav_path, "does not exist")
66 |
67 | sample_list = []
68 |
69 | for speaker in sorted(samples):
70 | sample_list.extend(samples[speaker])
71 |
72 | save_scp_files(args.wav, args.speaker, args.text, sample_list)
73 |
74 |
75 | if __name__ == "__main__":
76 | main(get_args())
77 |
--------------------------------------------------------------------------------
/mfa/step3_prepare_special_tokens.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def get_args():
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument('--special_tokens',
7 | type=str,
8 | help='Path to special_token.txt')
9 | return parser.parse_args()
10 |
11 | def main(args):
12 | with open(args.special_tokens, "w") as f:
13 | for line in {"sp0", "sp1", "sp2", "sp3", "sp4","engsp1", "engsp2", "engsp3", "engsp4", "", "cn_eng_sp", "eng_cn_sp", "." , "?", "LAUGH"}:
14 | f.write(f"{line}\n")
15 |
16 | if __name__ == '__main__':
17 | main(get_args())
18 |
--------------------------------------------------------------------------------
/mfa/step4_convert_text_to_phn.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Tsinghua University. (authors: Jie Chen)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Convert full-label pinyin sequences into phoneme sequences according to
15 | the lexicon.
16 | """
17 |
18 | import argparse
19 |
20 |
21 | def get_args():
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument('--text', type=str, help='Path to text.txt.')
24 | parser.add_argument('--special_tokens',
25 | type=str,
26 | help='Path to special_token.txt')
27 | parser.add_argument('--output', type=str, help='Path to output file.')
28 | return parser.parse_args()
29 |
30 |
31 | def main(args):
32 | with open(args.special_tokens) as fin:
33 | special_tokens = set([x.strip() for x in fin.readlines()])
34 | samples = []
35 | with open(args.text) as fin:
36 | for line in fin:
37 | tokens = []
38 | word = []
39 | for ph in line.strip().split():
40 | if ph in special_tokens:
41 | word = "_".join(word)
42 |
43 | tokens.append(word)
44 | tokens.append(ph)
45 | word = []
46 | else:
47 | ph = ph #[A] -> A
48 | word.append(ph)
49 |
50 | samples.append(' '.join(tokens))
51 | with open(args.output, 'w') as fout:
52 | fout.writelines([x + '\n' for x in samples])
53 |
54 |
55 | if __name__ == '__main__':
56 | main(get_args())
57 |
--------------------------------------------------------------------------------
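
A worked example of the grouping performed above: phonemes between special tokens are joined with "_" into word-level units (the tokens below are made up):

special_tokens = {"sp1", "."}
line = "HH AH0 L OW1 sp1 W ER1 L D ."
tokens, word = [], []
for ph in line.split():
    if ph in special_tokens:
        tokens.append("_".join(word))
        tokens.append(ph)
        word = []
    else:
        word.append(ph)
print(" ".join(tokens))  # "HH_AH0_L_OW1 sp1 W_ER1_L_D ."
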
/mfa/step5_prepare_alignment.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright 2022 Binbin Zhang(binbzha@qq.com), Jie Chen(unrea1sama@outlook.com)
3 | """Generate lab files from data list for alignment
4 | """
5 |
6 | import argparse
7 | import pathlib
8 | import random, os
9 | from tqdm import tqdm
10 | def get_args():
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument("--wav", type=str, help='Path to wav.txt.')
13 | parser.add_argument("--speaker", type=str, help='Path to speaker.txt.')
14 | parser.add_argument(
15 | "--text",
16 | type=str,
17 |         help=('Path to text.txt. '
18 |               'It should only contain phonemes and special tokens.'))
19 | parser.add_argument('--special_tokens',
20 | type=str,
21 | help='Path to special_token.txt.')
22 | parser.add_argument(
23 | '--pronounciation_dict',
24 | type=str,
25 |         help='Path to export the pronunciation dictionary for MFA.')
26 | parser.add_argument('--output_dir',
27 | type=str,
28 | help='Path to directory for exporting .lab files.')
29 | return parser.parse_args()
30 |
31 |
32 | def main(args):
33 | output_dir = pathlib.Path(args.output_dir)
34 | pronounciation_dict = set()
35 | with open(args.special_tokens) as fin:
36 | special_tokens = set([x.strip() for x in fin.readlines()])
37 |
38 | num_speaker = 1
39 | with open(args.wav) as f:
40 | index = [i for i in range(len(f.readlines()))]
41 | _mfa_groups = [index[i::num_speaker] for i in range(num_speaker)]
42 | mfa_groups = []
43 | for i, group in enumerate(_mfa_groups):
44 | mfa_groups.extend([i for _ in range(len(group))])
45 |
46 | random.shuffle(mfa_groups)
47 | os.system(f"rm -rf {args.output_dir}/*")
48 | with open(args.wav) as fwav, open(args.speaker) as fspeaker, open(
49 | args.text) as ftext:
50 | for wav_path, speaker, text, i in tqdm(zip(fwav, fspeaker, ftext, mfa_groups)):
51 |             i = speaker.strip()  # group the .lab/.wav files by speaker name rather than by the shuffled group index
52 | wav_path, speaker, text = (pathlib.Path(wav_path.strip()),
53 | speaker.strip(), text.strip().split())
54 | lab_dir = output_dir / i
55 | lab_dir.mkdir(parents=True, exist_ok=True)
56 |
57 | name=wav_path.stem.strip()
58 |
59 | lab_file = output_dir / i / f'{i}_{name}.lab'
60 | wav_file = output_dir / i / f'{i}_{name}.wav'
61 | try:
62 | os.symlink(wav_path, wav_file)
63 |             except OSError:  # broken source path or an already-existing link
64 |                 print("ERROR PATH", wav_path)
65 | continue
66 |
67 |
68 | with lab_file.open('w') as fout:
69 | text_no_special_tokens = [ph for ph in text if ph not in special_tokens]
70 | pronounciation_dict |= set(text_no_special_tokens)
71 | fout.writelines([' '.join(text_no_special_tokens)])
72 | with open(args.pronounciation_dict, 'w') as fout:
73 | fout.writelines([
74 | '{} {}\n'.format(symbol, " ".join(symbol.split("_"))) for symbol in pronounciation_dict
75 | ])
76 |
77 |
78 | if __name__ == '__main__':
79 | main(get_args())
80 |
--------------------------------------------------------------------------------
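
The pronunciation dictionary written above maps each underscore-joined unit back to its space-separated phonemes so MFA can align at that level; a one-line sketch (symbol invented):

symbol = "HH_AH0_L_OW1"
print("{} {}".format(symbol, " ".join(symbol.split("_"))))  # HH_AH0_L_OW1 HH AH0 L OW1
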
/mfa/step8_make_data_list.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Tsinghua University(Jie Chen)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import argparse
16 | import jsonlines
17 | import pathlib
18 |
19 | def read_lists(list_file):
20 | lists = []
21 | with open(list_file, 'r', encoding='utf8') as fin:
22 | for line in fin:
23 | lists.append(line.strip())
24 | return lists
25 |
26 | def get_args():
27 | parser = argparse.ArgumentParser()
28 | parser.add_argument('--wav', type=str, help='Path to wav.txt.')
29 | parser.add_argument('--speaker', type=str, help='Path to speaker.txt.')
30 | parser.add_argument('--text', type=str, help='Path to text.txt.')
31 | parser.add_argument('--duration', type=str, help='Path to duration.txt.')
32 | parser.add_argument('--datalist_path',
33 | type=str,
34 | help='Path to export datalist.jsonl.')
35 | args = parser.parse_args()
36 | return args
37 |
38 |
39 | def main(args):
40 | wavs = read_lists(args.wav)
41 | speakers = read_lists(args.speaker)
42 | texts = read_lists(args.text)
43 | durations = read_lists(args.duration)
44 | with jsonlines.open(args.datalist_path, 'w') as fdatalist:
45 | for wav, speaker, text, duration in zip(wavs, speakers, texts,
46 | durations):
47 | key = pathlib.Path(wav).stem
48 | fdatalist.write({
49 | 'key': key,
50 | 'wav_path': wav,
51 | 'speaker': speaker,
52 | 'text': text.split(),
53 | 'duration': [float(x) for x in duration.split()]
54 | })
55 |
56 |
57 | if __name__ == '__main__':
58 | main(get_args())
59 |
--------------------------------------------------------------------------------
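
For reference, one record of the datalist.jsonl assembled above might look as follows (all values invented; durations come from the earlier alignment step):

record = {
    "key": "utt_0001",
    "wav_path": "/path/to/utt_0001.wav",
    "speaker": "speaker1",
    "text": ["HH", "AH0", "L", "OW1"],
    "duration": [7.0, 5.0, 6.0, 9.0],   # one duration per token, in frames
}
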
/mfa/step9_datalist_from_mfa.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import jsonlines
4 | import argparse
5 | import os
6 |
7 |
8 | def main(args):
9 | ROOT_DIR=os.path.abspath(args.data_dir)
10 | TEXT_DIR=f"{ROOT_DIR}/text"
11 | MFA_DIR=f"{ROOT_DIR}/mfa"
12 | TRAIN_DIR=f"{ROOT_DIR}/train"
13 | VALID_DIR=f"{ROOT_DIR}/valid"
14 |
15 | with jsonlines.open(f"{MFA_DIR}/datalist.jsonl") as f:
16 | data = list(f)
17 |
18 | with jsonlines.open(f"{TEXT_DIR}/datalist.jsonl") as f:
19 | data_ref = {sample["key"]:sample for sample in list(f)}
20 |
21 | new_data = []
22 | with jsonlines.open(f"{TEXT_DIR}/datalist_mfa.jsonl", "w") as f:
23 | for sample in data:
24 | if "duration" in sample:
25 | del sample["duration"]
26 |
27 |
28 |
29 | # if "emotion" not in sample:
30 | # sample["emotion"]="default"
31 |
32 |
33 | for i, ph in enumerate(sample["text"]):
34 | if ph.isupper():
35 | sample["text"][i] = "[" + ph + "]"
36 |
37 | if ph =="cnengsp":
38 | sample["text"][i] = "cn_eng_sp"
39 | if ph =="engcnsp":
40 | sample["text"][i] = "eng_cn_sp"
41 |
42 | sample_ref = data_ref[sample["key"]]
43 |
44 | sample["original_text"]=sample_ref["original_text"]
45 | sample["prompt"] = sample_ref["prompt"]
46 | new_data.append(sample)
47 | f.write(sample)
48 |
49 | with jsonlines.open(f"{TRAIN_DIR}/datalist_mfa.jsonl", "w") as f:
50 | for sample in new_data[:-3]:
51 | f.write(sample)
52 |
53 | with jsonlines.open(f"{VALID_DIR}/datalist_mfa.jsonl", "w") as f:
54 | for sample in data[-3:]:
55 | f.write(sample)
56 |
57 | if __name__ == "__main__":
58 | p = argparse.ArgumentParser()
59 | p.add_argument('--data_dir', type=str, required=True)
60 | args = p.parse_args()
61 |
62 | main(args)
--------------------------------------------------------------------------------
/models/hifigan/dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023, YOUDAO
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | import jsonlines
17 | from transformers import AutoTokenizer
18 | import os, sys
19 | import numpy as np
20 | from scipy.io.wavfile import read
21 | from torch.nn.utils.rnn import pad_sequence
22 | import copy
23 | from models.prompt_tts_modified.tacotron_stft import TacotronSTFT
24 |
25 |
26 | def get_mel(filename, stft, sampling_rate, trim=False):
27 |
28 | sr, wav = read(filename)
29 | if sr != sampling_rate:
30 | raise ValueError("{} SR doesn't match target {} SR".format(sr, sampling_rate))
31 |
32 | wav = wav / 32768.0
33 |
34 | wav = torch.FloatTensor(wav.astype(np.float32))
35 | ### trimming ###
36 | if trim:
37 | frac = 0.005
38 | start = torch.where(
39 | torch.abs(wav)>(torch.abs(wav).max()*frac)
40 | )[0][0]
41 | end = torch.where(torch.abs(wav)>(torch.abs(wav).max()*frac))[0][-1]
42 | ### 50ms silence padding ###
43 | wav = torch.nn.functional.pad(wav[start:end], (sampling_rate//20, sampling_rate//20))
44 | melspec = stft.mel_spectrogram(wav.unsqueeze(0))
45 |
46 | return melspec.squeeze(0), wav
47 |
48 | def pad_mel(data, downsample_ratio, max_len ):
49 | batch_size = len(data)
50 | num_mels = data[0].size(0)
51 | padded = torch.zeros((batch_size, num_mels, max_len))
52 | for i in range(batch_size):
53 | lens = data[i].size(1)
54 | if lens % downsample_ratio!=0:
55 | data[i] = data[i][:,:-(lens % downsample_ratio)]
56 | padded[i, :, :data[i].size(1)] = data[i]
57 |
58 | return padded
59 |
60 | class DatasetTTS(torch.utils.data.Dataset):
61 | def __init__(self, data_path, config):
62 | self.sampling_rate=config.sampling_rate
63 | self.datalist = self.load_files(data_path)
64 | self.stft = TacotronSTFT(
65 | filter_length=config.filter_length,
66 | hop_length=config.hop_length,
67 | win_length=config.win_length,
68 | n_mel_channels=config.n_mel_channels,
69 | sampling_rate=config.sampling_rate,
70 | mel_fmin=config.mel_fmin,
71 | mel_fmax=config.mel_fmax
72 | )
73 | self.trim = config.trim
74 | self.config=config
75 |
76 |
77 | def load_files(self, data_path):
78 | with jsonlines.open(data_path) as f:
79 | data = list(f)
80 | return data
81 |
82 |
83 | def __len__(self):
84 | return len(self.datalist)
85 |
86 | def __getitem__(self, index):
87 |
88 | uttid = self.datalist[index]["key"]
89 |
90 |
91 | mel, wav = get_mel(self.datalist[index]["wav_path"], self.stft, self.sampling_rate, trim=self.trim)
92 |
93 | return {
94 | "mel": mel,
95 | "uttid": uttid,
96 | "wav": wav,
97 | }
98 |
99 |
100 | def TextMelCollate(self, data):
101 |
102 | # Right zero-pad melspectrogram
103 | mel = [x['mel'] for x in data]
104 | max_target_len = max([x.shape[1] for x in mel])
105 |
106 | # wav
107 | wav = [x["wav"] for x in data]
108 |
109 | padded_wav = pad_sequence(wav,
110 | batch_first=True,
111 | padding_value=0.0)
112 | padded_mel = pad_mel(mel, self.config.downsample_ratio, max_target_len)
113 |
114 | mel_lens = torch.LongTensor([x.shape[1] for x in mel])
115 |
116 | res = {
117 | "mel" : padded_mel,
118 | "mel_lens" : mel_lens,
119 | "wav" : padded_wav,
120 | }
121 | return res
122 |
123 |
124 |
--------------------------------------------------------------------------------
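
A sketch of wiring DatasetTTS into a PyTorch DataLoader. The config values and the datalist path below are placeholders; the real values live in the repository's config module:

from types import SimpleNamespace
from torch.utils.data import DataLoader
from models.hifigan.dataset import DatasetTTS

# Illustrative config; only the attributes read by DatasetTTS/TextMelCollate are set.
config = SimpleNamespace(
    sampling_rate=16000, filter_length=1024, hop_length=256, win_length=1024,
    n_mel_channels=80, mel_fmin=0, mel_fmax=8000, trim=False, downsample_ratio=1,
)
dataset = DatasetTTS("path/to/datalist.jsonl", config)   # placeholder path
loader = DataLoader(dataset, batch_size=16, shuffle=True,
                    collate_fn=dataset.TextMelCollate)
batch = next(iter(loader))  # keys: "mel" (B, n_mels, T), "mel_lens" (B,), "wav" (B, T_wav)
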
/models/hifigan/env.py:
--------------------------------------------------------------------------------
1 | """
2 | from https://github.com/jik876/hifi-gan
3 | """
4 |
5 | import os
6 | import shutil
7 |
8 |
9 | class AttrDict(dict):
10 | def __init__(self, *args, **kwargs):
11 | super(AttrDict, self).__init__(*args, **kwargs)
12 | self.__dict__ = self
13 |
14 |
15 | def build_env(config, config_name, path):
16 | t_path = os.path.join(path, config_name)
17 | if config != t_path:
18 | os.makedirs(path, exist_ok=True)
19 | shutil.copyfile(config, os.path.join(path, config_name))
20 |
--------------------------------------------------------------------------------
/models/hifigan/get_random_segments.py:
--------------------------------------------------------------------------------
1 | """
2 | from https://github.com/espnet/espnet
3 | """
4 |
5 | import torch
6 |
7 |
8 | def get_random_segments( x: torch.Tensor, x_lengths: torch.Tensor, segment_size: int):
9 | b, d, t = x.size()
10 | max_start_idx = x_lengths - segment_size
11 | max_start_idx = torch.clamp(max_start_idx, min=0)
12 | start_idxs = (torch.rand([b]).to(x.device) * max_start_idx).to(
13 | dtype=torch.long,
14 | )
15 | segments = get_segments(x, start_idxs, segment_size)
16 | return segments, start_idxs, segment_size
17 |
18 |
19 | def get_segments( x: torch.Tensor, start_idxs: torch.Tensor, segment_size: int):
20 | b, c, t = x.size()
21 | segments = x.new_zeros(b, c, segment_size)
22 | if t < segment_size:
23 | x = torch.nn.functional.pad(x, (0, segment_size - t), 'constant')
24 | for i, start_idx in enumerate(start_idxs):
25 | segment = x[i, :, start_idx : start_idx + segment_size]
26 | segments[i,:,:segment.size(1)] = segment
27 | return segments
28 |
--------------------------------------------------------------------------------
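
Minimal usage sketch for get_random_segments: cut fixed-size windows out of a batch of variable-length feature sequences (shapes are illustrative):

import torch
from models.hifigan.get_random_segments import get_random_segments

x = torch.randn(2, 80, 120)                  # (batch, channels, frames)
lengths = torch.tensor([120, 90])
segments, start_idxs, seg_size = get_random_segments(x, lengths, segment_size=32)
print(segments.shape)                        # torch.Size([2, 80, 32])
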
/models/hifigan/get_vocoder.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023, YOUDAO
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os, json, torch
16 | from models.hifigan.env import AttrDict
17 | from models.hifigan.models import Generator
18 |
19 | MAX_WAV_VALUE = 32768.0
20 |
21 | def vocoder(hifi_gan_path, hifi_gan_name):
22 | device = torch.device('cpu')
23 | config_file = os.path.join(os.path.split(hifi_gan_path)[0], 'config.json')
24 | with open(config_file) as f:
25 | data = f.read()
26 | global h
27 | json_config = json.loads(data)
28 | h = AttrDict(json_config)
29 | torch.manual_seed(h.seed)
30 | generator = Generator(h).to(device)
31 |
32 | state_dict_g = torch.load(hifi_gan_path+hifi_gan_name, map_location=device)
33 |
34 | generator.load_state_dict(state_dict_g['generator'])
35 | generator.eval()
36 | generator.remove_weight_norm()
37 | return generator
38 |
39 | def vocoder2(config,hifi_gan_ckpt_path):
40 | device = torch.device('cpu')
41 | global h
42 | generator = Generator(config.model).to(device)
43 |
44 | state_dict_g = torch.load(hifi_gan_ckpt_path, map_location=device)
45 |
46 | generator.load_state_dict(state_dict_g['generator'])
47 | generator.eval()
48 | generator.remove_weight_norm()
49 | return generator
50 |
51 |
52 | def vocoder_inference(vocoder, melspec, max_db, min_db):
53 | with torch.no_grad():
54 | x = melspec*(max_db-min_db)+min_db
55 | device = torch.device('cpu')
56 | x = torch.FloatTensor(x).to(device)
57 | y_g_hat = vocoder(x)
58 | audio = y_g_hat.squeeze().numpy()
59 | return audio
--------------------------------------------------------------------------------
/models/hifigan/pretrained_discriminator.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023, YOUDAO
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch.nn as nn
16 | import torch
17 | from models.hifigan.models import MultiScaleDiscriminator, MultiPeriodDiscriminator
18 |
19 |
20 |
21 | class Discriminator(nn.Module):
22 | def __init__(self, config) -> None:
23 | super().__init__()
24 |
25 | self.msd = MultiScaleDiscriminator()
26 | self.mpd = MultiPeriodDiscriminator()
27 | if config.pretrained_discriminator:
28 | state_dict_do = torch.load(config.pretrained_discriminator,map_location="cpu")
29 |
30 | self.mpd.load_state_dict(state_dict_do['mpd'])
31 | self.msd.load_state_dict(state_dict_do['msd'])
32 | print("pretrained discriminator is loaded")
33 | def forward(self, y, y_hat):
34 | y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = self.mpd(y, y_hat)
35 | y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = self.msd(y, y_hat)
36 |
37 | return y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g, y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g
--------------------------------------------------------------------------------
/models/prompt_tts_modified/audio_processing.py:
--------------------------------------------------------------------------------
1 | """
2 | from https://github.com/espnet/espnet
3 | """
4 |
5 | import torch
6 | import numpy as np
7 | from scipy.signal import get_window
8 | import librosa.util as librosa_util
9 |
10 |
11 | def window_sumsquare(window,
12 | n_frames,
13 | hop_length=200,
14 | win_length=800,
15 | n_fft=800,
16 | dtype=np.float32,
17 | norm=None):
18 | if win_length is None:
19 | win_length = n_fft
20 |
21 | n = n_fft + hop_length * (n_frames - 1)
22 | x = np.zeros(n, dtype=dtype)
23 |
24 | # Compute the squared window at the desired length
25 | win_sq = get_window(window, win_length, fftbins=True)
26 | win_sq = librosa_util.normalize(win_sq, norm=norm)**2
27 | win_sq = librosa_util.pad_center(win_sq, n_fft)
28 |
29 | # Fill the envelope
30 | for i in range(n_frames):
31 | sample = i * hop_length
32 | x[sample:min(n, sample+n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
33 | return x
34 |
35 |
36 | def griffin_lim(magnitudes, stft_fn, n_iters=30):
37 |
38 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
39 | angles = angles.astype(np.float32)
40 | angles = torch.autograd.Variable(torch.from_numpy(angles))
41 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
42 |
43 | for i in range(n_iters):
44 | _, angles = stft_fn.transform(signal)
45 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
46 | return signal
47 |
48 |
49 |
50 | def dynamic_range_compression(x, C=1, clip_val=1e-5):
51 | return torch.log(torch.clamp(x, min=clip_val) * C)
52 |
53 |
54 | def dynamic_range_decompression(x, C=1):
55 | return torch.exp(x) / C
56 |
--------------------------------------------------------------------------------
/models/prompt_tts_modified/jets.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023, YOUDAO
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | import torch.nn as nn
17 | import numpy as np
18 | from typing import Optional
19 |
20 | from models.prompt_tts_modified.model_open_source import PromptTTS
21 | from models.hifigan.models import Generator as HiFiGANGenerator
22 |
23 | from models.hifigan.get_random_segments import get_random_segments, get_segments
24 |
25 |
26 | class JETSGenerator(nn.Module):
27 | def __init__(self, config) -> None:
28 |
29 | super().__init__()
30 |
31 | self.upsample_factor=int(np.prod(config.model.upsample_rates))
32 |
33 | self.segment_size = config.segment_size
34 |
35 | self.am = PromptTTS(config)
36 |
37 | self.generator = HiFiGANGenerator(config.model)
38 |
39 | # try:
40 | # model_CKPT = torch.load(config.pretrained_am, map_location="cpu")
41 | # self.am.load_state_dict(model_CKPT['model'])
42 | # state_dict_g = torch.load(config.pretrained_vocoder,map_location="cpu")
43 | # self.generator.load_state_dict(state_dict_g['generator'])
44 | # print("pretrained generator is loaded")
45 | # except:
46 | # print("pretrained generator is not loaded for training")
47 | self.config=config
48 |
49 |
50 | def forward(self, inputs_ling, input_lengths, inputs_speaker, inputs_style_embedding , inputs_content_embedding, mel_targets=None, output_lengths=None, pitch_targets=None, energy_targets=None, alpha=1.0, cut_flag=True):
51 |
52 | outputs = self.am(inputs_ling, input_lengths, inputs_speaker, inputs_style_embedding , inputs_content_embedding, mel_targets , output_lengths , pitch_targets , energy_targets , alpha)
53 |
54 |
55 | if mel_targets is not None and cut_flag:
56 | z_segments, z_start_idxs, segment_size = get_random_segments(
57 | outputs["dec_outputs"].transpose(1,2),
58 | output_lengths,
59 | self.segment_size,
60 | )
61 | else:
62 | z_segments = outputs["dec_outputs"].transpose(1,2)
63 | z_start_idxs=None
64 | segment_size=self.segment_size
65 |
66 | wav = self.generator(z_segments)
67 |
68 | outputs["wav_predictions"] = wav
69 | outputs["z_start_idxs"]= z_start_idxs
70 | outputs["segment_size"] = segment_size
71 | return outputs
72 |
--------------------------------------------------------------------------------
/models/prompt_tts_modified/loss.py:
--------------------------------------------------------------------------------
1 | """
2 | This code is modified from https://github.com/alibaba-damo-academy/KAN-TTS.
3 | """
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import numpy as np
9 |
10 | def get_mask_from_lengths(lengths, max_len=None):
11 | batch_size = lengths.shape[0]
12 | if max_len is None:
13 | max_len = torch.max(lengths).item()
14 |
15 | ids = (
16 | torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(lengths.device)
17 | )
18 | mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)
19 |
20 | return mask
21 |
22 | class MelReconLoss(torch.nn.Module):
23 | def __init__(self, loss_type="mae"):
24 | super(MelReconLoss, self).__init__()
25 | self.loss_type = loss_type
26 | if loss_type == "mae":
27 | self.criterion = torch.nn.L1Loss(reduction="none")
28 | elif loss_type == "mse":
29 | self.criterion = torch.nn.MSELoss(reduction="none")
30 | else:
31 | raise ValueError("Unknown loss type: {}".format(loss_type))
32 |
33 | def forward(self, output_lengths, mel_targets, dec_outputs, postnet_outputs=None):
34 | """
35 | mel_targets: B, C, T
36 | """
37 | output_masks = get_mask_from_lengths(
38 | output_lengths, max_len=mel_targets.size(1)
39 | )
40 | output_masks = ~output_masks
41 | valid_outputs = output_masks.sum()
42 |
43 | mel_loss_ = torch.sum(
44 | self.criterion(mel_targets, dec_outputs) * output_masks.unsqueeze(-1)
45 | ) / (valid_outputs * mel_targets.size(-1))
46 |
47 | if postnet_outputs is not None:
48 | mel_loss = torch.sum(
49 | self.criterion(mel_targets, postnet_outputs)
50 | * output_masks.unsqueeze(-1)
51 | ) / (valid_outputs * mel_targets.size(-1))
52 | else:
53 | mel_loss = 0.0
54 |
55 | return mel_loss_, mel_loss
56 |
57 |
58 |
59 | class ForwardSumLoss(torch.nn.Module):
60 |
61 | def __init__(self):
62 | super().__init__()
63 |
64 | def forward(
65 | self,
66 | log_p_attn: torch.Tensor,
67 | ilens: torch.Tensor,
68 | olens: torch.Tensor,
69 | blank_prob: float = np.e**-1,
70 | ) -> torch.Tensor:
71 | B = log_p_attn.size(0)
72 |
73 | # a row must be added to the attention matrix to account for
74 | # blank token of CTC loss
75 | # (B,T_feats,T_text+1)
76 | log_p_attn_pd = F.pad(log_p_attn, (1, 0, 0, 0, 0, 0), value=np.log(blank_prob))
77 |
78 | loss = 0
79 | for bidx in range(B):
80 |             # construct target sequence.
81 |             # Every text token is mapped to a unique sequence number.
82 | target_seq = torch.arange(1, ilens[bidx] + 1).unsqueeze(0)
83 | cur_log_p_attn_pd = log_p_attn_pd[
84 | bidx, : olens[bidx], : ilens[bidx] + 1
85 | ].unsqueeze(
86 | 1
87 | ) # (T_feats,1,T_text+1)
88 | cur_log_p_attn_pd = F.log_softmax(cur_log_p_attn_pd, dim=-1)
89 | loss += F.ctc_loss(
90 | log_probs=cur_log_p_attn_pd,
91 | targets=target_seq,
92 | input_lengths=olens[bidx : bidx + 1],
93 | target_lengths=ilens[bidx : bidx + 1],
94 | zero_infinity=True,
95 | )
96 | loss = loss / B
97 | return loss
98 |
99 | class ProsodyReconLoss(torch.nn.Module):
100 | def __init__(self, loss_type="mae"):
101 | super(ProsodyReconLoss, self).__init__()
102 | self.loss_type = loss_type
103 | if loss_type == "mae":
104 | self.criterion = torch.nn.L1Loss(reduction="none")
105 | elif loss_type == "mse":
106 | self.criterion = torch.nn.MSELoss(reduction="none")
107 | else:
108 | raise ValueError("Unknown loss type: {}".format(loss_type))
109 |
110 | def forward(
111 | self,
112 | input_lengths,
113 | duration_targets,
114 | pitch_targets,
115 | energy_targets,
116 | log_duration_predictions,
117 | pitch_predictions,
118 | energy_predictions,
119 | ):
120 | input_masks = get_mask_from_lengths(
121 | input_lengths, max_len=duration_targets.size(1)
122 | )
123 | input_masks = ~input_masks
124 | valid_inputs = input_masks.sum()
125 |
126 | dur_loss = (
127 | torch.sum(
128 | self.criterion(
129 | torch.log(duration_targets.float() + 1), log_duration_predictions
130 | )
131 | * input_masks
132 | )
133 | / valid_inputs
134 | )
135 | pitch_loss = (
136 | torch.sum(self.criterion(pitch_targets, pitch_predictions) * input_masks)
137 | / valid_inputs
138 | )
139 | energy_loss = (
140 | torch.sum(self.criterion(energy_targets, energy_predictions) * input_masks)
141 | / valid_inputs
142 | )
143 |
144 | return dur_loss, pitch_loss, energy_loss
145 |
146 |
147 | class TTSLoss(torch.nn.Module):
148 | def __init__(self, loss_type="mae") -> None:
149 | super().__init__()
150 |
151 | self.Mel_Loss = MelReconLoss()
152 |         self.Prosody_Loss = ProsodyReconLoss(loss_type)
153 | self.ForwardSum_Loss = ForwardSumLoss()
154 |
155 | def forward(self, outputs):
156 |
157 | dec_outputs = outputs["dec_outputs"]
158 | postnet_outputs = outputs["postnet_outputs"]
159 | log_duration_predictions = outputs["log_duration_predictions"]
160 | pitch_predictions = outputs["pitch_predictions"]
161 | energy_predictions = outputs["energy_predictions"]
162 | duration_targets = outputs["duration_targets"]
163 | pitch_targets = outputs["pitch_targets"]
164 | energy_targets = outputs["energy_targets"]
165 | output_lengths = outputs["output_lengths"]
166 | input_lengths = outputs["input_lengths"]
167 | mel_targets = outputs["mel_targets"].transpose(1,2)
168 | log_p_attn = outputs["log_p_attn"]
169 | bin_loss = outputs["bin_loss"]
170 |
171 | dec_mel_loss, postnet_mel_loss = self.Mel_Loss(output_lengths, mel_targets, dec_outputs, postnet_outputs)
172 |         dur_loss, pitch_loss, energy_loss = self.Prosody_Loss(input_lengths, duration_targets, pitch_targets, energy_targets, log_duration_predictions, pitch_predictions, energy_predictions)
173 | forwardsum_loss = self.ForwardSum_Loss(log_p_attn, input_lengths, output_lengths)
174 |
175 | res = {
176 | "dec_mel_loss": dec_mel_loss,
177 | "postnet_mel_loss": postnet_mel_loss,
178 | "dur_loss": dur_loss,
179 | "pitch_loss": pitch_loss,
180 | "energy_loss": energy_loss,
181 | "forwardsum_loss": forwardsum_loss,
182 | "bin_loss": bin_loss,
183 | }
184 |
185 | return res
--------------------------------------------------------------------------------
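
Quick illustration of get_mask_from_lengths, which the losses above rely on: True marks padded positions, and the losses invert the mask to keep only valid frames:

import torch
from models.prompt_tts_modified.loss import get_mask_from_lengths

mask = get_mask_from_lengths(torch.tensor([3, 1]), max_len=4)
# tensor([[False, False, False,  True],
#         [False,  True,  True,  True]])
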
/models/prompt_tts_modified/modules/alignment.py:
--------------------------------------------------------------------------------
1 | """
2 | This code is modified from https://github.com/espnet/espnet.
3 | """
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from numba import jit
10 | from scipy.stats import betabinom
11 |
12 |
13 | class AlignmentModule(nn.Module):
14 |
15 | def __init__(self, adim, odim, cache_prior=True):
16 | super().__init__()
17 | self.cache_prior = cache_prior
18 | self._cache = {}
19 |
20 | self.t_conv1 = nn.Conv1d(adim, adim, kernel_size=3, padding=1)
21 | self.t_conv2 = nn.Conv1d(adim, adim, kernel_size=1, padding=0)
22 |
23 | self.f_conv1 = nn.Conv1d(odim, adim, kernel_size=3, padding=1)
24 | self.f_conv2 = nn.Conv1d(adim, adim, kernel_size=3, padding=1)
25 | self.f_conv3 = nn.Conv1d(adim, adim, kernel_size=1, padding=0)
26 |
27 | def forward(self, text, feats, text_lengths, feats_lengths, x_masks=None):
28 |
29 | text = text.transpose(1, 2)
30 | text = F.relu(self.t_conv1(text))
31 | text = self.t_conv2(text)
32 | text = text.transpose(1, 2)
33 |
34 | feats = feats.transpose(1, 2)
35 | feats = F.relu(self.f_conv1(feats))
36 | feats = F.relu(self.f_conv2(feats))
37 | feats = self.f_conv3(feats)
38 | feats = feats.transpose(1, 2)
39 |
40 | dist = feats.unsqueeze(2) - text.unsqueeze(1)
41 | dist = torch.norm(dist, p=2, dim=3)
42 | score = -dist
43 |
44 | if x_masks is not None:
45 | x_masks = x_masks.unsqueeze(-2)
46 | score = score.masked_fill(x_masks, -np.inf)
47 |
48 | log_p_attn = F.log_softmax(score, dim=-1)
49 | # add beta-binomial prior
50 | bb_prior = self._generate_prior(
51 | text_lengths,
52 | feats_lengths,
53 | ).to(dtype=log_p_attn.dtype, device=log_p_attn.device)
54 |
55 | log_p_attn = log_p_attn + bb_prior
56 |
57 | return log_p_attn
58 |
59 | def _generate_prior(self, text_lengths, feats_lengths, w=1) -> torch.Tensor:
60 |
61 | B = len(text_lengths)
62 | T_text = text_lengths.max()
63 | T_feats = feats_lengths.max()
64 |
65 | bb_prior = torch.full((B, T_feats, T_text), fill_value=-np.inf)
66 | for bidx in range(B):
67 | T = feats_lengths[bidx].item()
68 | N = text_lengths[bidx].item()
69 |
70 | key = str(T) + "," + str(N)
71 | if self.cache_prior and key in self._cache:
72 | prob = self._cache[key]
73 | else:
74 | alpha = w * np.arange(1, T + 1, dtype=float) # (T,)
75 | beta = w * np.array([T - t + 1 for t in alpha])
76 | k = np.arange(N)
77 | batched_k = k[..., None] # (N,1)
78 | prob = betabinom.logpmf(batched_k, N, alpha, beta) # (N,T)
79 |
80 | # store cache
81 | if self.cache_prior and key not in self._cache:
82 | self._cache[key] = prob
83 |
84 | prob = torch.from_numpy(prob).transpose(0, 1) # -> (T,N)
85 | bb_prior[bidx, :T, :N] = prob
86 |
87 | return bb_prior
88 |
89 |
90 |
91 |
92 | @jit(nopython=True)
93 | def _monotonic_alignment_search(log_p_attn):
94 |
95 | T_mel = log_p_attn.shape[0]
96 | T_inp = log_p_attn.shape[1]
97 | Q = np.full((T_inp, T_mel), fill_value=-np.inf)
98 |
99 | log_prob = log_p_attn.transpose(1, 0) # -> (T_inp,T_mel)
100 | # 1. Q <- init first row for all j
101 | for j in range(T_mel):
102 | Q[0, j] = log_prob[0, : j + 1].sum()
103 |
104 | # 2.
105 | for j in range(1, T_mel):
106 | for i in range(1, min(j + 1, T_inp)):
107 | Q[i, j] = max(Q[i - 1, j - 1], Q[i, j - 1]) + log_prob[i, j]
108 |
109 | # 3.
110 | A = np.full((T_mel,), fill_value=T_inp - 1)
111 | for j in range(T_mel - 2, -1, -1): # T_mel-2, ..., 0
112 | # 'i' in {A[j+1]-1, A[j+1]}
113 | i_a = A[j + 1] - 1
114 | i_b = A[j + 1]
115 | if i_b == 0:
116 | argmax_i = 0
117 | elif Q[i_a, j] >= Q[i_b, j]:
118 | argmax_i = i_a
119 | else:
120 | argmax_i = i_b
121 | A[j] = argmax_i
122 | return A
123 |
124 |
125 | def viterbi_decode(log_p_attn, text_lengths, feats_lengths):
126 |
127 | B = log_p_attn.size(0)
128 | T_text = log_p_attn.size(2)
129 | device = log_p_attn.device
130 |
131 | bin_loss = 0
132 | ds = torch.zeros((B, T_text), device=device)
133 | for b in range(B):
134 | cur_log_p_attn = log_p_attn[b, : feats_lengths[b], : text_lengths[b]]
135 | viterbi = _monotonic_alignment_search(cur_log_p_attn.detach().cpu().numpy())
136 | _ds = np.bincount(viterbi)
137 | ds[b, : len(_ds)] = torch.from_numpy(_ds).to(device)
138 |
139 | t_idx = torch.arange(feats_lengths[b])
140 | bin_loss = bin_loss - cur_log_p_attn[t_idx, viterbi].mean()
141 | bin_loss = bin_loss / B
142 | return ds, bin_loss
143 |
144 |
145 | @jit(nopython=True)
146 | def _average_by_duration(ds, xs, text_lengths, feats_lengths):
147 | B = ds.shape[0]
148 | xs_avg = np.zeros_like(ds)
149 | ds = ds.astype(np.int32)
150 | for b in range(B):
151 | t_text = text_lengths[b]
152 | t_feats = feats_lengths[b]
153 | d = ds[b, :t_text]
154 | d_cumsum = d.cumsum()
155 | d_cumsum = [0] + list(d_cumsum)
156 | x = xs[b, :t_feats]
157 | for n, (start, end) in enumerate(zip(d_cumsum[:-1], d_cumsum[1:])):
158 | if len(x[start:end]) != 0:
159 | xs_avg[b, n] = x[start:end].mean()
160 | else:
161 | xs_avg[b, n] = 0
162 | return xs_avg
163 |
164 |
165 | def average_by_duration(ds, xs, text_lengths, feats_lengths):
166 |
167 | device = ds.device
168 | args = [ds, xs, text_lengths, feats_lengths]
169 | args = [arg.detach().cpu().numpy() for arg in args]
170 | xs_avg = _average_by_duration(*args)
171 | xs_avg = torch.from_numpy(xs_avg).to(device)
172 | return xs_avg
173 |
174 |
175 | class GaussianUpsampling(torch.nn.Module):
176 |
177 | def __init__(self, delta=0.1):
178 | super().__init__()
179 | self.delta = delta
180 | def forward(self, hs, ds, h_masks=None, d_masks=None, alpha=1.0):
181 |
182 |
183 | ds = ds * alpha
184 |
185 | B = ds.size(0)
186 | device = ds.device
187 | if ds.sum() == 0:
188 |             # NOTE(kan-bayashi): This case must not happen in teacher forcing.
189 |             # It can happen at inference time with a bad duration predictor,
190 |             # so we do not need to handle the padded-sequence case here.
191 | ds[ds.sum(dim=1).eq(0)] = 1
192 |
193 | if h_masks is None:
194 |             mel_lengths = torch.sum(ds, dim=-1).int() # lengths = [5, 3, 2]
195 |             T_feats = mel_lengths.max().item() # T_feats = 5
196 | else:
197 | T_feats = h_masks.size(-1)
198 | t = torch.arange(0, T_feats).unsqueeze(0).repeat(B,1).to(device).float()
199 | if h_masks is not None:
200 | t = t * h_masks.float()
201 |
202 | c = ds.cumsum(dim=-1) - ds/2
203 |
204 | energy = -1 * self.delta * (t.unsqueeze(-1) - c.unsqueeze(1)) ** 2
205 |
206 | if d_masks is not None:
207 | energy = energy.masked_fill(~(d_masks.unsqueeze(1).repeat(1,T_feats,1)), -float("inf"))
208 |
209 | p_attn = torch.softmax(energy, dim=2) # (B, T_feats, T_text)
210 | hs = torch.matmul(p_attn, hs)
211 | return hs
--------------------------------------------------------------------------------
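
A sketch of GaussianUpsampling above: expand phoneme-level hidden states to frame level according to per-phoneme durations (all shapes and values are illustrative):

import torch
from models.prompt_tts_modified.modules.alignment import GaussianUpsampling

hs = torch.randn(1, 4, 256)                # (batch, phonemes, channels)
ds = torch.tensor([[2.0, 3.0, 1.0, 4.0]])  # predicted frames per phoneme
frames = GaussianUpsampling()(hs, ds)
print(frames.shape)                        # torch.Size([1, 10, 256]); 10 = sum of durations
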
/models/prompt_tts_modified/modules/initialize.py:
--------------------------------------------------------------------------------
1 | """
2 | This code is modified from https://github.com/espnet/espnet.
3 | """
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import numpy as np
9 |
10 |
11 | def initialize(model: torch.nn.Module, init: str):
12 | for p in model.parameters():
13 | if p.dim() > 1:
14 | if init == "xavier_uniform":
15 | torch.nn.init.xavier_uniform_(p.data)
16 | elif init == "xavier_normal":
17 | torch.nn.init.xavier_normal_(p.data)
18 | elif init == "kaiming_uniform":
19 | torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu")
20 | elif init == "kaiming_normal":
21 | torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu")
22 | else:
23 | raise ValueError("Unknown initialization: " + init)
24 | # bias init
25 | for p in model.parameters():
26 | if p.dim() == 1:
27 | p.data.zero_()
28 |
29 | # reset some modules with default init
30 | for m in model.modules():
31 | if isinstance(
32 | m, (torch.nn.Embedding, torch.nn.LayerNorm, torch.nn.GroupNorm)
33 | ):
34 | m.reset_parameters()
35 | if hasattr(m, "espnet_initialization_fn"):
36 | m.espnet_initialization_fn()
37 |
38 | # TODO(xkc): Hacking s3prl_frontend and wav2vec2encoder initialization
39 | if getattr(model, "encoder", None) and getattr(
40 | model.encoder, "reload_pretrained_parameters", None
41 | ):
42 | model.encoder.reload_pretrained_parameters()
43 | if getattr(model, "frontend", None) and getattr(
44 | model.frontend, "reload_pretrained_parameters", None
45 | ):
46 | model.frontend.reload_pretrained_parameters()
47 | if getattr(model, "postencoder", None) and getattr(
48 | model.postencoder, "reload_pretrained_parameters", None
49 | ):
50 | model.postencoder.reload_pretrained_parameters()
--------------------------------------------------------------------------------
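
Minimal sketch of applying the initialize() helper above to a small module (the module and init scheme are illustrative):

import torch
from models.prompt_tts_modified.modules.initialize import initialize

net = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))
initialize(net, "xavier_uniform")  # >1-dim weights -> Xavier uniform, biases -> zero, LayerNorm reset to defaults
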
/models/prompt_tts_modified/modules/variance.py:
--------------------------------------------------------------------------------
1 | """
2 | This code is modified from https://github.com/espnet/espnet.
3 | """
4 |
5 | import torch
6 |
7 | from models.prompt_tts_modified.modules.encoder import LayerNorm
8 |
9 | class DurationPredictor(torch.nn.Module):
10 |
11 | def __init__(
12 | self, idim, n_layers=2, n_chans=384, kernel_size=3, dropout_rate=0.1, offset=1.0
13 | ):
14 |
15 | super(DurationPredictor, self).__init__()
16 | self.offset = offset
17 | self.conv = torch.nn.ModuleList()
18 | for idx in range(n_layers):
19 | in_chans = idim if idx == 0 else n_chans
20 | self.conv += [
21 | torch.nn.Sequential(
22 | torch.nn.Conv1d(
23 | in_chans,
24 | n_chans,
25 | kernel_size,
26 | stride=1,
27 | padding=(kernel_size - 1) // 2,
28 | ),
29 | torch.nn.ReLU(),
30 | LayerNorm(n_chans, dim=1),
31 | torch.nn.Dropout(dropout_rate),
32 | )
33 | ]
34 | self.linear = torch.nn.Linear(n_chans, 1)
35 |
36 | def _forward(self, xs, x_masks=None, is_inference=False):
37 |
38 | if x_masks is not None:
39 | xs = xs.masked_fill(x_masks, 0.0)
40 |
41 | xs = xs.transpose(1, -1) # (B, idim, Tmax)
42 | for f in self.conv:
43 | xs = f(xs) # (B, C, Tmax)
44 |
45 | # NOTE: calculate in log domain
46 |         xs = self.linear(xs.transpose(1, -1))  # (B, Tmax, 1)
47 | if is_inference:
48 | # NOTE: calculate in linear domain
49 | xs = torch.clamp(
50 | torch.round(xs.exp() - self.offset), min=0
51 | ).long() # avoid negative value
52 |
53 | if x_masks is not None:
54 | xs = xs.masked_fill(x_masks, 0.0)
55 |
56 | return xs.squeeze(-1)
57 |
58 | def forward(self, xs, x_masks=None):
59 |
60 | return self._forward(xs, x_masks, False)
61 |
62 | def inference(self, xs, x_masks=None):
63 |
64 | return self._forward(xs, x_masks, True)
65 |
66 |
67 |
68 | class VariancePredictor(torch.nn.Module):
69 |
70 |
71 | def __init__(
72 | self,
73 | idim: int,
74 | n_layers: int = 2,
75 | n_chans: int = 384,
76 | kernel_size: int = 3,
77 | bias: bool = True,
78 | dropout_rate: float = 0.5,
79 | ):
80 | super().__init__()
81 | self.conv = torch.nn.ModuleList()
82 | for idx in range(n_layers):
83 | in_chans = idim if idx == 0 else n_chans
84 | self.conv += [
85 | torch.nn.Sequential(
86 | torch.nn.Conv1d(
87 | in_chans,
88 | n_chans,
89 | kernel_size,
90 | stride=1,
91 | padding=(kernel_size - 1) // 2,
92 | bias=bias,
93 | ),
94 | torch.nn.ReLU(),
95 | LayerNorm(n_chans, dim=1),
96 | torch.nn.Dropout(dropout_rate),
97 | )
98 | ]
99 | self.linear = torch.nn.Linear(n_chans, 1)
100 |
101 | def forward(self, xs: torch.Tensor, x_masks: torch.Tensor = None) -> torch.Tensor:
102 | """Calculate forward propagation.
103 |
104 | Args:
105 | xs (Tensor): Batch of input sequences (B, Tmax, idim).
106 | x_masks (ByteTensor): Batch of masks indicating padded part (B, Tmax).
107 |
108 | Returns:
109 | Tensor: Batch of predicted sequences (B, Tmax, 1).
110 |
111 | """
112 | if x_masks is not None:
113 | xs = xs.masked_fill(x_masks, 0.0)
114 |
115 | xs = xs.transpose(1, -1) # (B, idim, Tmax)
116 | for f in self.conv:
117 | xs = f(xs) # (B, C, Tmax)
118 |
119 | xs = self.linear(xs.transpose(1, 2)) # (B, Tmax, 1)
120 |
121 | if x_masks is not None:
122 | xs = xs.masked_fill(x_masks, 0.0)
123 |
124 | return xs.squeeze(-1)
--------------------------------------------------------------------------------
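
Usage sketch for DurationPredictor above: forward() is used in training and predicts in the log domain, while inference() returns rounded, non-negative integer durations (dimensions are illustrative):

import torch
from models.prompt_tts_modified.modules.variance import DurationPredictor

dp = DurationPredictor(idim=256)
xs = torch.randn(2, 7, 256)   # (batch, phonemes, channels)
log_d = dp(xs)                # (2, 7) log-domain durations, compared against log(duration + 1) targets in the loss
d = dp.inference(xs)          # (2, 7) integer durations, clamped at >= 0
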
/models/prompt_tts_modified/scheduler.py:
--------------------------------------------------------------------------------
1 | """
2 | This code is modified from https://github.com/alibaba-damo-academy/KAN-TTS.
3 | """
4 |
5 | from torch.optim.lr_scheduler import *
6 | from torch.optim.lr_scheduler import _LRScheduler
7 |
8 | class FindLR(_LRScheduler):
9 |
10 |
11 | def __init__(self, optimizer, max_steps, max_lr=10):
12 | self.max_steps = max_steps
13 | self.max_lr = max_lr
14 | super().__init__(optimizer)
15 |
16 | def get_lr(self):
17 | return [
18 | base_lr
19 | * ((self.max_lr / base_lr) ** (self.last_epoch / (self.max_steps - 1)))
20 | for base_lr in self.base_lrs
21 | ]
22 |
23 |
24 | class NoamLR(_LRScheduler):
25 | def __init__(self, optimizer, warmup_steps):
26 | self.warmup_steps = warmup_steps
27 | super().__init__(optimizer)
28 |
29 | def get_lr(self):
30 | last_epoch = max(1, self.last_epoch)
31 | scale = self.warmup_steps ** 0.5 * min(
32 | last_epoch ** (-0.5), last_epoch * self.warmup_steps ** (-1.5)
33 | )
34 | return [base_lr * scale for base_lr in self.base_lrs]
--------------------------------------------------------------------------------
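
A sketch of attaching NoamLR above to an optimizer; the model, learning rate, and warmup_steps are illustrative:

import torch
from models.prompt_tts_modified.scheduler import NoamLR

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = NoamLR(optimizer, warmup_steps=4000)
for _ in range(10):
    optimizer.step()
    scheduler.step()  # lr ramps up over warmup_steps, then decays roughly as step**-0.5
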
/models/prompt_tts_modified/simbert.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023, YOUDAO
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | import torch.nn as nn
17 |
18 | from transformers import AutoModel
19 | import numpy as np
20 |
21 | class ClassificationHead(nn.Module):
22 | def __init__(self, hidden_size, num_labels, dropout_rate=0.1) -> None:
23 | super().__init__()
24 |
25 |
26 | self.dropout = nn.Dropout(dropout_rate)
27 | self.classifier = nn.Linear(hidden_size, num_labels)
28 |
29 | def forward(self, pooled_output):
30 |
31 | return self.classifier(self.dropout(pooled_output))
32 |
33 | class StyleEncoder(nn.Module):
34 | def __init__(self, config) -> None:
35 | super().__init__()
36 |
37 | self.bert = AutoModel.from_pretrained(config.bert_path)
38 |
39 | self.pitch_clf = ClassificationHead(config.bert_hidden_size, config.pitch_n_labels)
40 | self.speed_clf = ClassificationHead(config.bert_hidden_size, config.speed_n_labels)
41 | self.energy_clf = ClassificationHead(config.bert_hidden_size, config.energy_n_labels)
42 | self.emotion_clf = ClassificationHead(config.bert_hidden_size, config.emotion_n_labels)
43 | self.style_embed_proj = nn.Linear(config.bert_hidden_size, config.style_dim)
44 |
45 |
46 |
47 |
48 | def forward(self, input_ids, token_type_ids, attention_mask):
49 | outputs = self.bert(
50 | input_ids,
51 | attention_mask=attention_mask,
52 | token_type_ids=token_type_ids,
53 | ) # return a dict having ['last_hidden_state', 'pooler_output']
54 |
55 | pooled_output = outputs["pooler_output"]
56 |
57 | pitch_outputs = self.pitch_clf(pooled_output)
58 | speed_outputs = self.speed_clf(pooled_output)
59 | energy_outputs = self.energy_clf(pooled_output)
60 | emotion_outputs = self.emotion_clf(pooled_output)
61 | pred_style_embed = self.style_embed_proj(pooled_output)
62 |
63 | res = {
64 | "pooled_output":pooled_output,
65 | "pitch_outputs":pitch_outputs,
66 | "speed_outputs":speed_outputs,
67 | "energy_outputs":energy_outputs,
68 | "emotion_outputs":emotion_outputs,
69 | # "pred_style_embed":pred_style_embed,
70 | }
71 |
72 | return res
73 |
74 |
75 |
76 | class StylePretrainLoss(nn.Module):
77 | def __init__(self) -> None:
78 | super().__init__()
79 |
80 | self.loss = nn.CrossEntropyLoss()
81 |
82 | def forward(self, inputs, outputs):
83 |
84 | pitch_loss = self.loss(outputs["pitch_outputs"], inputs["pitch"])
85 | energy_loss = self.loss(outputs["energy_outputs"], inputs["energy"])
86 | speed_loss = self.loss(outputs["speed_outputs"], inputs["speed"])
87 | emotion_loss = self.loss(outputs["emotion_outputs"], inputs["emotion"])
88 |
89 | return {
90 | "pitch_loss" : pitch_loss,
91 | "energy_loss": energy_loss,
92 | "speed_loss" : speed_loss,
93 | "emotion_loss" : emotion_loss,
94 | }
95 |
96 |
97 | class StylePretrainLoss2(StylePretrainLoss):
98 | def __init__(self) -> None:
99 | super().__init__()
100 |
101 |         self.loss = nn.CrossEntropyLoss()  # redundant; already set in StylePretrainLoss.__init__
102 |
103 | def forward(self, inputs, outputs):
104 | res = super().forward(inputs, outputs)
105 | speaker_loss = self.loss(outputs["speaker_outputs"], inputs["speaker"])
106 | res["speaker_loss"] = speaker_loss
107 | return res
108 |
109 | def flat_accuracy(preds, labels):
110 | """
111 | Function to calculate the accuracy of our predictions vs labels
112 | """
113 | pred_flat = np.argmax(preds, axis=1).flatten()
114 | labels_flat = labels.flatten()
115 | return np.sum(pred_flat == labels_flat) / len(labels_flat)
116 |
--------------------------------------------------------------------------------
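A self-contained sketch of how StylePretrainLoss consumes the classifier outputs above. The batch size, hidden size and label counts are made up, and a full StyleEncoder is not built here because it would download BERT weights from config.bert_path:

    import torch
    from models.prompt_tts_modified.simbert import ClassificationHead, StylePretrainLoss

    B, hidden = 2, 768                                               # hypothetical batch/hidden sizes
    n_labels = {"pitch": 5, "speed": 5, "energy": 5, "emotion": 8}   # assumed label counts

    pooled = torch.randn(B, hidden)                                  # stands in for BERT's pooler_output
    heads = {k: ClassificationHead(hidden, n) for k, n in n_labels.items()}
    outputs = {f"{k}_outputs": head(pooled) for k, head in heads.items()}
    targets = {k: torch.randint(0, n, (B,)) for k, n in n_labels.items()}

    losses = StylePretrainLoss()(targets, outputs)
    print({k: float(v) for k, v in losses.items()})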
/models/prompt_tts_modified/stft.py:
--------------------------------------------------------------------------------
1 | """
2 | This code is modified from https://github.com/pseeth/pytorch-stft.
3 | """
4 |
5 | import torch
6 | import numpy as np
7 | import torch.nn.functional as F
8 | from torch.autograd import Variable
9 | from scipy.signal import get_window
10 | from librosa.util import pad_center, tiny
11 | from models.prompt_tts_modified.audio_processing import window_sumsquare
12 |
13 |
14 | class STFT(torch.nn.Module):
15 | def __init__(self, filter_length=800, hop_length=200, win_length=800,
16 | window='hann'):
17 | super(STFT, self).__init__()
18 | self.filter_length = filter_length
19 | self.hop_length = hop_length
20 | self.win_length = win_length
21 | self.window = window
22 | self.forward_transform = None
23 | scale = self.filter_length / self.hop_length
24 | fourier_basis = np.fft.fft(np.eye(self.filter_length))
25 |
26 | cutoff = int((self.filter_length / 2 + 1))
27 | fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
28 | np.imag(fourier_basis[:cutoff, :])])
29 |
30 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
31 | inverse_basis = torch.FloatTensor(
32 | np.linalg.pinv(scale * fourier_basis).T[:, None, :])
33 |
34 | if window is not None:
35 | assert(filter_length >= win_length)
36 | # get window and zero center pad it to filter_length
37 | fft_window = get_window(window, win_length, fftbins=True)
38 | fft_window = pad_center(data=fft_window, size=filter_length)
39 | fft_window = torch.from_numpy(fft_window).float()
40 |
41 | # window the bases
42 | forward_basis *= fft_window
43 | inverse_basis *= fft_window
44 |
45 | self.register_buffer('forward_basis', forward_basis.float())
46 | self.register_buffer('inverse_basis', inverse_basis.float())
47 |
48 | def transform(self, input_data):
49 | num_batches = input_data.size(0)
50 | num_samples = input_data.size(1)
51 |
52 | self.num_samples = num_samples
53 |
54 | # similar to librosa, reflect-pad the input
55 | input_data = input_data.view(num_batches, 1, num_samples)
56 | input_data = F.pad(
57 | input_data.unsqueeze(1),
58 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
59 | mode='reflect')
60 | input_data = input_data.squeeze(1)
61 |
62 | forward_transform = F.conv1d(
63 | input_data,
64 | Variable(self.forward_basis, requires_grad=False),
65 | stride=self.hop_length,
66 | padding=0)
67 |
68 | cutoff = int((self.filter_length / 2) + 1)
69 | real_part = forward_transform[:, :cutoff, :]
70 | imag_part = forward_transform[:, cutoff:, :]
71 |
72 | magnitude = torch.sqrt(real_part**2 + imag_part**2)
73 | phase = torch.autograd.Variable(
74 | torch.atan2(imag_part.data, real_part.data))
75 |
76 | return magnitude, phase
77 |
78 | def inverse(self, magnitude, phase):
79 | recombine_magnitude_phase = torch.cat(
80 | [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)
81 |
82 | inverse_transform = F.conv_transpose1d(
83 | recombine_magnitude_phase,
84 | Variable(self.inverse_basis, requires_grad=False),
85 | stride=self.hop_length,
86 | padding=0)
87 |
88 | if self.window is not None:
89 | window_sum = window_sumsquare(
90 | self.window, magnitude.size(-1), hop_length=self.hop_length,
91 | win_length=self.win_length, n_fft=self.filter_length,
92 | dtype=np.float32)
93 | # remove modulation effects
94 | approx_nonzero_indices = torch.from_numpy(
95 | np.where(window_sum > tiny(window_sum))[0])
96 | window_sum = torch.autograd.Variable(
97 | torch.from_numpy(window_sum), requires_grad=False)
98 | window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
99 | inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
100 |
101 | # scale by hop ratio
102 | inverse_transform *= float(self.filter_length) / self.hop_length
103 |
104 | inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
105 |         inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2)]
106 |
107 | return inverse_transform
108 |
109 | def forward(self, input_data):
110 | self.magnitude, self.phase = self.transform(input_data)
111 | reconstruction = self.inverse(self.magnitude, self.phase)
112 | return reconstruction
113 |
--------------------------------------------------------------------------------
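A round-trip sanity check for the STFT module above; a sketch that assumes the repository root is on PYTHONPATH and that the librosa/scipy dependencies are installed:

    import math
    import torch
    from models.prompt_tts_modified.stft import STFT

    stft = STFT(filter_length=800, hop_length=200, win_length=800, window='hann')

    t = torch.linspace(0, 1, 16000)
    signal = torch.sin(2 * math.pi * 440.0 * t).unsqueeze(0)   # (batch=1, samples): a 440 Hz tone

    magnitude, phase = stft.transform(signal)
    reconstruction = stft.inverse(magnitude, phase)
    # reconstruction matches the input up to edge effects
    print(signal.shape, magnitude.shape, reconstruction.shape)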
/models/prompt_tts_modified/style_encoder.py:
--------------------------------------------------------------------------------
1 | """
2 | This code is modified from https://github.com/yl4579/StyleTTS.
3 | """
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torch.nn.utils import spectral_norm
8 |
9 | import math
10 |
11 | class LearnedDownSample(nn.Module):
12 | def __init__(self, layer_type, dim_in):
13 | super().__init__()
14 | self.layer_type = layer_type
15 |
16 | if self.layer_type == 'none':
17 | self.conv = nn.Identity()
18 | elif self.layer_type == 'timepreserve':
19 | self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, padding=(1, 0)))
20 | elif self.layer_type == 'half':
21 | self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, padding=1))
22 | else:
23 |             raise RuntimeError('Got unexpected downsample type %s, expected one of [none, timepreserve, half]' % self.layer_type)
24 |
25 | def forward(self, x):
26 | return self.conv(x)
27 |
28 | class LearnedUpSample(nn.Module):
29 | def __init__(self, layer_type, dim_in):
30 | super().__init__()
31 | self.layer_type = layer_type
32 |
33 | if self.layer_type == 'none':
34 | self.conv = nn.Identity()
35 | elif self.layer_type == 'timepreserve':
36 | self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, output_padding=(1, 0), padding=(1, 0))
37 | elif self.layer_type == 'half':
38 | self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, output_padding=1, padding=1)
39 | else:
40 |             raise RuntimeError('Got unexpected upsample type %s, expected one of [none, timepreserve, half]' % self.layer_type)
41 |
42 |
43 | def forward(self, x):
44 | return self.conv(x)
45 |
46 | class DownSample(nn.Module):
47 | def __init__(self, layer_type):
48 | super().__init__()
49 | self.layer_type = layer_type
50 |
51 | def forward(self, x):
52 | if self.layer_type == 'none':
53 | return x
54 | elif self.layer_type == 'timepreserve':
55 | return F.avg_pool2d(x, (2, 1))
56 | elif self.layer_type == 'half':
57 | if x.shape[-1] % 2 != 0:
58 | x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
59 | return F.avg_pool2d(x, 2)
60 | else:
61 |             raise RuntimeError('Got unexpected downsample type %s, expected one of [none, timepreserve, half]' % self.layer_type)
62 |
63 |
64 | class UpSample(nn.Module):
65 | def __init__(self, layer_type):
66 | super().__init__()
67 | self.layer_type = layer_type
68 |
69 | def forward(self, x):
70 | if self.layer_type == 'none':
71 | return x
72 | elif self.layer_type == 'timepreserve':
73 | return F.interpolate(x, scale_factor=(2, 1), mode='nearest')
74 | elif self.layer_type == 'half':
75 | return F.interpolate(x, scale_factor=2, mode='nearest')
76 | else:
77 |             raise RuntimeError('Got unexpected upsample type %s, expected one of [none, timepreserve, half]' % self.layer_type)
78 |
79 |
80 | class ResBlk(nn.Module):
81 | def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
82 | normalize=False, downsample='none'):
83 | super().__init__()
84 | self.actv = actv
85 | self.normalize = normalize
86 | self.downsample = DownSample(downsample)
87 | self.downsample_res = LearnedDownSample(downsample, dim_in)
88 | self.learned_sc = dim_in != dim_out
89 | self._build_weights(dim_in, dim_out)
90 |
91 | def _build_weights(self, dim_in, dim_out):
92 | self.conv1 = spectral_norm(nn.Conv2d(dim_in, dim_in, 3, 1, 1))
93 | self.conv2 = spectral_norm(nn.Conv2d(dim_in, dim_out, 3, 1, 1))
94 | if self.normalize:
95 | self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
96 | self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
97 | if self.learned_sc:
98 | self.conv1x1 = spectral_norm(nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False))
99 |
100 | def _shortcut(self, x):
101 | if self.learned_sc:
102 | x = self.conv1x1(x)
103 | if self.downsample:
104 | x = self.downsample(x)
105 | return x
106 |
107 | def _residual(self, x):
108 | if self.normalize:
109 | x = self.norm1(x)
110 | x = self.actv(x)
111 | x = self.conv1(x)
112 | x = self.downsample_res(x)
113 | if self.normalize:
114 | x = self.norm2(x)
115 | x = self.actv(x)
116 | x = self.conv2(x)
117 | return x
118 |
119 | def forward(self, x):
120 | x = self._shortcut(x) + self._residual(x)
121 | return x / math.sqrt(2) # unit variance
122 |
123 | class StyleEncoder(nn.Module):
124 | def __init__(self, dim_in=48, style_dim=48, max_conv_dim=384):
125 | super().__init__()
126 | blocks = []
127 | blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]
128 |
129 | repeat_num = 4
130 | for _ in range(repeat_num):
131 | dim_out = min(dim_in*2, max_conv_dim)
132 | blocks += [ResBlk(dim_in, dim_out, downsample='half')]
133 | dim_in = dim_out
134 |
135 | blocks += [nn.LeakyReLU(0.2)]
136 | blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
137 | blocks += [nn.AdaptiveAvgPool2d(1)]
138 | blocks += [nn.LeakyReLU(0.2)]
139 | self.shared = nn.Sequential(*blocks)
140 |
141 | self.unshared = nn.Linear(dim_out, style_dim)
142 |
143 | def forward(self, x):
144 | h = self.shared(x)
145 | h = h.view(h.size(0), -1)
146 | s = self.unshared(h)
147 |
148 | return s
149 |
150 |
151 | class CosineSimilarityLoss(nn.Module):
152 | def __init__(self) -> None:
153 | super().__init__()
154 |
155 | self.loss_fn = torch.nn.CosineEmbeddingLoss()
156 |
157 | def forward(self, output1, output2):
158 | B = output1.size(0)
159 | target = torch.ones(B, device=output1.device, requires_grad=False)
160 | loss = self.loss_fn(output1, output2, target)
161 | return loss
162 |
163 |
--------------------------------------------------------------------------------
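A shape-level sketch of the mel-based StyleEncoder above; the batch size, mel dimensions and style_dim=128 are illustrative values, not settings taken from this repository:

    import torch
    from models.prompt_tts_modified.style_encoder import StyleEncoder

    encoder = StyleEncoder(dim_in=48, style_dim=128, max_conv_dim=384)

    mel = torch.randn(4, 1, 80, 120)   # (batch, channel, n_mels, frames)
    style = encoder(mel)
    print(style.shape)                 # torch.Size([4, 128])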
/models/prompt_tts_modified/tacotron_stft.py:
--------------------------------------------------------------------------------
1 | """
2 | from https://github.com/NVIDIA/tacotron2
3 | """
4 |
5 | import torch
6 | from librosa.filters import mel as librosa_mel_fn
7 | from models.prompt_tts_modified.audio_processing import dynamic_range_compression
8 | from models.prompt_tts_modified.audio_processing import dynamic_range_decompression
9 | from models.prompt_tts_modified.stft import STFT
10 |
11 |
12 | class LinearNorm(torch.nn.Module):
13 | def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
14 | super(LinearNorm, self).__init__()
15 | self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
16 |
17 | torch.nn.init.xavier_uniform_(
18 | self.linear_layer.weight,
19 | gain=torch.nn.init.calculate_gain(w_init_gain))
20 |
21 | def forward(self, x):
22 | return self.linear_layer(x)
23 |
24 |
25 | class ConvNorm(torch.nn.Module):
26 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
27 | padding=None, dilation=1, bias=True, w_init_gain='linear'):
28 | super(ConvNorm, self).__init__()
29 | if padding is None:
30 | assert(kernel_size % 2 == 1)
31 | padding = int(dilation * (kernel_size - 1) / 2)
32 |
33 | self.conv = torch.nn.Conv1d(in_channels, out_channels,
34 | kernel_size=kernel_size, stride=stride,
35 | padding=padding, dilation=dilation,
36 | bias=bias)
37 |
38 | torch.nn.init.xavier_uniform_(
39 | self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
40 |
41 | def forward(self, signal):
42 | conv_signal = self.conv(signal)
43 | return conv_signal
44 |
45 |
46 | class TacotronSTFT(torch.nn.Module):
47 | def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
48 | n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
49 | mel_fmax=8000.0):
50 | super(TacotronSTFT, self).__init__()
51 | self.n_mel_channels = n_mel_channels
52 | self.sampling_rate = sampling_rate
53 | self.stft_fn = STFT(filter_length, hop_length, win_length)
54 | mel_basis = librosa_mel_fn(
55 | sr=sampling_rate,
56 | n_fft=filter_length,
57 | n_mels=n_mel_channels,
58 | fmin=mel_fmin,
59 | fmax=mel_fmax)
60 | mel_basis = torch.from_numpy(mel_basis).float()
61 | self.register_buffer('mel_basis', mel_basis)
62 |
63 | def spectral_normalize(self, magnitudes):
64 | output = dynamic_range_compression(magnitudes)
65 | return output
66 |
67 | def spectral_de_normalize(self, magnitudes):
68 | output = dynamic_range_decompression(magnitudes)
69 | return output
70 |
71 | def mel_spectrogram(self, y):
72 |
73 | assert(torch.min(y.data) >= -1)
74 | assert(torch.max(y.data) <= 1)
75 |
76 | magnitudes, phases = self.stft_fn.transform(y)
77 | magnitudes = magnitudes.data
78 | mel_output = torch.matmul(self.mel_basis, magnitudes)
79 | mel_output = self.spectral_normalize(mel_output)
80 | return mel_output
81 |
82 |
83 |
84 |
85 |
--------------------------------------------------------------------------------
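A sketch showing how TacotronSTFT turns a normalized waveform into a log-mel spectrogram; the synthetic tone below simply stands in for real audio scaled to [-1, 1]:

    import math
    import torch
    from models.prompt_tts_modified.tacotron_stft import TacotronSTFT

    stft = TacotronSTFT(filter_length=1024, hop_length=256, win_length=1024,
                        n_mel_channels=80, sampling_rate=22050,
                        mel_fmin=0.0, mel_fmax=8000.0)

    t = torch.linspace(0, 1, 22050)
    wav = 0.5 * torch.sin(2 * math.pi * 220.0 * t).unsqueeze(0)   # (1, samples), within [-1, 1]

    mel = stft.mel_spectrogram(wav)
    print(mel.shape)   # (1, 80, n_frames)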
/openaiapi.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import io
4 | import torch
5 | import glob
6 |
7 | from fastapi import FastAPI, Response
8 | from pydantic import BaseModel
9 |
10 | from frontend import g2p_cn_en, ROOT_DIR, read_lexicon, G2p
11 | from models.prompt_tts_modified.jets import JETSGenerator
12 | from models.prompt_tts_modified.simbert import StyleEncoder
13 | from transformers import AutoTokenizer
14 | import numpy as np
15 | import soundfile as sf
16 | import pyrubberband as pyrb
17 | from pydub import AudioSegment
18 | from yacs import config as CONFIG
19 | from config.joint.config import Config
20 |
21 | LOGGER = logging.getLogger(__name__)
22 |
23 | DEFAULTS = {
24 | }
25 |
26 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27 | print(DEVICE)
28 | config = Config()
29 | MAX_WAV_VALUE = 32768.0
30 |
31 |
32 | def get_env(key):
33 | return os.environ.get(key, DEFAULTS.get(key))
34 |
35 |
36 | def get_int_env(key):
37 | return int(get_env(key))
38 |
39 |
40 | def get_float_env(key):
41 | return float(get_env(key))
42 |
43 |
44 | def get_bool_env(key):
45 | return get_env(key).lower() == 'true'
46 |
47 |
48 | def scan_checkpoint(cp_dir, prefix, c=8):
49 | pattern = os.path.join(cp_dir, prefix + '?'*c)
50 | cp_list = glob.glob(pattern)
51 | if len(cp_list) == 0:
52 | return None
53 | return sorted(cp_list)[-1]
54 |
55 |
56 | def get_models():
57 |
58 | am_checkpoint_path = scan_checkpoint(
59 | f'{config.output_directory}/prompt_tts_open_source_joint/ckpt', 'g_')
60 |
61 | # f'{config.output_directory}/style_encoder/ckpt/checkpoint_163431'
62 | style_encoder_checkpoint_path = scan_checkpoint(
63 | f'{config.output_directory}/style_encoder/ckpt', 'checkpoint_', 6)
64 |
65 | with open(config.model_config_path, 'r') as fin:
66 | conf = CONFIG.load_cfg(fin)
67 |
68 | conf.n_vocab = config.n_symbols
69 | conf.n_speaker = config.speaker_n_labels
70 |
71 | style_encoder = StyleEncoder(config)
72 | model_CKPT = torch.load(style_encoder_checkpoint_path, map_location="cpu")
73 | model_ckpt = {}
74 | for key, value in model_CKPT['model'].items():
75 |         new_key = key[7:]  # drop the 7-character "module." prefix from DataParallel-style checkpoints
76 | model_ckpt[new_key] = value
77 | style_encoder.load_state_dict(model_ckpt, strict=False)
78 | generator = JETSGenerator(conf).to(DEVICE)
79 |
80 | model_CKPT = torch.load(am_checkpoint_path, map_location=DEVICE)
81 | generator.load_state_dict(model_CKPT['generator'])
82 | generator.eval()
83 |
84 | tokenizer = AutoTokenizer.from_pretrained(config.bert_path)
85 |
86 | with open(config.token_list_path, 'r') as f:
87 |         token2id = {t.strip(): idx for idx, t in enumerate(f.readlines())}
88 |
89 | with open(config.speaker2id_path, encoding='utf-8') as f:
90 | speaker2id = {t.strip(): idx for idx, t in enumerate(f.readlines())}
91 |
92 | return (style_encoder, generator, tokenizer, token2id, speaker2id)
93 |
94 |
95 | def get_style_embedding(prompt, tokenizer, style_encoder):
96 | prompt = tokenizer([prompt], return_tensors="pt")
97 | input_ids = prompt["input_ids"]
98 | token_type_ids = prompt["token_type_ids"]
99 | attention_mask = prompt["attention_mask"]
100 | with torch.no_grad():
101 | output = style_encoder(
102 | input_ids=input_ids,
103 | token_type_ids=token_type_ids,
104 | attention_mask=attention_mask,
105 | )
106 | style_embedding = output["pooled_output"].cpu().squeeze().numpy()
107 | return style_embedding
108 |
109 |
110 | def emotivoice_tts(text, prompt, content, speaker, models):
111 | (style_encoder, generator, tokenizer, token2id, speaker2id) = models
112 |
113 | style_embedding = get_style_embedding(prompt, tokenizer, style_encoder)
114 | content_embedding = get_style_embedding(content, tokenizer, style_encoder)
115 |
116 | speaker = speaker2id[speaker]
117 |
118 | text_int = [token2id[ph] for ph in text.split()]
119 |
120 | sequence = torch.from_numpy(np.array(text_int)).to(
121 | DEVICE).long().unsqueeze(0)
122 | sequence_len = torch.from_numpy(np.array([len(text_int)])).to(DEVICE)
123 | style_embedding = torch.from_numpy(style_embedding).to(DEVICE).unsqueeze(0)
124 | content_embedding = torch.from_numpy(
125 | content_embedding).to(DEVICE).unsqueeze(0)
126 | speaker = torch.from_numpy(np.array([speaker])).to(DEVICE)
127 |
128 | with torch.no_grad():
129 |
130 | infer_output = generator(
131 | inputs_ling=sequence,
132 | inputs_style_embedding=style_embedding,
133 | input_lengths=sequence_len,
134 | inputs_content_embedding=content_embedding,
135 | inputs_speaker=speaker,
136 | alpha=1.0
137 | )
138 |
139 | audio = infer_output["wav_predictions"].squeeze() * MAX_WAV_VALUE
140 | audio = audio.cpu().numpy().astype('int16')
141 |
142 | return audio
143 |
144 |
145 | speakers = config.speakers
146 | models = get_models()
147 | app = FastAPI()
148 | lexicon = read_lexicon(f"{ROOT_DIR}/lexicon/librispeech-lexicon.txt")
149 | g2p = G2p()
150 |
151 | from typing import Optional
152 | class SpeechRequest(BaseModel):
153 | input: str
154 | voice: str = '8051'
155 | prompt: Optional[str] = ''
156 | language: Optional[str] = 'zh_us'
157 | model: Optional[str] = 'emoti-voice'
158 | response_format: Optional[str] = 'mp3'
159 | speed: Optional[float] = 1.0
160 |
161 |
162 | @app.post("/v1/audio/speech")
163 | def text_to_speech(speechRequest: SpeechRequest):
164 |
165 | text = g2p_cn_en(speechRequest.input, g2p, lexicon)
166 | np_audio = emotivoice_tts(text, speechRequest.prompt,
167 | speechRequest.input, speechRequest.voice,
168 | models)
169 | y_stretch = np_audio
170 | if speechRequest.speed != 1.0:
171 | y_stretch = pyrb.time_stretch(np_audio, config.sampling_rate, speechRequest.speed)
172 | wav_buffer = io.BytesIO()
173 | sf.write(file=wav_buffer, data=y_stretch,
174 | samplerate=config.sampling_rate, format='WAV')
175 | buffer = wav_buffer
176 | response_format = speechRequest.response_format
177 | if response_format != 'wav':
178 | wav_audio = AudioSegment.from_wav(wav_buffer)
179 | wav_audio.frame_rate=config.sampling_rate
180 | buffer = io.BytesIO()
181 | wav_audio.export(buffer, format=response_format)
182 |
183 | return Response(content=buffer.getvalue(),
184 | media_type=f"audio/{response_format}")
185 |
--------------------------------------------------------------------------------
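openaiapi.py above exposes an OpenAI-style /v1/audio/speech endpoint. A hedged client sketch using the requests package follows; it assumes the server was started with something like uvicorn openaiapi:app --port 8000 and that the checkpoints expected by get_models() are in place. The host, port, voice id and prompt text are placeholders:

    import requests

    payload = {
        "input": "Hello, welcome to EmotiVoice!",
        "voice": "8051",            # speaker id; '8051' is the default in SpeechRequest
        "prompt": "happy",          # free-text style/emotion prompt
        "response_format": "mp3",   # 'wav' skips the pydub/ffmpeg conversion step
        "speed": 1.0,
    }

    resp = requests.post("http://127.0.0.1:8000/v1/audio/speech", json=payload)  # assumed host/port
    resp.raise_for_status()
    with open("speech.mp3", "wb") as f:
        f.write(resp.content)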
/plot_image.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import torch.nn.functional as F
3 |
4 | import os
5 |
6 | def plot_image_sambert(target, melspec, mel_lengths=None, text_lengths=None, save_dir=None, global_step=None, name=None):
7 | # Draw mel_plots
8 | mel_plots, axes = plt.subplots(2,1,figsize=(20,15))
9 |
10 |     T = mel_lengths[-1]  # frame length of the last item in the batch; only that item is plotted
11 |     L = 100  # unused
12 |
13 |
14 | axes[0].imshow(target[-1].detach().cpu()[:,:T],
15 | origin='lower',
16 | aspect='auto')
17 |
18 | axes[1].imshow(melspec[-1].detach().cpu()[:,:T],
19 | origin='lower',
20 | aspect='auto')
21 | for i in range(2):
22 | tmp_dir = save_dir+'/att/'+name+'_'+str(global_step)
23 | if not os.path.exists(tmp_dir):
24 | os.makedirs(tmp_dir)
25 | plt.savefig(tmp_dir+'/'+name+'_'+str(global_step)+'_melspec_%s.png'%i)
26 |
27 | return mel_plots
--------------------------------------------------------------------------------
/prepare_for_training.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023, YOUDAO
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | import os
17 | import shutil
18 | import argparse
19 |
20 |
21 | def main(args):
22 | from os.path import join
23 | data_dir = args.data_dir
24 | exp_dir = args.exp_dir
25 | os.makedirs(exp_dir, exist_ok=True)
26 |
27 | info_dir = join(exp_dir, 'info')
28 | prepare_info(data_dir, info_dir)
29 |
30 | config_dir = join(exp_dir, 'config')
31 | prepare_config(data_dir, info_dir, exp_dir, config_dir)
32 |
33 | ckpt_dir = join(exp_dir, 'ckpt')
34 | prepare_ckpt(data_dir, info_dir, ckpt_dir)
35 |
36 |
37 | ROOT_DIR = os.path.dirname(os.path.abspath("__file__"))
38 | def prepare_info(data_dir, info_dir):
39 | import jsonlines
40 | print('prepare_info: %s' %info_dir)
41 | os.makedirs(info_dir, exist_ok=True)
42 |
43 | for name in ["emotion", "energy", "pitch", "speed", "tokenlist"]:
44 | shutil.copy(f"{ROOT_DIR}/data/youdao/text/{name}", f"{info_dir}/{name}")
45 |
46 | d_speaker = {} # get all the speakers from datalist
47 | with jsonlines.open(f"{data_dir}/train/datalist.jsonl") as reader:
48 | for obj in reader:
49 | speaker = obj["speaker"]
50 |             if speaker not in d_speaker:
51 | d_speaker[speaker] = 1
52 | else:
53 | d_speaker[speaker] += 1
54 |
55 | with open(f"{ROOT_DIR}/data/youdao/text/speaker2") as f, \
56 | open(f"{info_dir}/speaker", "w") as fout:
57 |
58 | for line in f:
59 | speaker = line.strip()
60 | if speaker in d_speaker:
61 |                 print('warning: speaker [%s] is duplicated in [%s]; skipping it here' % (speaker, data_dir))
62 | continue
63 | fout.write(line.strip()+"\n")
64 |
65 | for speaker in sorted(d_speaker.keys()):
66 | fout.write(speaker + "\n")
67 |
68 |
69 | def prepare_config(data_dir, info_dir, exp_dir, config_dir):
70 | print('prepare_config: %s' %config_dir)
71 | os.makedirs(config_dir, exist_ok=True)
72 |
73 | with open(f"{ROOT_DIR}/config/template.py") as f, \
74 | open(f"{config_dir}/config.py", "w") as fout:
75 |
76 | for line in f:
77 |             fout.write(line.replace('<data_dir>', data_dir).replace('<info_dir>', info_dir).replace('<exp_dir>', exp_dir))
78 |
79 |
80 | def prepare_ckpt(data_dir, info_dir, ckpt_dir):
81 | print('prepare_ckpt: %s' %ckpt_dir)
82 | os.makedirs(ckpt_dir, exist_ok=True)
83 |
84 | with open(f"{info_dir}/speaker") as f:
85 | speaker_list=[line.strip() for line in f]
86 | assert len(speaker_list) >= 2014
87 |
88 | gen_ckpt_path = f"{ROOT_DIR}/outputs/prompt_tts_open_source_joint/ckpt/g_00140000"
89 | disc_ckpt_path = f"{ROOT_DIR}/outputs/prompt_tts_open_source_joint/ckpt/do_00140000"
90 |
91 | gen_ckpt = torch.load(gen_ckpt_path, map_location="cpu")
92 |
93 | speaker_embeddings = gen_ckpt["generator"]["am.spk_tokenizer.weight"].clone()
94 |
95 | new_embedding = torch.randn((len(speaker_list)-speaker_embeddings.size(0), speaker_embeddings.size(1)))
96 |
97 | gen_ckpt["generator"]["am.spk_tokenizer.weight"] = torch.cat([speaker_embeddings, new_embedding], dim=0)
98 |
99 |
100 | torch.save(gen_ckpt, f"{ckpt_dir}/pretrained_generator")
101 | shutil.copy(disc_ckpt_path, f"{ckpt_dir}/pretrained_discriminator")
102 |
103 |
104 |
105 | if __name__ == "__main__":
106 |
107 | p = argparse.ArgumentParser()
108 | p.add_argument('--data_dir', type=str, required=True)
109 | p.add_argument('--exp_dir', type=str, required=True)
110 | args = p.parse_args()
111 |
112 | main(args)
113 |
--------------------------------------------------------------------------------
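A sketch of driving main() above without the command line. The directory names are hypothetical; data_dir must contain train/datalist.jsonl produced by the data/* recipes, and the pretrained checkpoints referenced in prepare_ckpt() must exist under outputs/:

    import argparse
    from prepare_for_training import main

    args = argparse.Namespace(
        data_dir="data/DataBaker",   # placeholder dataset directory
        exp_dir="exp/DataBaker",     # placeholder experiment directory
    )
    main(args)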
/requirements.openaiapi.txt:
--------------------------------------------------------------------------------
1 | fastapi
2 | python-multipart
3 | uvicorn[standard]
4 | pydub
5 | pyrubberband
6 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchaudio
3 | numpy
4 | numba
5 | scipy
6 | transformers
7 | soundfile
8 | yacs
9 | g2p_en
10 | jieba
11 | pypinyin
12 | pypinyin_dict
13 | streamlit
14 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from setuptools import find_packages, setup
4 |
5 | requirements={
6 | "infer": [
7 | "numpy>=1.24.3",
8 | "scipy>=1.10.1",
9 | "torch>=2.1",
10 | "torchaudio",
11 | "soundfile>=0.12.0",
12 | "librosa>=0.10.0",
13 | "scikit-learn",
14 | "numba==0.58.1",
15 | "inflect>=5.6.0",
16 | "tqdm>=4.64.1",
17 | "pyyaml>=6.0",
18 | "transformers==4.26.1",
19 | "yacs",
20 | "g2p_en",
21 | "jieba",
22 | "pypinyin",
23 | "streamlit",
24 | "pandas>=1.4,<2.0",
25 | ],
26 | "openai": [
27 | "fastapi",
28 | "python-multipart",
29 | "uvicorn[standard]",
30 | "pydub",
31 | ],
32 | "train": [
33 | "jsonlines",
34 | "praatio",
35 | "pyworld",
36 | "flake8",
37 | "flake8-bugbear",
38 | "flake8-comprehensions",
39 | "flake8-executable",
40 | "flake8-pyi",
41 | "mccabe",
42 | "pycodestyle",
43 | "pyflakes",
44 | "tensorboard",
45 | "einops",
46 | "matplotlib",
47 | ]
48 | }
49 |
50 | infer_requires = requirements["infer"]
51 | openai_requires = requirements["infer"] + requirements["openai"]
52 | train_requires = requirements["infer"] + requirements["train"]
53 |
54 | VERSION = '0.2.0'
55 |
56 | with open("README.md", "r", encoding="utf-8") as readme_file:
57 | README = readme_file.read()
58 |
59 |
60 | setup(
61 | name="EmotiVoice",
62 | version=VERSION,
63 | url="https://github.com/netease-youdao/EmotiVoice",
64 | author="Huaxuan Wang",
65 | author_email="wanghx04@rd.netease.com",
66 | description="EmotiVoice 😊: a Multi-Voice and Prompt-Controlled TTS Engine",
67 | long_description=README,
68 | long_description_content_type="text/markdown",
69 | license="Apache Software License",
70 | # package
71 | packages=find_packages(),
72 | project_urls={
73 | "Documentation": "https://github.com/netease-youdao/EmotiVoice/wiki",
74 | "Tracker": "https://github.com/netease-youdao/EmotiVoice/issues",
75 | "Repository": "https://github.com/netease-youdao/EmotiVoice",
76 | },
77 | install_requires=infer_requires,
78 | extras_require={
79 | "train": train_requires,
80 | "openai": openai_requires,
81 | },
82 | python_requires=">=3.8.0",
83 | classifiers=[
84 | "Programming Language :: Python",
85 | "Programming Language :: Python :: 3",
86 | "Programming Language :: Python :: 3.8",
87 | "Programming Language :: Python :: 3.9",
88 | "Programming Language :: Python :: 3.10",
89 | "Programming Language :: Python :: 3.11",
90 | "Development Status :: 3 - Alpha",
91 | "Intended Audience :: Science/Research",
92 | "Operating System :: POSIX :: Linux",
93 | "License :: OSI Approved :: Apache Software License",
94 | "Topic :: Software Development :: Libraries :: Python Modules",
95 | "Topic :: Multimedia :: Sound/Audio :: Speech",
96 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
97 | ],
98 | )
--------------------------------------------------------------------------------
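The extras_require groups above map onto optional dependency sets. A minimal sketch of installing the package with the training extras from a checkout, driven from Python for consistency with the rest of these examples (the equivalent shell command is pip install ".[train]"):

    import subprocess
    import sys

    # Install the current checkout plus the 'train' extra via pip.
    subprocess.check_call([sys.executable, "-m", "pip", "install", ".[train]"])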
/text/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | from https://github.com/keithito/tacotron
3 | """
4 |
5 | import re
6 | from text import cleaners
7 | from text.symbols import symbols
8 |
9 |
10 | # Mappings from symbol to numeric ID and vice versa:
11 | _symbol_to_id = {s: i for i, s in enumerate(symbols)}
12 | _id_to_symbol = {i: s for i, s in enumerate(symbols)}
13 |
14 | # Regular expression matching text enclosed in curly braces:
15 | _curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")
16 |
17 |
18 | def text_to_sequence(text, cleaner_names):
19 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
20 |
21 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded
22 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
23 |
24 | Args:
25 | text: string to convert to a sequence
26 | cleaner_names: names of the cleaner functions to run the text through
27 |
28 | Returns:
29 | List of integers corresponding to the symbols in the text
30 | """
31 | sequence = []
32 |
33 | # Check for curly braces and treat their contents as ARPAbet:
34 | while len(text):
35 | m = _curly_re.match(text)
36 |
37 | if not m:
38 | sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
39 | break
40 | sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
41 |
42 | sequence += _arpabet_to_sequence(m.group(2))
43 | text = m.group(3)
44 |
45 |
46 | return sequence
47 |
48 |
49 | def sequence_to_text(sequence):
50 | """Converts a sequence of IDs back to a string"""
51 | result = ""
52 | for symbol_id in sequence:
53 | if symbol_id in _id_to_symbol:
54 | s = _id_to_symbol[symbol_id]
55 | # Enclose ARPAbet back in curly braces:
56 | if len(s) > 1 and s[0] == "@":
57 | s = "{%s}" % s[1:]
58 | result += s
59 | return result.replace("}{", " ")
60 |
61 |
62 | def _clean_text(text, cleaner_names):
63 | for name in cleaner_names:
64 | cleaner = getattr(cleaners, name)
65 | if not cleaner:
66 | raise Exception("Unknown cleaner: %s" % name)
67 | text = cleaner(text)
68 | return text
69 |
70 |
71 | def _symbols_to_sequence(symbols):
72 | return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
73 |
74 |
75 | def _arpabet_to_sequence(text):
76 | return _symbols_to_sequence(["@" + s for s in text.split()])
77 |
78 |
79 | def _should_keep_symbol(s):
80 | return s in _symbol_to_id and s != "_" and s != "~"
81 |
--------------------------------------------------------------------------------
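A quick round trip through text_to_sequence and sequence_to_text; it assumes the repository root is on PYTHONPATH and that the unidecode/inflect dependencies of the cleaners are installed:

    from text import text_to_sequence, sequence_to_text

    ids = text_to_sequence("Turn left on {HH AW1 S S T AH0 N} Street.", ["english_cleaners"])
    print(ids)
    print(sequence_to_text(ids))   # the ARPAbet span comes back wrapped in curly braces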
/text/cleaners.py:
--------------------------------------------------------------------------------
1 | """
2 | from https://github.com/keithito/tacotron
3 | """
4 |
5 | '''
6 | Cleaners are transformations that run over the input text at both training and eval time.
7 |
8 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
9 | hyperparameter. Some cleaners are English-specific. You'll typically want to use:
10 | 1. "english_cleaners" for English text
11 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
12 | the Unidecode library (https://pypi.python.org/pypi/Unidecode)
13 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
14 | the symbols in symbols.py to match your data).
15 | '''
16 |
17 |
18 | # Regular expression matching whitespace:
19 | import re
20 | from unidecode import unidecode
21 | from .numbers import normalize_numbers
22 | _whitespace_re = re.compile(r'\s+')
23 |
24 | # List of (regular expression, replacement) pairs for abbreviations:
25 | _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
26 | ('mrs', 'misess'),
27 | ('mr', 'mister'),
28 | ('dr', 'doctor'),
29 | ('st', 'saint'),
30 | ('co', 'company'),
31 | ('jr', 'junior'),
32 | ('maj', 'major'),
33 | ('gen', 'general'),
34 | ('drs', 'doctors'),
35 | ('rev', 'reverend'),
36 | ('lt', 'lieutenant'),
37 | ('hon', 'honorable'),
38 | ('sgt', 'sergeant'),
39 | ('capt', 'captain'),
40 | ('esq', 'esquire'),
41 | ('ltd', 'limited'),
42 | ('col', 'colonel'),
43 | ('ft', 'fort'),
44 | ]]
45 |
46 |
47 | def expand_abbreviations(text):
48 | for regex, replacement in _abbreviations:
49 | text = re.sub(regex, replacement, text)
50 | return text
51 |
52 |
53 | def expand_numbers(text):
54 | return normalize_numbers(text)
55 |
56 |
57 | def lowercase(text):
58 | return text.lower()
59 |
60 |
61 | def collapse_whitespace(text):
62 | return re.sub(_whitespace_re, ' ', text)
63 |
64 |
65 | def convert_to_ascii(text):
66 | return unidecode(text)
67 |
68 |
69 | def basic_cleaners(text):
70 | '''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
71 | text = lowercase(text)
72 | text = collapse_whitespace(text)
73 | return text
74 |
75 |
76 | def transliteration_cleaners(text):
77 | '''Pipeline for non-English text that transliterates to ASCII.'''
78 | text = convert_to_ascii(text)
79 | text = lowercase(text)
80 | text = collapse_whitespace(text)
81 | return text
82 |
83 |
84 | def english_cleaners(text):
85 | '''Pipeline for English text, including number and abbreviation expansion.'''
86 | text = convert_to_ascii(text)
87 | text = lowercase(text)
88 | text = expand_numbers(text)
89 | text = expand_abbreviations(text)
90 | text = collapse_whitespace(text)
91 | return text
92 |
--------------------------------------------------------------------------------
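A small demonstration of the English cleaner pipeline; the output in the comment is approximate and may vary with the installed inflect version:

    from text.cleaners import english_cleaners

    print(english_cleaners("Dr. Smith paid $3.50 on May 1st."))
    # roughly: "doctor smith paid 3 dollars, 50 cents on may first."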
/text/cmudict.py:
--------------------------------------------------------------------------------
1 | """
2 | from https://github.com/keithito/tacotron
3 | """
4 |
5 | import re
6 |
7 |
8 | valid_symbols = [
9 | "AA",
10 | "AA0",
11 | "AA1",
12 | "AA2",
13 | "AE",
14 | "AE0",
15 | "AE1",
16 | "AE2",
17 | "AH",
18 | "AH0",
19 | "AH1",
20 | "AH2",
21 | "AO",
22 | "AO0",
23 | "AO1",
24 | "AO2",
25 | "AW",
26 | "AW0",
27 | "AW1",
28 | "AW2",
29 | "AY",
30 | "AY0",
31 | "AY1",
32 | "AY2",
33 | "B",
34 | "CH",
35 | "D",
36 | "DH",
37 | "EH",
38 | "EH0",
39 | "EH1",
40 | "EH2",
41 | "ER",
42 | "ER0",
43 | "ER1",
44 | "ER2",
45 | "EY",
46 | "EY0",
47 | "EY1",
48 | "EY2",
49 | "F",
50 | "G",
51 | "HH",
52 | "IH",
53 | "IH0",
54 | "IH1",
55 | "IH2",
56 | "IY",
57 | "IY0",
58 | "IY1",
59 | "IY2",
60 | "JH",
61 | "K",
62 | "L",
63 | "M",
64 | "N",
65 | "NG",
66 | "OW",
67 | "OW0",
68 | "OW1",
69 | "OW2",
70 | "OY",
71 | "OY0",
72 | "OY1",
73 | "OY2",
74 | "P",
75 | "R",
76 | "S",
77 | "SH",
78 | "T",
79 | "TH",
80 | "UH",
81 | "UH0",
82 | "UH1",
83 | "UH2",
84 | "UW",
85 | "UW0",
86 | "UW1",
87 | "UW2",
88 | "V",
89 | "W",
90 | "Y",
91 | "Z",
92 | "ZH",
93 | ]
94 |
95 | _valid_symbol_set = set(valid_symbols)
96 |
97 |
98 | class CMUDict:
99 | """Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict"""
100 |
101 | def __init__(self, file_or_path, keep_ambiguous=True):
102 | if isinstance(file_or_path, str):
103 | with open(file_or_path, encoding="latin-1") as f:
104 | entries = _parse_cmudict(f)
105 | else:
106 | entries = _parse_cmudict(file_or_path)
107 | if not keep_ambiguous:
108 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
109 | self._entries = entries
110 |
111 | def __len__(self):
112 | return len(self._entries)
113 |
114 | def lookup(self, word):
115 | """Returns list of ARPAbet pronunciations of the given word."""
116 | return self._entries.get(word.upper())
117 |
118 |
119 | _alt_re = re.compile(r"\([0-9]+\)")
120 |
121 |
122 | def _parse_cmudict(file):
123 | cmudict = {}
124 | for line in file:
125 | if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
126 |             parts = line.split("  ")  # CMUdict format: word and pronunciation separated by two spaces
127 | word = re.sub(_alt_re, "", parts[0])
128 | pronunciation = _get_pronunciation(parts[1])
129 | if pronunciation:
130 | if word in cmudict:
131 | cmudict[word].append(pronunciation)
132 | else:
133 | cmudict[word] = [pronunciation]
134 | return cmudict
135 |
136 |
137 | def _get_pronunciation(s):
138 | parts = s.strip().split(" ")
139 | for part in parts:
140 | if part not in _valid_symbol_set:
141 | return None
142 | return " ".join(parts)
143 |
--------------------------------------------------------------------------------
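CMUDict accepts either a path or a file-like object. The snippet below feeds it an in-memory fragment in CMUdict format (two spaces between word and pronunciation), so it runs without the full dictionary file:

    import io
    from text.cmudict import CMUDict

    fragment = io.StringIO(
        "HELLO  HH AH0 L OW1\n"
        "WORLD  W ER1 L D\n"
    )
    d = CMUDict(fragment)
    print(len(d))             # 2
    print(d.lookup("hello"))  # ['HH AH0 L OW1']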
/text/numbers.py:
--------------------------------------------------------------------------------
1 | """
2 | from https://github.com/keithito/tacotron
3 | """
4 |
5 | import inflect
6 | import re
7 |
8 |
9 | _inflect = inflect.engine()
10 | _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
11 | _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
12 | _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
13 | _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
14 | _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
15 | _number_re = re.compile(r"[0-9]+")
16 |
17 |
18 | def _remove_commas(m):
19 | return m.group(1).replace(",", "")
20 |
21 |
22 | def _expand_decimal_point(m):
23 | return m.group(1).replace(".", " point ")
24 |
25 |
26 | def _expand_dollars(m):
27 | match = m.group(1)
28 | parts = match.split(".")
29 | if len(parts) > 2:
30 | return match + " dollars" # Unexpected format
31 | dollars = int(parts[0]) if parts[0] else 0
32 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
33 | if dollars and cents:
34 | dollar_unit = "dollar" if dollars == 1 else "dollars"
35 | cent_unit = "cent" if cents == 1 else "cents"
36 | return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
37 | elif dollars:
38 | dollar_unit = "dollar" if dollars == 1 else "dollars"
39 | return "%s %s" % (dollars, dollar_unit)
40 | elif cents:
41 | cent_unit = "cent" if cents == 1 else "cents"
42 | return "%s %s" % (cents, cent_unit)
43 | else:
44 | return "zero dollars"
45 |
46 |
47 | def _expand_ordinal(m):
48 | return _inflect.number_to_words(m.group(0))
49 |
50 |
51 | def _expand_number(m):
52 | num = int(m.group(0))
53 | if num > 1000 and num < 3000:
54 | if num == 2000:
55 | return "two thousand"
56 | elif num > 2000 and num < 2010:
57 | return "two thousand " + _inflect.number_to_words(num % 100)
58 | elif num % 100 == 0:
59 | return _inflect.number_to_words(num // 100) + " hundred"
60 | else:
61 | return _inflect.number_to_words(
62 | num, andword="", zero="oh", group=2
63 | ).replace(", ", " ")
64 | else:
65 | return _inflect.number_to_words(num, andword="")
66 |
67 |
68 | def normalize_numbers(text):
69 | text = re.sub(_comma_number_re, _remove_commas, text)
70 | text = re.sub(_pounds_re, r"\1 pounds", text)
71 | text = re.sub(_dollars_re, _expand_dollars, text)
72 | text = re.sub(_decimal_number_re, _expand_decimal_point, text)
73 | text = re.sub(_ordinal_re, _expand_ordinal, text)
74 | text = re.sub(_number_re, _expand_number, text)
75 | return text
76 |
--------------------------------------------------------------------------------
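A few illustrative calls to normalize_numbers; the expansions in the comments are approximate and depend on the inflect version:

    from text.numbers import normalize_numbers

    print(normalize_numbers("I owe $12.50"))          # "I owe 12 dollars, 50 cents"
    print(normalize_numbers("in 1,000 steps"))        # "in one thousand steps"
    print(normalize_numbers("the 3rd of June 2005"))  # "the third of June two thousand five"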
/text/symbols.py:
--------------------------------------------------------------------------------
1 | """
2 | from https://github.com/keithito/tacotron
3 | """
4 |
5 | """
6 | Defines the set of symbols used in text input to the model.
7 |
8 | The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. """
9 |
10 | from text import cmudict
11 |
12 | _pad = "_"
13 | _punctuation = "!'(),.:;? "
14 | _special = "-"
15 | _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
16 | _silences = ["@sp", "@spn", "@sil"]
17 |
18 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
19 | _arpabet = ["@" + s for s in cmudict.valid_symbols]
20 |
21 |
22 | # Export all symbols:
23 | symbols = (
24 | [_pad]
25 | + list(_special)
26 | + list(_punctuation)
27 | + list(_letters)
28 | + _arpabet
29 | + _silences
30 | )
31 |
--------------------------------------------------------------------------------
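A tiny check of the exported symbol set; the exact count follows directly from the lists defined above:

    from text.symbols import symbols

    print(len(symbols))    # pad + special + punctuation + letters + ARPAbet + silences
    print(symbols[:5])     # ['_', '-', '!', "'", '(']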