├── AISHELL-3
│   ├── ReadMe.txt
│   └── spk-info.txt
├── LICENSE
├── README.md
├── aishell_data_clean.ipynb
├── asr.py
├── asr_v2.py
├── asr_v3.py
├── asr_web.py
├── config.py
├── infer_server.py
├── requirements.txt
├── static
│   ├── index.css
│   ├── record.js
│   └── record.png
├── templates
│   └── index.html
├── train.py
├── train_v2.py
├── train_v3.py
├── train_v4.py
└── utils
    ├── binary.py
    ├── callback.py
    ├── data_utils.py
    ├── model_utils.py
    ├── reader.py
    └── utils.py
/AISHELL-3/ReadMe.txt:
--------------------------------------------------------------------------------
1 |
2 | AISHELL-3
3 |
4 | 北京希尔贝壳科技有限公司
5 | Beijing Shell Shell Technology Co.,Ltd
6 | 10/18/2020
7 |
8 | 1. AISHELL-3 Speech Data
9 |
10 | - Sampling Rate : 44.1kHz
11 | - Sample Format : 16bit
12 | - Environment : Quiet indoor
13 | - Speech Data Type : PCM
14 | - Channel Number : 1
15 | - Recording Equipment : High fidelity microphone
16 | - Sentences : 88035 utterances
17 | - Speaker : 218 speakers (43 male and 175 female)
18 |
19 |
20 | 2. Data Structure
21 | │ README.txt (readme)
22 | │ ChangeLog (Change Information)
23 | │ phone_set.txt (phone Information)
24 | │ spk_info.txt (Speaker Information)
25 | └─ test (Test Data File)
26 | └─ train (Train Data File)
27 | │├─content.txt (Transcript Content)
28 | │├─prosody_label_train-set.txt (Prosody Label)
29 | │├─wav (Audio Data File)
30 | │├─SSB005 (Speaker ID File)
31 | ││ ││ ID2166W0001.wav (Audio)
32 |
33 | 4. System
34 | AISHELL-3 is a large-scale and high-fidelity multi-speaker Mandarin speech corpus which could be used to train multi-speaker Text-to-Speech (TTS) systems.
35 | You can download the data set from: http://www.aishelltech.com/aishell_3.
36 | The baseline system code and generated samples are available online from: https://sos1sos2sixteen.github.io/aishell3/.
37 |
--------------------------------------------------------------------------------
/AISHELL-3/spk-info.txt:
--------------------------------------------------------------------------------
1 | # voice-file name; age group; gender; accent
2 | # In years, A:< 14, B:14 - 25, C:26 - 40, D:> 41.
3 |
4 | SSB1837 B female north
5 | SSB0578 B female north
6 | SSB1216 B female north
7 | SSB1161 B female north
8 | SSB0016 B female north
9 | SSB1365 B male north
10 | SSB1759 B female north
11 | SSB0588 B female north
12 | SSB0534 C female north
13 | SSB0380 B female north
14 | SSB0273 B male north
15 | SSB1863 B male north
16 | SSB1125 B female south
17 | SSB1872 B female north
18 | SSB0993 C female north
19 | SSB0666 C female north
20 | SSB0668 B female north
21 | SSB0395 B female north
22 | SSB1831 B male south
23 | SSB1064 B female north
24 | SSB1219 B female north
25 | SSB1322 B female north
26 | SSB1002 B female north
27 | SSB1878 B female north
28 | SSB0415 B female north
29 | SSB1781 C female north
30 | SSB0631 B male south
31 | SSB0686 B female north
32 | SSB1328 B male north
33 | SSB0366 B female south
34 | SSB1100 C male south
35 | SSB0241 B male north
36 | SSB0966 B male north
37 | SSB1340 B female north
38 | SSB0762 B female north
39 | SSB0073 B male north
40 | SSB0632 B female south
41 | SSB0915 B female south
42 | SSB0748 B female north
43 | SSB1956 B female north
44 | SSB1056 B female south
45 | SSB0716 B female south
46 | SSB0629 B male north
47 | SSB1806 B female north
48 | SSB0599 C female north
49 | SSB0720 B female north
50 | SSB0385 B female south
51 | SSB0809 B female north
52 | SSB0342 B female others
53 | SSB0760 B female north
54 | SSB1253 B female south
55 | SSB1575 B female south
56 | SSB0863 B male north
57 | SSB1110 D female north
58 | SSB0200 B female north
59 | SSB1215 B female north
60 | SSB0375 B male south
61 | SSB1828 C female north
62 | SSB0737 D female north
63 | SSB0341 C female north
64 | SSB0009 B female south
65 | SSB0309 D female north
66 | SSB1055 C female north
67 | SSB1448 B male north
68 | SSB1176 B female north
69 | SSB1001 B female north
70 | SSB0193 B female south
71 | SSB0710 C male north
72 | SSB0427 B female north
73 | SSB0338 B female north
74 | SSB1131 B female south
75 | SSB1108 B female north
76 | SSB0149 B female south
77 | SSB0736 B male south
78 | SSB1555 B female south
79 | SSB0614 C female north
80 | SSB1072 B female south
81 | SSB1728 B female north
82 | SSB1382 B female north
83 | SSB0851 B female north
84 | SSB1585 B female south
85 | SSB1891 C female north
86 | SSB1393 B female north
87 | SSB1274 B female north
88 | SSB1204 B female north
89 | SSB1452 B female north
90 | SSB0570 B female north
91 | SSB0780 B female north
92 | SSB1593 B female south
93 | SSB0913 B female north
94 | SSB1302 B female north
95 | SSB0323 B female north
96 | SSB1135 B female north
97 | SSB0382 B female north
98 | SSB0887 B male north
99 | SSB1625 C female south
100 | SSB1366 C female north
101 | SSB0693 B female south
102 | SSB0594 B female north
103 | SSB1686 B female north
104 | SSB0012 B female north
105 | SSB0139 B male south
106 | SSB0751 B female north
107 | SSB0606 D female north
108 | SSB1341 B female south
109 | SSB0145 B female north
110 | SSB1136 B male south
111 | SSB0339 B female north
112 | SSB0482 B female north
113 | SSB0502 B female north
114 | SSB1650 B female north
115 | SSB0817 B male north
116 | SSB0261 C male north
117 | SSB0316 B male north
118 | SSB0033 B female south
119 | SSB0723 B female north
120 | SSB1008 B female north
121 | SSB0700 B female north
122 | SSB1457 B female north
123 | SSB0601 B female north
124 | SSB1809 B female north
125 | SSB1739 B female north
126 | SSB0407 C male north
127 | SSB0426 B female south
128 | SSB0470 B female north
129 | SSB0935 B female north
130 | SSB0822 B female north
131 | SSB0746 B female north
132 | SSB0758 B female north
133 | SSB1221 B female north
134 | SSB0038 B female north
135 | SSB1624 B male south
136 | SSB0133 B female south
137 | SSB0778 B female north
138 | SSB0702 B female south
139 | SSB1383 B male south
140 | SSB1563 B female south
141 | SSB1670 B female north
142 | SSB1096 B female south
143 | SSB0299 B female south
144 | SSB1711 B female north
145 | SSB1810 B female north
146 | SSB1115 B female south
147 | SSB1684 B male north
148 | SSB1402 C female north
149 | SSB1918 B female north
150 | SSB0246 B female south
151 | SSB1607 B female north
152 | SSB1437 C female north
153 | SSB0011 C female north
154 | SSB0288 C female north
155 | SSB0539 C female north
156 | SSB0394 B male north
157 | SSB0379 B female south
158 | SSB1187 B male north
159 | SSB0671 C female north
160 | SSB0544 B male north
161 | SSB0005 B female north
162 | SSB1846 B female south
163 | SSB0122 B female north
164 | SSB1630 C male north
165 | SSB1399 B female north
166 | SSB1567 B female others
167 | SSB0267 C female north
168 | SSB1392 B female north
169 | SSB0287 B female north
170 | SSB0717 B female south
171 | SSB0018 B female south
172 | SSB0315 B female south
173 | SSB1126 A female north
174 | SSB1320 D female north
175 | SSB0919 B female north
176 | SSB0623 B male south
177 | SSB0871 C female north
178 | SSB0786 B female north
179 | SSB1024 B female north
180 | SSB1935 B male north
181 | SSB0794 B female north
182 | SSB1832 B female north
183 | SSB1000 C female north
184 | SSB1431 C female north
185 | SSB0535 B male north
186 | SSB1782 C female north
187 | SSB0354 D female north
188 | SSB0393 A female north
189 | SSB0057 B female north
190 | SSB1385 B female north
191 | SSB1050 D female north
192 | SSB0435 B female north
193 | SSB1197 B female north
194 | SSB1939 B female north
195 | SSB1239 B male north
196 | SSB0434 D male north
197 | SSB1020 B male south
198 | SSB1091 B female north
199 | SSB1745 B male north
200 | SSB0565 B female north
201 | SSB0711 B female south
202 | SSB0603 B male north
203 | SSB0749 B female south
204 | SSB0997 B female north
205 | SSB0590 C male north
206 | SSB0784 C male north
207 | SSB1699 C female north
208 | SSB0607 C female north
209 | SSB1902 B female north
210 | SSB0043 B female north
211 | SSB1138 B female north
212 | SSB1408 B male north
213 | SSB1218 B female south
214 | SSB1377 B female north
215 | SSB0609 C male north
216 | SSB0987 B female north
217 | SSB0080 B female south
218 | SSB0307 B female north
219 | SSB0197 C female south
220 | SSB1203 C female north
221 | SSB0112 B female south
222 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Thirteen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # asr_AISHELL-3
2 | Train a speech recognition model with the AISHELL-3 dataset
3 |
4 | ## Usage
5 | Create a virtual environment
6 | ```
7 | conda create -n asr python=3.10
8 | ```
9 | Set up the environment
10 | ```
11 | conda activate asr
12 | pip install -r requirements.txt
13 | ```
14 | Run training (the expected AISHELL-3 data layout is sketched at the end of this README)
15 | ```
16 | python train.py
17 | ```
18 | If training has already been run and the audio feature file features.pkl already exists,
19 | you can run __train_v2.py__ directly
20 | ```
21 | python train_v2.py
22 | ```
23 | * TensorBoard logging is added in train_v4,
24 | so the training logs can be viewed
25 | ## Web-based speech recognition
26 | Run
27 | ```
28 | python asr_web.py
29 | ```
30 | * Read in an existing recording
31 | * Record locally and upload it for recognition
32 | * Preview
33 | 
34 |
35 | ## Viewing the training logs
36 | Enter the command
37 | ```
38 | tensorboard --logdir=log_path
39 | ```
40 | ## Loss curve:
41 | __epochs=25__
42 | 
43 |
44 | ## librosa version issue
45 | *librosa==0.7.2* is used here,
46 | and the following error may appear
47 | 
48 | **See this blog post:**
49 |
50 | [Fixing pip install errors and version compatibility issues for librosa, numba and llvmlite in an offline environment](https://blog.csdn.net/qq_39691492/article/details/130829401)
51 |
52 | *The changes are as follows:*
53 |
54 | 
55 |
56 | ## The API part references
57 | [Whisper-Finetune](https://github.com/yeyupiaoling/Whisper-Finetune)
58 |
59 | The fine-tuned Whisper model from that project is replaced with the ASR model trained in this repo
60 |
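61 | ## Expected data layout
62 | The training scripts read the paths configured in config.py. Below is a sketch of the assumed layout (names taken from config.py and the dataset ReadMe; adjust it if your copy of AISHELL-3 is arranged differently):
63 | ```
64 | AISHELL-3/
65 | ├── ReadMe.txt
66 | ├── spk-info.txt
67 | ├── train/
68 | │   ├── content.txt
69 | │   └── wav/
70 | │       └── SSB0005/
71 | │           └── SSB00050001.wav
72 | └── test/
73 |     ├── content.txt
74 |     └── wav/
75 | ```
76 |
77 | ## Calling the recognition API
78 | A minimal client sketch for the /recognition endpoint exposed by asr_web.py, assuming the server is running locally on its default port 5000; test.wav is a placeholder name for any WAV recording:
79 | ```
80 | import requests  # assumed to be installed; not listed in requirements.txt
81 |
82 | # POST the audio as a multipart form field named "audio",
83 | # matching the UploadFile parameter of api_recognition() in asr_web.py.
84 | with open("test.wav", "rb") as f:
85 |     resp = requests.post("http://127.0.0.1:5000/recognition",
86 |                          files={"audio": ("test.wav", f, "audio/wav")})
87 |
88 | # The server responds with JSON like {"results": "<recognized text>", "code": 0}.
89 | print(resp.json()["results"])
90 | ```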
--------------------------------------------------------------------------------
/aishell_data_clean.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 14,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "path = 'AISHELL-3/train/wav/'\n",
11 | "files = os.listdir(path)"
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": 48,
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "# 音频文件\n",
21 | "wav_file=[]\n",
22 | "# 音频文件的相对路径\n",
23 | "wav_file_path=[]\n",
24 | "\n",
25 | "for i in files:\n",
26 | " temp_files = os.listdir(path+i)\n",
27 | " for j in temp_files:\n",
28 | " wav_file.append(j)\n",
29 | " wav_file_path.append(path+i+'/'+j)\n",
30 | "\n"
31 | ]
32 | },
33 | {
34 | "cell_type": "markdown",
35 | "metadata": {},
36 | "source": [
37 | "* 两种导入方式"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 54,
43 | "metadata": {},
44 | "outputs": [],
45 | "source": [
46 | "# with open(\"AISHELL-3/train/content.txt\", \"r\", encoding='utf-8') as f: # 打开文件\n",
47 | "# data = f.read() # 读取文件\n",
48 | " # print(data)"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": 60,
54 | "metadata": {},
55 | "outputs": [],
56 | "source": [
57 | "f=open(\"AISHELL-3/train/content.txt\", \"r\", encoding='utf-8')\n",
58 | "txt=[]\n",
59 | "for line in f:\n",
60 | " txt.append(line.strip())\n"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": 66,
66 | "metadata": {},
67 | "outputs": [],
68 | "source": [
69 | "# 音频文件名字\n",
70 | "txt_filename = []\n",
71 | "# 对应文件内容\n",
72 | "txt_wav = []\n",
73 | "\n",
74 | "for i in txt:\n",
75 | " temp_txt_filename,temp_txt_wav = i.split('\\t')\n",
76 | " txt_filename.append(temp_txt_filename)\n",
77 | " txt_wav.append(temp_txt_wav)"
78 | ]
79 | },
80 | {
81 | "cell_type": "code",
82 | "execution_count": 76,
83 | "metadata": {},
84 | "outputs": [],
85 | "source": [
86 | "# 字典 字母+数字\n",
87 | "en_num_all = []\n",
88 | "\n",
89 | "for letter in 'abcdefghijklmnopqrstuvwxyz':\n",
90 | " en_num_all.extend(letter)\n",
91 | "\n",
92 | "for number in range(10): \n",
93 | " en_num_all.extend(str(number))\n",
94 | " \n",
95 | "en_num_all.extend(' ')"
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": 109,
101 | "metadata": {},
102 | "outputs": [],
103 | "source": [
104 | "# 音频文字\n",
105 | "texts = []\n",
106 | "\n",
107 | "for i in txt_wav:\n",
108 | " temp = ''\n",
109 | " for j in i:\n",
110 | " if j in en_num_all:\n",
111 | " continue\n",
112 | " else:\n",
113 | " # print(j)\n",
114 | " temp = temp+j\n",
115 | " texts.append(temp)\n",
116 | " "
117 | ]
118 | },
119 | {
120 | "cell_type": "markdown",
121 | "metadata": {},
122 | "source": [
123 | "#### texts 与 wav_file_path 对应"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": 110,
129 | "metadata": {},
130 | "outputs": [
131 | {
132 | "data": {
133 | "text/plain": [
134 | "'一百三十三万三千五百六十五'"
135 | ]
136 | },
137 | "execution_count": 110,
138 | "metadata": {},
139 | "output_type": "execute_result"
140 | }
141 | ],
142 | "source": [
143 | "texts[-1]"
144 | ]
145 | },
146 | {
147 | "cell_type": "code",
148 | "execution_count": 107,
149 | "metadata": {},
150 | "outputs": [
151 | {
152 | "data": {
153 | "text/plain": [
154 | "'AISHELL-3/train/wav/SSB1956/SSB19560481.wav'"
155 | ]
156 | },
157 | "execution_count": 107,
158 | "metadata": {},
159 | "output_type": "execute_result"
160 | }
161 | ],
162 | "source": [
163 | "wav_file_path[-1]"
164 | ]
165 | }
166 | ],
167 | "metadata": {
168 | "kernelspec": {
169 | "display_name": "Python 3",
170 | "language": "python",
171 | "name": "python3"
172 | },
173 | "language_info": {
174 | "codemirror_mode": {
175 | "name": "ipython",
176 | "version": 3
177 | },
178 | "file_extension": ".py",
179 | "mimetype": "text/x-python",
180 | "name": "python",
181 | "nbconvert_exporter": "python",
182 | "pygments_lexer": "ipython3",
183 | "version": "3.10.2"
184 | }
185 | },
186 | "nbformat": 4,
187 | "nbformat_minor": 2
188 | }
189 |
--------------------------------------------------------------------------------
/asr.py:
--------------------------------------------------------------------------------
1 | from keras.models import load_model
2 | from keras import backend as K
3 | import numpy as np
4 | import librosa
5 | from python_speech_features import mfcc
6 | import speech_recognition as sr
7 | import pickle
8 | import glob
9 | import config
10 | import wave
11 | import os
12 |
13 | def save_as_wav(audio, output_file_path):
14 | with wave.open(output_file_path, 'wb') as wav_file:
15 | wav_file.setnchannels(1) # 单声道
16 | wav_file.setsampwidth(2) # 16位PCM编码
17 | wav_file.setframerate(44100) # 采样率为44.1kHz
18 | wav_file.writeframes(audio.frame_data)
19 |
20 | def input_audio():
21 | r = sr.Recognizer()
22 | with sr.Microphone() as source:
23 | print("请说...")
24 | r.pause_threshold = 1
25 | audio = r.listen(source)
26 | output_file_path = "temp_file.wav"
27 | save_as_wav(audio, output_file_path)
28 | wavs = glob.glob('temp_file.wav')
29 | return wavs
30 |
31 | def load_file():
32 | with open(config.pkl_path, 'rb') as fr:
33 | [char2id, id2char, mfcc_mean, mfcc_std] = pickle.load(fr)
34 | model = load_model(config.model_path)
35 | return char2id, id2char, mfcc_mean, mfcc_std, model
36 |
37 | def set_data(wavs, mfcc_mean, mfcc_std):
38 | mfcc_dim = config.mfcc_dim
39 | index = np.random.randint(len(wavs))
40 | audio, sr = librosa.load(wavs[index])
41 | energy = librosa.feature.rms(audio)
42 | frames = np.nonzero(energy >= np.max(energy) / 5)
43 | indices = librosa.core.frames_to_samples(frames)[1]
44 | audio = audio[indices[0]:indices[-1]] if indices.size else audio[0:0]
45 | X_data = mfcc(audio, sr, numcep=mfcc_dim, nfft=551)
46 | X_data = (X_data - mfcc_mean) / (mfcc_std + 1e-14)
47 | print(X_data.shape)
48 | return X_data
49 |
50 |
51 | def wav_pred(model,X_data,id2char):
52 | pred = model.predict(np.expand_dims(X_data, axis=0))
53 | pred_ids = K.eval(K.ctc_decode(pred, [X_data.shape[0]], greedy=False, beam_width=10, top_paths=1)[0][0])
54 | pred_ids = pred_ids.flatten().tolist()
55 | words=''
56 | judge=0
57 | for i in pred_ids:
58 | if i != -1:
59 | judge=1
60 | words=words+id2char[i]
61 | if judge==1:
62 | print(words)
63 | else:
64 | print("未检测到")
65 |
66 | def run():
67 | wavs = input_audio()
68 | char2id, id2char, mfcc_mean, mfcc_std, model = load_file()
69 | X_data = set_data(wavs, mfcc_mean, mfcc_std)
70 | wav_pred(model,X_data,id2char)
71 | os.remove("temp_file.wav")
72 |
73 |
74 | if __name__ == '__main__' :
75 | run()
76 |
77 |
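78 | # Usage note (assumption): running this script requires the trained model at
79 | # config.model_path and the dictionary pickle at config.pkl_path, both of which
80 | # are produced by train.py; it records speech from the default microphone via
81 | # speech_recognition and prints the decoded transcript (or "未检测到" if empty).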
--------------------------------------------------------------------------------
/asr_v2.py:
--------------------------------------------------------------------------------
1 | from keras.models import load_model
2 | from keras import backend as K
3 | import numpy as np
4 | import librosa
5 | from python_speech_features import mfcc
6 | import speech_recognition as sr
7 | import pickle
8 | import glob
9 | import config
10 | import wave
11 | import os
12 |
13 | def save_as_wav(audio, output_file_path):
14 | with wave.open(output_file_path, 'wb') as wav_file:
15 | wav_file.setnchannels(1) # 单声道
16 | wav_file.setsampwidth(2) # 16位PCM编码
17 | wav_file.setframerate(44100) # 采样率为44.1kHz
18 | wav_file.writeframes(audio.frame_data)
19 |
20 | def input_audio():
21 | r = sr.Recognizer()
22 | with sr.Microphone() as source:
23 | print("请说...")
24 | r.pause_threshold = 1
25 | audio = r.listen(source)
26 | output_file_path = "temp_file.wav"
27 | save_as_wav(audio, output_file_path)
28 | wavs = glob.glob('temp_file.wav')
29 | os.remove("temp_file.wav")
30 | return wavs
31 |
32 | def out_load_audio():
33 | path = config.audio_path
34 | wavs = glob.glob(path)
35 | return wavs
36 |
37 | def load_file():
38 | with open(config.pkl_path, 'rb') as fr:
39 | [char2id, id2char, mfcc_mean, mfcc_std] = pickle.load(fr)
40 | model = load_model(config.model_path)
41 | return char2id, id2char, mfcc_mean, mfcc_std, model
42 |
43 | def set_data(wavs, mfcc_mean, mfcc_std):
44 | mfcc_dim = config.mfcc_dim
45 | index = np.random.randint(len(wavs))
46 | audio, sr = librosa.load(wavs[index])
47 | energy = librosa.feature.rms(audio)
48 | frames = np.nonzero(energy >= np.max(energy) / 5)
49 | indices = librosa.core.frames_to_samples(frames)[1]
50 | audio = audio[indices[0]:indices[-1]] if indices.size else audio[0:0]
51 | X_data = mfcc(audio, sr, numcep=mfcc_dim, nfft=551)
52 | X_data = (X_data - mfcc_mean) / (mfcc_std + 1e-14)
53 | # print(X_data.shape)
54 | return X_data
55 |
56 |
57 | def wav_pred(model,X_data,id2char):
58 | pred = model.predict(np.expand_dims(X_data, axis=0))
59 | pred_ids = K.eval(K.ctc_decode(pred, [X_data.shape[0]], greedy=False, beam_width=10, top_paths=1)[0][0])
60 | pred_ids = pred_ids.flatten().tolist()
61 | words=''
62 | judge=0
63 | for i in pred_ids:
64 | if i != -1:
65 | judge=1
66 | words=words+id2char[i]
67 | if judge==1:
68 | print(words)
69 | else:
70 | print("未检测到")
71 |
72 | def run():
73 | # wavs = input_audio()
74 | wavs = out_load_audio()
75 | char2id, id2char, mfcc_mean, mfcc_std, model = load_file()
76 | X_data = set_data(wavs, mfcc_mean, mfcc_std)
77 | wav_pred(model,X_data,id2char)
78 |
79 |
80 |
81 | if __name__ == '__main__' :
82 | run()
83 |
84 |
--------------------------------------------------------------------------------
/asr_v3.py:
--------------------------------------------------------------------------------
1 | from keras.models import load_model
2 | from keras import backend as K
3 | import numpy as np
4 | import librosa
5 | from python_speech_features import mfcc
6 | import speech_recognition as sr
7 | import pickle
8 | import glob
9 | import config
10 | import wave
11 | import os
12 | import pyaudio
13 | from tqdm import tqdm
14 |
15 | class set_audio():
16 | CHUNK = 1024 # 每个缓冲区的帧数
17 | FORMAT = pyaudio.paInt16 # 采样位数
18 | CHANNELS = 1 # 单声道
19 | RATE = 44100 # 采样频率
20 |
21 | # 可设置录制时间
22 | def record_audio(record_second):
23 | """ 录音功能 """
24 | p = pyaudio.PyAudio() # 实例化对象
25 | stream = p.open(format=set_audio.FORMAT,
26 | channels=set_audio.CHANNELS,
27 | rate=set_audio.RATE,
28 | input=True,
29 | frames_per_buffer=set_audio.CHUNK) # 打开流,传入响应参数
30 |
31 | wf = wave.open('temp_file.wav', 'wb') # 打开 wav 文件。
32 | wf.setnchannels(set_audio.CHANNELS) # 声道设置
33 | wf.setsampwidth(p.get_sample_size(set_audio.FORMAT)) # 采样位数设置
34 | wf.setframerate(set_audio.RATE) # 采样频率设置
35 |
36 | for _ in tqdm(range(0, int(set_audio.RATE * record_second / set_audio.CHUNK))):
37 | data = stream.read(set_audio.CHUNK)
38 | wf.writeframes(data) # 写入数据
39 | stream.stop_stream() # 关闭流
40 | stream.close()
41 | p.terminate()
42 | wf.close()
43 |
44 | wavs = glob.glob('temp_file.wav')
45 |
46 | # os.remove("temp_file.wav")
47 |
48 | return wavs
49 |
50 |
51 |
52 | def save_as_wav(audio, output_file_path):
53 | with wave.open(output_file_path, 'wb') as wav_file:
54 | wav_file.setnchannels(1) # 单声道
55 | wav_file.setsampwidth(2) # 16位PCM编码
56 | wav_file.setframerate(44100) # 采样率为44.1kHz
57 | wav_file.writeframes(audio.frame_data)
58 |
59 | # 录音自动停止
60 | def input_audio():
61 | r = sr.Recognizer()
62 | with sr.Microphone() as source:
63 | print("请说...")
64 | r.pause_threshold = 1
65 | audio = r.listen(source)
66 | output_file_path = "temp_file.wav"
67 | save_as_wav(audio, output_file_path)
68 | wavs = glob.glob('temp_file.wav')
69 |
70 | return wavs
71 |
72 | def load_file():
73 | with open(config.pkl_path, 'rb') as fr:
74 | [char2id, id2char, mfcc_mean, mfcc_std] = pickle.load(fr)
75 | model = load_model(config.model_path)
76 | return char2id, id2char, mfcc_mean, mfcc_std, model
77 |
78 | def set_data(wavs, mfcc_mean, mfcc_std):
79 | mfcc_dim = config.mfcc_dim
80 | index = np.random.randint(len(wavs))
81 | audio, sr = librosa.load(wavs[index])
82 | energy = librosa.feature.rms(audio)
83 | frames = np.nonzero(energy >= np.max(energy) / 5)
84 | indices = librosa.core.frames_to_samples(frames)[1]
85 | audio = audio[indices[0]:indices[-1]] if indices.size else audio[0:0]
86 | X_data = mfcc(audio, sr, numcep=mfcc_dim, nfft=551)
87 | X_data = (X_data - mfcc_mean) / (mfcc_std + 1e-14)
88 | # print(X_data.shape)
89 | return X_data
90 |
91 |
92 | def wav_pred(model,X_data,id2char):
93 | pred = model.predict(np.expand_dims(X_data, axis=0))
94 | pred_ids = K.eval(K.ctc_decode(pred, [X_data.shape[0]], greedy=False, beam_width=10, top_paths=1)[0][0])
95 | pred_ids = pred_ids.flatten().tolist()
96 | words=''
97 | judge=0
98 | for i in pred_ids:
99 | if i != -1:
100 | judge=1
101 | words=words+id2char[i]
102 | if judge==1:
103 | print(words)
104 | else:
105 | print("未检测到")
106 |
107 | def run():
108 | # 自动停止录音
109 | # wavs = input_audio()
110 | # 设置录制时间
111 | wavs = record_audio(record_second=5)
112 | char2id, id2char, mfcc_mean, mfcc_std, model = load_file()
113 | X_data = set_data(wavs, mfcc_mean, mfcc_std)
114 | wav_pred(model,X_data,id2char)
115 | os.remove("temp_file.wav")
116 |
117 |
118 | if __name__ == '__main__' :
119 | run()
120 |
121 |
--------------------------------------------------------------------------------
/asr_web.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import functools
3 | import os
4 |
5 | import torch
6 | import uvicorn
7 | from fastapi import FastAPI, File, Body, UploadFile, Request
8 | from starlette.staticfiles import StaticFiles
9 | from starlette.templating import Jinja2Templates
10 | from utils.utils import add_arguments, print_arguments
11 |
12 | from keras.models import load_model
13 | from keras import backend as K
14 | import numpy as np
15 | import librosa
16 | from python_speech_features import mfcc
17 | import speech_recognition as sr
18 | import pickle
19 | import config
20 | import wave
21 | import io
22 |
23 | from pydub import AudioSegment
24 | from io import BytesIO
25 | import librosa
26 | import numpy as np
27 | from python_speech_features import mfcc
28 |
29 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
30 | parser = argparse.ArgumentParser(description=__doc__)
31 | add_arg = functools.partial(add_arguments, argparser=parser)
32 | add_arg("host", type=str, default="0.0.0.0", help="监听主机的IP地址")
33 | add_arg("port", type=int, default=5000, help="服务所使用的端口号")
34 | # add_arg("model_path", type=str, default="models/whisper-tiny-finetune/", help="合并模型的路径,或者是huggingface上模型的名称")
35 | # add_arg("model_path", type=str, default="models/tiny-finetune/", help="合并模型的路径,或者是huggingface上模型的名称")
36 | add_arg("use_gpu", type=bool, default=True, help="是否使用gpu进行预测")
37 | add_arg("num_beams", type=int, default=1, help="解码搜索大小")
38 | add_arg("batch_size", type=int, default=16, help="预测batch_size大小")
39 | add_arg("use_compile", type=bool, default=False, help="是否使用Pytorch2.0的编译器")
40 | add_arg("assistant_model_path", type=str, default=None, help="助手模型,可以提高推理速度,例如openai/whisper-tiny")
41 | add_arg("local_files_only", type=bool, default=True, help="是否只在本地加载模型,不尝试下载")
42 | add_arg("use_flash_attention_2", type=bool, default=False, help="是否使用FlashAttention2加速")
43 | add_arg("use_bettertransformer", type=bool, default=False, help="是否使用BetterTransformer加速")
44 | args = parser.parse_args()
45 | print_arguments(args)
46 |
47 | # 设置设备
48 | device = "cuda:0" if torch.cuda.is_available() and args.use_gpu else "cpu"
49 | torch_dtype = torch.float16 if torch.cuda.is_available() and args.use_gpu else torch.float32
50 |
51 |
52 | model = load_model(config.model_path)
53 |
54 | with open(config.pkl_path, 'rb') as fr:
55 | [char2id, id2char, mfcc_mean, mfcc_std] = pickle.load(fr)
56 |
57 |
58 | app = FastAPI(title="thirteen语音识别")
59 | app.mount('/static', StaticFiles(directory='static'), name='static')
60 | templates = Jinja2Templates(directory="templates")
61 | model_semaphore = None
62 |
63 |
64 | def release_model_semaphore():
65 | model_semaphore.release()
66 |
67 |
68 | def recognition(file: File,mfcc_mean, mfcc_std):
69 |
70 | X_data = extract_mfcc_features(file, mfcc_mean, mfcc_std)
71 | pred = model.predict(np.expand_dims(X_data, axis=0))
72 | pred_ids = K.eval(K.ctc_decode(pred, [X_data.shape[0]], greedy=False, beam_width=10, top_paths=1)[0][0])
73 | pred_ids = pred_ids.flatten().tolist()
74 | results = ''
75 | judge=0
76 | for i in pred_ids:
77 | if i != -1:
78 | judge=1
79 | results = results + id2char[i]
80 | if judge!=1:
81 | results = '未检测到'
82 |
83 | return results
84 |
85 |
86 | def extract_mfcc_features(audio_bytes, mfcc_mean, mfcc_std):
87 | # 使用pydub将bytes转换为WAV格式的AudioSegment(如果它不是WAV的话)
88 | # 注意:这里我们假设input_bytes是WAV或我们可以转换为WAV的格式
89 | # 如果input_bytes不是WAV且格式未知,你可能需要先检测它
90 | audio_segment = AudioSegment.from_file(BytesIO(audio_bytes), format="wav") # 如果已经是WAV,或者确定可以解析为WAV
91 |
92 | # 确保输出为WAV格式(如果之前不是的话,这一步其实是多余的,因为from_file已经处理了)
93 | # 但为了清晰起见,我们还是将其导出为WAV的bytes
94 | wav_bytes = BytesIO()
95 | audio_segment.export(wav_bytes, format="wav")
96 |
97 | # 重置BytesIO的指针到开头
98 | wav_bytes.seek(0)
99 |
100 | # 使用librosa加载WAV音频
101 | y, sr = librosa.load(wav_bytes)
102 |
103 | # 提取RMS能量
104 | energy = librosa.feature.rms(y=y)
105 |
106 | # 找到能量大于最大能量1/5的帧
107 | frames = np.nonzero(energy[0] >= np.max(energy[0]) / 5)
108 |
109 | # 将帧索引转换为样本索引
110 | if frames[0].size:
111 | indices = librosa.core.frames_to_samples(frames)[0]
112 | y = y[indices[0]:indices[-1]]
113 |
114 | # 提取MFCC特征
115 | mfcc_dim = 13 # 你可以根据需要修改MFCC的维度
116 | mfcc_features = mfcc(y, sr, numcep=mfcc_dim, nfft=551)
117 |
118 | # 这里假设你已经有了mfcc_mean和mfcc_std用于标准化(通常需要在训练阶段计算)
119 | # 如果没有,你可以跳过标准化步骤,或者计算它们
120 | mfcc_features = (mfcc_features - mfcc_mean) / (mfcc_std + 1e-14)
121 |
122 | return mfcc_features
123 |
124 |
125 | @app.post("/recognition")
126 | async def api_recognition(audio: UploadFile = File(..., description="音频文件")):
127 | # if language == "None": language = None
128 | data = await audio.read()
129 | with io.BytesIO(data) as bio:
130 | with wave.open(bio, 'rb') as wav_file:
131 | pass
132 | results = recognition(file= data, mfcc_mean= mfcc_mean, mfcc_std= mfcc_std)
133 | ret = {"results": results, "code": 0}
134 | return ret
135 |
136 |
137 | @app.get("/")
138 | async def index(request: Request):
139 | return templates.TemplateResponse("index.html", {"request": request, "id": id})
140 |
141 |
142 | if __name__ == '__main__':
143 | uvicorn.run(app, host=args.host, port=args.port)
144 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | # Path to the training audio files
2 | train_wav_data_path = 'AISHELL-3/train/wav/'
3 |
4 | # Path to the training transcript file
5 | train_texts_data_path = 'AISHELL-3/train/content.txt'
6 |
7 | # Path to the test audio files
8 | test_wav_data_path = 'AISHELL-3/test/wav/'
9 |
10 | # Path to the test transcript file
11 | test_texts_data_path = 'AISHELL-3/test/content.txt'
12 |
13 | # Model save path / model file name
14 | model_path = 'model/asr_AISHELL.h5'
15 |
16 | # Dictionary pickle path / pickle file name
17 | pkl_path = 'pkl_all/dictionary.pkl'
18 |
19 | # Labels pickle path
20 | labels_path = 'pkl_all/labels.pkl'
21 |
22 | # features.pkl path
23 | features_path = 'pkl_all/features.pkl'
24 |
25 | # Path of an external audio file to load for inference
26 | audio_path = 'AISHELL-3/train/wav/SSB0005/SSB00050001.wav'
27 | # model_name
28 |
29 | batch_size = 16
30 |
31 | epochs = 25
32 |
33 | num_blocks = 3
34 |
35 | filters = 128
36 |
37 | mfcc_dim = 13
38 |
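39 | # Note (assumption): the output directories used above ('model/' and 'pkl_all/') are
40 | # not created automatically by the training scripts, so they may need to be created
41 | # manually (e.g. os.makedirs('model', exist_ok=True)) before running train.py.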
--------------------------------------------------------------------------------
/infer_server.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import functools
3 | import os
4 |
5 | import torch
6 | import uvicorn
7 | from fastapi import FastAPI, File, Body, UploadFile, Request
8 | from starlette.staticfiles import StaticFiles
9 | from starlette.templating import Jinja2Templates
10 | from utils.utils import add_arguments, print_arguments
11 |
12 | from keras.models import load_model
13 | from keras import backend as K
14 | import numpy as np
15 | import librosa
16 | from python_speech_features import mfcc
17 | import speech_recognition as sr
18 | import pickle
19 | import config
20 | import wave
21 | import io
22 |
23 | from pydub import AudioSegment
24 | from io import BytesIO
25 | import librosa
26 | import numpy as np
27 | from python_speech_features import mfcc
28 |
29 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
30 | parser = argparse.ArgumentParser(description=__doc__)
31 | add_arg = functools.partial(add_arguments, argparser=parser)
32 | add_arg("host", type=str, default="0.0.0.0", help="监听主机的IP地址")
33 | add_arg("port", type=int, default=5000, help="服务所使用的端口号")
34 | # add_arg("model_path", type=str, default="models/whisper-tiny-finetune/", help="合并模型的路径,或者是huggingface上模型的名称")
35 | # add_arg("model_path", type=str, default="models/tiny-finetune/", help="合并模型的路径,或者是huggingface上模型的名称")
36 | add_arg("use_gpu", type=bool, default=True, help="是否使用gpu进行预测")
37 | add_arg("num_beams", type=int, default=1, help="解码搜索大小")
38 | add_arg("batch_size", type=int, default=16, help="预测batch_size大小")
39 | add_arg("use_compile", type=bool, default=False, help="是否使用Pytorch2.0的编译器")
40 | add_arg("assistant_model_path", type=str, default=None, help="助手模型,可以提高推理速度,例如openai/whisper-tiny")
41 | add_arg("local_files_only", type=bool, default=True, help="是否只在本地加载模型,不尝试下载")
42 | add_arg("use_flash_attention_2", type=bool, default=False, help="是否使用FlashAttention2加速")
43 | add_arg("use_bettertransformer", type=bool, default=False, help="是否使用BetterTransformer加速")
44 | args = parser.parse_args()
45 | print_arguments(args)
46 |
47 | # 设置设备
48 | device = "cuda:0" if torch.cuda.is_available() and args.use_gpu else "cpu"
49 | torch_dtype = torch.float16 if torch.cuda.is_available() and args.use_gpu else torch.float32
50 |
51 |
52 | model = load_model(config.model_path)
53 |
54 | with open(config.pkl_path, 'rb') as fr:
55 | [char2id, id2char, mfcc_mean, mfcc_std] = pickle.load(fr)
56 |
57 |
58 | app = FastAPI(title="thirteen语音识别")
59 | app.mount('/static', StaticFiles(directory='static'), name='static')
60 | templates = Jinja2Templates(directory="templates")
61 | model_semaphore = None
62 |
63 |
64 | def release_model_semaphore():
65 | model_semaphore.release()
66 |
67 |
68 | def recognition(file: File,mfcc_mean, mfcc_std):
69 |
70 | X_data = extract_mfcc_features(file, mfcc_mean, mfcc_std)
71 | pred = model.predict(np.expand_dims(X_data, axis=0))
72 | pred_ids = K.eval(K.ctc_decode(pred, [X_data.shape[0]], greedy=False, beam_width=10, top_paths=1)[0][0])
73 | pred_ids = pred_ids.flatten().tolist()
74 | results = ''
75 | judge=0
76 | for i in pred_ids:
77 | if i != -1:
78 | judge=1
79 | results = results + id2char[i]
80 | if judge!=1:
81 | results = '未检测到'
82 |
83 | return results
84 |
85 |
86 | def extract_mfcc_features(audio_bytes, mfcc_mean, mfcc_std):
87 | # 使用pydub将bytes转换为WAV格式的AudioSegment(如果它不是WAV的话)
88 | # 注意:这里我们假设input_bytes是WAV或我们可以转换为WAV的格式
89 | # 如果input_bytes不是WAV且格式未知,你可能需要先检测它
90 | audio_segment = AudioSegment.from_file(BytesIO(audio_bytes), format="wav") # 如果已经是WAV,或者确定可以解析为WAV
91 |
92 | # 确保输出为WAV格式(如果之前不是的话,这一步其实是多余的,因为from_file已经处理了)
93 | # 但为了清晰起见,我们还是将其导出为WAV的bytes
94 | wav_bytes = BytesIO()
95 | audio_segment.export(wav_bytes, format="wav")
96 |
97 | # 重置BytesIO的指针到开头
98 | wav_bytes.seek(0)
99 |
100 | # 使用librosa加载WAV音频
101 | y, sr = librosa.load(wav_bytes)
102 |
103 | # 提取RMS能量
104 | energy = librosa.feature.rms(y=y)
105 |
106 | # 找到能量大于最大能量1/5的帧
107 | frames = np.nonzero(energy[0] >= np.max(energy[0]) / 5)
108 |
109 | # 将帧索引转换为样本索引
110 | if frames[0].size:
111 | indices = librosa.core.frames_to_samples(frames)[0]
112 | y = y[indices[0]:indices[-1]]
113 |
114 | # 提取MFCC特征
115 | mfcc_dim = 13 # 你可以根据需要修改MFCC的维度
116 | mfcc_features = mfcc(y, sr, numcep=mfcc_dim, nfft=551)
117 |
118 | # 这里假设你已经有了mfcc_mean和mfcc_std用于标准化(通常需要在训练阶段计算)
119 | # 如果没有,你可以跳过标准化步骤,或者计算它们
120 | mfcc_features = (mfcc_features - mfcc_mean) / (mfcc_std + 1e-14)
121 |
122 | return mfcc_features
123 |
124 |
125 | @app.post("/recognition")
126 | async def api_recognition(audio: UploadFile = File(..., description="音频文件")):
127 | # if language == "None": language = None
128 | data = await audio.read()
129 | with io.BytesIO(data) as bio:
130 | with wave.open(bio, 'rb') as wav_file:
131 | pass
132 | results = recognition(file= data, mfcc_mean= mfcc_mean, mfcc_std= mfcc_std)
133 | ret = {"results": results, "code": 0}
134 | return ret
135 |
136 |
137 | @app.get("/")
138 | async def index(request: Request):
139 | return templates.TemplateResponse("index.html", {"request": request, "id": id})
140 |
141 |
142 | if __name__ == '__main__':
143 | uvicorn.run(app, host=args.host, port=args.port)
144 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | ipython==8.12.3
2 | keras==2.10.0
3 | librosa==0.7.2
4 | matplotlib==3.8.3
5 | numpy==1.24.4
6 | python_speech_features==0.6
7 | scipy==1.13.0
8 | SpeechRecognition==3.10.1
9 | tqdm==4.66.2
10 | # Additional packages imported by asr_web.py / infer_server.py / asr_v3.py (left unpinned;
11 | # TensorFlow 2.10.x pairs with keras==2.10.0, and pydub also needs ffmpeg on the system):
12 | tensorflow
13 | fastapi
14 | python-multipart
15 | Jinja2
16 | uvicorn
17 | torch
18 | pydub
19 | PyAudio
20 |
--------------------------------------------------------------------------------
/static/index.css:
--------------------------------------------------------------------------------
1 | * {
2 | box-sizing: border-box;
3 | }
4 |
5 | body {
6 | font-family: "Helvetica Neue", "Roboto", sans-serif;
7 | background-color: #f2f2f2;
8 | margin: 0;
9 | padding: 0;
10 | }
11 |
12 | #header {
13 | background-color: #fff;
14 | color: #333;
15 | display: flex;
16 | justify-content: center;
17 | align-items: center;
18 | height: 80px;
19 | }
20 |
21 | h1 {
22 | font-size: 36px;
23 | margin: 0;
24 | }
25 |
26 | #content {
27 | background-color: #fff;
28 | border-radius: 10px;
29 | box-shadow: 0 2px 10px rgba(0, 0, 0, 0.2);
30 | margin: 50px auto;
31 | max-width: 800px;
32 | padding: 20px;
33 | }
34 |
35 | #content div {
36 | display: flex;
37 | flex-wrap: wrap;
38 | justify-content: space-between;
39 | margin-bottom: 20px;
40 | }
41 |
42 | #content a {
43 | background-color: #fff;
44 | border-radius: 5px;
45 | box-shadow: 0 2px 10px rgba(0, 0, 0, 0.2);
46 | color: #333;
47 | padding: 10px;
48 | text-align: center;
49 | text-decoration: none;
50 | transition: background-color 0.2s;
51 | width: 20%;
52 | }
53 |
54 | #content a:hover {
55 | background-color: #f2f2f2;
56 | }
57 |
58 | #content img {
59 | cursor: pointer;
60 | height: 50px;
61 | transition: transform 0.2s;
62 | width: 50px;
63 | }
64 |
65 | #content img:hover {
66 | transform: scale(1.1);
67 | }
68 |
69 | #result {
70 | background-color: #fff;
71 | border-radius: 5px;
72 | box-shadow: 0 2px 10px rgba(0, 0, 0, 0.2);
73 | padding: 10px;
74 | }
75 |
76 | #result textarea {
77 | border: none;
78 | border-radius: 5px;
79 | font-size: 16px;
80 | height: 300px;
81 | margin-top: 10px;
82 | padding: 10px;
83 | resize: none;
84 | width: 100%;
85 | }
86 |
87 | @media only screen and (max-width: 600px) {
88 | #content a {
89 | width: 100%;
90 | }
91 | }
--------------------------------------------------------------------------------
/static/record.js:
--------------------------------------------------------------------------------
1 | //兼容
2 | window.URL = window.URL || window.webkitURL;
3 | //获取计算机的设备:摄像头或者录音设备
4 | navigator.getUserMedia = navigator.getUserMedia || navigator.webkitGetUserMedia || navigator.mozGetUserMedia || navigator.msGetUserMedia;
5 |
6 | var HZRecorder = function (stream, config) {
7 | config = config || {};
8 | config.sampleBits = config.sampleBits || 16; //采样数位 8, 16
9 | config.sampleRate = config.sampleRate || 16000; //采样率 16000
10 |
11 | //创建一个音频环境对象
12 | var audioContext = window.AudioContext || window.webkitAudioContext;
13 | var context = new audioContext();
14 | var audioInput = context.createMediaStreamSource(stream);
15 | // 第二个和第三个参数指的是输入和输出都是单声道,2是双声道。
16 | var recorder = context.createScriptProcessor(4096, 2, 2);
17 |
18 | var audioData = {
19 | size: 0 //录音文件长度
20 | , buffer: [] //录音缓存
21 | , inputSampleRate: context.sampleRate //输入采样率
22 | , inputSampleBits: 16 //输入采样数位 8, 16
23 | , outputSampleRate: config.sampleRate //输出采样率
24 | , outputSampleBits: config.sampleBits //输出采样数位 8, 16
25 | , input: function (data) {
26 | this.buffer.push(new Float32Array(data));
27 | this.size += data.length;
28 | }
29 | , compress: function () { //合并压缩
30 | //合并
31 | var data = new Float32Array(this.size);
32 | var offset = 0;
33 | for (var i = 0; i < this.buffer.length; i++) {
34 | data.set(this.buffer[i], offset);
35 | offset += this.buffer[i].length;
36 | }
37 | //压缩
38 | var compression = parseInt(this.inputSampleRate / this.outputSampleRate);
39 | var length = data.length / compression;
40 | var result = new Float32Array(length);
41 | var index = 0, j = 0;
42 | while (index < length) {
43 | result[index] = data[j];
44 | j += compression;
45 | index++;
46 | }
47 | return result;
48 | }
49 | , encodeWAV: function () {
50 | var sampleRate = Math.min(this.inputSampleRate, this.outputSampleRate);
51 | var sampleBits = Math.min(this.inputSampleBits, this.outputSampleBits);
52 | var bytes = this.compress();
53 | var dataLength = bytes.length * (sampleBits / 8);
54 | var buffer = new ArrayBuffer(44 + dataLength);
55 | var data = new DataView(buffer);
56 |
57 | var channelCount = 1;//单声道
58 | var offset = 0;
59 |
60 | var writeString = function (str) {
61 | for (var i = 0; i < str.length; i++) {
62 | data.setUint8(offset + i, str.charCodeAt(i));
63 | }
64 | }
65 |
66 | // 资源交换文件标识符
67 | writeString('RIFF');
68 | offset += 4;
69 | // 下个地址开始到文件尾总字节数,即文件大小-8
70 | data.setUint32(offset, 36 + dataLength, true);
71 | offset += 4;
72 | // WAV文件标志
73 | writeString('WAVE');
74 | offset += 4;
75 | // 波形格式标志
76 | writeString('fmt ');
77 | offset += 4;
78 | // 过滤字节,一般为 0x10 = 16
79 | data.setUint32(offset, 16, true);
80 | offset += 4;
81 | // 格式类别 (PCM形式采样数据)
82 | data.setUint16(offset, 1, true);
83 | offset += 2;
84 | // 通道数
85 | data.setUint16(offset, channelCount, true);
86 | offset += 2;
87 | // 采样率,每秒样本数,表示每个通道的播放速度
88 | data.setUint32(offset, sampleRate, true);
89 | offset += 4;
90 | // 波形数据传输率 (每秒平均字节数) 单声道×每秒数据位数×每样本数据位/8
91 | data.setUint32(offset, channelCount * sampleRate * (sampleBits / 8), true);
92 | offset += 4;
93 | // 快数据调整数 采样一次占用字节数 单声道×每样本的数据位数/8
94 | data.setUint16(offset, channelCount * (sampleBits / 8), true);
95 | offset += 2;
96 | // 每样本数据位数
97 | data.setUint16(offset, sampleBits, true);
98 | offset += 2;
99 | // 数据标识符
100 | writeString('data');
101 | offset += 4;
102 | // 采样数据总数,即数据总大小-44
103 | data.setUint32(offset, dataLength, true);
104 | offset += 4;
105 | // 写入采样数据
106 | if (sampleBits === 8) {
107 | for (var i = 0; i < bytes.length; i++, offset++) {
108 | var s = Math.max(-1, Math.min(1, bytes[i]));
109 | var val = s < 0 ? s * 0x8000 : s * 0x7FFF;
110 | val = parseInt(255 / (65535 / (val + 32768)));
111 | data.setInt8(offset, val, true);
112 | }
113 | } else {
114 | for (var i = 0; i < bytes.length; i++, offset += 2) {
115 | var s = Math.max(-1, Math.min(1, bytes[i]));
116 | data.setInt16(offset, s < 0 ? s * 0x8000 : s * 0x7FFF, true);
117 | }
118 | }
119 |
120 | return new Blob([data], {type: 'audio/wav'});
121 | }
122 | };
123 |
124 | //开始录音
125 | this.start = function () {
126 | audioInput.connect(recorder);
127 | recorder.connect(context.destination);
128 | }
129 |
130 | //停止
131 | this.stop = function () {
132 | recorder.disconnect();
133 | }
134 |
135 | //获取音频文件
136 | this.getBlob = function () {
137 | this.stop();
138 | return audioData.encodeWAV();
139 | }
140 |
141 | //回放
142 | this.play = function (audio) {
143 | audio.src = window.URL.createObjectURL(this.getBlob());
144 | }
145 | //清除
146 | this.clear = function () {
147 | audioData.buffer = [];
148 | audioData.size = 0;
149 | }
150 |
151 | //上传
152 | this.upload = function (url, callback) {
153 | var fd = new FormData();
154 | // 上传的文件名和数据
155 | fd.append("audio", this.getBlob());
156 | var xhr = new XMLHttpRequest();
157 | xhr.timeout = 60000
158 | if (callback) {
159 | xhr.upload.addEventListener("progress", function (e) {
160 | callback('uploading', e);
161 | }, false);
162 | xhr.addEventListener("load", function (e) {
163 | callback('ok', e);
164 | }, false);
165 | xhr.addEventListener("error", function (e) {
166 | callback('error', e);
167 | }, false);
168 | xhr.addEventListener("abort", function (e) {
169 | callback('cancel', e);
170 | }, false);
171 | }
172 | xhr.open("POST", url);
173 | xhr.send(fd);
174 | }
175 |
176 | //音频采集
177 | recorder.onaudioprocess = function (e) {
178 | audioData.input(e.inputBuffer.getChannelData(0));
179 | //record(e.inputBuffer.getChannelData(0));
180 | }
181 |
182 | };
183 | //抛出异常
184 | HZRecorder.throwError = function (message) {
185 | alert(message);
186 | throw new function () {
187 | this.toString = function () {
188 | return message;
189 | }
190 | }
191 | }
192 | //是否支持录音
193 | HZRecorder.canRecording = (navigator.getUserMedia != null);
194 | //获取录音机
195 | HZRecorder.get = function (callback, config) {
196 | if (callback) {
197 | if (navigator.getUserMedia) {
198 | navigator.getUserMedia(
199 | {audio: true} //只启用音频
200 | , function (stream) {
201 | var rec = new HZRecorder(stream, config);
202 | callback(rec);
203 | }
204 | , function (error) {
205 | switch (error.code || error.name) {
206 | case 'PERMISSION_DENIED':
207 | case 'PermissionDeniedError':
208 | HZRecorder.throwError('用户拒绝提供信息。');
209 | break;
210 | case 'NOT_SUPPORTED_ERROR':
211 | case 'NotSupportedError':
212 | HZRecorder.throwError('浏览器不支持硬件设备。');
213 | break;
214 | case 'MANDATORY_UNSATISFIED_ERROR':
215 | case 'MandatoryUnsatisfiedError':
216 | HZRecorder.throwError('无法发现指定的硬件设备。');
217 | break;
218 | default:
219 | HZRecorder.throwError('无法打开麦克风。异常信息:' + (error.code || error.name));
220 | break;
221 | }
222 | });
223 | } else {
224 | window.alert('不是HTTPS协议或者localhost地址,不能使用录音功能!')
225 | HZRecorder.throwError('当前浏览器不支持录音功能。');
226 | return;
227 | }
228 | }
229 | };
--------------------------------------------------------------------------------
/static/record.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WThirteen/asr_AISHELL-3/3140e6f914dbb3ccac906b40dc6836cdd122e46d/static/record.png
--------------------------------------------------------------------------------
/templates/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | Thirteen_temp
6 |
7 |
8 |
9 |
10 |
13 |
14 |
20 |
21 |
22 |
23 | 上传进度:
24 |
25 |
170 |
171 |
172 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #导入相关的库
2 | from keras.models import Model
3 | from keras.layers import Input, Activation, Conv1D, Lambda, Add, Multiply, BatchNormalization
4 | from keras.optimizers import Adam, SGD
5 | from keras import backend as K
6 | from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
7 | import numpy as np
8 | import matplotlib.pyplot as plt
9 | from mpl_toolkits.axes_grid1 import make_axes_locatable
10 | import random
11 | import pickle
12 | import glob
13 | from tqdm import tqdm
14 | import os
15 | from python_speech_features import mfcc
16 | import scipy.io.wavfile as wav
17 | import librosa
18 | from IPython.display import Audio
19 | import config
20 |
21 |
22 | def load_texts_data(path,en_num_all):
23 | f=open(path, "r", encoding='utf-8')
24 | txt=[]
25 | for line in f:
26 | txt.append(line.strip())
27 |
28 | # 音频文件名字
29 | txt_filename = []
30 | # 对应文件内容
31 | txt_wav = []
32 |
33 | for i in txt:
34 | temp_txt_filename,temp_txt_wav = i.split('\t')
35 | txt_filename.append(temp_txt_filename)
36 | txt_wav.append(temp_txt_wav)
37 |
38 | # 音频文字
39 | texts = []
40 |
41 | for i in txt_wav:
42 | temp = ''
43 | for j in i:
44 | if j in en_num_all:
45 | continue
46 | else:
47 | # print(j)
48 | temp = temp+j
49 | texts.append(temp)
50 |
51 | return texts
52 |
53 | def create_en_num():
54 | # 字典 字母+数字
55 | en_num_all = []
56 | # 字母
57 | for letter in 'abcdefghijklmnopqrstuvwxyz':
58 | en_num_all.extend(letter)
59 | # 数字
60 | for number in range(10):
61 | en_num_all.extend(str(number))
62 | # 空格
63 | en_num_all.extend(' ')
64 |
65 | return en_num_all
66 |
67 | def load_wav_data(path):
68 | files = os.listdir(path)
69 | # 音频文件
70 | wav_file = []
71 | # 音频文件的相对路径
72 | wav_file_path = []
73 |
74 | for i in files:
75 | temp_files = os.listdir(path+i)
76 | for j in temp_files:
77 | wav_file.append(j)
78 | wav_file_path.append(path+i+'/'+j)
79 |
80 | return wav_file_path
81 |
82 | #根据数据集标定的音素读入
83 | def load_and_trim(path):
84 | audio, sr = librosa.load(path)
85 | # energy = librosa.feature.rmse(audio)
86 | energy = librosa.feature.rms(audio)
87 | frames = np.nonzero(energy >= np.max(energy) / 5)
88 | indices = librosa.core.frames_to_samples(frames)[1]
89 | audio = audio[indices[0]:indices[-1]] if indices.size else audio[0:0]
90 | return audio, sr
91 |
92 | #可视化,显示语音文件的MFCC图
93 | def visualize(paths,texts,index,mfcc_dim):
94 | path = paths[index]
95 | text = texts[index]
96 | print('Audio Text:', text)
97 |
98 | audio, sr = load_and_trim(path)
99 | plt.figure(figsize=(12, 3))
100 | plt.plot(np.arange(len(audio)), audio)
101 | plt.title('Raw Audio Signal')
102 | plt.xlabel('Time')
103 | plt.ylabel('Audio Amplitude')
104 | plt.show()
105 |
106 | feature = mfcc(audio, sr, numcep=mfcc_dim, nfft=551)
107 | print('Shape of MFCC:', feature.shape)
108 |
109 | fig = plt.figure(figsize=(12, 5))
110 | ax = fig.add_subplot(111)
111 | im = ax.imshow(feature, cmap=plt.cm.jet, aspect='auto')
112 | plt.title('Normalized MFCC')
113 | plt.ylabel('Time')
114 | plt.xlabel('MFCC Coefficient')
115 | plt.colorbar(im, cax=make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05))
116 | ax.set_xticks(np.arange(0, 13, 2), minor=False);
117 | plt.show()
118 |
119 | return path
120 |
121 | # Audio(visualize(0))
122 |
123 | def wav_features(paths,total):
124 | #提取音频特征并存储
125 | features = []
126 | for i in tqdm(range(total)):
127 | path = paths[i]
128 | audio, sr = load_and_trim(path)
129 | features.append(mfcc(audio, sr, numcep=config.mfcc_dim, nfft=551))
130 | return features
131 |
132 | def save_features(features):
133 | with open(config.features_path, 'wb') as fw:
134 | pickle.dump(features,fw)
135 |
136 | def load_features():
137 | with open(config.features_path, 'rb') as f:
138 | features = pickle.load(f)
139 | return features
140 |
141 | def normalized_features(features):
142 | #随机选择100个数据集
143 | samples = random.sample(features, 100)
144 | samples = np.vstack(samples)
145 | #平均MFCC的值为了归一化处理
146 | mfcc_mean = np.mean(samples, axis=0)
147 | #计算标准差为了归一化
148 | mfcc_std = np.std(samples, axis=0)
149 | # print(mfcc_mean)
150 | # print(mfcc_std)
151 | #归一化特征
152 | features = [(feature - mfcc_mean) / (mfcc_std + 1e-14) for feature in features]
153 |
154 | return mfcc_mean,mfcc_std,features
155 |
156 | def save_labels(texts):
157 | #将数据集读入的标签和对应id存储列表
158 | chars = {}
159 | for text in texts:
160 | for c in text:
161 | chars[c] = chars.get(c, 0) + 1
162 |
163 | chars = sorted(chars.items(), key=lambda x: x[1], reverse=True)
164 | chars = [char[0] for char in chars]
165 | # print(len(chars), chars[:100])
166 |
167 | char2id = {c: i for i, c in enumerate(chars)}
168 | id2char = {i: c for i, c in enumerate(chars)}
169 |
170 | return char2id,id2char
171 |
172 | def data_set(total,features,texts):
173 | data_index = np.arange(total)
174 | np.random.shuffle(data_index)
175 | train_size = int(0.9 * total)
176 | test_size = total - train_size
177 | train_index = data_index[:train_size]
178 | test_index = data_index[train_size:]
179 | #神经网络输入和输出X,Y的读入数据集特征
180 | X_train = [features[i] for i in train_index]
181 | Y_train = [texts[i] for i in train_index]
182 | X_test = [features[i] for i in test_index]
183 | Y_test = [texts[i] for i in test_index]
184 |
185 | return X_train,Y_train,X_test,Y_test
186 |
187 |
188 | #定义训练批次的产生,一次训练16个
189 | def batch_generator(x, y,char2id):
190 | batch_size = config.batch_size
191 | offset = 0
192 | while True:
193 | offset += batch_size
194 |
195 | if offset == batch_size or offset >= len(x):
196 | data_index = np.arange(len(x))
197 | np.random.shuffle(data_index)
198 | x = [x[i] for i in data_index]
199 | y = [y[i] for i in data_index]
200 | offset = batch_size
201 |
202 | X_data = x[offset - batch_size: offset]
203 | Y_data = y[offset - batch_size: offset]
204 |
205 | X_maxlen = max([X_data[i].shape[0] for i in range(batch_size)])
206 | Y_maxlen = max([len(Y_data[i]) for i in range(batch_size)])
207 |
208 | X_batch = np.zeros([batch_size, X_maxlen, config.mfcc_dim])
209 | Y_batch = np.ones([batch_size, Y_maxlen]) * len(char2id)
210 | X_length = np.zeros([batch_size, 1], dtype='int32')
211 | Y_length = np.zeros([batch_size, 1], dtype='int32')
212 |
213 | for i in range(batch_size):
214 | X_length[i, 0] = X_data[i].shape[0]
215 | X_batch[i, :X_length[i, 0], :] = X_data[i]
216 |
217 | Y_length[i, 0] = len(Y_data[i])
218 | Y_batch[i, :Y_length[i, 0]] = [char2id[c] for c in Y_data[i]]
219 |
220 | inputs = {'X': X_batch, 'Y': Y_batch, 'X_length': X_length, 'Y_length': Y_length}
221 | outputs = {'ctc': np.zeros([batch_size])}
222 |
223 | yield (inputs, outputs)
224 |
225 | def input_layer():
226 | X = Input(shape=(None, config.mfcc_dim,), dtype='float32', name='X')
227 | Y = Input(shape=(None,), dtype='float32', name='Y')
228 | X_length = Input(shape=(1,), dtype='int32', name='X_length')
229 | Y_length = Input(shape=(1,), dtype='int32', name='Y_length')
230 |
231 | return X,Y,X_length,Y_length
232 |
233 |
234 | #卷积1层
235 | def conv1d(inputs, filters, kernel_size, dilation_rate):
236 | return Conv1D(filters=filters, kernel_size=kernel_size, strides=1, padding='causal', activation=None,
237 | dilation_rate=dilation_rate)(inputs)
238 |
239 | #标准化函数
240 | def batchnorm(inputs):
241 | return BatchNormalization()(inputs)
242 |
243 | #激活层函数
244 | def activation(inputs, activation):
245 | return Activation(activation)(inputs)
246 |
247 | #全连接层函数
248 | def res_block(inputs, filters, kernel_size, dilation_rate):
249 | hf = activation(batchnorm(conv1d(inputs, filters, kernel_size, dilation_rate)), 'tanh')
250 | hg = activation(batchnorm(conv1d(inputs, filters, kernel_size, dilation_rate)), 'sigmoid')
251 | h0 = Multiply()([hf, hg])
252 |
253 | ha = activation(batchnorm(conv1d(h0, filters, 1, 1)), 'tanh')
254 | hs = activation(batchnorm(conv1d(h0, filters, 1, 1)), 'tanh')
255 |
256 | return Add()([ha, inputs]), hs
257 |
258 | #计算损失函数
259 | def calc_ctc_loss(args):
260 | y, yp, ypl, yl = args
261 | return K.ctc_batch_cost(y, yp, ypl, yl)
262 |
263 | def model_train(X,Y,X_length,Y_length,char2id):
264 | h0 = activation(batchnorm(conv1d(X, config.filters, 1, 1)), 'tanh')
265 | shortcut = []
266 | for i in range(config.num_blocks):
267 | for r in [1, 2, 4, 8, 16]:
268 | h0, s = res_block(h0, config.filters, 7, r)
269 | shortcut.append(s)
270 |
271 | h1 = activation(Add()(shortcut), 'relu')
272 | h1 = activation(batchnorm(conv1d(h1, config.filters, 1, 1)), 'relu')
273 | #softmax损失函数输出结果
274 | Y_pred = activation(batchnorm(conv1d(h1, len(char2id) + 1, 1, 1)), 'softmax')
275 | sub_model = Model(inputs=X, outputs=Y_pred)
276 |
277 | ctc_loss = Lambda(calc_ctc_loss, output_shape=(1,), name='ctc')([Y, Y_pred, X_length, Y_length])
278 | #加载模型训练
279 | model = Model(inputs=[X, Y, X_length, Y_length], outputs=ctc_loss)
280 | #建立优化器
281 | optimizer = SGD(lr=0.02, momentum=0.9, nesterov=True, clipnorm=5)
282 | #激活模型开始计算
283 | model.compile(loss={'ctc': lambda ctc_true, ctc_pred: ctc_pred}, optimizer=optimizer)
284 |
285 | return sub_model,model
286 |
287 | def save_model_pkl(sub_model,char2id,id2char,mfcc_mean,mfcc_std):
288 | #保存模型
289 | sub_model.save(config.model_path)
290 | #将字保存在pl=pkl中
291 | with open(config.pkl_path, 'wb') as fw:
292 | pickle.dump([char2id, id2char, mfcc_mean, mfcc_std], fw)
293 |
294 |
295 | def draw_loss(history):
296 | train_loss = history.history['loss']
297 | valid_loss = history.history['val_loss']
298 | plt.plot(np.linspace(1, config.epochs, config.epochs), train_loss, label='train')
299 | plt.plot(np.linspace(1, config.epochs, config.epochs), valid_loss, label='valid')
300 | plt.legend(loc='upper right')
301 | plt.xlabel('Epoch')
302 | plt.ylabel('Loss')
303 | plt.show()
304 |
305 |
306 | def run():
307 | print("-----load data-----")
308 | path_train = load_wav_data(path=config.train_wav_data_path)
309 | path_test = load_wav_data(path=config.test_wav_data_path)
310 | paths = []
311 | paths.extend(path_train), paths.extend(path_test)
312 |
313 | privacy_dict = create_en_num()
314 | texts_train = load_texts_data(path=config.train_texts_data_path, en_num_all=privacy_dict)
315 | texts_test = load_texts_data(path=config.test_texts_data_path, en_num_all=privacy_dict)
316 | texts = []
317 | texts.extend(texts_train), texts.extend(texts_test)
318 |
319 | char2id,id2char = save_labels(texts)
320 |
321 | total = len(texts)
322 | print("-----Extract audio features-----")
323 | features = wav_features(paths,total)
324 |
325 | print("-----save features-----")
326 | save_features(features)
327 |
328 |
329 | mfcc_mean,mfcc_std,features = normalized_features(features)
330 |
331 | X_train,Y_train,X_test,Y_test = data_set(total,features,texts)
332 |
333 | X,Y,X_length,Y_length = input_layer()
334 |
335 | sub_model,model = model_train(X,Y,X_length,Y_length,char2id)
336 |
337 | # 回调
338 | checkpointer = ModelCheckpoint(filepath=config.model_path, verbose=0)
339 | # 监控 损失值(loss)作为指标
340 | lr_decay = ReduceLROnPlateau(monitor='loss', factor=0.2, patience=1, min_lr=1e-6)
341 | #开始训练
342 | history = model.fit_generator(
343 | generator=batch_generator(X_train, Y_train, char2id),
344 | steps_per_epoch=len(X_train) // config.batch_size,
345 | epochs=config.epochs,
346 | validation_data=batch_generator(X_test, Y_test, char2id),
347 | validation_steps=len(X_test) // config.batch_size,
348 | callbacks=[checkpointer, lr_decay])
349 |
350 | save_model_pkl(sub_model,char2id,id2char,mfcc_mean,mfcc_std)
351 | draw_loss(history)
352 |
353 |
354 | if __name__ == '__main__' :
355 | run()
356 |
--------------------------------------------------------------------------------
/train_v2.py:
--------------------------------------------------------------------------------
1 | # import required libraries
2 | from keras.models import Model
3 | from keras.layers import Input, Activation, Conv1D, Lambda, Add, Multiply, BatchNormalization
4 | from keras.optimizers import Adam, SGD
5 | from keras import backend as K
6 | from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
7 | import numpy as np
8 | import matplotlib.pyplot as plt
9 | from mpl_toolkits.axes_grid1 import make_axes_locatable
10 | import random
11 | import pickle
12 | import glob
13 | from tqdm import tqdm
14 | import os
15 | from python_speech_features import mfcc
16 | import scipy.io.wavfile as wav
17 | import librosa
18 | from IPython.display import Audio
19 | import config
20 |
21 |
22 | def load_texts_data(path,en_num_all):
23 | f=open(path, "r", encoding='utf-8')
24 | txt=[]
25 | for line in f:
26 | txt.append(line.strip())
27 |
28 |     # audio file names
29 |     txt_filename = []
30 |     # corresponding transcript content
31 | txt_wav = []
32 |
33 | for i in txt:
34 | temp_txt_filename,temp_txt_wav = i.split('\t')
35 | txt_filename.append(temp_txt_filename)
36 | txt_wav.append(temp_txt_wav)
37 |
38 |     # transcript text (English letters, digits and spaces removed)
39 | texts = []
40 |
41 | for i in txt_wav:
42 | temp = ''
43 | for j in i:
44 | if j in en_num_all:
45 | continue
46 | else:
47 | # print(j)
48 | temp = temp+j
49 | texts.append(temp)
50 |
51 | return texts
52 |
53 | def create_en_num():
54 |     # character set to strip from transcripts: letters + digits + space
55 |     en_num_all = []
56 |     # letters
57 |     for letter in 'abcdefghijklmnopqrstuvwxyz':
58 |         en_num_all.extend(letter)
59 |     # digits
60 |     for number in range(10):
61 |         en_num_all.extend(str(number))
62 |     # space
63 |     en_num_all.extend(' ')
64 |
65 | return en_num_all
66 |
67 | def load_wav_data(path):
68 | files = os.listdir(path)
69 |     # audio file names
70 |     wav_file = []
71 |     # relative paths of the audio files
72 | wav_file_path = []
73 |
74 | for i in files:
75 | temp_files = os.listdir(path+i)
76 | for j in temp_files:
77 | wav_file.append(j)
78 | wav_file_path.append(path+i+'/'+j)
79 |
80 | return wav_file_path
81 |
82 | # load an audio file and trim low-energy (silent) leading/trailing frames
83 | def load_and_trim(path):
84 |     audio, sr = librosa.load(path)
85 |     # librosa.feature.rmse was renamed to rms; pass the signal as a keyword argument for newer librosa versions
86 |     energy = librosa.feature.rms(y=audio)
87 | frames = np.nonzero(energy >= np.max(energy) / 5)
88 | indices = librosa.core.frames_to_samples(frames)[1]
89 | audio = audio[indices[0]:indices[-1]] if indices.size else audio[0:0]
90 | return audio, sr
91 |
92 | # visualization: plot the raw waveform and the MFCC features of one audio file
93 | def visualize(paths,texts,index,mfcc_dim):
94 | path = paths[index]
95 | text = texts[index]
96 | print('Audio Text:', text)
97 |
98 | audio, sr = load_and_trim(path)
99 | plt.figure(figsize=(12, 3))
100 | plt.plot(np.arange(len(audio)), audio)
101 | plt.title('Raw Audio Signal')
102 | plt.xlabel('Time')
103 | plt.ylabel('Audio Amplitude')
104 | plt.show()
105 |
106 | feature = mfcc(audio, sr, numcep=mfcc_dim, nfft=551)
107 | print('Shape of MFCC:', feature.shape)
108 |
109 | fig = plt.figure(figsize=(12, 5))
110 | ax = fig.add_subplot(111)
111 | im = ax.imshow(feature, cmap=plt.cm.jet, aspect='auto')
112 | plt.title('Normalized MFCC')
113 | plt.ylabel('Time')
114 | plt.xlabel('MFCC Coefficient')
115 | plt.colorbar(im, cax=make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05))
116 | ax.set_xticks(np.arange(0, 13, 2), minor=False);
117 | plt.show()
118 |
119 | return path
120 |
121 | # Audio(visualize(0))
122 |
123 | def wav_features(paths,total):
124 |     # extract MFCC features for every audio file
125 | features = []
126 | for i in tqdm(range(total)):
127 | path = paths[i]
128 | audio, sr = load_and_trim(path)
129 | features.append(mfcc(audio, sr, numcep=config.mfcc_dim, nfft=551))
130 | return features
131 |
132 | def save_features(features):
133 | with open(config.features_path, 'wb') as fw:
134 | pickle.dump(features,fw)
135 |
136 | def load_features():
137 | with open(config.features_path, 'rb') as f:
138 | features = pickle.load(f)
139 | return features
140 |
141 | def normalized_features(features):
142 |     # randomly sample 100 feature matrices to estimate the statistics
143 |     samples = random.sample(features, 100)
144 |     samples = np.vstack(samples)
145 |     # per-coefficient MFCC mean used for normalization
146 |     mfcc_mean = np.mean(samples, axis=0)
147 |     # per-coefficient MFCC standard deviation used for normalization
148 |     mfcc_std = np.std(samples, axis=0)
149 |     # print(mfcc_mean)
150 |     # print(mfcc_std)
151 |     # normalize the features
152 | features = [(feature - mfcc_mean) / (mfcc_std + 1e-14) for feature in features]
153 |
154 | return mfcc_mean,mfcc_std,features
155 |
156 | def save_labels(texts):
157 |     # build the character-to-id and id-to-character mappings from the transcripts
158 | chars = {}
159 | for text in texts:
160 | for c in text:
161 | chars[c] = chars.get(c, 0) + 1
162 |
163 | chars = sorted(chars.items(), key=lambda x: x[1], reverse=True)
164 | chars = [char[0] for char in chars]
165 | # print(len(chars), chars[:100])
166 |
167 | char2id = {c: i for i, c in enumerate(chars)}
168 | id2char = {i: c for i, c in enumerate(chars)}
169 |
170 | return char2id,id2char
171 |
172 | def data_set(total,features,texts):
173 | data_index = np.arange(total)
174 | np.random.shuffle(data_index)
175 | train_size = int(0.9 * total)
176 | test_size = total - train_size
177 | train_index = data_index[:train_size]
178 | test_index = data_index[train_size:]
179 |     # split the features (X) and transcripts (Y) into training and test sets
180 | X_train = [features[i] for i in train_index]
181 | Y_train = [texts[i] for i in train_index]
182 | X_test = [features[i] for i in test_index]
183 | Y_test = [texts[i] for i in test_index]
184 |
185 | return X_train,Y_train,X_test,Y_test
186 |
187 |
188 | # batch generator: yields config.batch_size samples per training step
189 | def batch_generator(x, y,char2id):
190 | batch_size = config.batch_size
191 | offset = 0
192 | while True:
193 | offset += batch_size
194 |
195 | if offset == batch_size or offset >= len(x):
196 | data_index = np.arange(len(x))
197 | np.random.shuffle(data_index)
198 | x = [x[i] for i in data_index]
199 | y = [y[i] for i in data_index]
200 | offset = batch_size
201 |
202 | X_data = x[offset - batch_size: offset]
203 | Y_data = y[offset - batch_size: offset]
204 |
205 | X_maxlen = max([X_data[i].shape[0] for i in range(batch_size)])
206 | Y_maxlen = max([len(Y_data[i]) for i in range(batch_size)])
207 |
208 | X_batch = np.zeros([batch_size, X_maxlen, config.mfcc_dim])
209 | Y_batch = np.ones([batch_size, Y_maxlen]) * len(char2id)
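            # label padding uses index len(char2id), the extra softmax class reserved for the CTC blank;
            # positions beyond Y_length are ignored by ctc_batch_cost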
210 | X_length = np.zeros([batch_size, 1], dtype='int32')
211 | Y_length = np.zeros([batch_size, 1], dtype='int32')
212 |
213 | for i in range(batch_size):
214 | X_length[i, 0] = X_data[i].shape[0]
215 | X_batch[i, :X_length[i, 0], :] = X_data[i]
216 |
217 | Y_length[i, 0] = len(Y_data[i])
218 | Y_batch[i, :Y_length[i, 0]] = [char2id[c] for c in Y_data[i]]
219 |
220 | inputs = {'X': X_batch, 'Y': Y_batch, 'X_length': X_length, 'Y_length': Y_length}
221 | outputs = {'ctc': np.zeros([batch_size])}
222 |
223 | yield (inputs, outputs)
224 |
225 | def input_layer():
226 | X = Input(shape=(None, config.mfcc_dim,), dtype='float32', name='X')
227 | Y = Input(shape=(None,), dtype='float32', name='Y')
228 | X_length = Input(shape=(1,), dtype='int32', name='X_length')
229 | Y_length = Input(shape=(1,), dtype='int32', name='Y_length')
230 |
231 | return X,Y,X_length,Y_length
232 |
233 |
234 | # Conv1D layer (causal, dilated)
235 | def conv1d(inputs, filters, kernel_size, dilation_rate):
236 | return Conv1D(filters=filters, kernel_size=kernel_size, strides=1, padding='causal', activation=None,
237 | dilation_rate=dilation_rate)(inputs)
238 |
239 | # batch normalization layer
240 | def batchnorm(inputs):
241 | return BatchNormalization()(inputs)
242 |
243 | # activation layer
244 | def activation(inputs, activation):
245 | return Activation(activation)(inputs)
246 |
247 | # gated residual block (WaveNet-style): returns the residual output and a skip connection
248 | def res_block(inputs, filters, kernel_size, dilation_rate):
249 | hf = activation(batchnorm(conv1d(inputs, filters, kernel_size, dilation_rate)), 'tanh')
250 | hg = activation(batchnorm(conv1d(inputs, filters, kernel_size, dilation_rate)), 'sigmoid')
251 | h0 = Multiply()([hf, hg])
252 |
253 | ha = activation(batchnorm(conv1d(h0, filters, 1, 1)), 'tanh')
254 | hs = activation(batchnorm(conv1d(h0, filters, 1, 1)), 'tanh')
255 |
256 | return Add()([ha, inputs]), hs
257 |
258 | # compute the CTC loss
259 | def calc_ctc_loss(args):
260 | y, yp, ypl, yl = args
261 | return K.ctc_batch_cost(y, yp, ypl, yl)
262 |
263 | def model_train(X,Y,X_length,Y_length,char2id):
264 | h0 = activation(batchnorm(conv1d(X, config.filters, 1, 1)), 'tanh')
265 | shortcut = []
266 | for i in range(config.num_blocks):
267 | for r in [1, 2, 4, 8, 16]:
268 | h0, s = res_block(h0, config.filters, 7, r)
269 | shortcut.append(s)
270 |
271 | h1 = activation(Add()(shortcut), 'relu')
272 | h1 = activation(batchnorm(conv1d(h1, config.filters, 1, 1)), 'relu')
273 |     # softmax output over the vocabulary (+1 class for the CTC blank)
274 | Y_pred = activation(batchnorm(conv1d(h1, len(char2id) + 1, 1, 1)), 'softmax')
275 | sub_model = Model(inputs=X, outputs=Y_pred)
276 |
277 | ctc_loss = Lambda(calc_ctc_loss, output_shape=(1,), name='ctc')([Y, Y_pred, X_length, Y_length])
278 |     # training model: takes features and labels and outputs the CTC loss
279 |     model = Model(inputs=[X, Y, X_length, Y_length], outputs=ctc_loss)
280 |     # build the optimizer
281 |     optimizer = SGD(lr=0.02, momentum=0.9, nesterov=True, clipnorm=5)
282 |     # compile; the Lambda layer already computes the CTC loss, so the compile-time loss just passes y_pred through
283 | model.compile(loss={'ctc': lambda ctc_true, ctc_pred: ctc_pred}, optimizer=optimizer)
284 |
285 | return sub_model,model
286 |
287 | def save_model_pkl(sub_model,char2id,id2char,mfcc_mean,mfcc_std):
288 |     # save the inference model
289 |     sub_model.save(config.model_path)
290 |     # save the vocabulary and MFCC normalization statistics to the pkl file
291 | with open(config.pkl_path, 'wb') as fw:
292 | pickle.dump([char2id, id2char, mfcc_mean, mfcc_std], fw)
293 |
294 |
295 | def draw_loss(history):
296 | train_loss = history.history['loss']
297 | valid_loss = history.history['val_loss']
298 | plt.plot(np.linspace(1, config.epochs, config.epochs), train_loss, label='train')
299 | plt.plot(np.linspace(1, config.epochs, config.epochs), valid_loss, label='valid')
300 | plt.legend(loc='upper right')
301 | plt.xlabel('Epoch')
302 | plt.ylabel('Loss')
303 | plt.show()
304 |
305 |
306 | def run():
307 | print("-----load data-----")
308 | path_train = load_wav_data(path=config.train_wav_data_path)
309 | path_test = load_wav_data(path=config.test_wav_data_path)
310 | paths = []
311 | paths.extend(path_train), paths.extend(path_test)
312 |
313 | privacy_dict = create_en_num()
314 | texts_train = load_texts_data(path=config.train_texts_data_path, en_num_all=privacy_dict)
315 | texts_test = load_texts_data(path=config.test_texts_data_path, en_num_all=privacy_dict)
316 | texts = []
317 | texts.extend(texts_train), texts.extend(texts_test)
318 |
319 | char2id,id2char = save_labels(texts)
320 |
321 | total = len(texts)
322 |
323 | # print("-----Extract audio features-----")
324 | # features = wav_features(paths,total)
325 |
326 | # print("-----save features-----")
327 | # save_features(features)
328 |
329 | print("-----load features-----")
330 | features = load_features()
331 |
332 |
333 | mfcc_mean,mfcc_std,features = normalized_features(features)
334 |
335 | X_train,Y_train,X_test,Y_test = data_set(total,features,texts)
336 |
337 | X,Y,X_length,Y_length = input_layer()
338 |
339 | sub_model,model = model_train(X,Y,X_length,Y_length,char2id)
340 |
341 |     # callbacks
342 |     checkpointer = ModelCheckpoint(filepath=config.model_path, verbose=0)
343 |     # reduce the learning rate when the monitored loss stops improving
344 |     lr_decay = ReduceLROnPlateau(monitor='loss', factor=0.2, patience=1, min_lr=1e-6)
345 |     # start training
346 | history = model.fit_generator(
347 | generator=batch_generator(X_train, Y_train, char2id),
348 | steps_per_epoch=len(X_train) // config.batch_size,
349 | epochs=config.epochs,
350 | validation_data=batch_generator(X_test, Y_test, char2id),
351 | validation_steps=len(X_test) // config.batch_size,
352 | callbacks=[checkpointer, lr_decay])
353 |
354 | save_model_pkl(sub_model,char2id,id2char,mfcc_mean,mfcc_std)
355 | draw_loss(history)
356 |
357 |
358 | if __name__ == '__main__' :
359 | run()
360 |
--------------------------------------------------------------------------------
/train_v3.py:
--------------------------------------------------------------------------------
1 | # import required libraries
2 | from keras.models import Model
3 | from keras.layers import Input, Activation, Conv1D, Lambda, Add, Multiply, BatchNormalization
4 | from keras.optimizers import Adam, SGD
5 | from keras import backend as K
6 | from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
7 | import numpy as np
8 | import matplotlib.pyplot as plt
9 | from mpl_toolkits.axes_grid1 import make_axes_locatable
10 | import random
11 | import pickle
12 | import glob
13 | from tqdm import tqdm
14 | import os
15 | from python_speech_features import mfcc
16 | import scipy.io.wavfile as wav
17 | import librosa
18 | from IPython.display import Audio
19 | import config
20 |
21 |
22 | def load_texts_data(path,en_num_all):
23 | f=open(path, "r", encoding='utf-8')
24 | txt=[]
25 | for line in f:
26 | txt.append(line.strip())
27 |
28 | # 音频文件名字
29 | txt_filename = []
30 | # 对应文件内容
31 | txt_wav = []
32 |
33 | for i in txt:
34 | temp_txt_filename,temp_txt_wav = i.split('\t')
35 | txt_filename.append(temp_txt_filename)
36 | txt_wav.append(temp_txt_wav)
37 |
38 | # 音频文字
39 | texts = []
40 |
41 | for i in txt_wav:
42 | temp = ''
43 | for j in i:
44 | if j in en_num_all:
45 | continue
46 | else:
47 | # print(j)
48 | temp = temp+j
49 | texts.append(temp)
50 |
51 | return texts
52 |
53 | def create_en_num():
54 | # 字典 字母+数字
55 | en_num_all = []
56 | # 字母
57 | for letter in 'abcdefghijklmnopqrstuvwxyz':
58 | en_num_all.extend(letter)
59 | # 数字
60 | for number in range(10):
61 | en_num_all.extend(str(number))
62 | # 空格
63 | en_num_all.extend(' ')
64 |
65 | return en_num_all
66 |
67 | def load_wav_data(path):
68 | files = os.listdir(path)
69 | # 音频文件
70 | wav_file = []
71 | # 音频文件的相对路径
72 | wav_file_path = []
73 |
74 | for i in files:
75 | temp_files = os.listdir(path+i)
76 | for j in temp_files:
77 | wav_file.append(j)
78 | wav_file_path.append(path+i+'/'+j)
79 |
80 | return wav_file_path
81 |
82 | # load an audio file and trim low-energy (silent) leading/trailing frames
83 | def load_and_trim(path):
84 |     audio, sr = librosa.load(path)
85 |     # librosa.feature.rmse was renamed to rms; pass the signal as a keyword argument for newer librosa versions
86 |     energy = librosa.feature.rms(y=audio)
87 | frames = np.nonzero(energy >= np.max(energy) / 5)
88 | indices = librosa.core.frames_to_samples(frames)[1]
89 | audio = audio[indices[0]:indices[-1]] if indices.size else audio[0:0]
90 | return audio, sr
91 |
92 | #可视化,显示语音文件的MFCC图
93 | def visualize(paths,texts,index,mfcc_dim):
94 | path = paths[index]
95 | text = texts[index]
96 | print('Audio Text:', text)
97 |
98 | audio, sr = load_and_trim(path)
99 | plt.figure(figsize=(12, 3))
100 | plt.plot(np.arange(len(audio)), audio)
101 | plt.title('Raw Audio Signal')
102 | plt.xlabel('Time')
103 | plt.ylabel('Audio Amplitude')
104 | plt.show()
105 |
106 | feature = mfcc(audio, sr, numcep=mfcc_dim, nfft=551)
107 | print('Shape of MFCC:', feature.shape)
108 |
109 | fig = plt.figure(figsize=(12, 5))
110 | ax = fig.add_subplot(111)
111 | im = ax.imshow(feature, cmap=plt.cm.jet, aspect='auto')
112 | plt.title('Normalized MFCC')
113 | plt.ylabel('Time')
114 | plt.xlabel('MFCC Coefficient')
115 | plt.colorbar(im, cax=make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05))
116 | ax.set_xticks(np.arange(0, 13, 2), minor=False);
117 | plt.show()
118 |
119 | return path
120 |
121 | # Audio(visualize(0))
122 |
123 | def wav_features(paths,total):
124 | #提取音频特征并存储
125 | features = []
126 | for i in tqdm(range(total)):
127 | path = paths[i]
128 | audio, sr = load_and_trim(path)
129 | features.append(mfcc(audio, sr, numcep=config.mfcc_dim, nfft=551))
130 | return features
131 |
132 | def save_features(features):
133 | with open(config.features_path, 'wb') as fw:
134 | pickle.dump(features,fw)
135 |
136 | def load_features():
137 | with open(config.features_path, 'rb') as f:
138 | features = pickle.load(f)
139 | return features
140 |
141 | def normalized_features(features):
142 | #随机选择100个数据集
143 | samples = random.sample(features, 100)
144 | samples = np.vstack(samples)
145 | #平均MFCC的值为了归一化处理
146 | mfcc_mean = np.mean(samples, axis=0)
147 | #计算标准差为了归一化
148 | mfcc_std = np.std(samples, axis=0)
149 | # print(mfcc_mean)
150 | # print(mfcc_std)
151 | #归一化特征
152 | features = [(feature - mfcc_mean) / (mfcc_std + 1e-14) for feature in features]
153 |
154 | return mfcc_mean,mfcc_std,features
155 |
156 | def save_labels(texts):
157 | #将数据集读入的标签和对应id存储列表
158 | chars = {}
159 | for text in texts:
160 | for c in text:
161 | chars[c] = chars.get(c, 0) + 1
162 |
163 | chars = sorted(chars.items(), key=lambda x: x[1], reverse=True)
164 | chars = [char[0] for char in chars]
165 | # print(len(chars), chars[:100])
166 |
167 | char2id = {c: i for i, c in enumerate(chars)}
168 | id2char = {i: c for i, c in enumerate(chars)}
169 |
170 | with open(config.labels_path, 'wb') as fw:
171 | pickle.dump([texts,char2id, id2char], fw)
172 |
173 | return texts,char2id,id2char
174 |
175 |
176 | def load_labels():
177 | with open(config.labels_path, 'rb') as f:
178 | texts,char2id,id2char = pickle.load(f)
179 | return texts,char2id,id2char
180 |
181 |
182 | def data_set(total,features,texts):
183 | data_index = np.arange(total)
184 | np.random.shuffle(data_index)
185 | train_size = int(0.9 * total)
186 | test_size = total - train_size
187 | train_index = data_index[:train_size]
188 | test_index = data_index[train_size:]
189 | #神经网络输入和输出X,Y的读入数据集特征
190 | X_train = [features[i] for i in train_index]
191 | Y_train = [texts[i] for i in train_index]
192 | X_test = [features[i] for i in test_index]
193 | Y_test = [texts[i] for i in test_index]
194 |
195 | return X_train,Y_train,X_test,Y_test
196 |
197 |
198 | # batch generator: yields config.batch_size samples per training step
199 | def batch_generator(x, y,char2id):
200 | batch_size = config.batch_size
201 | offset = 0
202 | while True:
203 | offset += batch_size
204 |
205 | if offset == batch_size or offset >= len(x):
206 | data_index = np.arange(len(x))
207 | np.random.shuffle(data_index)
208 | x = [x[i] for i in data_index]
209 | y = [y[i] for i in data_index]
210 | offset = batch_size
211 |
212 | X_data = x[offset - batch_size: offset]
213 | Y_data = y[offset - batch_size: offset]
214 |
215 | X_maxlen = max([X_data[i].shape[0] for i in range(batch_size)])
216 | Y_maxlen = max([len(Y_data[i]) for i in range(batch_size)])
217 |
218 | X_batch = np.zeros([batch_size, X_maxlen, config.mfcc_dim])
219 | Y_batch = np.ones([batch_size, Y_maxlen]) * len(char2id)
220 | X_length = np.zeros([batch_size, 1], dtype='int32')
221 | Y_length = np.zeros([batch_size, 1], dtype='int32')
222 |
223 | for i in range(batch_size):
224 | X_length[i, 0] = X_data[i].shape[0]
225 | X_batch[i, :X_length[i, 0], :] = X_data[i]
226 |
227 | Y_length[i, 0] = len(Y_data[i])
228 | Y_batch[i, :Y_length[i, 0]] = [char2id[c] for c in Y_data[i]]
229 |
230 | inputs = {'X': X_batch, 'Y': Y_batch, 'X_length': X_length, 'Y_length': Y_length}
231 | outputs = {'ctc': np.zeros([batch_size])}
232 |
233 | yield (inputs, outputs)
234 |
235 | def input_layer():
236 | X = Input(shape=(None, config.mfcc_dim,), dtype='float32', name='X')
237 | Y = Input(shape=(None,), dtype='float32', name='Y')
238 | X_length = Input(shape=(1,), dtype='int32', name='X_length')
239 | Y_length = Input(shape=(1,), dtype='int32', name='Y_length')
240 |
241 | return X,Y,X_length,Y_length
242 |
243 |
244 | #卷积1层
245 | def conv1d(inputs, filters, kernel_size, dilation_rate):
246 | return Conv1D(filters=filters, kernel_size=kernel_size, strides=1, padding='causal', activation=None,
247 | dilation_rate=dilation_rate)(inputs)
248 |
249 | #标准化函数
250 | def batchnorm(inputs):
251 | return BatchNormalization()(inputs)
252 |
253 | #激活层函数
254 | def activation(inputs, activation):
255 | return Activation(activation)(inputs)
256 |
257 | # gated residual block (WaveNet-style): returns the residual output and a skip connection
258 | def res_block(inputs, filters, kernel_size, dilation_rate):
259 | hf = activation(batchnorm(conv1d(inputs, filters, kernel_size, dilation_rate)), 'tanh')
260 | hg = activation(batchnorm(conv1d(inputs, filters, kernel_size, dilation_rate)), 'sigmoid')
261 | h0 = Multiply()([hf, hg])
262 |
263 | ha = activation(batchnorm(conv1d(h0, filters, 1, 1)), 'tanh')
264 | hs = activation(batchnorm(conv1d(h0, filters, 1, 1)), 'tanh')
265 |
266 | return Add()([ha, inputs]), hs
267 |
268 | # compute the CTC loss
269 | def calc_ctc_loss(args):
270 | y, yp, ypl, yl = args
271 | return K.ctc_batch_cost(y, yp, ypl, yl)
272 |
273 | def model_train(X,Y,X_length,Y_length,char2id):
274 | h0 = activation(batchnorm(conv1d(X, config.filters, 1, 1)), 'tanh')
275 | shortcut = []
276 | for i in range(config.num_blocks):
277 | for r in [1, 2, 4, 8, 16]:
278 | h0, s = res_block(h0, config.filters, 7, r)
279 | shortcut.append(s)
280 |
281 | h1 = activation(Add()(shortcut), 'relu')
282 | h1 = activation(batchnorm(conv1d(h1, config.filters, 1, 1)), 'relu')
283 | #softmax损失函数输出结果
284 | Y_pred = activation(batchnorm(conv1d(h1, len(char2id) + 1, 1, 1)), 'softmax')
285 | sub_model = Model(inputs=X, outputs=Y_pred)
286 |
287 | ctc_loss = Lambda(calc_ctc_loss, output_shape=(1,), name='ctc')([Y, Y_pred, X_length, Y_length])
288 |     # training model: takes features and labels and outputs the CTC loss
289 |     model = Model(inputs=[X, Y, X_length, Y_length], outputs=ctc_loss)
290 |     # build the optimizer
291 |     optimizer = SGD(lr=0.02, momentum=0.9, nesterov=True, clipnorm=5)
292 |     # compile; the Lambda layer already computes the CTC loss, so the compile-time loss just passes y_pred through
293 | model.compile(loss={'ctc': lambda ctc_true, ctc_pred: ctc_pred}, optimizer=optimizer)
294 |
295 | return sub_model,model
296 |
297 | def save_model_pkl(sub_model,char2id,id2char,mfcc_mean,mfcc_std):
298 |     # save the inference model
299 |     sub_model.save(config.model_path)
300 |     # save the vocabulary and MFCC normalization statistics to the pkl file
301 | with open(config.pkl_path, 'wb') as fw:
302 | pickle.dump([char2id, id2char, mfcc_mean, mfcc_std], fw)
303 |
304 |
305 | def draw_loss(history):
306 | train_loss = history.history['loss']
307 | valid_loss = history.history['val_loss']
308 | plt.plot(np.linspace(1, config.epochs, config.epochs), train_loss, label='train')
309 | plt.plot(np.linspace(1, config.epochs, config.epochs), valid_loss, label='valid')
310 | plt.legend(loc='upper right')
311 | plt.xlabel('Epoch')
312 | plt.ylabel('Loss')
313 | plt.show()
314 |
315 |
316 | def run():
317 | # print("-----load data-----")
318 | # path_train = load_wav_data(path=config.train_wav_data_path)
319 | # path_test = load_wav_data(path=config.test_wav_data_path)
320 | # paths = []
321 | # paths.extend(path_train), paths.extend(path_test)
322 |
323 | # privacy_dict = create_en_num()
324 | # texts_train = load_texts_data(path=config.train_texts_data_path, en_num_all=privacy_dict)
325 | # texts_test = load_texts_data(path=config.test_texts_data_path, en_num_all=privacy_dict)
326 | # texts = []
327 | # texts.extend(texts_train), texts.extend(texts_test)
328 |
329 | # texts,char2id,id2char = save_labels(texts)
330 |
331 | print("-----load labels-----")
332 | texts,char2id,id2char = load_labels()
333 |
334 | total = len(texts)
335 |
336 | # print("-----Extract audio features-----")
337 | # features = wav_features(paths,total)
338 |
339 | # print("-----save features-----")
340 | # save_features(features)
341 |
342 | print("-----load features-----")
343 | features = load_features()
344 |
345 |
346 | mfcc_mean,mfcc_std,features = normalized_features(features)
347 |
348 | X_train,Y_train,X_test,Y_test = data_set(total,features,texts)
349 |
350 | X,Y,X_length,Y_length = input_layer()
351 |
352 | sub_model,model = model_train(X,Y,X_length,Y_length,char2id)
353 |
354 |     # callbacks
355 |     checkpointer = ModelCheckpoint(filepath=config.model_path, verbose=0)
356 |     # reduce the learning rate when the monitored loss stops improving
357 |     lr_decay = ReduceLROnPlateau(monitor='loss', factor=0.2, patience=1, min_lr=1e-6)
358 |     # start training
359 | history = model.fit_generator(
360 | generator=batch_generator(X_train, Y_train, char2id),
361 | steps_per_epoch=len(X_train) // config.batch_size,
362 | epochs=config.epochs,
363 | validation_data=batch_generator(X_test, Y_test, char2id),
364 | validation_steps=len(X_test) // config.batch_size,
365 | callbacks=[checkpointer, lr_decay])
366 |
367 | save_model_pkl(sub_model,char2id,id2char,mfcc_mean,mfcc_std)
368 | draw_loss(history)
369 |
370 |
371 | if __name__ == '__main__' :
372 | run()
373 |
--------------------------------------------------------------------------------
/train_v4.py:
--------------------------------------------------------------------------------
1 | # import required libraries
2 | from keras.models import Model
3 | from keras.layers import Input, Activation, Conv1D, Lambda, Add, Multiply, BatchNormalization
4 | from keras.optimizers import Adam, SGD
5 | from keras import backend as K
6 | from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
7 | import tensorflow as tf
8 | import numpy as np
9 | import matplotlib.pyplot as plt
10 | from mpl_toolkits.axes_grid1 import make_axes_locatable
11 | import random
12 | import pickle
13 | import glob
14 | from tqdm import tqdm
15 | import os
16 | from python_speech_features import mfcc
17 | import scipy.io.wavfile as wav
18 | import librosa
19 | from IPython.display import Audio
20 | import config
21 |
22 |
23 | def load_texts_data(path,en_num_all):
24 | f=open(path, "r", encoding='utf-8')
25 | txt=[]
26 | for line in f:
27 | txt.append(line.strip())
28 |
29 | # 音频文件名字
30 | txt_filename = []
31 | # 对应文件内容
32 | txt_wav = []
33 |
34 | for i in txt:
35 | temp_txt_filename,temp_txt_wav = i.split('\t')
36 | txt_filename.append(temp_txt_filename)
37 | txt_wav.append(temp_txt_wav)
38 |
39 | # 音频文字
40 | texts = []
41 |
42 | for i in txt_wav:
43 | temp = ''
44 | for j in i:
45 | if j in en_num_all:
46 | continue
47 | else:
48 | # print(j)
49 | temp = temp+j
50 | texts.append(temp)
51 |
52 | return texts
53 |
54 | def create_en_num():
55 | # 字典 字母+数字
56 | en_num_all = []
57 | # 字母
58 | for letter in 'abcdefghijklmnopqrstuvwxyz':
59 | en_num_all.extend(letter)
60 | # 数字
61 | for number in range(10):
62 | en_num_all.extend(str(number))
63 | # 空格
64 | en_num_all.extend(' ')
65 |
66 | return en_num_all
67 |
68 | def load_wav_data(path):
69 | files = os.listdir(path)
70 | # 音频文件
71 | wav_file = []
72 | # 音频文件的相对路径
73 | wav_file_path = []
74 |
75 | for i in files:
76 | temp_files = os.listdir(path+i)
77 | for j in temp_files:
78 | wav_file.append(j)
79 | wav_file_path.append(path+i+'/'+j)
80 |
81 | return wav_file_path
82 |
83 | # load an audio file and trim low-energy (silent) leading/trailing frames
84 | def load_and_trim(path):
85 |     audio, sr = librosa.load(path)
86 |     # librosa.feature.rmse was renamed to rms; pass the signal as a keyword argument for newer librosa versions
87 |     energy = librosa.feature.rms(y=audio)
88 | frames = np.nonzero(energy >= np.max(energy) / 5)
89 | indices = librosa.core.frames_to_samples(frames)[1]
90 | audio = audio[indices[0]:indices[-1]] if indices.size else audio[0:0]
91 | return audio, sr
92 |
93 | #可视化,显示语音文件的MFCC图
94 | def visualize(paths,texts,index,mfcc_dim):
95 | path = paths[index]
96 | text = texts[index]
97 | print('Audio Text:', text)
98 |
99 | audio, sr = load_and_trim(path)
100 | plt.figure(figsize=(12, 3))
101 | plt.plot(np.arange(len(audio)), audio)
102 | plt.title('Raw Audio Signal')
103 | plt.xlabel('Time')
104 | plt.ylabel('Audio Amplitude')
105 | plt.show()
106 |
107 | feature = mfcc(audio, sr, numcep=mfcc_dim, nfft=551)
108 | print('Shape of MFCC:', feature.shape)
109 |
110 | fig = plt.figure(figsize=(12, 5))
111 | ax = fig.add_subplot(111)
112 | im = ax.imshow(feature, cmap=plt.cm.jet, aspect='auto')
113 | plt.title('Normalized MFCC')
114 | plt.ylabel('Time')
115 | plt.xlabel('MFCC Coefficient')
116 | plt.colorbar(im, cax=make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05))
117 | ax.set_xticks(np.arange(0, 13, 2), minor=False);
118 | plt.show()
119 |
120 | return path
121 |
122 | # Audio(visualize(0))
123 |
124 | def wav_features(paths,total):
125 | #提取音频特征并存储
126 | features = []
127 | for i in tqdm(range(total)):
128 | path = paths[i]
129 | audio, sr = load_and_trim(path)
130 | features.append(mfcc(audio, sr, numcep=config.mfcc_dim, nfft=551))
131 | return features
132 |
133 | def save_features(features):
134 | with open(config.features_path, 'wb') as fw:
135 | pickle.dump(features,fw)
136 |
137 | def load_features():
138 | with open(config.features_path, 'rb') as f:
139 | features = pickle.load(f)
140 | return features
141 |
142 | def normalized_features(features):
143 | #随机选择100个数据集
144 | samples = random.sample(features, 100)
145 | samples = np.vstack(samples)
146 | #平均MFCC的值为了归一化处理
147 | mfcc_mean = np.mean(samples, axis=0)
148 | #计算标准差为了归一化
149 | mfcc_std = np.std(samples, axis=0)
150 | # print(mfcc_mean)
151 | # print(mfcc_std)
152 | #归一化特征
153 | features = [(feature - mfcc_mean) / (mfcc_std + 1e-14) for feature in features]
154 |
155 | return mfcc_mean,mfcc_std,features
156 |
157 | def save_labels(texts):
158 | #将数据集读入的标签和对应id存储列表
159 | chars = {}
160 | for text in texts:
161 | for c in text:
162 | chars[c] = chars.get(c, 0) + 1
163 |
164 | chars = sorted(chars.items(), key=lambda x: x[1], reverse=True)
165 | chars = [char[0] for char in chars]
166 | # print(len(chars), chars[:100])
167 |
168 | char2id = {c: i for i, c in enumerate(chars)}
169 | id2char = {i: c for i, c in enumerate(chars)}
170 |
171 | with open(config.labels_path, 'wb') as fw:
172 | pickle.dump([texts,char2id, id2char], fw)
173 |
174 | return texts,char2id,id2char
175 |
176 |
177 | def load_labels():
178 | with open(config.labels_path, 'rb') as f:
179 | texts,char2id,id2char = pickle.load(f)
180 | return texts,char2id,id2char
181 |
182 |
183 | def data_set(total,features,texts):
184 | data_index = np.arange(total)
185 | np.random.shuffle(data_index)
186 | train_size = int(0.9 * total)
187 | test_size = total - train_size
188 | train_index = data_index[:train_size]
189 | test_index = data_index[train_size:]
190 | #神经网络输入和输出X,Y的读入数据集特征
191 | X_train = [features[i] for i in train_index]
192 | Y_train = [texts[i] for i in train_index]
193 | X_test = [features[i] for i in test_index]
194 | Y_test = [texts[i] for i in test_index]
195 |
196 | return X_train,Y_train,X_test,Y_test
197 |
198 |
199 | # batch generator: yields config.batch_size samples per training step
200 | def batch_generator(x, y,char2id):
201 | batch_size = config.batch_size
202 | offset = 0
203 | while True:
204 | offset += batch_size
205 |
206 | if offset == batch_size or offset >= len(x):
207 | data_index = np.arange(len(x))
208 | np.random.shuffle(data_index)
209 | x = [x[i] for i in data_index]
210 | y = [y[i] for i in data_index]
211 | offset = batch_size
212 |
213 | X_data = x[offset - batch_size: offset]
214 | Y_data = y[offset - batch_size: offset]
215 |
216 | X_maxlen = max([X_data[i].shape[0] for i in range(batch_size)])
217 | Y_maxlen = max([len(Y_data[i]) for i in range(batch_size)])
218 |
219 | X_batch = np.zeros([batch_size, X_maxlen, config.mfcc_dim])
220 | Y_batch = np.ones([batch_size, Y_maxlen]) * len(char2id)
221 | X_length = np.zeros([batch_size, 1], dtype='int32')
222 | Y_length = np.zeros([batch_size, 1], dtype='int32')
223 |
224 | for i in range(batch_size):
225 | X_length[i, 0] = X_data[i].shape[0]
226 | X_batch[i, :X_length[i, 0], :] = X_data[i]
227 |
228 | Y_length[i, 0] = len(Y_data[i])
229 | Y_batch[i, :Y_length[i, 0]] = [char2id[c] for c in Y_data[i]]
230 |
231 | inputs = {'X': X_batch, 'Y': Y_batch, 'X_length': X_length, 'Y_length': Y_length}
232 | outputs = {'ctc': np.zeros([batch_size])}
233 |
234 | yield (inputs, outputs)
235 |
236 | def input_layer():
237 | X = Input(shape=(None, config.mfcc_dim,), dtype='float32', name='X')
238 | Y = Input(shape=(None,), dtype='float32', name='Y')
239 | X_length = Input(shape=(1,), dtype='int32', name='X_length')
240 | Y_length = Input(shape=(1,), dtype='int32', name='Y_length')
241 |
242 | return X,Y,X_length,Y_length
243 |
244 |
245 | #卷积1层
246 | def conv1d(inputs, filters, kernel_size, dilation_rate):
247 | return Conv1D(filters=filters, kernel_size=kernel_size, strides=1, padding='causal', activation=None,
248 | dilation_rate=dilation_rate)(inputs)
249 |
250 | #标准化函数
251 | def batchnorm(inputs):
252 | return BatchNormalization()(inputs)
253 |
254 | #激活层函数
255 | def activation(inputs, activation):
256 | return Activation(activation)(inputs)
257 |
258 | # gated residual block (WaveNet-style): returns the residual output and a skip connection
259 | def res_block(inputs, filters, kernel_size, dilation_rate):
260 | hf = activation(batchnorm(conv1d(inputs, filters, kernel_size, dilation_rate)), 'tanh')
261 | hg = activation(batchnorm(conv1d(inputs, filters, kernel_size, dilation_rate)), 'sigmoid')
262 | h0 = Multiply()([hf, hg])
263 |
264 | ha = activation(batchnorm(conv1d(h0, filters, 1, 1)), 'tanh')
265 | hs = activation(batchnorm(conv1d(h0, filters, 1, 1)), 'tanh')
266 |
267 | return Add()([ha, inputs]), hs
268 |
269 | # compute the CTC loss
270 | def calc_ctc_loss(args):
271 | y, yp, ypl, yl = args
272 | return K.ctc_batch_cost(y, yp, ypl, yl)
273 |
274 | def model_train(X,Y,X_length,Y_length,char2id):
275 | h0 = activation(batchnorm(conv1d(X, config.filters, 1, 1)), 'tanh')
276 | shortcut = []
277 | for i in range(config.num_blocks):
278 | for r in [1, 2, 4, 8, 16]:
279 | h0, s = res_block(h0, config.filters, 7, r)
280 | shortcut.append(s)
281 |
282 | h1 = activation(Add()(shortcut), 'relu')
283 | h1 = activation(batchnorm(conv1d(h1, config.filters, 1, 1)), 'relu')
284 |     # softmax output over the vocabulary (+1 class for the CTC blank)
285 | Y_pred = activation(batchnorm(conv1d(h1, len(char2id) + 1, 1, 1)), 'softmax')
286 | sub_model = Model(inputs=X, outputs=Y_pred)
287 |
288 | ctc_loss = Lambda(calc_ctc_loss, output_shape=(1,), name='ctc')([Y, Y_pred, X_length, Y_length])
289 |     # training model: takes features and labels and outputs the CTC loss
290 |     model = Model(inputs=[X, Y, X_length, Y_length], outputs=ctc_loss)
291 |     # build the optimizer
292 |     optimizer = SGD(lr=0.02, momentum=0.9, nesterov=True, clipnorm=5)
293 |     # compile; the Lambda layer already computes the CTC loss, so the compile-time loss just passes y_pred through
294 | model.compile(loss={'ctc': lambda ctc_true, ctc_pred: ctc_pred}, optimizer=optimizer)
295 |
296 | return sub_model,model
297 |
298 | def save_model_pkl(sub_model,char2id,id2char,mfcc_mean,mfcc_std):
299 |     # save the inference model
300 |     sub_model.save(config.model_path)
301 |     # save the vocabulary and MFCC normalization statistics to the pkl file
302 | with open(config.pkl_path, 'wb') as fw:
303 | pickle.dump([char2id, id2char, mfcc_mean, mfcc_std], fw)
304 |
305 |
306 | def draw_loss(history):
307 | train_loss = history.history['loss']
308 | valid_loss = history.history['val_loss']
309 | plt.plot(np.linspace(1, config.epochs, config.epochs), train_loss, label='train')
310 | plt.plot(np.linspace(1, config.epochs, config.epochs), valid_loss, label='valid')
311 | plt.legend(loc='upper right')
312 | plt.xlabel('Epoch')
313 | plt.ylabel('Loss')
314 | plt.show()
315 |
316 |
317 | def run():
318 | # print("-----load data-----")
319 | # path_train = load_wav_data(path=config.train_wav_data_path)
320 | # path_test = load_wav_data(path=config.test_wav_data_path)
321 | # paths = []
322 | # paths.extend(path_train), paths.extend(path_test)
323 |
324 | # privacy_dict = create_en_num()
325 | # texts_train = load_texts_data(path=config.train_texts_data_path, en_num_all=privacy_dict)
326 | # texts_test = load_texts_data(path=config.test_texts_data_path, en_num_all=privacy_dict)
327 | # texts = []
328 | # texts.extend(texts_train), texts.extend(texts_test)
329 |
330 | # texts,char2id,id2char = save_labels(texts)
331 |
332 | print("-----load labels-----")
333 | texts,char2id,id2char = load_labels()
334 |
335 | total = len(texts)
336 |
337 | # print("-----Extract audio features-----")
338 | # features = wav_features(paths,total)
339 |
340 | # print("-----save features-----")
341 | # save_features(features)
342 |
343 | print("-----load features-----")
344 | features = load_features()
345 |
346 |
347 | mfcc_mean,mfcc_std,features = normalized_features(features)
348 |
349 | X_train,Y_train,X_test,Y_test = data_set(total,features,texts)
350 |
351 | X,Y,X_length,Y_length = input_layer()
352 |
353 | sub_model,model = model_train(X,Y,X_length,Y_length,char2id)
354 |
355 |     # callbacks
356 |     checkpointer = ModelCheckpoint(filepath=config.model_path, verbose=0)
357 |     # reduce the learning rate when the monitored loss stops improving
358 |     lr_decay = ReduceLROnPlateau(monitor='loss', factor=0.2, patience=1, min_lr=1e-6)
359 |     # log training curves to TensorBoard
360 |     tf_callback = tf.keras.callbacks.TensorBoard(log_dir="./logs")
361 |     # start training
362 | history = model.fit_generator(
363 | generator=batch_generator(X_train, Y_train, char2id),
364 | steps_per_epoch=len(X_train) // config.batch_size,
365 | epochs=config.epochs,
366 | validation_data=batch_generator(X_test, Y_test, char2id),
367 | validation_steps=len(X_test) // config.batch_size,
368 | callbacks=[checkpointer, lr_decay,tf_callback])
369 |
370 | save_model_pkl(sub_model,char2id,id2char,mfcc_mean,mfcc_std)
371 | draw_loss(history)
372 |
373 |
374 | if __name__ == '__main__' :
375 | run()
376 |
--------------------------------------------------------------------------------
/utils/binary.py:
--------------------------------------------------------------------------------
1 | import json
2 | import mmap
3 |
4 | import struct
5 |
6 | from tqdm import tqdm
7 |
8 |
9 | class DatasetWriter(object):
10 | def __init__(self, prefix):
11 |         # create the .data and .header files
12 | self.data_file = open(prefix + '.data', 'wb')
13 | self.header_file = open(prefix + '.header', 'wb')
14 | self.data_sum = 0
15 | self.offset = 0
16 | self.header = ''
17 |
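    # On-disk layout: each record in the .data file is [4-byte key length][key][4-byte data length][data];
    # the .header file stores one "key \t data-offset \t data-length" line per record for random access.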
18 | def add_data(self, data):
19 | key = str(self.data_sum)
20 | data = bytes(data, encoding="utf8")
21 |         # write the key and the record data
22 | self.data_file.write(struct.pack('I', len(key)))
23 | self.data_file.write(key.encode('ascii'))
24 | self.data_file.write(struct.pack('I', len(data)))
25 | self.data_file.write(data)
26 |         # write the index entry (key, data offset, data length)
27 | self.offset += 4 + len(key) + 4
28 | self.header = key + '\t' + str(self.offset) + '\t' + str(len(data)) + '\n'
29 | self.header_file.write(self.header.encode('ascii'))
30 | self.offset += len(data)
31 | self.data_sum += 1
32 |
33 | def close(self):
34 | self.data_file.close()
35 | self.header_file.close()
36 |
37 |
38 | class DatasetReader(object):
39 | def __init__(self, data_header_path, min_duration=0, max_duration=30):
40 | self.keys = []
41 | self.offset_dict = {}
42 | self.fp = open(data_header_path.replace('.header', '.data'), 'rb')
43 | self.m = mmap.mmap(self.fp.fileno(), 0, access=mmap.ACCESS_READ)
44 | for line in tqdm(open(data_header_path, 'rb'), desc='读取数据列表'):
45 | key, val_pos, val_len = line.split('\t'.encode('ascii'))
46 | data = self.m[int(val_pos):int(val_pos) + int(val_len)]
47 | data = str(data, encoding="utf-8")
48 | data = json.loads(data)
49 |             # skip audio outside the duration limits
50 | if data["duration"] < min_duration:
51 | continue
52 | if max_duration != -1 and data["duration"] > max_duration:
53 | continue
54 | self.keys.append(key)
55 | self.offset_dict[key] = (int(val_pos), int(val_len))
56 |
57 |     # get a single record by key
58 | def get_data(self, key):
59 | p = self.offset_dict.get(key, None)
60 | if p is None:
61 | return None
62 | val_pos, val_len = p
63 | data = self.m[val_pos:val_pos + val_len]
64 | data = str(data, encoding="utf-8")
65 | return json.loads(data)
66 |
67 |     # get all keys
68 | def get_keys(self):
69 | return self.keys
70 |
71 | def __len__(self):
72 | return len(self.keys)
73 |
--------------------------------------------------------------------------------
/utils/callback.py:
--------------------------------------------------------------------------------
1 | import os
3 | import shutil
4 |
5 | from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
6 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
7 |
8 |
9 | # callback that runs when the trainer saves a checkpoint
10 | class SavePeftModelCallback(TrainerCallback):
11 | def on_save(self,
12 | args: TrainingArguments,
13 | state: TrainerState,
14 | control: TrainerControl,
15 | **kwargs, ):
16 | if args.local_rank == 0 or args.local_rank == -1:
17 |             # copy the best-performing checkpoint to a fixed folder
18 |             best_checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-best")
19 |             # only the most recent checkpoints are kept, so make sure the best one still exists on disk
20 | if os.path.exists(state.best_model_checkpoint):
21 | if os.path.exists(best_checkpoint_folder):
22 | shutil.rmtree(best_checkpoint_folder)
23 | shutil.copytree(state.best_model_checkpoint, best_checkpoint_folder)
24 | print(f"效果最好的检查点为:{state.best_model_checkpoint},评估结果为:{state.best_metric}")
25 | return control
26 |
--------------------------------------------------------------------------------
/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | from dataclasses import dataclass
3 | from typing import Any, List, Dict, Union
4 |
5 | import torch
6 | from zhconv import convert
7 |
8 |
9 | # remove punctuation marks
10 | def remove_punctuation(text: str or List[str]):
11 | punctuation = '!,.;:?、!,。;:?'
12 | if isinstance(text, str):
13 | text = re.sub(r'[{}]+'.format(punctuation), '', text).strip()
14 | return text
15 | elif isinstance(text, list):
16 | result_text = []
17 | for t in text:
18 | t = re.sub(r'[{}]+'.format(punctuation), '', t).strip()
19 | result_text.append(t)
20 | return result_text
21 | else:
22 | raise Exception(f'不支持该类型{type(text)}')
23 |
24 |
25 | # convert Traditional Chinese to Simplified Chinese
26 | def to_simple(text: str or List[str]):
27 | if isinstance(text, str):
28 | text = convert(text, 'zh-cn')
29 | return text
30 | elif isinstance(text, list):
31 | result_text = []
32 | for t in text:
33 | t = convert(t, 'zh-cn')
34 | result_text.append(t)
35 | return result_text
36 | else:
37 | raise Exception(f'不支持该类型{type(text)}')
38 |
39 |
40 | @dataclass
41 | class DataCollatorSpeechSeq2SeqWithPadding:
42 | processor: Any
43 |
44 | def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
45 | # split inputs and labels since they have to be of different lengths and need different padding methods
46 | # first treat the audio inputs by simply returning torch tensors
47 | input_features = [{"input_features": feature["input_features"][0]} for feature in features]
48 | batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
49 |
50 | # get the tokenized label sequences
51 | label_features = [{"input_ids": feature["labels"]} for feature in features]
52 | # pad the labels to max length
53 | labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
54 |
55 | # replace padding with -100 to ignore loss correctly
56 | labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
57 |
58 | # if bos token is appended in previous tokenization step,
59 |         # cut bos token here as it's appended later anyway
60 | if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
61 | labels = labels[:, 1:]
62 |
63 | batch["labels"] = labels
64 |
65 | return batch
66 |
--------------------------------------------------------------------------------
/utils/model_utils.py:
--------------------------------------------------------------------------------
1 | import bitsandbytes as bnb
2 | import torch
3 | from transformers.trainer_pt_utils import LabelSmoother
4 |
5 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index
6 |
7 |
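# Collect the names of every Linear (or bitsandbytes Linear8bitLt) module in the model;
# the returned list is typically passed to LoRA as target_modules.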
8 | def find_all_linear_names(use_8bit, model):
9 | cls = bnb.nn.Linear8bitLt if use_8bit else torch.nn.Linear
10 | lora_module_names = set()
11 | for name, module in model.named_modules():
12 | if isinstance(module, cls):
13 | names = name.split('.')
14 | lora_module_names.add(names[0] if len(names) == 1 else names[-1])
15 | target_modules = list(lora_module_names)
16 | return target_modules
17 |
18 |
19 | def load_from_checkpoint(resume_from_checkpoint, model=None):
20 | pass
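    # NOTE: placeholder; no checkpoint state is loaded here and the function returns None.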
21 |
--------------------------------------------------------------------------------
/utils/reader.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import random
4 | import sys
5 | from typing import List
6 |
7 | import librosa
8 | import numpy as np
9 | import soundfile
10 | from torch.utils.data import Dataset
11 | from tqdm import tqdm
12 |
13 | from utils.binary import DatasetReader
14 |
15 |
16 | class CustomDataset(Dataset):
17 | def __init__(self,
18 | data_list_path,
19 | processor,
20 | mono=True,
21 | language=None,
22 | timestamps=False,
23 | sample_rate=16000,
24 | min_duration=0.5,
25 | max_duration=30,
26 | augment_config_path=None):
27 |         """
28 |         Args:
29 |             data_list_path: path to the data list file, or to the header file of a binary data list
30 |             processor: Whisper preprocessor, obtained via WhisperProcessor.from_pretrained
31 |             mono: whether to convert the audio to a single channel; must be True
32 |             language: language of the fine-tuning data
33 |             timestamps: whether to use timestamps during fine-tuning
34 |             sample_rate: audio sample rate, 16000 by default
35 |             min_duration: audio shorter than this is filtered out, in seconds; must be >= 0.5, default 0.5 s
36 |             max_duration: audio longer than this is filtered out, in seconds; must be <= 30, default 30 s
37 |             augment_config_path: path to the data augmentation config file
38 |         """
39 | super(CustomDataset, self).__init__()
40 | assert min_duration >= 0.5, f"min_duration不能小于0.5,当前为:{min_duration}"
41 | assert max_duration <= 30, f"max_duration不能大于30,当前为:{max_duration}"
42 | self.data_list_path = data_list_path
43 | self.processor = processor
44 | self.data_list_path = data_list_path
45 | self.sample_rate = sample_rate
46 | self.mono = mono
47 | self.language = language
48 | self.timestamps = timestamps
49 | self.min_duration = min_duration
50 | self.max_duration = max_duration
51 | self.vocab = self.processor.tokenizer.get_vocab()
52 | self.startoftranscript = self.vocab['<|startoftranscript|>']
53 | self.endoftext = self.vocab['<|endoftext|>']
54 | if '<|nospeech|>' in self.vocab.keys():
55 | self.nospeech = self.vocab['<|nospeech|>']
56 | self.timestamp_begin = None
57 | else:
58 |             # compatibility with older models
59 | self.nospeech = self.vocab['<|nocaptions|>']
60 | self.timestamp_begin = self.vocab['<|notimestamps|>'] + 1
61 | self.data_list: List[dict] = []
62 |         # load the data list
63 |         self._load_data_list()
64 |         # data augmentation configuration
65 | self.augment_configs = None
66 | self.noises_path = None
67 | self.speed_rates = None
68 | if augment_config_path:
69 | with open(augment_config_path, 'r', encoding='utf-8') as f:
70 | self.augment_configs = json.load(f)
71 |
72 |     # load the data list
73 | def _load_data_list(self):
74 | if self.data_list_path.endswith(".header"):
75 |             # read the binary data list
76 | self.dataset_reader = DatasetReader(data_header_path=self.data_list_path,
77 | min_duration=self.min_duration,
78 | max_duration=self.max_duration)
79 | self.data_list = self.dataset_reader.get_keys()
80 | else:
81 |             # read the plain-text data list
82 | with open(self.data_list_path, 'r', encoding='utf-8') as f:
83 | lines = f.readlines()
84 | self.data_list = []
85 | for line in tqdm(lines, desc='读取数据列表'):
86 | if isinstance(line, str):
87 | line = json.loads(line)
88 | if not isinstance(line, dict): continue
89 |                 # skip audio outside the duration limits
90 | if line["duration"] < self.min_duration:
91 | continue
92 | if self.max_duration != -1 and line["duration"] > self.max_duration:
93 | continue
94 | self.data_list.append(dict(line))
95 |
96 |     # get the audio data, sample rate and transcript for one entry of the data list
97 | def _get_list_data(self, idx):
98 | if self.data_list_path.endswith(".header"):
99 | data_list = self.dataset_reader.get_data(self.data_list[idx])
100 | else:
101 | data_list = self.data_list[idx]
102 |         # get the audio path and the label
103 | audio_file = data_list["audio"]['path']
104 | transcript = data_list["sentences"] if self.timestamps else data_list["sentence"]
105 | language = data_list["language"] if 'language' in data_list.keys() else None
106 | if 'start_time' not in data_list["audio"].keys():
107 | sample, sample_rate = soundfile.read(audio_file, dtype='float32')
108 | else:
109 | start_time, end_time = data_list["audio"]["start_time"], data_list["audio"]["end_time"]
110 |             # read only the required slice of the audio file
111 | sample, sample_rate = self.slice_from_file(audio_file, start=start_time, end=end_time)
112 | sample = sample.T
113 |         # convert to mono
114 |         if self.mono:
115 |             sample = librosa.to_mono(sample)
116 |         # data augmentation
117 |         if self.augment_configs:
118 |             sample, sample_rate = self.augment(sample, sample_rate)
119 |         # resample to the target sample rate
120 | if self.sample_rate != sample_rate:
121 | sample = self.resample(sample, orig_sr=sample_rate, target_sr=self.sample_rate)
122 | return sample, sample_rate, transcript, language
123 |
124 | def _load_timestamps_transcript(self, transcript: List[dict]):
125 | assert isinstance(transcript, list), f"transcript应该为list,当前为:{type(transcript)}"
126 | data = dict()
127 | labels = self.processor.tokenizer.prefix_tokens[:3]
128 | for t in transcript:
129 |             # encode the start time, target text and end time as label ids
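            # note: timestamps are snapped to a 0.02 s grid; for older vocabularies the token id is
            # timestamp_begin + round(time * 100) // 2, otherwise the literal <|x.xx|> token is looked up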
130 | start = t['start'] if round(t['start'] * 100) % 2 == 0 else t['start'] + 0.01
131 | if self.timestamp_begin is None:
132 | start = self.vocab[f'<|{start:.2f}|>']
133 | else:
134 | start = self.timestamp_begin + round(start * 100) // 2
135 | end = t['end'] if round(t['end'] * 100) % 2 == 0 else t['end'] - 0.01
136 | if self.timestamp_begin is None:
137 | end = self.vocab[f'<|{end:.2f}|>']
138 | else:
139 | end = self.timestamp_begin + round(end * 100) // 2
140 | label = self.processor(text=t['text']).input_ids[4:-1]
141 | labels.extend([start])
142 | labels.extend(label)
143 | labels.extend([end])
144 | data['labels'] = labels + [self.endoftext]
145 | return data
146 |
147 | def __getitem__(self, idx):
148 | try:
149 |             # get the audio data, sample rate and transcript from the data list
150 |             sample, sample_rate, transcript, language = self._get_list_data(idx=idx)
151 |             # the language can be overridden per sample
152 | self.processor.tokenizer.set_prefix_tokens(language=language if language is not None else self.language)
153 | if len(transcript) > 0:
154 |                 # load the transcript with timestamp tokens
155 | if self.timestamps:
156 | data = self._load_timestamps_transcript(transcript=transcript)
157 |                     # compute log-Mel input features from the audio array
158 | data["input_features"] = self.processor(audio=sample, sampling_rate=self.sample_rate).input_features
159 | else:
160 |                     # get log-Mel features and label ids
161 | data = self.processor(audio=sample, sampling_rate=self.sample_rate, text=transcript)
162 | else:
163 |                 # if there is no text, use the <|nospeech|> token
164 | data = self.processor(audio=sample, sampling_rate=self.sample_rate)
165 | data['labels'] = [self.startoftranscript, self.nospeech, self.endoftext]
166 | return data
167 | except Exception as e:
168 | print(f'读取数据出错,序号:{idx},错误信息:{e}', file=sys.stderr)
169 | return self.__getitem__(random.randint(0, self.__len__() - 1))
170 |
171 | def __len__(self):
172 | return len(self.data_list)
173 |
174 |     # read a slice of an audio file
175 | @staticmethod
176 | def slice_from_file(file, start, end):
177 | sndfile = soundfile.SoundFile(file)
178 | sample_rate = sndfile.samplerate
179 | duration = round(float(len(sndfile)) / sample_rate, 3)
180 | start = round(start, 3)
181 | end = round(end, 3)
182 |         # negative values count from the end of the file
183 |         if start < 0.0: start += duration
184 |         if end < 0.0: end += duration
185 |         # keep the slice within bounds
186 | if start < 0.0: start = 0.0
187 | if end > duration: end = duration
188 | if end < 0.0:
189 | raise ValueError("切片结束位置(%f s)越界" % end)
190 | if start > end:
191 | raise ValueError("切片开始位置(%f s)晚于切片结束位置(%f s)" % (start, end))
192 | start_frame = int(start * sample_rate)
193 | end_frame = int(end * sample_rate)
194 | sndfile.seek(start_frame)
195 | sample = sndfile.read(frames=end_frame - start_frame, dtype='float32')
196 | return sample, sample_rate
197 |
198 |     # data augmentation
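    # Example augment_configs entry (assumed JSON shape, matching the keys read below):
    #   {"type": "speed", "prob": 0.5, "params": {"min_speed_rate": 0.9, "max_speed_rate": 1.1, "num_rates": 3}}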
199 | def augment(self, sample, sample_rate):
200 | for config in self.augment_configs:
201 | if config['type'] == 'speed' and random.random() < config['prob']:
202 | if self.speed_rates is None:
203 | min_speed_rate, max_speed_rate, num_rates = config['params']['min_speed_rate'], \
204 | config['params']['max_speed_rate'], config['params']['num_rates']
205 | self.speed_rates = np.linspace(min_speed_rate, max_speed_rate, num_rates, endpoint=True)
206 | rate = random.choice(self.speed_rates)
207 | sample = self.change_speed(sample, speed_rate=rate)
208 | if config['type'] == 'shift' and random.random() < config['prob']:
209 | min_shift_ms, max_shift_ms = config['params']['min_shift_ms'], config['params']['max_shift_ms']
210 | shift_ms = random.randint(min_shift_ms, max_shift_ms)
211 | sample = self.shift(sample, sample_rate, shift_ms=shift_ms)
212 | if config['type'] == 'volume' and random.random() < config['prob']:
213 | min_gain_dBFS, max_gain_dBFS = config['params']['min_gain_dBFS'], config['params']['max_gain_dBFS']
214 | gain = random.randint(min_gain_dBFS, max_gain_dBFS)
215 | sample = self.volume(sample, gain=gain)
216 | if config['type'] == 'resample' and random.random() < config['prob']:
217 | new_sample_rates = config['params']['new_sample_rates']
218 | new_sample_rate = np.random.choice(new_sample_rates)
219 | sample = self.resample(sample, orig_sr=sample_rate, target_sr=new_sample_rate)
220 | sample_rate = new_sample_rate
221 | if config['type'] == 'noise' and random.random() < config['prob']:
222 | min_snr_dB, max_snr_dB = config['params']['min_snr_dB'], config['params']['max_snr_dB']
223 | if self.noises_path is None:
224 | self.noises_path = []
225 | noise_dir = config['params']['noise_dir']
226 | if os.path.exists(noise_dir):
227 | for file in os.listdir(noise_dir):
228 | self.noises_path.append(os.path.join(noise_dir, file))
229 | noise_path = random.choice(self.noises_path)
230 | snr_dB = random.randint(min_snr_dB, max_snr_dB)
231 | sample = self.add_noise(sample, sample_rate, noise_path=noise_path, snr_dB=snr_dB)
232 | return sample, sample_rate
233 |
234 |     # change playback speed
235 | @staticmethod
236 | def change_speed(sample, speed_rate):
237 | if speed_rate == 1.0:
238 | return sample
239 | if speed_rate <= 0:
240 | raise ValueError("速度速率应大于零")
241 | old_length = sample.shape[0]
242 | new_length = int(old_length / speed_rate)
243 | old_indices = np.arange(old_length)
244 | new_indices = np.linspace(start=0, stop=old_length, num=new_length)
245 | sample = np.interp(new_indices, old_indices, sample).astype(np.float32)
246 | return sample
247 |
248 |     # shift the audio in time (positive shift_ms moves it earlier, negative moves it later)
249 | @staticmethod
250 | def shift(sample, sample_rate, shift_ms):
251 | duration = sample.shape[0] / sample_rate
252 | if abs(shift_ms) / 1000.0 > duration:
253 | raise ValueError("shift_ms的绝对值应该小于音频持续时间")
254 | shift_samples = int(shift_ms * sample_rate / 1000)
255 | if shift_samples > 0:
256 | sample[:-shift_samples] = sample[shift_samples:]
257 | sample[-shift_samples:] = 0
258 | elif shift_samples < 0:
259 | sample[-shift_samples:] = sample[:shift_samples]
260 | sample[:-shift_samples] = 0
261 | return sample
262 |
263 |     # change the volume by gain dB
264 | @staticmethod
265 | def volume(sample, gain):
266 | sample *= 10.**(gain / 20.)
267 | return sample
268 |
269 |     # resample the audio
270 | @staticmethod
271 | def resample(sample, orig_sr, target_sr):
272 | sample = librosa.resample(sample, orig_sr=orig_sr, target_sr=target_sr)
273 | return sample
274 |
275 |     # add noise from a noise file at the given SNR
276 |     def add_noise(self, sample, sample_rate, noise_path, snr_dB, max_gain_db=300.0):
277 |         noise_sample, sr = librosa.load(noise_path, sr=sample_rate)
278 |         # normalize the sample volume so the added noise does not become too loud
279 | target_db = -20
280 | gain = min(max_gain_db, target_db - self.rms_db(sample))
281 | sample *= 10. ** (gain / 20.)
282 |         # scale the noise to reach the requested SNR
283 | sample_rms_db, noise_rms_db = self.rms_db(sample), self.rms_db(noise_sample)
284 | noise_gain_db = min(sample_rms_db - noise_rms_db - snr_dB, max_gain_db)
285 | noise_sample *= 10. ** (noise_gain_db / 20.)
286 |         # match the noise length to the sample length
287 | if noise_sample.shape[0] < sample.shape[0]:
288 | diff_duration = sample.shape[0] - noise_sample.shape[0]
289 | noise_sample = np.pad(noise_sample, (0, diff_duration), 'wrap')
290 | elif noise_sample.shape[0] > sample.shape[0]:
291 | start_frame = random.randint(0, noise_sample.shape[0] - sample.shape[0])
292 | noise_sample = noise_sample[start_frame:sample.shape[0] + start_frame]
293 | sample += noise_sample
294 | return sample
295 |
296 | @staticmethod
297 | def rms_db(sample):
298 | mean_square = np.mean(sample ** 2)
299 | return 10 * np.log10(mean_square)
300 |
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import tarfile
4 | import urllib.request
5 |
6 | from tqdm import tqdm
7 |
8 |
9 | def print_arguments(args):
10 | print("----------- Configuration Arguments -----------")
11 | for arg, value in vars(args).items():
12 | print("%s: %s" % (arg, value))
13 | print("------------------------------------------------")
14 |
15 |
16 | def strtobool(val):
17 | val = val.lower()
18 | if val in ('y', 'yes', 't', 'true', 'on', '1'):
19 | return True
20 | elif val in ('n', 'no', 'f', 'false', 'off', '0'):
21 | return False
22 | else:
23 | raise ValueError("invalid truth value %r" % (val,))
24 |
25 |
26 | def str_none(val):
27 | if val == 'None':
28 | return None
29 | else:
30 | return val
31 |
32 |
33 | def add_arguments(argname, type, default, help, argparser, **kwargs):
34 | type = strtobool if type == bool else type
35 | type = str_none if type == str else type
36 | argparser.add_argument("--" + argname,
37 | default=default,
38 | type=type,
39 | help=help + ' Default: %(default)s.',
40 | **kwargs)
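# Example (hypothetical usage): add_arguments("batch_size", int, 16, "Training batch size.", parser)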
41 |
42 |
43 | def md5file(fname):
44 | hash_md5 = hashlib.md5()
45 | f = open(fname, "rb")
46 | for chunk in iter(lambda: f.read(4096), b""):
47 | hash_md5.update(chunk)
48 | f.close()
49 | return hash_md5.hexdigest()
50 |
51 |
52 | def download(url, md5sum, target_dir):
53 | """Download file from url to target_dir, and check md5sum."""
54 | if not os.path.exists(target_dir): os.makedirs(target_dir)
55 | filepath = os.path.join(target_dir, url.split("/")[-1])
56 | if not (os.path.exists(filepath) and md5file(filepath) == md5sum):
57 | print(f"Downloading {url} to {filepath} ...")
58 | with urllib.request.urlopen(url) as source, open(filepath, "wb") as output:
59 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True,
60 | unit_divisor=1024) as loop:
61 | while True:
62 | buffer = source.read(8192)
63 | if not buffer:
64 | break
65 |
66 | output.write(buffer)
67 | loop.update(len(buffer))
68 |     print(f"\nMD5 Checksum {filepath} ...")
69 | if not md5file(filepath) == md5sum:
70 | raise RuntimeError("MD5 checksum failed.")
71 | else:
72 | print(f"File exists, skip downloading. ({filepath})")
73 | return filepath
74 |
75 |
76 | def unpack(filepath, target_dir, rm_tar=False):
77 | """Unpack the file to the target_dir."""
78 | print("Unpacking %s ..." % filepath)
79 | tar = tarfile.open(filepath)
80 | tar.extractall(target_dir)
81 | tar.close()
82 | if rm_tar:
83 | os.remove(filepath)
84 |
85 |
86 | def make_inputs_require_grad(module, input, output):
87 | output.requires_grad_(True)
88 |
--------------------------------------------------------------------------------