├── AISHELL-3
│   ├── ReadMe.txt
│   └── spk-info.txt
├── LICENSE
├── README.md
├── aishell_data_clean.ipynb
├── asr.py
├── asr_v2.py
├── asr_v3.py
├── asr_web.py
├── config.py
├── infer_server.py
├── requirements.txt
├── static
│   ├── index.css
│   ├── record.js
│   └── record.png
├── templates
│   └── index.html
├── train.py
├── train_v2.py
├── train_v3.py
├── train_v4.py
└── utils
    ├── binary.py
    ├── callback.py
    ├── data_utils.py
    ├── model_utils.py
    ├── reader.py
    └── utils.py
/AISHELL-3/ReadMe.txt: -------------------------------------------------------------------------------- 1 | 2 | AISHELL-3 3 | 4 | 北京希尔贝壳科技有限公司 5 | Beijing Shell Shell Technology Co.,Ltd 6 | 10/18/2020 7 | 8 | 1. AISHELL-3 Speech Data 9 | 10 | - Sampling Rate : 44.1kHz 11 | - Sample Format : 16bit 12 | - Environment : Quiet indoor 13 | - Speech Data Type : PCM 14 | - Channel Number : 1 15 | - Recording Equipment : High fidelity microphone 16 | - Sentences : 88035 utterances 17 | - Speaker : 218 speakers (43 male and 175 female) 18 | 19 | 20 | 2. Data Structure 21 | │ README.txt (readme) 22 | │ ChangeLog (Change Information) 23 | │ phone_set.txt (phone Information) 24 | │ spk_info.txt (Speaker Information) 25 | └─ test (Test Data File) 26 | └─ train (Train Data File) 27 | │├─content.txt (Transcript Content) 28 | │├─prosody_label_train-set.txt (Prosody Label) 29 | │├─wav (Audio Data File) 30 | │├─SSB005 (Speaker ID File) 31 | ││ ││ ID2166W0001.wav (Audio) 32 | 33 | 4. System 34 | AISHELL-3 is a large-scale and high-fidelity multi-speaker Mandarin speech corpus which could be used to train multi-speaker Text-to-Speech (TTS) systems. 35 | You can download the data set from: http://www.aishelltech.com/aishell_3. 36 | The baseline system code and generated samples are available online from: https://sos1sos2sixteen.github.io/aishell3/. 37 | -------------------------------------------------------------------------------- /AISHELL-3/spk-info.txt: -------------------------------------------------------------------------------- 1 | # voice-file name; age group; gender; accent 2 | # In years, A:< 14, B:14 - 25, C:26 - 40, D:> 41. 
3 | 4 | SSB1837 B female north 5 | SSB0578 B female north 6 | SSB1216 B female north 7 | SSB1161 B female north 8 | SSB0016 B female north 9 | SSB1365 B male north 10 | SSB1759 B female north 11 | SSB0588 B female north 12 | SSB0534 C female north 13 | SSB0380 B female north 14 | SSB0273 B male north 15 | SSB1863 B male north 16 | SSB1125 B female south 17 | SSB1872 B female north 18 | SSB0993 C female north 19 | SSB0666 C female north 20 | SSB0668 B female north 21 | SSB0395 B female north 22 | SSB1831 B male south 23 | SSB1064 B female north 24 | SSB1219 B female north 25 | SSB1322 B female north 26 | SSB1002 B female north 27 | SSB1878 B female north 28 | SSB0415 B female north 29 | SSB1781 C female north 30 | SSB0631 B male south 31 | SSB0686 B female north 32 | SSB1328 B male north 33 | SSB0366 B female south 34 | SSB1100 C male south 35 | SSB0241 B male north 36 | SSB0966 B male north 37 | SSB1340 B female north 38 | SSB0762 B female north 39 | SSB0073 B male north 40 | SSB0632 B female south 41 | SSB0915 B female south 42 | SSB0748 B female north 43 | SSB1956 B female north 44 | SSB1056 B female south 45 | SSB0716 B female south 46 | SSB0629 B male north 47 | SSB1806 B female north 48 | SSB0599 C female north 49 | SSB0720 B female north 50 | SSB0385 B female south 51 | SSB0809 B female north 52 | SSB0342 B female others 53 | SSB0760 B female north 54 | SSB1253 B female south 55 | SSB1575 B female south 56 | SSB0863 B male north 57 | SSB1110 D female north 58 | SSB0200 B female north 59 | SSB1215 B female north 60 | SSB0375 B male south 61 | SSB1828 C female north 62 | SSB0737 D female north 63 | SSB0341 C female north 64 | SSB0009 B female south 65 | SSB0309 D female north 66 | SSB1055 C female north 67 | SSB1448 B male north 68 | SSB1176 B female north 69 | SSB1001 B female north 70 | SSB0193 B female south 71 | SSB0710 C male north 72 | SSB0427 B female north 73 | SSB0338 B female north 74 | SSB1131 B female south 75 | SSB1108 B female north 76 | SSB0149 B female south 77 | SSB0736 B male south 78 | SSB1555 B female south 79 | SSB0614 C female north 80 | SSB1072 B female south 81 | SSB1728 B female north 82 | SSB1382 B female north 83 | SSB0851 B female north 84 | SSB1585 B female south 85 | SSB1891 C female north 86 | SSB1393 B female north 87 | SSB1274 B female north 88 | SSB1204 B female north 89 | SSB1452 B female north 90 | SSB0570 B female north 91 | SSB0780 B female north 92 | SSB1593 B female south 93 | SSB0913 B female north 94 | SSB1302 B female north 95 | SSB0323 B female north 96 | SSB1135 B female north 97 | SSB0382 B female north 98 | SSB0887 B male north 99 | SSB1625 C female south 100 | SSB1366 C female north 101 | SSB0693 B female south 102 | SSB0594 B female north 103 | SSB1686 B female north 104 | SSB0012 B female north 105 | SSB0139 B male south 106 | SSB0751 B female north 107 | SSB0606 D female north 108 | SSB1341 B female south 109 | SSB0145 B female north 110 | SSB1136 B male south 111 | SSB0339 B female north 112 | SSB0482 B female north 113 | SSB0502 B female north 114 | SSB1650 B female north 115 | SSB0817 B male north 116 | SSB0261 C male north 117 | SSB0316 B male north 118 | SSB0033 B female south 119 | SSB0723 B female north 120 | SSB1008 B female north 121 | SSB0700 B female north 122 | SSB1457 B female north 123 | SSB0601 B female north 124 | SSB1809 B female north 125 | SSB1739 B female north 126 | SSB0407 C male north 127 | SSB0426 B female south 128 | SSB0470 B female north 129 | SSB0935 B female north 130 | SSB0822 B female north 131 | SSB0746 
B female north 132 | SSB0758 B female north 133 | SSB1221 B female north 134 | SSB0038 B female north 135 | SSB1624 B male south 136 | SSB0133 B female south 137 | SSB0778 B female north 138 | SSB0702 B female south 139 | SSB1383 B male south 140 | SSB1563 B female south 141 | SSB1670 B female north 142 | SSB1096 B female south 143 | SSB0299 B female south 144 | SSB1711 B female north 145 | SSB1810 B female north 146 | SSB1115 B female south 147 | SSB1684 B male north 148 | SSB1402 C female north 149 | SSB1918 B female north 150 | SSB0246 B female south 151 | SSB1607 B female north 152 | SSB1437 C female north 153 | SSB0011 C female north 154 | SSB0288 C female north 155 | SSB0539 C female north 156 | SSB0394 B male north 157 | SSB0379 B female south 158 | SSB1187 B male north 159 | SSB0671 C female north 160 | SSB0544 B male north 161 | SSB0005 B female north 162 | SSB1846 B female south 163 | SSB0122 B female north 164 | SSB1630 C male north 165 | SSB1399 B female north 166 | SSB1567 B female others 167 | SSB0267 C female north 168 | SSB1392 B female north 169 | SSB0287 B female north 170 | SSB0717 B female south 171 | SSB0018 B female south 172 | SSB0315 B female south 173 | SSB1126 A female north 174 | SSB1320 D female north 175 | SSB0919 B female north 176 | SSB0623 B male south 177 | SSB0871 C female north 178 | SSB0786 B female north 179 | SSB1024 B female north 180 | SSB1935 B male north 181 | SSB0794 B female north 182 | SSB1832 B female north 183 | SSB1000 C female north 184 | SSB1431 C female north 185 | SSB0535 B male north 186 | SSB1782 C female north 187 | SSB0354 D female north 188 | SSB0393 A female north 189 | SSB0057 B female north 190 | SSB1385 B female north 191 | SSB1050 D female north 192 | SSB0435 B female north 193 | SSB1197 B female north 194 | SSB1939 B female north 195 | SSB1239 B male north 196 | SSB0434 D male north 197 | SSB1020 B male south 198 | SSB1091 B female north 199 | SSB1745 B male north 200 | SSB0565 B female north 201 | SSB0711 B female south 202 | SSB0603 B male north 203 | SSB0749 B female south 204 | SSB0997 B female north 205 | SSB0590 C male north 206 | SSB0784 C male north 207 | SSB1699 C female north 208 | SSB0607 C female north 209 | SSB1902 B female north 210 | SSB0043 B female north 211 | SSB1138 B female north 212 | SSB1408 B male north 213 | SSB1218 B female south 214 | SSB1377 B female north 215 | SSB0609 C male north 216 | SSB0987 B female north 217 | SSB0080 B female south 218 | SSB0307 B female north 219 | SSB0197 C female south 220 | SSB1203 C female north 221 | SSB0112 B female south 222 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Thirteen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | --------------------------------------------------------------------------------
/README.md: -------------------------------------------------------------------------------- 1 | # asr_AISHELL-3 2 | 使用AISHELL-3 数据集 训练语音识别模型 3 | 4 | ## 使用方法 5 | 创建虚拟环境 6 | ``` 7 | conda create -n asr python=3.10 8 | ``` 9 | 配置环境 10 | ``` 11 | activate asr 12 | pip install -r requirements.txt 13 | ``` 14 | 运行 15 | ``` 16 | python train.py 17 | ``` 18 | 如果已经运行且已经有音频特征文件 features.pkl 19 | 可直接运行 __train_v2.py__ 20 | ``` 21 | python train_v2.py 22 | ``` 23 | * 在train_v4中增加tensorboard 24 | 可查看训练日志 25 | ## 使用web进行语音识别 26 | 运行 27 | ``` 28 | python asr_web.py 29 | ``` 30 | * 可进行读取录音 31 | * 本地录制并上传进行识别 32 | * 预览 33 | ![image](https://github.com/WThirteen/asr_AISHELL-3/assets/100677199/59201975-12ea-46cf-9e4a-e490c02211c0) 34 | 35 | ## 查看训练日志 36 | 输入命令 37 | ``` 38 | tensorboard --logdir=log_path 39 | ``` 40 | ## loss曲线: 41 | __epochs=25__ 42 | ![epochs_25](https://github.com/WThirteen/asr_AISHELL-3/assets/100677199/c4ad5342-aee6-4950-833d-59c424b15f1e) 43 | 44 | ## librosa版本问题 45 | 这里使用的 *librosa==0.7.2* 46 | 可能会出现 47 | ![image](https://github.com/WThirteen/asr_thchs30/assets/100677199/6022f953-e40b-4b9e-9009-24a69d8a6e14) 48 | **参考这份博客:** 49 | 50 | [解决不联网环境pip安装librosa、numba、llvmlite报错和版本兼容问题](https://blog.csdn.net/qq_39691492/article/details/130829401) 51 | 52 | *修改如下:* 53 | 54 | ![image](https://github.com/WThirteen/asr_thchs30/assets/100677199/14ef3f58-7bb1-4f85-bc58-d49d761a86ae) 55 | 56 | ## api的部分参考 57 | [Whisper-Finetune](https://github.com/yeyupiaoling/Whisper-Finetune) 58 | 59 | 将原来微调的whisper模型换成这里训练的asr模型 --------------------------------------------------------------------------------
/aishell_data_clean.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 14, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "path = 'AISHELL-3/train/wav/'\n", 11 | "files = os.listdir(path)" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 48, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "# 音频文件\n", 21 | "wav_file=[]\n", 22 | "# 音频文件的相对路径\n", 23 | "wav_file_path=[]\n", 24 | "\n", 25 | "for i in files:\n", 26 | " temp_files = os.listdir(path+i)\n", 27 | " for j in temp_files:\n", 28 | " wav_file.append(j)\n", 29 | " wav_file_path.append(path+i+'/'+j)\n", 30 | "\n" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": {}, 36 | "source": [ 37 | "* 两种导入方式" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 54, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "# with open(\"AISHELL-3/train/content.txt\", \"r\", encoding='utf-8') as f: # 打开文件\n", 47 | "# data = f.read() # 读取文件\n", 48 | " # print(data)" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 60, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "f=open(\"AISHELL-3/train/content.txt\", \"r\", 
encoding='utf-8')\n", 58 | "txt=[]\n", 59 | "for line in f:\n", 60 | " txt.append(line.strip())\n" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 66, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "# 音频文件名字\n", 70 | "txt_filename = []\n", 71 | "# 对应文件内容\n", 72 | "txt_wav = []\n", 73 | "\n", 74 | "for i in txt:\n", 75 | " temp_txt_filename,temp_txt_wav = i.split('\\t')\n", 76 | " txt_filename.append(temp_txt_filename)\n", 77 | " txt_wav.append(temp_txt_wav)" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 76, 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "# 字典 字母+数字\n", 87 | "en_num_all = []\n", 88 | "\n", 89 | "for letter in 'abcdefghijklmnopqrstuvwxyz':\n", 90 | " en_num_all.extend(letter)\n", 91 | "\n", 92 | "for number in range(10): \n", 93 | " en_num_all.extend(str(number))\n", 94 | " \n", 95 | "en_num_all.extend(' ')" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 109, 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "# 音频文字\n", 105 | "texts = []\n", 106 | "\n", 107 | "for i in txt_wav:\n", 108 | " temp = ''\n", 109 | " for j in i:\n", 110 | " if j in en_num_all:\n", 111 | " continue\n", 112 | " else:\n", 113 | " # print(j)\n", 114 | " temp = temp+j\n", 115 | " texts.append(temp)\n", 116 | " " 117 | ] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "metadata": {}, 122 | "source": [ 123 | "#### texts 与 wav_file_path 对应" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 110, 129 | "metadata": {}, 130 | "outputs": [ 131 | { 132 | "data": { 133 | "text/plain": [ 134 | "'一百三十三万三千五百六十五'" 135 | ] 136 | }, 137 | "execution_count": 110, 138 | "metadata": {}, 139 | "output_type": "execute_result" 140 | } 141 | ], 142 | "source": [ 143 | "texts[-1]" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": 107, 149 | "metadata": {}, 150 | "outputs": [ 151 | { 152 | "data": { 153 | "text/plain": [ 154 | "'AISHELL-3/train/wav/SSB1956/SSB19560481.wav'" 155 | ] 156 | }, 157 | "execution_count": 107, 158 | "metadata": {}, 159 | "output_type": "execute_result" 160 | } 161 | ], 162 | "source": [ 163 | "wav_file_path[-1]" 164 | ] 165 | } 166 | ], 167 | "metadata": { 168 | "kernelspec": { 169 | "display_name": "Python 3", 170 | "language": "python", 171 | "name": "python3" 172 | }, 173 | "language_info": { 174 | "codemirror_mode": { 175 | "name": "ipython", 176 | "version": 3 177 | }, 178 | "file_extension": ".py", 179 | "mimetype": "text/x-python", 180 | "name": "python", 181 | "nbconvert_exporter": "python", 182 | "pygments_lexer": "ipython3", 183 | "version": "3.10.2" 184 | } 185 | }, 186 | "nbformat": 4, 187 | "nbformat_minor": 2 188 | } 189 | -------------------------------------------------------------------------------- /asr.py: -------------------------------------------------------------------------------- 1 | from keras.models import load_model 2 | from keras import backend as K 3 | import numpy as np 4 | import librosa 5 | from python_speech_features import mfcc 6 | import speech_recognition as sr 7 | import pickle 8 | import glob 9 | import config 10 | import wave 11 | import os 12 | 13 | def save_as_wav(audio, output_file_path): 14 | with wave.open(output_file_path, 'wb') as wav_file: 15 | wav_file.setnchannels(1) # 单声道 16 | wav_file.setsampwidth(2) # 16位PCM编码 17 | wav_file.setframerate(44100) # 采样率为44.1kHz 18 | wav_file.writeframes(audio.frame_data) 19 | 20 | def input_audio(): 21 | r = sr.Recognizer() 22 | with 
sr.Microphone() as source: 23 | print("请说...") 24 | r.pause_threshold = 1 25 | audio = r.listen(source) 26 | output_file_path = "temp_file.wav" 27 | save_as_wav(audio, output_file_path) 28 | wavs = glob.glob('temp_file.wav') 29 | return wavs 30 | 31 | def load_file(): 32 | with open(config.pkl_path, 'rb') as fr: 33 | [char2id, id2char, mfcc_mean, mfcc_std] = pickle.load(fr) 34 | model = load_model(config.model_path) 35 | return char2id, id2char, mfcc_mean, mfcc_std, model 36 | 37 | def set_data(wavs, mfcc_mean, mfcc_std): 38 | mfcc_dim = config.mfcc_dim 39 | index = np.random.randint(len(wavs)) 40 | audio, sr = librosa.load(wavs[index]) 41 | energy = librosa.feature.rms(audio) 42 | frames = np.nonzero(energy >= np.max(energy) / 5) 43 | indices = librosa.core.frames_to_samples(frames)[1] 44 | audio = audio[indices[0]:indices[-1]] if indices.size else audio[0:0] 45 | X_data = mfcc(audio, sr, numcep=mfcc_dim, nfft=551) 46 | X_data = (X_data - mfcc_mean) / (mfcc_std + 1e-14) 47 | print(X_data.shape) 48 | return X_data 49 | 50 | 51 | def wav_pred(model,X_data,id2char): 52 | pred = model.predict(np.expand_dims(X_data, axis=0)) 53 | pred_ids = K.eval(K.ctc_decode(pred, [X_data.shape[0]], greedy=False, beam_width=10, top_paths=1)[0][0]) 54 | pred_ids = pred_ids.flatten().tolist() 55 | words='' 56 | judge=0 57 | for i in pred_ids: 58 | if i != -1: 59 | judge=1 60 | words=words+id2char[i] 61 | if judge==1: 62 | print(words) 63 | else: 64 | print("未检测到") 65 | 66 | def run(): 67 | wavs = input_audio() 68 | char2id, id2char, mfcc_mean, mfcc_std, model = load_file() 69 | X_data = set_data(wavs, mfcc_mean, mfcc_std) 70 | wav_pred(model,X_data,id2char) 71 | os.remove("temp_file.wav") 72 | 73 | 74 | if __name__ == '__main__' : 75 | run() 76 | 77 | -------------------------------------------------------------------------------- /asr_v2.py: -------------------------------------------------------------------------------- 1 | from keras.models import load_model 2 | from keras import backend as K 3 | import numpy as np 4 | import librosa 5 | from python_speech_features import mfcc 6 | import speech_recognition as sr 7 | import pickle 8 | import glob 9 | import config 10 | import wave 11 | import os 12 | 13 | def save_as_wav(audio, output_file_path): 14 | with wave.open(output_file_path, 'wb') as wav_file: 15 | wav_file.setnchannels(1) # 单声道 16 | wav_file.setsampwidth(2) # 16位PCM编码 17 | wav_file.setframerate(44100) # 采样率为44.1kHz 18 | wav_file.writeframes(audio.frame_data) 19 | 20 | def input_audio(): 21 | r = sr.Recognizer() 22 | with sr.Microphone() as source: 23 | print("请说...") 24 | r.pause_threshold = 1 25 | audio = r.listen(source) 26 | output_file_path = "temp_file.wav" 27 | save_as_wav(audio, output_file_path) 28 | wavs = glob.glob('temp_file.wav') 29 | os.remove("temp_file.wav") 30 | return wavs 31 | 32 | def out_load_audio(): 33 | path = config.audio_path 34 | wavs = glob.glob(path) 35 | return wavs 36 | 37 | def load_file(): 38 | with open(config.pkl_path, 'rb') as fr: 39 | [char2id, id2char, mfcc_mean, mfcc_std] = pickle.load(fr) 40 | model = load_model(config.model_path) 41 | return char2id, id2char, mfcc_mean, mfcc_std, model 42 | 43 | def set_data(wavs, mfcc_mean, mfcc_std): 44 | mfcc_dim = config.mfcc_dim 45 | index = np.random.randint(len(wavs)) 46 | audio, sr = librosa.load(wavs[index]) 47 | energy = librosa.feature.rms(audio) 48 | frames = np.nonzero(energy >= np.max(energy) / 5) 49 | indices = librosa.core.frames_to_samples(frames)[1] 50 | audio = audio[indices[0]:indices[-1]] if indices.size 
else audio[0:0] 51 | X_data = mfcc(audio, sr, numcep=mfcc_dim, nfft=551) 52 | X_data = (X_data - mfcc_mean) / (mfcc_std + 1e-14) 53 | # print(X_data.shape) 54 | return X_data 55 | 56 | 57 | def wav_pred(model,X_data,id2char): 58 | pred = model.predict(np.expand_dims(X_data, axis=0)) 59 | pred_ids = K.eval(K.ctc_decode(pred, [X_data.shape[0]], greedy=False, beam_width=10, top_paths=1)[0][0]) 60 | pred_ids = pred_ids.flatten().tolist() 61 | words='' 62 | judge=0 63 | for i in pred_ids: 64 | if i != -1: 65 | judge=1 66 | words=words+id2char[i] 67 | if judge==1: 68 | print(words) 69 | else: 70 | print("未检测到") 71 | 72 | def run(): 73 | # wavs = input_audio() 74 | wavs = out_load_audio() 75 | char2id, id2char, mfcc_mean, mfcc_std, model = load_file() 76 | X_data = set_data(wavs, mfcc_mean, mfcc_std) 77 | wav_pred(model,X_data,id2char) 78 | 79 | 80 | 81 | if __name__ == '__main__' : 82 | run() 83 | 84 | -------------------------------------------------------------------------------- /asr_v3.py: -------------------------------------------------------------------------------- 1 | from keras.models import load_model 2 | from keras import backend as K 3 | import numpy as np 4 | import librosa 5 | from python_speech_features import mfcc 6 | import speech_recognition as sr 7 | import pickle 8 | import glob 9 | import config 10 | import wave 11 | import os 12 | import pyaudio 13 | from tqdm import tqdm 14 | 15 | class set_audio(): 16 | CHUNK = 1024 # 每个缓冲区的帧数 17 | FORMAT = pyaudio.paInt16 # 采样位数 18 | CHANNELS = 1 # 单声道 19 | RATE = 44100 # 采样频率 20 | 21 | # 可设置录制时间 22 | def record_audio(record_second): 23 | """ 录音功能 """ 24 | p = pyaudio.PyAudio() # 实例化对象 25 | stream = p.open(format=set_audio.FORMAT, 26 | channels=set_audio.CHANNELS, 27 | rate=set_audio.RATE, 28 | input=True, 29 | frames_per_buffer=set_audio.CHUNK) # 打开流,传入响应参数 30 | 31 | wf = wave.open('temp_file.wav', 'wb') # 打开 wav 文件。 32 | wf.setnchannels(set_audio.CHANNELS) # 声道设置 33 | wf.setsampwidth(p.get_sample_size(set_audio.FORMAT)) # 采样位数设置 34 | wf.setframerate(set_audio.RATE) # 采样频率设置 35 | 36 | for _ in tqdm(range(0, int(set_audio.RATE * record_second / set_audio.CHUNK))): 37 | data = stream.read(set_audio.CHUNK) 38 | wf.writeframes(data) # 写入数据 39 | stream.stop_stream() # 关闭流 40 | stream.close() 41 | p.terminate() 42 | wf.close() 43 | 44 | wavs = glob.glob('temp_file.wav') 45 | 46 | # os.remove("temp_file.wav") 47 | 48 | return wavs 49 | 50 | 51 | 52 | def save_as_wav(audio, output_file_path): 53 | with wave.open(output_file_path, 'wb') as wav_file: 54 | wav_file.setnchannels(1) # 单声道 55 | wav_file.setsampwidth(2) # 16位PCM编码 56 | wav_file.setframerate(44100) # 采样率为44.1kHz 57 | wav_file.writeframes(audio.frame_data) 58 | 59 | # 录音自动停止 60 | def input_audio(): 61 | r = sr.Recognizer() 62 | with sr.Microphone() as source: 63 | print("请说...") 64 | r.pause_threshold = 1 65 | audio = r.listen(source) 66 | output_file_path = "temp_file.wav" 67 | save_as_wav(audio, output_file_path) 68 | wavs = glob.glob('temp_file.wav') 69 | 70 | return wavs 71 | 72 | def load_file(): 73 | with open(config.pkl_path, 'rb') as fr: 74 | [char2id, id2char, mfcc_mean, mfcc_std] = pickle.load(fr) 75 | model = load_model(config.model_path) 76 | return char2id, id2char, mfcc_mean, mfcc_std, model 77 | 78 | def set_data(wavs, mfcc_mean, mfcc_std): 79 | mfcc_dim = config.mfcc_dim 80 | index = np.random.randint(len(wavs)) 81 | audio, sr = librosa.load(wavs[index]) 82 | energy = librosa.feature.rms(audio) 83 | frames = np.nonzero(energy >= np.max(energy) / 5) 84 | indices = 
librosa.core.frames_to_samples(frames)[1] 85 | audio = audio[indices[0]:indices[-1]] if indices.size else audio[0:0] 86 | X_data = mfcc(audio, sr, numcep=mfcc_dim, nfft=551) 87 | X_data = (X_data - mfcc_mean) / (mfcc_std + 1e-14) 88 | # print(X_data.shape) 89 | return X_data 90 | 91 | 92 | def wav_pred(model,X_data,id2char): 93 | pred = model.predict(np.expand_dims(X_data, axis=0)) 94 | pred_ids = K.eval(K.ctc_decode(pred, [X_data.shape[0]], greedy=False, beam_width=10, top_paths=1)[0][0]) 95 | pred_ids = pred_ids.flatten().tolist() 96 | words='' 97 | judge=0 98 | for i in pred_ids: 99 | if i != -1: 100 | judge=1 101 | words=words+id2char[i] 102 | if judge==1: 103 | print(words) 104 | else: 105 | print("未检测到") 106 | 107 | def run(): 108 | # 自动停止录音 109 | # wavs = input_audio() 110 | # 设置录制时间 111 | wavs = record_audio(record_second=5) 112 | char2id, id2char, mfcc_mean, mfcc_std, model = load_file() 113 | X_data = set_data(wavs, mfcc_mean, mfcc_std) 114 | wav_pred(model,X_data,id2char) 115 | os.remove("temp_file.wav") 116 | 117 | 118 | if __name__ == '__main__' : 119 | run() 120 | 121 | -------------------------------------------------------------------------------- /asr_web.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import functools 3 | import os 4 | 5 | import torch 6 | import uvicorn 7 | from fastapi import FastAPI, File, Body, UploadFile, Request 8 | from starlette.staticfiles import StaticFiles 9 | from starlette.templating import Jinja2Templates 10 | from utils.utils import add_arguments, print_arguments 11 | 12 | from keras.models import load_model 13 | from keras import backend as K 14 | import numpy as np 15 | import librosa 16 | from python_speech_features import mfcc 17 | import speech_recognition as sr 18 | import pickle 19 | import config 20 | import wave 21 | import io 22 | 23 | from pydub import AudioSegment 24 | from io import BytesIO 25 | import librosa 26 | import numpy as np 27 | from python_speech_features import mfcc 28 | 29 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' 30 | parser = argparse.ArgumentParser(description=__doc__) 31 | add_arg = functools.partial(add_arguments, argparser=parser) 32 | add_arg("host", type=str, default="0.0.0.0", help="监听主机的IP地址") 33 | add_arg("port", type=int, default=5000, help="服务所使用的端口号") 34 | # add_arg("model_path", type=str, default="models/whisper-tiny-finetune/", help="合并模型的路径,或者是huggingface上模型的名称") 35 | # add_arg("model_path", type=str, default="models/tiny-finetune/", help="合并模型的路径,或者是huggingface上模型的名称") 36 | add_arg("use_gpu", type=bool, default=True, help="是否使用gpu进行预测") 37 | add_arg("num_beams", type=int, default=1, help="解码搜索大小") 38 | add_arg("batch_size", type=int, default=16, help="预测batch_size大小") 39 | add_arg("use_compile", type=bool, default=False, help="是否使用Pytorch2.0的编译器") 40 | add_arg("assistant_model_path", type=str, default=None, help="助手模型,可以提高推理速度,例如openai/whisper-tiny") 41 | add_arg("local_files_only", type=bool, default=True, help="是否只在本地加载模型,不尝试下载") 42 | add_arg("use_flash_attention_2", type=bool, default=False, help="是否使用FlashAttention2加速") 43 | add_arg("use_bettertransformer", type=bool, default=False, help="是否使用BetterTransformer加速") 44 | args = parser.parse_args() 45 | print_arguments(args) 46 | 47 | # 设置设备 48 | device = "cuda:0" if torch.cuda.is_available() and args.use_gpu else "cpu" 49 | torch_dtype = torch.float16 if torch.cuda.is_available() and args.use_gpu else torch.float32 50 | 51 | 52 | model = load_model(config.model_path) 53 | 54 | with 
open(config.pkl_path, 'rb') as fr: 55 | [char2id, id2char, mfcc_mean, mfcc_std] = pickle.load(fr) 56 | 57 | 58 | app = FastAPI(title="thirteen语音识别") 59 | app.mount('/static', StaticFiles(directory='static'), name='static') 60 | templates = Jinja2Templates(directory="templates") 61 | model_semaphore = None 62 | 63 | 64 | def release_model_semaphore(): 65 | model_semaphore.release() 66 | 67 | 68 | def recognition(file: File,mfcc_mean, mfcc_std): 69 | 70 | X_data = extract_mfcc_features(file, mfcc_mean, mfcc_std) 71 | pred = model.predict(np.expand_dims(X_data, axis=0)) 72 | pred_ids = K.eval(K.ctc_decode(pred, [X_data.shape[0]], greedy=False, beam_width=10, top_paths=1)[0][0]) 73 | pred_ids = pred_ids.flatten().tolist() 74 | results = '' 75 | judge=0 76 | for i in pred_ids: 77 | if i != -1: 78 | judge=1 79 | results = results + id2char[i] 80 | if judge!=1: 81 | results = '未检测到' 82 | 83 | return results 84 | 85 | 86 | def extract_mfcc_features(audio_bytes, mfcc_mean, mfcc_std): 87 | # 使用pydub将bytes转换为WAV格式的AudioSegment(如果它不是WAV的话) 88 | # 注意:这里我们假设input_bytes是WAV或我们可以转换为WAV的格式 89 | # 如果input_bytes不是WAV且格式未知,你可能需要先检测它 90 | audio_segment = AudioSegment.from_file(BytesIO(audio_bytes), format="wav") # 如果已经是WAV,或者确定可以解析为WAV 91 | 92 | # 确保输出为WAV格式(如果之前不是的话,这一步其实是多余的,因为from_file已经处理了) 93 | # 但为了清晰起见,我们还是将其导出为WAV的bytes 94 | wav_bytes = BytesIO() 95 | audio_segment.export(wav_bytes, format="wav") 96 | 97 | # 重置BytesIO的指针到开头 98 | wav_bytes.seek(0) 99 | 100 | # 使用librosa加载WAV音频 101 | y, sr = librosa.load(wav_bytes) 102 | 103 | # 提取RMS能量 104 | energy = librosa.feature.rms(y=y) 105 | 106 | # 找到能量大于最大能量1/5的帧 107 | frames = np.nonzero(energy[0] >= np.max(energy[0]) / 5) 108 | 109 | # 将帧索引转换为样本索引 110 | if frames[0].size: 111 | indices = librosa.core.frames_to_samples(frames)[0] 112 | y = y[indices[0]:indices[-1]] 113 | 114 | # 提取MFCC特征 115 | mfcc_dim = 13 # 你可以根据需要修改MFCC的维度 116 | mfcc_features = mfcc(y, sr, numcep=mfcc_dim, nfft=551) 117 | 118 | # 这里假设你已经有了mfcc_mean和mfcc_std用于标准化(通常需要在训练阶段计算) 119 | # 如果没有,你可以跳过标准化步骤,或者计算它们 120 | mfcc_features = (mfcc_features - mfcc_mean) / (mfcc_std + 1e-14) 121 | 122 | return mfcc_features 123 | 124 | 125 | @app.post("/recognition") 126 | async def api_recognition(audio: UploadFile = File(..., description="音频文件")): 127 | # if language == "None": language = None 128 | data = await audio.read() 129 | with io.BytesIO(data) as bio: 130 | with wave.open(bio, 'rb') as wav_file: 131 | pass 132 | results = recognition(file= data, mfcc_mean= mfcc_mean, mfcc_std= mfcc_std) 133 | ret = {"results": results, "code": 0} 134 | return ret 135 | 136 | 137 | @app.get("/") 138 | async def index(request: Request): 139 | return templates.TemplateResponse("index.html", {"request": request, "id": id}) 140 | 141 | 142 | if __name__ == '__main__': 143 | uvicorn.run(app, host=args.host, port=args.port) 144 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # 训练数据音频文件路径 2 | train_wav_data_path = 'AISHELL-3/train/wav/' 3 | 4 | # 训练数据内容文件路径 5 | train_texts_data_path = 'AISHELL-3/train/content.txt' 6 | 7 | # 测试数据音频文件路径 8 | test_wav_data_path = 'AISHELL-3/test/wav/' 9 | 10 | # 测试数据内容文件路径 11 | test_texts_data_path = 'AISHELL-3/test/content.txt' 12 | 13 | # 存放模型路径 /模型名字 14 | model_path = 'model/asr_AISHELL.h5' 15 | 16 | # 存放pkl路径 /pkl名字 17 | pkl_path = 'pkl_all/dictionary.pkl' 18 | 19 | # 存放labels路径 20 | labels_path = 'pkl_all/labels.pkl' 21 | 22 | # features.pkl路径 23 | 
features_path = 'pkl_all/features.pkl' 24 | 25 | # 外部导入音频路径 26 | audio_path = 'AISHELL-3/train/wav/SSB0005/SSB00050001.wav' 27 | # model_name 28 | 29 | batch_size = 16 30 | 31 | epochs = 25 32 | 33 | num_blocks = 3 34 | 35 | filters = 128 36 | 37 | mfcc_dim = 13 38 | -------------------------------------------------------------------------------- /infer_server.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import functools 3 | import os 4 | 5 | import torch 6 | import uvicorn 7 | from fastapi import FastAPI, File, Body, UploadFile, Request 8 | from starlette.staticfiles import StaticFiles 9 | from starlette.templating import Jinja2Templates 10 | from utils.utils import add_arguments, print_arguments 11 | 12 | from keras.models import load_model 13 | from keras import backend as K 14 | import numpy as np 15 | import librosa 16 | from python_speech_features import mfcc 17 | import speech_recognition as sr 18 | import pickle 19 | import config 20 | import wave 21 | import io 22 | 23 | from pydub import AudioSegment 24 | from io import BytesIO 25 | import librosa 26 | import numpy as np 27 | from python_speech_features import mfcc 28 | 29 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' 30 | parser = argparse.ArgumentParser(description=__doc__) 31 | add_arg = functools.partial(add_arguments, argparser=parser) 32 | add_arg("host", type=str, default="0.0.0.0", help="监听主机的IP地址") 33 | add_arg("port", type=int, default=5000, help="服务所使用的端口号") 34 | # add_arg("model_path", type=str, default="models/whisper-tiny-finetune/", help="合并模型的路径,或者是huggingface上模型的名称") 35 | # add_arg("model_path", type=str, default="models/tiny-finetune/", help="合并模型的路径,或者是huggingface上模型的名称") 36 | add_arg("use_gpu", type=bool, default=True, help="是否使用gpu进行预测") 37 | add_arg("num_beams", type=int, default=1, help="解码搜索大小") 38 | add_arg("batch_size", type=int, default=16, help="预测batch_size大小") 39 | add_arg("use_compile", type=bool, default=False, help="是否使用Pytorch2.0的编译器") 40 | add_arg("assistant_model_path", type=str, default=None, help="助手模型,可以提高推理速度,例如openai/whisper-tiny") 41 | add_arg("local_files_only", type=bool, default=True, help="是否只在本地加载模型,不尝试下载") 42 | add_arg("use_flash_attention_2", type=bool, default=False, help="是否使用FlashAttention2加速") 43 | add_arg("use_bettertransformer", type=bool, default=False, help="是否使用BetterTransformer加速") 44 | args = parser.parse_args() 45 | print_arguments(args) 46 | 47 | # 设置设备 48 | device = "cuda:0" if torch.cuda.is_available() and args.use_gpu else "cpu" 49 | torch_dtype = torch.float16 if torch.cuda.is_available() and args.use_gpu else torch.float32 50 | 51 | 52 | model = load_model(config.model_path) 53 | 54 | with open(config.pkl_path, 'rb') as fr: 55 | [char2id, id2char, mfcc_mean, mfcc_std] = pickle.load(fr) 56 | 57 | 58 | app = FastAPI(title="thirteen语音识别") 59 | app.mount('/static', StaticFiles(directory='static'), name='static') 60 | templates = Jinja2Templates(directory="templates") 61 | model_semaphore = None 62 | 63 | 64 | def release_model_semaphore(): 65 | model_semaphore.release() 66 | 67 | 68 | def recognition(file: File,mfcc_mean, mfcc_std): 69 | 70 | X_data = extract_mfcc_features(file, mfcc_mean, mfcc_std) 71 | pred = model.predict(np.expand_dims(X_data, axis=0)) 72 | pred_ids = K.eval(K.ctc_decode(pred, [X_data.shape[0]], greedy=False, beam_width=10, top_paths=1)[0][0]) 73 | pred_ids = pred_ids.flatten().tolist() 74 | results = '' 75 | judge=0 76 | for i in pred_ids: 77 | if i != -1: 78 | judge=1 79 | results = 
results + id2char[i] 80 | if judge!=1: 81 | results = '未检测到' 82 | 83 | return results 84 | 85 | 86 | def extract_mfcc_features(audio_bytes, mfcc_mean, mfcc_std): 87 | # 使用pydub将bytes转换为WAV格式的AudioSegment(如果它不是WAV的话) 88 | # 注意:这里我们假设input_bytes是WAV或我们可以转换为WAV的格式 89 | # 如果input_bytes不是WAV且格式未知,你可能需要先检测它 90 | audio_segment = AudioSegment.from_file(BytesIO(audio_bytes), format="wav") # 如果已经是WAV,或者确定可以解析为WAV 91 | 92 | # 确保输出为WAV格式(如果之前不是的话,这一步其实是多余的,因为from_file已经处理了) 93 | # 但为了清晰起见,我们还是将其导出为WAV的bytes 94 | wav_bytes = BytesIO() 95 | audio_segment.export(wav_bytes, format="wav") 96 | 97 | # 重置BytesIO的指针到开头 98 | wav_bytes.seek(0) 99 | 100 | # 使用librosa加载WAV音频 101 | y, sr = librosa.load(wav_bytes) 102 | 103 | # 提取RMS能量 104 | energy = librosa.feature.rms(y=y) 105 | 106 | # 找到能量大于最大能量1/5的帧 107 | frames = np.nonzero(energy[0] >= np.max(energy[0]) / 5) 108 | 109 | # 将帧索引转换为样本索引 110 | if frames[0].size: 111 | indices = librosa.core.frames_to_samples(frames)[0] 112 | y = y[indices[0]:indices[-1]] 113 | 114 | # 提取MFCC特征 115 | mfcc_dim = 13 # 你可以根据需要修改MFCC的维度 116 | mfcc_features = mfcc(y, sr, numcep=mfcc_dim, nfft=551) 117 | 118 | # 这里假设你已经有了mfcc_mean和mfcc_std用于标准化(通常需要在训练阶段计算) 119 | # 如果没有,你可以跳过标准化步骤,或者计算它们 120 | mfcc_features = (mfcc_features - mfcc_mean) / (mfcc_std + 1e-14) 121 | 122 | return mfcc_features 123 | 124 | 125 | @app.post("/recognition") 126 | async def api_recognition(audio: UploadFile = File(..., description="音频文件")): 127 | # if language == "None": language = None 128 | data = await audio.read() 129 | with io.BytesIO(data) as bio: 130 | with wave.open(bio, 'rb') as wav_file: 131 | pass 132 | results = recognition(file= data, mfcc_mean= mfcc_mean, mfcc_std= mfcc_std) 133 | ret = {"results": results, "code": 0} 134 | return ret 135 | 136 | 137 | @app.get("/") 138 | async def index(request: Request): 139 | return templates.TemplateResponse("index.html", {"request": request, "id": id}) 140 | 141 | 142 | if __name__ == '__main__': 143 | uvicorn.run(app, host=args.host, port=args.port) 144 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ipython==8.12.3 2 | keras==2.10.0 3 | librosa==0.7.2 4 | matplotlib==3.8.3 5 | numpy==1.24.4 6 | python_speech_features==0.6 7 | scipy==1.13.0 8 | SpeechRecognition==3.10.1 9 | tqdm==4.66.2 10 | -------------------------------------------------------------------------------- /static/index.css: -------------------------------------------------------------------------------- 1 | * { 2 | box-sizing: border-box; 3 | } 4 | 5 | body { 6 | font-family: "Helvetica Neue", "Roboto", sans-serif; 7 | background-color: #f2f2f2; 8 | margin: 0; 9 | padding: 0; 10 | } 11 | 12 | #header { 13 | background-color: #fff; 14 | color: #333; 15 | display: flex; 16 | justify-content: center; 17 | align-items: center; 18 | height: 80px; 19 | } 20 | 21 | h1 { 22 | font-size: 36px; 23 | margin: 0; 24 | } 25 | 26 | #content { 27 | background-color: #fff; 28 | border-radius: 10px; 29 | box-shadow: 0 2px 10px rgba(0, 0, 0, 0.2); 30 | margin: 50px auto; 31 | max-width: 800px; 32 | padding: 20px; 33 | } 34 | 35 | #content div { 36 | display: flex; 37 | flex-wrap: wrap; 38 | justify-content: space-between; 39 | margin-bottom: 20px; 40 | } 41 | 42 | #content a { 43 | background-color: #fff; 44 | border-radius: 5px; 45 | box-shadow: 0 2px 10px rgba(0, 0, 0, 0.2); 46 | color: #333; 47 | padding: 10px; 48 | text-align: center; 49 | text-decoration: none; 50 | 
transition: background-color 0.2s; 51 | width: 20%; 52 | } 53 | 54 | #content a:hover { 55 | background-color: #f2f2f2; 56 | } 57 | 58 | #content img { 59 | cursor: pointer; 60 | height: 50px; 61 | transition: transform 0.2s; 62 | width: 50px; 63 | } 64 | 65 | #content img:hover { 66 | transform: scale(1.1); 67 | } 68 | 69 | #result { 70 | background-color: #fff; 71 | border-radius: 5px; 72 | box-shadow: 0 2px 10px rgba(0, 0, 0, 0.2); 73 | padding: 10px; 74 | } 75 | 76 | #result textarea { 77 | border: none; 78 | border-radius: 5px; 79 | font-size: 16px; 80 | height: 300px; 81 | margin-top: 10px; 82 | padding: 10px; 83 | resize: none; 84 | width: 100%; 85 | } 86 | 87 | @media only screen and (max-width: 600px) { 88 | #content a { 89 | width: 100%; 90 | } 91 | } -------------------------------------------------------------------------------- /static/record.js: -------------------------------------------------------------------------------- 1 | //兼容 2 | window.URL = window.URL || window.webkitURL; 3 | //获取计算机的设备:摄像头或者录音设备 4 | navigator.getUserMedia = navigator.getUserMedia || navigator.webkitGetUserMedia || navigator.mozGetUserMedia || navigator.msGetUserMedia; 5 | 6 | var HZRecorder = function (stream, config) { 7 | config = config || {}; 8 | config.sampleBits = config.sampleBits || 16; //采样数位 8, 16 9 | config.sampleRate = config.sampleRate || 16000; //采样率 16000 10 | 11 | //创建一个音频环境对象 12 | var audioContext = window.AudioContext || window.webkitAudioContext; 13 | var context = new audioContext(); 14 | var audioInput = context.createMediaStreamSource(stream); 15 | // 第二个和第三个参数指的是输入和输出都是单声道,2是双声道。 16 | var recorder = context.createScriptProcessor(4096, 2, 2); 17 | 18 | var audioData = { 19 | size: 0 //录音文件长度 20 | , buffer: [] //录音缓存 21 | , inputSampleRate: context.sampleRate //输入采样率 22 | , inputSampleBits: 16 //输入采样数位 8, 16 23 | , outputSampleRate: config.sampleRate //输出采样率 24 | , outputSampleBits: config.sampleBits //输出采样数位 8, 16 25 | , input: function (data) { 26 | this.buffer.push(new Float32Array(data)); 27 | this.size += data.length; 28 | } 29 | , compress: function () { //合并压缩 30 | //合并 31 | var data = new Float32Array(this.size); 32 | var offset = 0; 33 | for (var i = 0; i < this.buffer.length; i++) { 34 | data.set(this.buffer[i], offset); 35 | offset += this.buffer[i].length; 36 | } 37 | //压缩 38 | var compression = parseInt(this.inputSampleRate / this.outputSampleRate); 39 | var length = data.length / compression; 40 | var result = new Float32Array(length); 41 | var index = 0, j = 0; 42 | while (index < length) { 43 | result[index] = data[j]; 44 | j += compression; 45 | index++; 46 | } 47 | return result; 48 | } 49 | , encodeWAV: function () { 50 | var sampleRate = Math.min(this.inputSampleRate, this.outputSampleRate); 51 | var sampleBits = Math.min(this.inputSampleBits, this.outputSampleBits); 52 | var bytes = this.compress(); 53 | var dataLength = bytes.length * (sampleBits / 8); 54 | var buffer = new ArrayBuffer(44 + dataLength); 55 | var data = new DataView(buffer); 56 | 57 | var channelCount = 1;//单声道 58 | var offset = 0; 59 | 60 | var writeString = function (str) { 61 | for (var i = 0; i < str.length; i++) { 62 | data.setUint8(offset + i, str.charCodeAt(i)); 63 | } 64 | } 65 | 66 | // 资源交换文件标识符 67 | writeString('RIFF'); 68 | offset += 4; 69 | // 下个地址开始到文件尾总字节数,即文件大小-8 70 | data.setUint32(offset, 36 + dataLength, true); 71 | offset += 4; 72 | // WAV文件标志 73 | writeString('WAVE'); 74 | offset += 4; 75 | // 波形格式标志 76 | writeString('fmt '); 77 | offset += 4; 78 | // 过滤字节,一般为 0x10 = 16 
79 | data.setUint32(offset, 16, true); 80 | offset += 4; 81 | // 格式类别 (PCM形式采样数据) 82 | data.setUint16(offset, 1, true); 83 | offset += 2; 84 | // 通道数 85 | data.setUint16(offset, channelCount, true); 86 | offset += 2; 87 | // 采样率,每秒样本数,表示每个通道的播放速度 88 | data.setUint32(offset, sampleRate, true); 89 | offset += 4; 90 | // 波形数据传输率 (每秒平均字节数) 单声道×每秒数据位数×每样本数据位/8 91 | data.setUint32(offset, channelCount * sampleRate * (sampleBits / 8), true); 92 | offset += 4; 93 | // 快数据调整数 采样一次占用字节数 单声道×每样本的数据位数/8 94 | data.setUint16(offset, channelCount * (sampleBits / 8), true); 95 | offset += 2; 96 | // 每样本数据位数 97 | data.setUint16(offset, sampleBits, true); 98 | offset += 2; 99 | // 数据标识符 100 | writeString('data'); 101 | offset += 4; 102 | // 采样数据总数,即数据总大小-44 103 | data.setUint32(offset, dataLength, true); 104 | offset += 4; 105 | // 写入采样数据 106 | if (sampleBits === 8) { 107 | for (var i = 0; i < bytes.length; i++, offset++) { 108 | var s = Math.max(-1, Math.min(1, bytes[i])); 109 | var val = s < 0 ? s * 0x8000 : s * 0x7FFF; 110 | val = parseInt(255 / (65535 / (val + 32768))); 111 | data.setInt8(offset, val, true); 112 | } 113 | } else { 114 | for (var i = 0; i < bytes.length; i++, offset += 2) { 115 | var s = Math.max(-1, Math.min(1, bytes[i])); 116 | data.setInt16(offset, s < 0 ? s * 0x8000 : s * 0x7FFF, true); 117 | } 118 | } 119 | 120 | return new Blob([data], {type: 'audio/wav'}); 121 | } 122 | }; 123 | 124 | //开始录音 125 | this.start = function () { 126 | audioInput.connect(recorder); 127 | recorder.connect(context.destination); 128 | } 129 | 130 | //停止 131 | this.stop = function () { 132 | recorder.disconnect(); 133 | } 134 | 135 | //获取音频文件 136 | this.getBlob = function () { 137 | this.stop(); 138 | return audioData.encodeWAV(); 139 | } 140 | 141 | //回放 142 | this.play = function (audio) { 143 | audio.src = window.URL.createObjectURL(this.getBlob()); 144 | } 145 | //清除 146 | this.clear = function () { 147 | audioData.buffer = []; 148 | audioData.size = 0; 149 | } 150 | 151 | //上传 152 | this.upload = function (url, callback) { 153 | var fd = new FormData(); 154 | // 上传的文件名和数据 155 | fd.append("audio", this.getBlob()); 156 | var xhr = new XMLHttpRequest(); 157 | xhr.timeout = 60000 158 | if (callback) { 159 | xhr.upload.addEventListener("progress", function (e) { 160 | callback('uploading', e); 161 | }, false); 162 | xhr.addEventListener("load", function (e) { 163 | callback('ok', e); 164 | }, false); 165 | xhr.addEventListener("error", function (e) { 166 | callback('error', e); 167 | }, false); 168 | xhr.addEventListener("abort", function (e) { 169 | callback('cancel', e); 170 | }, false); 171 | } 172 | xhr.open("POST", url); 173 | xhr.send(fd); 174 | } 175 | 176 | //音频采集 177 | recorder.onaudioprocess = function (e) { 178 | audioData.input(e.inputBuffer.getChannelData(0)); 179 | //record(e.inputBuffer.getChannelData(0)); 180 | } 181 | 182 | }; 183 | //抛出异常 184 | HZRecorder.throwError = function (message) { 185 | alert(message); 186 | throw new function () { 187 | this.toString = function () { 188 | return message; 189 | } 190 | } 191 | } 192 | //是否支持录音 193 | HZRecorder.canRecording = (navigator.getUserMedia != null); 194 | //获取录音机 195 | HZRecorder.get = function (callback, config) { 196 | if (callback) { 197 | if (navigator.getUserMedia) { 198 | navigator.getUserMedia( 199 | {audio: true} //只启用音频 200 | , function (stream) { 201 | var rec = new HZRecorder(stream, config); 202 | callback(rec); 203 | } 204 | , function (error) { 205 | switch (error.code || error.name) { 206 | case 'PERMISSION_DENIED': 207 | 
case 'PermissionDeniedError': 208 | HZRecorder.throwError('用户拒绝提供信息。'); 209 | break; 210 | case 'NOT_SUPPORTED_ERROR': 211 | case 'NotSupportedError': 212 | HZRecorder.throwError('浏览器不支持硬件设备。'); 213 | break; 214 | case 'MANDATORY_UNSATISFIED_ERROR': 215 | case 'MandatoryUnsatisfiedError': 216 | HZRecorder.throwError('无法发现指定的硬件设备。'); 217 | break; 218 | default: 219 | HZRecorder.throwError('无法打开麦克风。异常信息:' + (error.code || error.name)); 220 | break; 221 | } 222 | }); 223 | } else { 224 | window.alert('不是HTTPS协议或者localhost地址,不能使用录音功能!') 225 | HZRecorder.throwError('当前浏览器不支持录音功能。'); 226 | return; 227 | } 228 | } 229 | }; --------------------------------------------------------------------------------
/static/record.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WThirteen/asr_AISHELL-3/3140e6f914dbb3ccac906b40dc6836cdd122e46d/static/record.png --------------------------------------------------------------------------------
/templates/index.html: -------------------------------------------------------------------------------- [index.html markup was not preserved in this export; only the visible text survived: page title "Thirteen_temp" and the UI labels 选择音频文件, 上传录音, 录音, 上传进度]
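In record.js above, `upload()` posts the recorded WAV as a multipart form field named "audio"; the matching POST `/recognition` route is defined in asr_web.py (infer_server.py exposes an identical route). Below is a minimal sketch of calling that endpoint outside the browser, assuming the server was started with `python asr_web.py` on the default port 5000 and that the third-party `requests` package (not listed in requirements.txt) is installed:

```python
# Minimal client sketch for the /recognition endpoint (see asr_web.py).
# Assumptions: server running locally on port 5000; `requests` installed.
import requests

# A PCM WAV file; this path is the same sample file used for config.audio_path.
wav_path = "AISHELL-3/train/wav/SSB0005/SSB00050001.wav"

with open(wav_path, "rb") as f:
    # api_recognition() reads a multipart upload from the form field "audio".
    resp = requests.post("http://127.0.0.1:5000/recognition", files={"audio": f})

print(resp.json())  # e.g. {"results": "<recognized text>", "code": 0}
```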
25 | 170 | 171 | 172 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #导入相关的库 2 | from keras.models import Model 3 | from keras.layers import Input, Activation, Conv1D, Lambda, Add, Multiply, BatchNormalization 4 | from keras.optimizers import Adam, SGD 5 | from keras import backend as K 6 | from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | from mpl_toolkits.axes_grid1 import make_axes_locatable 10 | import random 11 | import pickle 12 | import glob 13 | from tqdm import tqdm 14 | import os 15 | from python_speech_features import mfcc 16 | import scipy.io.wavfile as wav 17 | import librosa 18 | from IPython.display import Audio 19 | import config 20 | 21 | 22 | def load_texts_data(path,en_num_all): 23 | f=open(path, "r", encoding='utf-8') 24 | txt=[] 25 | for line in f: 26 | txt.append(line.strip()) 27 | 28 | # 音频文件名字 29 | txt_filename = [] 30 | # 对应文件内容 31 | txt_wav = [] 32 | 33 | for i in txt: 34 | temp_txt_filename,temp_txt_wav = i.split('\t') 35 | txt_filename.append(temp_txt_filename) 36 | txt_wav.append(temp_txt_wav) 37 | 38 | # 音频文字 39 | texts = [] 40 | 41 | for i in txt_wav: 42 | temp = '' 43 | for j in i: 44 | if j in en_num_all: 45 | continue 46 | else: 47 | # print(j) 48 | temp = temp+j 49 | texts.append(temp) 50 | 51 | return texts 52 | 53 | def create_en_num(): 54 | # 字典 字母+数字 55 | en_num_all = [] 56 | # 字母 57 | for letter in 'abcdefghijklmnopqrstuvwxyz': 58 | en_num_all.extend(letter) 59 | # 数字 60 | for number in range(10): 61 | en_num_all.extend(str(number)) 62 | # 空格 63 | en_num_all.extend(' ') 64 | 65 | return en_num_all 66 | 67 | def load_wav_data(path): 68 | files = os.listdir(path) 69 | # 音频文件 70 | wav_file = [] 71 | # 音频文件的相对路径 72 | wav_file_path = [] 73 | 74 | for i in files: 75 | temp_files = os.listdir(path+i) 76 | for j in temp_files: 77 | wav_file.append(j) 78 | wav_file_path.append(path+i+'/'+j) 79 | 80 | return wav_file_path 81 | 82 | #根据数据集标定的音素读入 83 | def load_and_trim(path): 84 | audio, sr = librosa.load(path) 85 | # energy = librosa.feature.rmse(audio) 86 | energy = librosa.feature.rms(audio) 87 | frames = np.nonzero(energy >= np.max(energy) / 5) 88 | indices = librosa.core.frames_to_samples(frames)[1] 89 | audio = audio[indices[0]:indices[-1]] if indices.size else audio[0:0] 90 | return audio, sr 91 | 92 | #可视化,显示语音文件的MFCC图 93 | def visualize(paths,texts,index,mfcc_dim): 94 | path = paths[index] 95 | text = texts[index] 96 | print('Audio Text:', text) 97 | 98 | audio, sr = load_and_trim(path) 99 | plt.figure(figsize=(12, 3)) 100 | plt.plot(np.arange(len(audio)), audio) 101 | plt.title('Raw Audio Signal') 102 | plt.xlabel('Time') 103 | plt.ylabel('Audio Amplitude') 104 | plt.show() 105 | 106 | feature = mfcc(audio, sr, numcep=mfcc_dim, nfft=551) 107 | print('Shape of MFCC:', feature.shape) 108 | 109 | fig = plt.figure(figsize=(12, 5)) 110 | ax = fig.add_subplot(111) 111 | im = ax.imshow(feature, cmap=plt.cm.jet, aspect='auto') 112 | plt.title('Normalized MFCC') 113 | plt.ylabel('Time') 114 | plt.xlabel('MFCC Coefficient') 115 | plt.colorbar(im, cax=make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05)) 116 | ax.set_xticks(np.arange(0, 13, 2), minor=False); 117 | plt.show() 118 | 119 | return path 120 | 121 | # Audio(visualize(0)) 122 | 123 | def wav_features(paths,total): 124 | #提取音频特征并存储 125 | features = [] 126 | for i in tqdm(range(total)): 
127 | path = paths[i] 128 | audio, sr = load_and_trim(path) 129 | features.append(mfcc(audio, sr, numcep=config.mfcc_dim, nfft=551)) 130 | return features 131 | 132 | def save_features(features): 133 | with open(config.features_path, 'wb') as fw: 134 | pickle.dump(features,fw) 135 | 136 | def load_features(): 137 | with open(config.features_path, 'rb') as f: 138 | features = pickle.load(f) 139 | return features 140 | 141 | def normalized_features(features): 142 | #随机选择100个数据集 143 | samples = random.sample(features, 100) 144 | samples = np.vstack(samples) 145 | #平均MFCC的值为了归一化处理 146 | mfcc_mean = np.mean(samples, axis=0) 147 | #计算标准差为了归一化 148 | mfcc_std = np.std(samples, axis=0) 149 | # print(mfcc_mean) 150 | # print(mfcc_std) 151 | #归一化特征 152 | features = [(feature - mfcc_mean) / (mfcc_std + 1e-14) for feature in features] 153 | 154 | return mfcc_mean,mfcc_std,features 155 | 156 | def save_labels(texts): 157 | #将数据集读入的标签和对应id存储列表 158 | chars = {} 159 | for text in texts: 160 | for c in text: 161 | chars[c] = chars.get(c, 0) + 1 162 | 163 | chars = sorted(chars.items(), key=lambda x: x[1], reverse=True) 164 | chars = [char[0] for char in chars] 165 | # print(len(chars), chars[:100]) 166 | 167 | char2id = {c: i for i, c in enumerate(chars)} 168 | id2char = {i: c for i, c in enumerate(chars)} 169 | 170 | return char2id,id2char 171 | 172 | def data_set(total,features,texts): 173 | data_index = np.arange(total) 174 | np.random.shuffle(data_index) 175 | train_size = int(0.9 * total) 176 | test_size = total - train_size 177 | train_index = data_index[:train_size] 178 | test_index = data_index[train_size:] 179 | #神经网络输入和输出X,Y的读入数据集特征 180 | X_train = [features[i] for i in train_index] 181 | Y_train = [texts[i] for i in train_index] 182 | X_test = [features[i] for i in test_index] 183 | Y_test = [texts[i] for i in test_index] 184 | 185 | return X_train,Y_train,X_test,Y_test 186 | 187 | 188 | #定义训练批次的产生,一次训练16个 189 | def batch_generator(x, y,char2id): 190 | batch_size = config.batch_size 191 | offset = 0 192 | while True: 193 | offset += batch_size 194 | 195 | if offset == batch_size or offset >= len(x): 196 | data_index = np.arange(len(x)) 197 | np.random.shuffle(data_index) 198 | x = [x[i] for i in data_index] 199 | y = [y[i] for i in data_index] 200 | offset = batch_size 201 | 202 | X_data = x[offset - batch_size: offset] 203 | Y_data = y[offset - batch_size: offset] 204 | 205 | X_maxlen = max([X_data[i].shape[0] for i in range(batch_size)]) 206 | Y_maxlen = max([len(Y_data[i]) for i in range(batch_size)]) 207 | 208 | X_batch = np.zeros([batch_size, X_maxlen, config.mfcc_dim]) 209 | Y_batch = np.ones([batch_size, Y_maxlen]) * len(char2id) 210 | X_length = np.zeros([batch_size, 1], dtype='int32') 211 | Y_length = np.zeros([batch_size, 1], dtype='int32') 212 | 213 | for i in range(batch_size): 214 | X_length[i, 0] = X_data[i].shape[0] 215 | X_batch[i, :X_length[i, 0], :] = X_data[i] 216 | 217 | Y_length[i, 0] = len(Y_data[i]) 218 | Y_batch[i, :Y_length[i, 0]] = [char2id[c] for c in Y_data[i]] 219 | 220 | inputs = {'X': X_batch, 'Y': Y_batch, 'X_length': X_length, 'Y_length': Y_length} 221 | outputs = {'ctc': np.zeros([batch_size])} 222 | 223 | yield (inputs, outputs) 224 | 225 | def input_layer(): 226 | X = Input(shape=(None, config.mfcc_dim,), dtype='float32', name='X') 227 | Y = Input(shape=(None,), dtype='float32', name='Y') 228 | X_length = Input(shape=(1,), dtype='int32', name='X_length') 229 | Y_length = Input(shape=(1,), dtype='int32', name='Y_length') 230 | 231 | return X,Y,X_length,Y_length 
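# --- Illustrative note (not part of the original train.py) -------------------
# The network assembled in model_train() below stacks num_blocks groups of
# causal dilated Conv1D res_blocks with kernel_size=7 and dilation rates
# [1, 2, 4, 8, 16]; config.py sets num_blocks = 3 and filters = 128.
# Each res_block contributes (kernel_size - 1) * dilation_rate frames of left
# context, so the receptive field of the full stack works out to roughly:
#   1 + 3 * sum((7 - 1) * d for d in (1, 2, 4, 8, 16)) = 1 + 3 * 186 = 559 frames
# With the default 10 ms hop of python_speech_features' mfcc(), each CTC output
# step can therefore see on the order of 5.6 seconds of audio context.
# ------------------------------------------------------------------------------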
232 | 233 | 234 | #卷积1层 235 | def conv1d(inputs, filters, kernel_size, dilation_rate): 236 | return Conv1D(filters=filters, kernel_size=kernel_size, strides=1, padding='causal', activation=None, 237 | dilation_rate=dilation_rate)(inputs) 238 | 239 | #标准化函数 240 | def batchnorm(inputs): 241 | return BatchNormalization()(inputs) 242 | 243 | #激活层函数 244 | def activation(inputs, activation): 245 | return Activation(activation)(inputs) 246 | 247 | #全连接层函数 248 | def res_block(inputs, filters, kernel_size, dilation_rate): 249 | hf = activation(batchnorm(conv1d(inputs, filters, kernel_size, dilation_rate)), 'tanh') 250 | hg = activation(batchnorm(conv1d(inputs, filters, kernel_size, dilation_rate)), 'sigmoid') 251 | h0 = Multiply()([hf, hg]) 252 | 253 | ha = activation(batchnorm(conv1d(h0, filters, 1, 1)), 'tanh') 254 | hs = activation(batchnorm(conv1d(h0, filters, 1, 1)), 'tanh') 255 | 256 | return Add()([ha, inputs]), hs 257 | 258 | #计算损失函数 259 | def calc_ctc_loss(args): 260 | y, yp, ypl, yl = args 261 | return K.ctc_batch_cost(y, yp, ypl, yl) 262 | 263 | def model_train(X,Y,X_length,Y_length,char2id): 264 | h0 = activation(batchnorm(conv1d(X, config.filters, 1, 1)), 'tanh') 265 | shortcut = [] 266 | for i in range(config.num_blocks): 267 | for r in [1, 2, 4, 8, 16]: 268 | h0, s = res_block(h0, config.filters, 7, r) 269 | shortcut.append(s) 270 | 271 | h1 = activation(Add()(shortcut), 'relu') 272 | h1 = activation(batchnorm(conv1d(h1, config.filters, 1, 1)), 'relu') 273 | #softmax损失函数输出结果 274 | Y_pred = activation(batchnorm(conv1d(h1, len(char2id) + 1, 1, 1)), 'softmax') 275 | sub_model = Model(inputs=X, outputs=Y_pred) 276 | 277 | ctc_loss = Lambda(calc_ctc_loss, output_shape=(1,), name='ctc')([Y, Y_pred, X_length, Y_length]) 278 | #加载模型训练 279 | model = Model(inputs=[X, Y, X_length, Y_length], outputs=ctc_loss) 280 | #建立优化器 281 | optimizer = SGD(lr=0.02, momentum=0.9, nesterov=True, clipnorm=5) 282 | #激活模型开始计算 283 | model.compile(loss={'ctc': lambda ctc_true, ctc_pred: ctc_pred}, optimizer=optimizer) 284 | 285 | return sub_model,model 286 | 287 | def save_model_pkl(sub_model,char2id,id2char,mfcc_mean,mfcc_std): 288 | #保存模型 289 | sub_model.save(config.model_path) 290 | #将字保存在pl=pkl中 291 | with open(config.pkl_path, 'wb') as fw: 292 | pickle.dump([char2id, id2char, mfcc_mean, mfcc_std], fw) 293 | 294 | 295 | def draw_loss(history): 296 | train_loss = history.history['loss'] 297 | valid_loss = history.history['val_loss'] 298 | plt.plot(np.linspace(1, config.epochs, config.epochs), train_loss, label='train') 299 | plt.plot(np.linspace(1, config.epochs, config.epochs), valid_loss, label='valid') 300 | plt.legend(loc='upper right') 301 | plt.xlabel('Epoch') 302 | plt.ylabel('Loss') 303 | plt.show() 304 | 305 | 306 | def run(): 307 | print("-----load data-----") 308 | path_train = load_wav_data(path=config.train_wav_data_path) 309 | path_test = load_wav_data(path=config.test_wav_data_path) 310 | paths = [] 311 | paths.extend(path_train), paths.extend(path_test) 312 | 313 | privacy_dict = create_en_num() 314 | texts_train = load_texts_data(path=config.train_texts_data_path, en_num_all=privacy_dict) 315 | texts_test = load_texts_data(path=config.test_texts_data_path, en_num_all=privacy_dict) 316 | texts = [] 317 | texts.extend(texts_train), texts.extend(texts_test) 318 | 319 | char2id,id2char = save_labels(texts) 320 | 321 | total = len(texts) 322 | print("-----Extract audio features-----") 323 | features = wav_features(paths,total) 324 | 325 | print("-----save features-----") 326 | 
save_features(features) 327 | 328 | 329 | mfcc_mean,mfcc_std,features = normalized_features(features) 330 | 331 | X_train,Y_train,X_test,Y_test = data_set(total,features,texts) 332 | 333 | X,Y,X_length,Y_length = input_layer() 334 | 335 | sub_model,model = model_train(X,Y,X_length,Y_length,char2id) 336 | 337 | # 回调 338 | checkpointer = ModelCheckpoint(filepath=config.model_path, verbose=0) 339 | # 监控 损失值(loss)作为指标 340 | lr_decay = ReduceLROnPlateau(monitor='loss', factor=0.2, patience=1, min_lr=1e-6) 341 | #开始训练 342 | history = model.fit_generator( 343 | generator=batch_generator(X_train, Y_train, char2id), 344 | steps_per_epoch=len(X_train) // config.batch_size, 345 | epochs=config.epochs, 346 | validation_data=batch_generator(X_test, Y_test, char2id), 347 | validation_steps=len(X_test) // config.batch_size, 348 | callbacks=[checkpointer, lr_decay]) 349 | 350 | save_model_pkl(sub_model,char2id,id2char,mfcc_mean,mfcc_std) 351 | draw_loss(history) 352 | 353 | 354 | if __name__ == '__main__' : 355 | run() 356 | -------------------------------------------------------------------------------- /train_v2.py: -------------------------------------------------------------------------------- 1 | #导入相关的库 2 | from keras.models import Model 3 | from keras.layers import Input, Activation, Conv1D, Lambda, Add, Multiply, BatchNormalization 4 | from keras.optimizers import Adam, SGD 5 | from keras import backend as K 6 | from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | from mpl_toolkits.axes_grid1 import make_axes_locatable 10 | import random 11 | import pickle 12 | import glob 13 | from tqdm import tqdm 14 | import os 15 | from python_speech_features import mfcc 16 | import scipy.io.wavfile as wav 17 | import librosa 18 | from IPython.display import Audio 19 | import config 20 | 21 | 22 | def load_texts_data(path,en_num_all): 23 | f=open(path, "r", encoding='utf-8') 24 | txt=[] 25 | for line in f: 26 | txt.append(line.strip()) 27 | 28 | # 音频文件名字 29 | txt_filename = [] 30 | # 对应文件内容 31 | txt_wav = [] 32 | 33 | for i in txt: 34 | temp_txt_filename,temp_txt_wav = i.split('\t') 35 | txt_filename.append(temp_txt_filename) 36 | txt_wav.append(temp_txt_wav) 37 | 38 | # 音频文字 39 | texts = [] 40 | 41 | for i in txt_wav: 42 | temp = '' 43 | for j in i: 44 | if j in en_num_all: 45 | continue 46 | else: 47 | # print(j) 48 | temp = temp+j 49 | texts.append(temp) 50 | 51 | return texts 52 | 53 | def create_en_num(): 54 | # 字典 字母+数字 55 | en_num_all = [] 56 | # 字母 57 | for letter in 'abcdefghijklmnopqrstuvwxyz': 58 | en_num_all.extend(letter) 59 | # 数字 60 | for number in range(10): 61 | en_num_all.extend(str(number)) 62 | # 空格 63 | en_num_all.extend(' ') 64 | 65 | return en_num_all 66 | 67 | def load_wav_data(path): 68 | files = os.listdir(path) 69 | # 音频文件 70 | wav_file = [] 71 | # 音频文件的相对路径 72 | wav_file_path = [] 73 | 74 | for i in files: 75 | temp_files = os.listdir(path+i) 76 | for j in temp_files: 77 | wav_file.append(j) 78 | wav_file_path.append(path+i+'/'+j) 79 | 80 | return wav_file_path 81 | 82 | #根据数据集标定的音素读入 83 | def load_and_trim(path): 84 | audio, sr = librosa.load(path) 85 | # energy = librosa.feature.rmse(audio) 86 | energy = librosa.feature.rms(audio) 87 | frames = np.nonzero(energy >= np.max(energy) / 5) 88 | indices = librosa.core.frames_to_samples(frames)[1] 89 | audio = audio[indices[0]:indices[-1]] if indices.size else audio[0:0] 90 | return audio, sr 91 | 92 | #可视化,显示语音文件的MFCC图 93 | def 
visualize(paths,texts,index,mfcc_dim): 94 | path = paths[index] 95 | text = texts[index] 96 | print('Audio Text:', text) 97 | 98 | audio, sr = load_and_trim(path) 99 | plt.figure(figsize=(12, 3)) 100 | plt.plot(np.arange(len(audio)), audio) 101 | plt.title('Raw Audio Signal') 102 | plt.xlabel('Time') 103 | plt.ylabel('Audio Amplitude') 104 | plt.show() 105 | 106 | feature = mfcc(audio, sr, numcep=mfcc_dim, nfft=551) 107 | print('Shape of MFCC:', feature.shape) 108 | 109 | fig = plt.figure(figsize=(12, 5)) 110 | ax = fig.add_subplot(111) 111 | im = ax.imshow(feature, cmap=plt.cm.jet, aspect='auto') 112 | plt.title('Normalized MFCC') 113 | plt.ylabel('Time') 114 | plt.xlabel('MFCC Coefficient') 115 | plt.colorbar(im, cax=make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05)) 116 | ax.set_xticks(np.arange(0, 13, 2), minor=False); 117 | plt.show() 118 | 119 | return path 120 | 121 | # Audio(visualize(0)) 122 | 123 | def wav_features(paths,total): 124 | #提取音频特征并存储 125 | features = [] 126 | for i in tqdm(range(total)): 127 | path = paths[i] 128 | audio, sr = load_and_trim(path) 129 | features.append(mfcc(audio, sr, numcep=config.mfcc_dim, nfft=551)) 130 | return features 131 | 132 | def save_features(features): 133 | with open(config.features_path, 'wb') as fw: 134 | pickle.dump(features,fw) 135 | 136 | def load_features(): 137 | with open(config.features_path, 'rb') as f: 138 | features = pickle.load(f) 139 | return features 140 | 141 | def normalized_features(features): 142 | #随机选择100个数据集 143 | samples = random.sample(features, 100) 144 | samples = np.vstack(samples) 145 | #平均MFCC的值为了归一化处理 146 | mfcc_mean = np.mean(samples, axis=0) 147 | #计算标准差为了归一化 148 | mfcc_std = np.std(samples, axis=0) 149 | # print(mfcc_mean) 150 | # print(mfcc_std) 151 | #归一化特征 152 | features = [(feature - mfcc_mean) / (mfcc_std + 1e-14) for feature in features] 153 | 154 | return mfcc_mean,mfcc_std,features 155 | 156 | def save_labels(texts): 157 | #将数据集读入的标签和对应id存储列表 158 | chars = {} 159 | for text in texts: 160 | for c in text: 161 | chars[c] = chars.get(c, 0) + 1 162 | 163 | chars = sorted(chars.items(), key=lambda x: x[1], reverse=True) 164 | chars = [char[0] for char in chars] 165 | # print(len(chars), chars[:100]) 166 | 167 | char2id = {c: i for i, c in enumerate(chars)} 168 | id2char = {i: c for i, c in enumerate(chars)} 169 | 170 | return char2id,id2char 171 | 172 | def data_set(total,features,texts): 173 | data_index = np.arange(total) 174 | np.random.shuffle(data_index) 175 | train_size = int(0.9 * total) 176 | test_size = total - train_size 177 | train_index = data_index[:train_size] 178 | test_index = data_index[train_size:] 179 | #神经网络输入和输出X,Y的读入数据集特征 180 | X_train = [features[i] for i in train_index] 181 | Y_train = [texts[i] for i in train_index] 182 | X_test = [features[i] for i in test_index] 183 | Y_test = [texts[i] for i in test_index] 184 | 185 | return X_train,Y_train,X_test,Y_test 186 | 187 | 188 | #定义训练批次的产生,一次训练16个 189 | def batch_generator(x, y,char2id): 190 | batch_size = config.batch_size 191 | offset = 0 192 | while True: 193 | offset += batch_size 194 | 195 | if offset == batch_size or offset >= len(x): 196 | data_index = np.arange(len(x)) 197 | np.random.shuffle(data_index) 198 | x = [x[i] for i in data_index] 199 | y = [y[i] for i in data_index] 200 | offset = batch_size 201 | 202 | X_data = x[offset - batch_size: offset] 203 | Y_data = y[offset - batch_size: offset] 204 | 205 | X_maxlen = max([X_data[i].shape[0] for i in range(batch_size)]) 206 | Y_maxlen = max([len(Y_data[i]) 
for i in range(batch_size)]) 207 | 208 | X_batch = np.zeros([batch_size, X_maxlen, config.mfcc_dim]) 209 | Y_batch = np.ones([batch_size, Y_maxlen]) * len(char2id) 210 | X_length = np.zeros([batch_size, 1], dtype='int32') 211 | Y_length = np.zeros([batch_size, 1], dtype='int32') 212 | 213 | for i in range(batch_size): 214 | X_length[i, 0] = X_data[i].shape[0] 215 | X_batch[i, :X_length[i, 0], :] = X_data[i] 216 | 217 | Y_length[i, 0] = len(Y_data[i]) 218 | Y_batch[i, :Y_length[i, 0]] = [char2id[c] for c in Y_data[i]] 219 | 220 | inputs = {'X': X_batch, 'Y': Y_batch, 'X_length': X_length, 'Y_length': Y_length} 221 | outputs = {'ctc': np.zeros([batch_size])} 222 | 223 | yield (inputs, outputs) 224 | 225 | def input_layer(): 226 | X = Input(shape=(None, config.mfcc_dim,), dtype='float32', name='X') 227 | Y = Input(shape=(None,), dtype='float32', name='Y') 228 | X_length = Input(shape=(1,), dtype='int32', name='X_length') 229 | Y_length = Input(shape=(1,), dtype='int32', name='Y_length') 230 | 231 | return X,Y,X_length,Y_length 232 | 233 | 234 | #卷积1层 235 | def conv1d(inputs, filters, kernel_size, dilation_rate): 236 | return Conv1D(filters=filters, kernel_size=kernel_size, strides=1, padding='causal', activation=None, 237 | dilation_rate=dilation_rate)(inputs) 238 | 239 | #标准化函数 240 | def batchnorm(inputs): 241 | return BatchNormalization()(inputs) 242 | 243 | #激活层函数 244 | def activation(inputs, activation): 245 | return Activation(activation)(inputs) 246 | 247 | #全连接层函数 248 | def res_block(inputs, filters, kernel_size, dilation_rate): 249 | hf = activation(batchnorm(conv1d(inputs, filters, kernel_size, dilation_rate)), 'tanh') 250 | hg = activation(batchnorm(conv1d(inputs, filters, kernel_size, dilation_rate)), 'sigmoid') 251 | h0 = Multiply()([hf, hg]) 252 | 253 | ha = activation(batchnorm(conv1d(h0, filters, 1, 1)), 'tanh') 254 | hs = activation(batchnorm(conv1d(h0, filters, 1, 1)), 'tanh') 255 | 256 | return Add()([ha, inputs]), hs 257 | 258 | #计算损失函数 259 | def calc_ctc_loss(args): 260 | y, yp, ypl, yl = args 261 | return K.ctc_batch_cost(y, yp, ypl, yl) 262 | 263 | def model_train(X,Y,X_length,Y_length,char2id): 264 | h0 = activation(batchnorm(conv1d(X, config.filters, 1, 1)), 'tanh') 265 | shortcut = [] 266 | for i in range(config.num_blocks): 267 | for r in [1, 2, 4, 8, 16]: 268 | h0, s = res_block(h0, config.filters, 7, r) 269 | shortcut.append(s) 270 | 271 | h1 = activation(Add()(shortcut), 'relu') 272 | h1 = activation(batchnorm(conv1d(h1, config.filters, 1, 1)), 'relu') 273 | #softmax损失函数输出结果 274 | Y_pred = activation(batchnorm(conv1d(h1, len(char2id) + 1, 1, 1)), 'softmax') 275 | sub_model = Model(inputs=X, outputs=Y_pred) 276 | 277 | ctc_loss = Lambda(calc_ctc_loss, output_shape=(1,), name='ctc')([Y, Y_pred, X_length, Y_length]) 278 | #加载模型训练 279 | model = Model(inputs=[X, Y, X_length, Y_length], outputs=ctc_loss) 280 | #建立优化器 281 | optimizer = SGD(lr=0.02, momentum=0.9, nesterov=True, clipnorm=5) 282 | #激活模型开始计算 283 | model.compile(loss={'ctc': lambda ctc_true, ctc_pred: ctc_pred}, optimizer=optimizer) 284 | 285 | return sub_model,model 286 | 287 | def save_model_pkl(sub_model,char2id,id2char,mfcc_mean,mfcc_std): 288 | #保存模型 289 | sub_model.save(config.model_path) 290 | #将字保存在pl=pkl中 291 | with open(config.pkl_path, 'wb') as fw: 292 | pickle.dump([char2id, id2char, mfcc_mean, mfcc_std], fw) 293 | 294 | 295 | def draw_loss(history): 296 | train_loss = history.history['loss'] 297 | valid_loss = history.history['val_loss'] 298 | plt.plot(np.linspace(1, config.epochs, 
config.epochs), train_loss, label='train') 299 | plt.plot(np.linspace(1, config.epochs, config.epochs), valid_loss, label='valid') 300 | plt.legend(loc='upper right') 301 | plt.xlabel('Epoch') 302 | plt.ylabel('Loss') 303 | plt.show() 304 | 305 | 306 | def run(): 307 | print("-----load data-----") 308 | path_train = load_wav_data(path=config.train_wav_data_path) 309 | path_test = load_wav_data(path=config.test_wav_data_path) 310 | paths = [] 311 | paths.extend(path_train), paths.extend(path_test) 312 | 313 | privacy_dict = create_en_num() 314 | texts_train = load_texts_data(path=config.train_texts_data_path, en_num_all=privacy_dict) 315 | texts_test = load_texts_data(path=config.test_texts_data_path, en_num_all=privacy_dict) 316 | texts = [] 317 | texts.extend(texts_train), texts.extend(texts_test) 318 | 319 | char2id,id2char = save_labels(texts) 320 | 321 | total = len(texts) 322 | 323 | # print("-----Extract audio features-----") 324 | # features = wav_features(paths,total) 325 | 326 | # print("-----save features-----") 327 | # save_features(features) 328 | 329 | print("-----load features-----") 330 | features = load_features() 331 | 332 | 333 | mfcc_mean,mfcc_std,features = normalized_features(features) 334 | 335 | X_train,Y_train,X_test,Y_test = data_set(total,features,texts) 336 | 337 | X,Y,X_length,Y_length = input_layer() 338 | 339 | sub_model,model = model_train(X,Y,X_length,Y_length,char2id) 340 | 341 | # 回调 342 | checkpointer = ModelCheckpoint(filepath=config.model_path, verbose=0) 343 | # 监控 损失值(loss)作为指标 344 | lr_decay = ReduceLROnPlateau(monitor='loss', factor=0.2, patience=1, min_lr=1e-6) 345 | #开始训练 346 | history = model.fit_generator( 347 | generator=batch_generator(X_train, Y_train, char2id), 348 | steps_per_epoch=len(X_train) // config.batch_size, 349 | epochs=config.epochs, 350 | validation_data=batch_generator(X_test, Y_test, char2id), 351 | validation_steps=len(X_test) // config.batch_size, 352 | callbacks=[checkpointer, lr_decay]) 353 | 354 | save_model_pkl(sub_model,char2id,id2char,mfcc_mean,mfcc_std) 355 | draw_loss(history) 356 | 357 | 358 | if __name__ == '__main__' : 359 | run() 360 | -------------------------------------------------------------------------------- /train_v3.py: -------------------------------------------------------------------------------- 1 | #导入相关的库 2 | from keras.models import Model 3 | from keras.layers import Input, Activation, Conv1D, Lambda, Add, Multiply, BatchNormalization 4 | from keras.optimizers import Adam, SGD 5 | from keras import backend as K 6 | from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | from mpl_toolkits.axes_grid1 import make_axes_locatable 10 | import random 11 | import pickle 12 | import glob 13 | from tqdm import tqdm 14 | import os 15 | from python_speech_features import mfcc 16 | import scipy.io.wavfile as wav 17 | import librosa 18 | from IPython.display import Audio 19 | import config 20 | 21 | 22 | def load_texts_data(path,en_num_all): 23 | f=open(path, "r", encoding='utf-8') 24 | txt=[] 25 | for line in f: 26 | txt.append(line.strip()) 27 | 28 | # 音频文件名字 29 | txt_filename = [] 30 | # 对应文件内容 31 | txt_wav = [] 32 | 33 | for i in txt: 34 | temp_txt_filename,temp_txt_wav = i.split('\t') 35 | txt_filename.append(temp_txt_filename) 36 | txt_wav.append(temp_txt_wav) 37 | 38 | # 音频文字 39 | texts = [] 40 | 41 | for i in txt_wav: 42 | temp = '' 43 | for j in i: 44 | if j in en_num_all: 45 | continue 46 | else: 47 | # print(j) 48 | temp = temp+j 49 
| texts.append(temp) 50 | 51 | return texts 52 | 53 | def create_en_num(): 54 | # 字典 字母+数字 55 | en_num_all = [] 56 | # 字母 57 | for letter in 'abcdefghijklmnopqrstuvwxyz': 58 | en_num_all.extend(letter) 59 | # 数字 60 | for number in range(10): 61 | en_num_all.extend(str(number)) 62 | # 空格 63 | en_num_all.extend(' ') 64 | 65 | return en_num_all 66 | 67 | def load_wav_data(path): 68 | files = os.listdir(path) 69 | # 音频文件 70 | wav_file = [] 71 | # 音频文件的相对路径 72 | wav_file_path = [] 73 | 74 | for i in files: 75 | temp_files = os.listdir(path+i) 76 | for j in temp_files: 77 | wav_file.append(j) 78 | wav_file_path.append(path+i+'/'+j) 79 | 80 | return wav_file_path 81 | 82 | #根据数据集标定的音素读入 83 | def load_and_trim(path): 84 | audio, sr = librosa.load(path) 85 | # energy = librosa.feature.rmse(audio) 86 | energy = librosa.feature.rms(audio) 87 | frames = np.nonzero(energy >= np.max(energy) / 5) 88 | indices = librosa.core.frames_to_samples(frames)[1] 89 | audio = audio[indices[0]:indices[-1]] if indices.size else audio[0:0] 90 | return audio, sr 91 | 92 | #可视化,显示语音文件的MFCC图 93 | def visualize(paths,texts,index,mfcc_dim): 94 | path = paths[index] 95 | text = texts[index] 96 | print('Audio Text:', text) 97 | 98 | audio, sr = load_and_trim(path) 99 | plt.figure(figsize=(12, 3)) 100 | plt.plot(np.arange(len(audio)), audio) 101 | plt.title('Raw Audio Signal') 102 | plt.xlabel('Time') 103 | plt.ylabel('Audio Amplitude') 104 | plt.show() 105 | 106 | feature = mfcc(audio, sr, numcep=mfcc_dim, nfft=551) 107 | print('Shape of MFCC:', feature.shape) 108 | 109 | fig = plt.figure(figsize=(12, 5)) 110 | ax = fig.add_subplot(111) 111 | im = ax.imshow(feature, cmap=plt.cm.jet, aspect='auto') 112 | plt.title('Normalized MFCC') 113 | plt.ylabel('Time') 114 | plt.xlabel('MFCC Coefficient') 115 | plt.colorbar(im, cax=make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05)) 116 | ax.set_xticks(np.arange(0, 13, 2), minor=False); 117 | plt.show() 118 | 119 | return path 120 | 121 | # Audio(visualize(0)) 122 | 123 | def wav_features(paths,total): 124 | #提取音频特征并存储 125 | features = [] 126 | for i in tqdm(range(total)): 127 | path = paths[i] 128 | audio, sr = load_and_trim(path) 129 | features.append(mfcc(audio, sr, numcep=config.mfcc_dim, nfft=551)) 130 | return features 131 | 132 | def save_features(features): 133 | with open(config.features_path, 'wb') as fw: 134 | pickle.dump(features,fw) 135 | 136 | def load_features(): 137 | with open(config.features_path, 'rb') as f: 138 | features = pickle.load(f) 139 | return features 140 | 141 | def normalized_features(features): 142 | #随机选择100个数据集 143 | samples = random.sample(features, 100) 144 | samples = np.vstack(samples) 145 | #平均MFCC的值为了归一化处理 146 | mfcc_mean = np.mean(samples, axis=0) 147 | #计算标准差为了归一化 148 | mfcc_std = np.std(samples, axis=0) 149 | # print(mfcc_mean) 150 | # print(mfcc_std) 151 | #归一化特征 152 | features = [(feature - mfcc_mean) / (mfcc_std + 1e-14) for feature in features] 153 | 154 | return mfcc_mean,mfcc_std,features 155 | 156 | def save_labels(texts): 157 | #将数据集读入的标签和对应id存储列表 158 | chars = {} 159 | for text in texts: 160 | for c in text: 161 | chars[c] = chars.get(c, 0) + 1 162 | 163 | chars = sorted(chars.items(), key=lambda x: x[1], reverse=True) 164 | chars = [char[0] for char in chars] 165 | # print(len(chars), chars[:100]) 166 | 167 | char2id = {c: i for i, c in enumerate(chars)} 168 | id2char = {i: c for i, c in enumerate(chars)} 169 | 170 | with open(config.labels_path, 'wb') as fw: 171 | pickle.dump([texts,char2id, id2char], fw) 172 | 173 | 
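    # (Added note) config.labels_path now caches [texts, char2id, id2char], so later runs
    # can call load_labels() below instead of re-parsing the transcript files.
    # Illustrative round-trip check, assuming the file written above exists:
    #   cached_texts, cached_c2i, cached_i2c = load_labels()
    #   assert cached_c2i == char2id and cached_i2c == id2char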
return texts,char2id,id2char 174 | 175 | 176 | def load_labels(): 177 | with open(config.labels_path, 'rb') as f: 178 | texts,char2id,id2char = pickle.load(f) 179 | return texts,char2id,id2char 180 | 181 | 182 | def data_set(total,features,texts): 183 | data_index = np.arange(total) 184 | np.random.shuffle(data_index) 185 | train_size = int(0.9 * total) 186 | test_size = total - train_size 187 | train_index = data_index[:train_size] 188 | test_index = data_index[train_size:] 189 | #神经网络输入和输出X,Y的读入数据集特征 190 | X_train = [features[i] for i in train_index] 191 | Y_train = [texts[i] for i in train_index] 192 | X_test = [features[i] for i in test_index] 193 | Y_test = [texts[i] for i in test_index] 194 | 195 | return X_train,Y_train,X_test,Y_test 196 | 197 | 198 | #定义训练批次的产生,一次训练16个 199 | def batch_generator(x, y,char2id): 200 | batch_size = config.batch_size 201 | offset = 0 202 | while True: 203 | offset += batch_size 204 | 205 | if offset == batch_size or offset >= len(x): 206 | data_index = np.arange(len(x)) 207 | np.random.shuffle(data_index) 208 | x = [x[i] for i in data_index] 209 | y = [y[i] for i in data_index] 210 | offset = batch_size 211 | 212 | X_data = x[offset - batch_size: offset] 213 | Y_data = y[offset - batch_size: offset] 214 | 215 | X_maxlen = max([X_data[i].shape[0] for i in range(batch_size)]) 216 | Y_maxlen = max([len(Y_data[i]) for i in range(batch_size)]) 217 | 218 | X_batch = np.zeros([batch_size, X_maxlen, config.mfcc_dim]) 219 | Y_batch = np.ones([batch_size, Y_maxlen]) * len(char2id) 220 | X_length = np.zeros([batch_size, 1], dtype='int32') 221 | Y_length = np.zeros([batch_size, 1], dtype='int32') 222 | 223 | for i in range(batch_size): 224 | X_length[i, 0] = X_data[i].shape[0] 225 | X_batch[i, :X_length[i, 0], :] = X_data[i] 226 | 227 | Y_length[i, 0] = len(Y_data[i]) 228 | Y_batch[i, :Y_length[i, 0]] = [char2id[c] for c in Y_data[i]] 229 | 230 | inputs = {'X': X_batch, 'Y': Y_batch, 'X_length': X_length, 'Y_length': Y_length} 231 | outputs = {'ctc': np.zeros([batch_size])} 232 | 233 | yield (inputs, outputs) 234 | 235 | def input_layer(): 236 | X = Input(shape=(None, config.mfcc_dim,), dtype='float32', name='X') 237 | Y = Input(shape=(None,), dtype='float32', name='Y') 238 | X_length = Input(shape=(1,), dtype='int32', name='X_length') 239 | Y_length = Input(shape=(1,), dtype='int32', name='Y_length') 240 | 241 | return X,Y,X_length,Y_length 242 | 243 | 244 | #卷积1层 245 | def conv1d(inputs, filters, kernel_size, dilation_rate): 246 | return Conv1D(filters=filters, kernel_size=kernel_size, strides=1, padding='causal', activation=None, 247 | dilation_rate=dilation_rate)(inputs) 248 | 249 | #标准化函数 250 | def batchnorm(inputs): 251 | return BatchNormalization()(inputs) 252 | 253 | #激活层函数 254 | def activation(inputs, activation): 255 | return Activation(activation)(inputs) 256 | 257 | #全连接层函数 258 | def res_block(inputs, filters, kernel_size, dilation_rate): 259 | hf = activation(batchnorm(conv1d(inputs, filters, kernel_size, dilation_rate)), 'tanh') 260 | hg = activation(batchnorm(conv1d(inputs, filters, kernel_size, dilation_rate)), 'sigmoid') 261 | h0 = Multiply()([hf, hg]) 262 | 263 | ha = activation(batchnorm(conv1d(h0, filters, 1, 1)), 'tanh') 264 | hs = activation(batchnorm(conv1d(h0, filters, 1, 1)), 'tanh') 265 | 266 | return Add()([ha, inputs]), hs 267 | 268 | #计算损失函数 269 | def calc_ctc_loss(args): 270 | y, yp, ypl, yl = args 271 | return K.ctc_batch_cost(y, yp, ypl, yl) 272 | 273 | def model_train(X,Y,X_length,Y_length,char2id): 274 | h0 = 
activation(batchnorm(conv1d(X, config.filters, 1, 1)), 'tanh') 275 | shortcut = [] 276 | for i in range(config.num_blocks): 277 | for r in [1, 2, 4, 8, 16]: 278 | h0, s = res_block(h0, config.filters, 7, r) 279 | shortcut.append(s) 280 | 281 | h1 = activation(Add()(shortcut), 'relu') 282 | h1 = activation(batchnorm(conv1d(h1, config.filters, 1, 1)), 'relu') 283 | #softmax损失函数输出结果 284 | Y_pred = activation(batchnorm(conv1d(h1, len(char2id) + 1, 1, 1)), 'softmax') 285 | sub_model = Model(inputs=X, outputs=Y_pred) 286 | 287 | ctc_loss = Lambda(calc_ctc_loss, output_shape=(1,), name='ctc')([Y, Y_pred, X_length, Y_length]) 288 | #加载模型训练 289 | model = Model(inputs=[X, Y, X_length, Y_length], outputs=ctc_loss) 290 | #建立优化器 291 | optimizer = SGD(lr=0.02, momentum=0.9, nesterov=True, clipnorm=5) 292 | #激活模型开始计算 293 | model.compile(loss={'ctc': lambda ctc_true, ctc_pred: ctc_pred}, optimizer=optimizer) 294 | 295 | return sub_model,model 296 | 297 | def save_model_pkl(sub_model,char2id,id2char,mfcc_mean,mfcc_std): 298 | #保存模型 299 | sub_model.save(config.model_path) 300 | #将字保存在pl=pkl中 301 | with open(config.pkl_path, 'wb') as fw: 302 | pickle.dump([char2id, id2char, mfcc_mean, mfcc_std], fw) 303 | 304 | 305 | def draw_loss(history): 306 | train_loss = history.history['loss'] 307 | valid_loss = history.history['val_loss'] 308 | plt.plot(np.linspace(1, config.epochs, config.epochs), train_loss, label='train') 309 | plt.plot(np.linspace(1, config.epochs, config.epochs), valid_loss, label='valid') 310 | plt.legend(loc='upper right') 311 | plt.xlabel('Epoch') 312 | plt.ylabel('Loss') 313 | plt.show() 314 | 315 | 316 | def run(): 317 | # print("-----load data-----") 318 | # path_train = load_wav_data(path=config.train_wav_data_path) 319 | # path_test = load_wav_data(path=config.test_wav_data_path) 320 | # paths = [] 321 | # paths.extend(path_train), paths.extend(path_test) 322 | 323 | # privacy_dict = create_en_num() 324 | # texts_train = load_texts_data(path=config.train_texts_data_path, en_num_all=privacy_dict) 325 | # texts_test = load_texts_data(path=config.test_texts_data_path, en_num_all=privacy_dict) 326 | # texts = [] 327 | # texts.extend(texts_train), texts.extend(texts_test) 328 | 329 | # texts,char2id,id2char = save_labels(texts) 330 | 331 | print("-----load labels-----") 332 | texts,char2id,id2char = load_labels() 333 | 334 | total = len(texts) 335 | 336 | # print("-----Extract audio features-----") 337 | # features = wav_features(paths,total) 338 | 339 | # print("-----save features-----") 340 | # save_features(features) 341 | 342 | print("-----load features-----") 343 | features = load_features() 344 | 345 | 346 | mfcc_mean,mfcc_std,features = normalized_features(features) 347 | 348 | X_train,Y_train,X_test,Y_test = data_set(total,features,texts) 349 | 350 | X,Y,X_length,Y_length = input_layer() 351 | 352 | sub_model,model = model_train(X,Y,X_length,Y_length,char2id) 353 | 354 | # 回调 355 | checkpointer = ModelCheckpoint(filepath=config.model_path, verbose=0) 356 | # 监控 损失值(loss)作为指标 357 | lr_decay = ReduceLROnPlateau(monitor='loss', factor=0.2, patience=1, min_lr=1e-6) 358 | #开始训练 359 | history = model.fit_generator( 360 | generator=batch_generator(X_train, Y_train, char2id), 361 | steps_per_epoch=len(X_train) // config.batch_size, 362 | epochs=config.epochs, 363 | validation_data=batch_generator(X_test, Y_test, char2id), 364 | validation_steps=len(X_test) // config.batch_size, 365 | callbacks=[checkpointer, lr_decay]) 366 | 367 | save_model_pkl(sub_model,char2id,id2char,mfcc_mean,mfcc_std) 
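    # (Added sketch, not part of the original script) The saved sub_model emits per-frame
    # softmax probabilities, so a greedy CTC decode of one normalized MFCC sequence could
    # look roughly like this (X_new is a hypothetical array of shape (1, time, config.mfcc_dim)):
    #   pred = sub_model.predict(X_new)
    #   ids = K.ctc_decode(pred, input_length=[pred.shape[1]], greedy=True)[0][0]
    #   text = ''.join(id2char[int(i)] for i in K.eval(ids)[0] if int(i) != -1)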
368 | draw_loss(history) 369 | 370 | 371 | if __name__ == '__main__' : 372 | run() 373 | -------------------------------------------------------------------------------- /train_v4.py: -------------------------------------------------------------------------------- 1 | #导入相关的库 2 | from keras.models import Model 3 | from keras.layers import Input, Activation, Conv1D, Lambda, Add, Multiply, BatchNormalization 4 | from keras.optimizers import Adam, SGD 5 | from keras import backend as K 6 | from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau 7 | import tensorflow as tf 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | from mpl_toolkits.axes_grid1 import make_axes_locatable 11 | import random 12 | import pickle 13 | import glob 14 | from tqdm import tqdm 15 | import os 16 | from python_speech_features import mfcc 17 | import scipy.io.wavfile as wav 18 | import librosa 19 | from IPython.display import Audio 20 | import config 21 | 22 | 23 | def load_texts_data(path,en_num_all): 24 | f=open(path, "r", encoding='utf-8') 25 | txt=[] 26 | for line in f: 27 | txt.append(line.strip()) 28 | 29 | # 音频文件名字 30 | txt_filename = [] 31 | # 对应文件内容 32 | txt_wav = [] 33 | 34 | for i in txt: 35 | temp_txt_filename,temp_txt_wav = i.split('\t') 36 | txt_filename.append(temp_txt_filename) 37 | txt_wav.append(temp_txt_wav) 38 | 39 | # 音频文字 40 | texts = [] 41 | 42 | for i in txt_wav: 43 | temp = '' 44 | for j in i: 45 | if j in en_num_all: 46 | continue 47 | else: 48 | # print(j) 49 | temp = temp+j 50 | texts.append(temp) 51 | 52 | return texts 53 | 54 | def create_en_num(): 55 | # 字典 字母+数字 56 | en_num_all = [] 57 | # 字母 58 | for letter in 'abcdefghijklmnopqrstuvwxyz': 59 | en_num_all.extend(letter) 60 | # 数字 61 | for number in range(10): 62 | en_num_all.extend(str(number)) 63 | # 空格 64 | en_num_all.extend(' ') 65 | 66 | return en_num_all 67 | 68 | def load_wav_data(path): 69 | files = os.listdir(path) 70 | # 音频文件 71 | wav_file = [] 72 | # 音频文件的相对路径 73 | wav_file_path = [] 74 | 75 | for i in files: 76 | temp_files = os.listdir(path+i) 77 | for j in temp_files: 78 | wav_file.append(j) 79 | wav_file_path.append(path+i+'/'+j) 80 | 81 | return wav_file_path 82 | 83 | #根据数据集标定的音素读入 84 | def load_and_trim(path): 85 | audio, sr = librosa.load(path) 86 | # energy = librosa.feature.rmse(audio) 87 | energy = librosa.feature.rms(audio) 88 | frames = np.nonzero(energy >= np.max(energy) / 5) 89 | indices = librosa.core.frames_to_samples(frames)[1] 90 | audio = audio[indices[0]:indices[-1]] if indices.size else audio[0:0] 91 | return audio, sr 92 | 93 | #可视化,显示语音文件的MFCC图 94 | def visualize(paths,texts,index,mfcc_dim): 95 | path = paths[index] 96 | text = texts[index] 97 | print('Audio Text:', text) 98 | 99 | audio, sr = load_and_trim(path) 100 | plt.figure(figsize=(12, 3)) 101 | plt.plot(np.arange(len(audio)), audio) 102 | plt.title('Raw Audio Signal') 103 | plt.xlabel('Time') 104 | plt.ylabel('Audio Amplitude') 105 | plt.show() 106 | 107 | feature = mfcc(audio, sr, numcep=mfcc_dim, nfft=551) 108 | print('Shape of MFCC:', feature.shape) 109 | 110 | fig = plt.figure(figsize=(12, 5)) 111 | ax = fig.add_subplot(111) 112 | im = ax.imshow(feature, cmap=plt.cm.jet, aspect='auto') 113 | plt.title('Normalized MFCC') 114 | plt.ylabel('Time') 115 | plt.xlabel('MFCC Coefficient') 116 | plt.colorbar(im, cax=make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05)) 117 | ax.set_xticks(np.arange(0, 13, 2), minor=False); 118 | plt.show() 119 | 120 | return path 121 | 122 | # Audio(visualize(0)) 123 | 124 | def 
wav_features(paths,total): 125 | #提取音频特征并存储 126 | features = [] 127 | for i in tqdm(range(total)): 128 | path = paths[i] 129 | audio, sr = load_and_trim(path) 130 | features.append(mfcc(audio, sr, numcep=config.mfcc_dim, nfft=551)) 131 | return features 132 | 133 | def save_features(features): 134 | with open(config.features_path, 'wb') as fw: 135 | pickle.dump(features,fw) 136 | 137 | def load_features(): 138 | with open(config.features_path, 'rb') as f: 139 | features = pickle.load(f) 140 | return features 141 | 142 | def normalized_features(features): 143 | #随机选择100个数据集 144 | samples = random.sample(features, 100) 145 | samples = np.vstack(samples) 146 | #平均MFCC的值为了归一化处理 147 | mfcc_mean = np.mean(samples, axis=0) 148 | #计算标准差为了归一化 149 | mfcc_std = np.std(samples, axis=0) 150 | # print(mfcc_mean) 151 | # print(mfcc_std) 152 | #归一化特征 153 | features = [(feature - mfcc_mean) / (mfcc_std + 1e-14) for feature in features] 154 | 155 | return mfcc_mean,mfcc_std,features 156 | 157 | def save_labels(texts): 158 | #将数据集读入的标签和对应id存储列表 159 | chars = {} 160 | for text in texts: 161 | for c in text: 162 | chars[c] = chars.get(c, 0) + 1 163 | 164 | chars = sorted(chars.items(), key=lambda x: x[1], reverse=True) 165 | chars = [char[0] for char in chars] 166 | # print(len(chars), chars[:100]) 167 | 168 | char2id = {c: i for i, c in enumerate(chars)} 169 | id2char = {i: c for i, c in enumerate(chars)} 170 | 171 | with open(config.labels_path, 'wb') as fw: 172 | pickle.dump([texts,char2id, id2char], fw) 173 | 174 | return texts,char2id,id2char 175 | 176 | 177 | def load_labels(): 178 | with open(config.labels_path, 'rb') as f: 179 | texts,char2id,id2char = pickle.load(f) 180 | return texts,char2id,id2char 181 | 182 | 183 | def data_set(total,features,texts): 184 | data_index = np.arange(total) 185 | np.random.shuffle(data_index) 186 | train_size = int(0.9 * total) 187 | test_size = total - train_size 188 | train_index = data_index[:train_size] 189 | test_index = data_index[train_size:] 190 | #神经网络输入和输出X,Y的读入数据集特征 191 | X_train = [features[i] for i in train_index] 192 | Y_train = [texts[i] for i in train_index] 193 | X_test = [features[i] for i in test_index] 194 | Y_test = [texts[i] for i in test_index] 195 | 196 | return X_train,Y_train,X_test,Y_test 197 | 198 | 199 | #定义训练批次的产生,一次训练16个 200 | def batch_generator(x, y,char2id): 201 | batch_size = config.batch_size 202 | offset = 0 203 | while True: 204 | offset += batch_size 205 | 206 | if offset == batch_size or offset >= len(x): 207 | data_index = np.arange(len(x)) 208 | np.random.shuffle(data_index) 209 | x = [x[i] for i in data_index] 210 | y = [y[i] for i in data_index] 211 | offset = batch_size 212 | 213 | X_data = x[offset - batch_size: offset] 214 | Y_data = y[offset - batch_size: offset] 215 | 216 | X_maxlen = max([X_data[i].shape[0] for i in range(batch_size)]) 217 | Y_maxlen = max([len(Y_data[i]) for i in range(batch_size)]) 218 | 219 | X_batch = np.zeros([batch_size, X_maxlen, config.mfcc_dim]) 220 | Y_batch = np.ones([batch_size, Y_maxlen]) * len(char2id) 221 | X_length = np.zeros([batch_size, 1], dtype='int32') 222 | Y_length = np.zeros([batch_size, 1], dtype='int32') 223 | 224 | for i in range(batch_size): 225 | X_length[i, 0] = X_data[i].shape[0] 226 | X_batch[i, :X_length[i, 0], :] = X_data[i] 227 | 228 | Y_length[i, 0] = len(Y_data[i]) 229 | Y_batch[i, :Y_length[i, 0]] = [char2id[c] for c in Y_data[i]] 230 | 231 | inputs = {'X': X_batch, 'Y': Y_batch, 'X_length': X_length, 'Y_length': Y_length} 232 | outputs = {'ctc': 
np.zeros([batch_size])} 233 | 234 | yield (inputs, outputs) 235 | 236 | def input_layer(): 237 | X = Input(shape=(None, config.mfcc_dim,), dtype='float32', name='X') 238 | Y = Input(shape=(None,), dtype='float32', name='Y') 239 | X_length = Input(shape=(1,), dtype='int32', name='X_length') 240 | Y_length = Input(shape=(1,), dtype='int32', name='Y_length') 241 | 242 | return X,Y,X_length,Y_length 243 | 244 | 245 | #卷积1层 246 | def conv1d(inputs, filters, kernel_size, dilation_rate): 247 | return Conv1D(filters=filters, kernel_size=kernel_size, strides=1, padding='causal', activation=None, 248 | dilation_rate=dilation_rate)(inputs) 249 | 250 | #标准化函数 251 | def batchnorm(inputs): 252 | return BatchNormalization()(inputs) 253 | 254 | #激活层函数 255 | def activation(inputs, activation): 256 | return Activation(activation)(inputs) 257 | 258 | #全连接层函数 259 | def res_block(inputs, filters, kernel_size, dilation_rate): 260 | hf = activation(batchnorm(conv1d(inputs, filters, kernel_size, dilation_rate)), 'tanh') 261 | hg = activation(batchnorm(conv1d(inputs, filters, kernel_size, dilation_rate)), 'sigmoid') 262 | h0 = Multiply()([hf, hg]) 263 | 264 | ha = activation(batchnorm(conv1d(h0, filters, 1, 1)), 'tanh') 265 | hs = activation(batchnorm(conv1d(h0, filters, 1, 1)), 'tanh') 266 | 267 | return Add()([ha, inputs]), hs 268 | 269 | #计算损失函数 270 | def calc_ctc_loss(args): 271 | y, yp, ypl, yl = args 272 | return K.ctc_batch_cost(y, yp, ypl, yl) 273 | 274 | def model_train(X,Y,X_length,Y_length,char2id): 275 | h0 = activation(batchnorm(conv1d(X, config.filters, 1, 1)), 'tanh') 276 | shortcut = [] 277 | for i in range(config.num_blocks): 278 | for r in [1, 2, 4, 8, 16]: 279 | h0, s = res_block(h0, config.filters, 7, r) 280 | shortcut.append(s) 281 | 282 | h1 = activation(Add()(shortcut), 'relu') 283 | h1 = activation(batchnorm(conv1d(h1, config.filters, 1, 1)), 'relu') 284 | #softmax损失函数输出结果 285 | Y_pred = activation(batchnorm(conv1d(h1, len(char2id) + 1, 1, 1)), 'softmax') 286 | sub_model = Model(inputs=X, outputs=Y_pred) 287 | 288 | ctc_loss = Lambda(calc_ctc_loss, output_shape=(1,), name='ctc')([Y, Y_pred, X_length, Y_length]) 289 | #加载模型训练 290 | model = Model(inputs=[X, Y, X_length, Y_length], outputs=ctc_loss) 291 | #建立优化器 292 | optimizer = SGD(lr=0.02, momentum=0.9, nesterov=True, clipnorm=5) 293 | #激活模型开始计算 294 | model.compile(loss={'ctc': lambda ctc_true, ctc_pred: ctc_pred}, optimizer=optimizer) 295 | 296 | return sub_model,model 297 | 298 | def save_model_pkl(sub_model,char2id,id2char,mfcc_mean,mfcc_std): 299 | #保存模型 300 | sub_model.save(config.model_path) 301 | #将字保存在pl=pkl中 302 | with open(config.pkl_path, 'wb') as fw: 303 | pickle.dump([char2id, id2char, mfcc_mean, mfcc_std], fw) 304 | 305 | 306 | def draw_loss(history): 307 | train_loss = history.history['loss'] 308 | valid_loss = history.history['val_loss'] 309 | plt.plot(np.linspace(1, config.epochs, config.epochs), train_loss, label='train') 310 | plt.plot(np.linspace(1, config.epochs, config.epochs), valid_loss, label='valid') 311 | plt.legend(loc='upper right') 312 | plt.xlabel('Epoch') 313 | plt.ylabel('Loss') 314 | plt.show() 315 | 316 | 317 | def run(): 318 | # print("-----load data-----") 319 | # path_train = load_wav_data(path=config.train_wav_data_path) 320 | # path_test = load_wav_data(path=config.test_wav_data_path) 321 | # paths = [] 322 | # paths.extend(path_train), paths.extend(path_test) 323 | 324 | # privacy_dict = create_en_num() 325 | # texts_train = load_texts_data(path=config.train_texts_data_path, en_num_all=privacy_dict) 
326 | # texts_test = load_texts_data(path=config.test_texts_data_path, en_num_all=privacy_dict) 327 | # texts = [] 328 | # texts.extend(texts_train), texts.extend(texts_test) 329 | 330 | # texts,char2id,id2char = save_labels(texts) 331 | 332 | print("-----load labels-----") 333 | texts,char2id,id2char = load_labels() 334 | 335 | total = len(texts) 336 | 337 | # print("-----Extract audio features-----") 338 | # features = wav_features(paths,total) 339 | 340 | # print("-----save features-----") 341 | # save_features(features) 342 | 343 | print("-----load features-----") 344 | features = load_features() 345 | 346 | 347 | mfcc_mean,mfcc_std,features = normalized_features(features) 348 | 349 | X_train,Y_train,X_test,Y_test = data_set(total,features,texts) 350 | 351 | X,Y,X_length,Y_length = input_layer() 352 | 353 | sub_model,model = model_train(X,Y,X_length,Y_length,char2id) 354 | 355 | # 回调 356 | checkpointer = ModelCheckpoint(filepath=config.model_path, verbose=0) 357 | # 监控 损失值(loss)作为指标 358 | lr_decay = ReduceLROnPlateau(monitor='loss', factor=0.2, patience=1, min_lr=1e-6) 359 | 360 | tf_callback = tf.keras.callbacks.TensorBoard(log_dir="./logs") 361 | #开始训练 362 | history = model.fit_generator( 363 | generator=batch_generator(X_train, Y_train, char2id), 364 | steps_per_epoch=len(X_train) // config.batch_size, 365 | epochs=config.epochs, 366 | validation_data=batch_generator(X_test, Y_test, char2id), 367 | validation_steps=len(X_test) // config.batch_size, 368 | callbacks=[checkpointer, lr_decay,tf_callback]) 369 | 370 | save_model_pkl(sub_model,char2id,id2char,mfcc_mean,mfcc_std) 371 | draw_loss(history) 372 | 373 | 374 | if __name__ == '__main__' : 375 | run() 376 | -------------------------------------------------------------------------------- /utils/binary.py: -------------------------------------------------------------------------------- 1 | import json 2 | import mmap 3 | 4 | import struct 5 | 6 | from tqdm import tqdm 7 | 8 | 9 | class DatasetWriter(object): 10 | def __init__(self, prefix): 11 | # 创建对应的数据文件 12 | self.data_file = open(prefix + '.data', 'wb') 13 | self.header_file = open(prefix + '.header', 'wb') 14 | self.data_sum = 0 15 | self.offset = 0 16 | self.header = '' 17 | 18 | def add_data(self, data): 19 | key = str(self.data_sum) 20 | data = bytes(data, encoding="utf8") 21 | # 写入图像数据 22 | self.data_file.write(struct.pack('I', len(key))) 23 | self.data_file.write(key.encode('ascii')) 24 | self.data_file.write(struct.pack('I', len(data))) 25 | self.data_file.write(data) 26 | # 写入索引 27 | self.offset += 4 + len(key) + 4 28 | self.header = key + '\t' + str(self.offset) + '\t' + str(len(data)) + '\n' 29 | self.header_file.write(self.header.encode('ascii')) 30 | self.offset += len(data) 31 | self.data_sum += 1 32 | 33 | def close(self): 34 | self.data_file.close() 35 | self.header_file.close() 36 | 37 | 38 | class DatasetReader(object): 39 | def __init__(self, data_header_path, min_duration=0, max_duration=30): 40 | self.keys = [] 41 | self.offset_dict = {} 42 | self.fp = open(data_header_path.replace('.header', '.data'), 'rb') 43 | self.m = mmap.mmap(self.fp.fileno(), 0, access=mmap.ACCESS_READ) 44 | for line in tqdm(open(data_header_path, 'rb'), desc='读取数据列表'): 45 | key, val_pos, val_len = line.split('\t'.encode('ascii')) 46 | data = self.m[int(val_pos):int(val_pos) + int(val_len)] 47 | data = str(data, encoding="utf-8") 48 | data = json.loads(data) 49 | # 跳过超出长度限制的音频 50 | if data["duration"] < min_duration: 51 | continue 52 | if max_duration != -1 and data["duration"] > 
max_duration: 53 | continue 54 | self.keys.append(key) 55 | self.offset_dict[key] = (int(val_pos), int(val_len)) 56 | 57 | # 获取一行列表数据 58 | def get_data(self, key): 59 | p = self.offset_dict.get(key, None) 60 | if p is None: 61 | return None 62 | val_pos, val_len = p 63 | data = self.m[val_pos:val_pos + val_len] 64 | data = str(data, encoding="utf-8") 65 | return json.loads(data) 66 | 67 | # 获取keys 68 | def get_keys(self): 69 | return self.keys 70 | 71 | def __len__(self): 72 | return len(self.keys) 73 | -------------------------------------------------------------------------------- /utils/callback.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os 3 | import shutil 4 | 5 | from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl 6 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR 7 | 8 | 9 | # 保存模型时的回调函数 10 | class SavePeftModelCallback(TrainerCallback): 11 | def on_save(self, 12 | args: TrainingArguments, 13 | state: TrainerState, 14 | control: TrainerControl, 15 | **kwargs, ): 16 | if args.local_rank == 0 or args.local_rank == -1: 17 | # 保存效果最好的模型 18 | best_checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-best") 19 | # 因为只保存最新5个检查点,所以要确保不是之前的检查点 20 | if os.path.exists(state.best_model_checkpoint): 21 | if os.path.exists(best_checkpoint_folder): 22 | shutil.rmtree(best_checkpoint_folder) 23 | shutil.copytree(state.best_model_checkpoint, best_checkpoint_folder) 24 | print(f"效果最好的检查点为:{state.best_model_checkpoint},评估结果为:{state.best_metric}") 25 | return control 26 | -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | from dataclasses import dataclass 3 | from typing import Any, List, Dict, Union 4 | 5 | import torch 6 | from zhconv import convert 7 | 8 | 9 | # 删除标点符号 10 | def remove_punctuation(text: str or List[str]): 11 | punctuation = '!,.;:?、!,。;:?' 
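    # (Added note) The pattern above mixes ASCII and full-width Chinese punctuation.
    # Illustrative behaviour, not taken from the original source:
    #   remove_punctuation('你好,世界!')        -> '你好世界'
    #   remove_punctuation(['好。', '真的吗?'])  -> ['好', '真的吗']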
12 | if isinstance(text, str): 13 | text = re.sub(r'[{}]+'.format(punctuation), '', text).strip() 14 | return text 15 | elif isinstance(text, list): 16 | result_text = [] 17 | for t in text: 18 | t = re.sub(r'[{}]+'.format(punctuation), '', t).strip() 19 | result_text.append(t) 20 | return result_text 21 | else: 22 | raise Exception(f'不支持该类型{type(text)}') 23 | 24 | 25 | # 将繁体中文总成简体中文 26 | def to_simple(text: str or List[str]): 27 | if isinstance(text, str): 28 | text = convert(text, 'zh-cn') 29 | return text 30 | elif isinstance(text, list): 31 | result_text = [] 32 | for t in text: 33 | t = convert(t, 'zh-cn') 34 | result_text.append(t) 35 | return result_text 36 | else: 37 | raise Exception(f'不支持该类型{type(text)}') 38 | 39 | 40 | @dataclass 41 | class DataCollatorSpeechSeq2SeqWithPadding: 42 | processor: Any 43 | 44 | def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: 45 | # split inputs and labels since they have to be of different lengths and need different padding methods 46 | # first treat the audio inputs by simply returning torch tensors 47 | input_features = [{"input_features": feature["input_features"][0]} for feature in features] 48 | batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt") 49 | 50 | # get the tokenized label sequences 51 | label_features = [{"input_ids": feature["labels"]} for feature in features] 52 | # pad the labels to max length 53 | labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt") 54 | 55 | # replace padding with -100 to ignore loss correctly 56 | labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100) 57 | 58 | # if bos token is appended in previous tokenization step, 59 | # cut bos token here as it's append later anyways 60 | if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item(): 61 | labels = labels[:, 1:] 62 | 63 | batch["labels"] = labels 64 | 65 | return batch 66 | -------------------------------------------------------------------------------- /utils/model_utils.py: -------------------------------------------------------------------------------- 1 | import bitsandbytes as bnb 2 | import torch 3 | from transformers.trainer_pt_utils import LabelSmoother 4 | 5 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index 6 | 7 | 8 | def find_all_linear_names(use_8bit, model): 9 | cls = bnb.nn.Linear8bitLt if use_8bit else torch.nn.Linear 10 | lora_module_names = set() 11 | for name, module in model.named_modules(): 12 | if isinstance(module, cls): 13 | names = name.split('.') 14 | lora_module_names.add(names[0] if len(names) == 1 else names[-1]) 15 | target_modules = list(lora_module_names) 16 | return target_modules 17 | 18 | 19 | def load_from_checkpoint(resume_from_checkpoint, model=None): 20 | pass 21 | -------------------------------------------------------------------------------- /utils/reader.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | import sys 5 | from typing import List 6 | 7 | import librosa 8 | import numpy as np 9 | import soundfile 10 | from torch.utils.data import Dataset 11 | from tqdm import tqdm 12 | 13 | from utils.binary import DatasetReader 14 | 15 | 16 | class CustomDataset(Dataset): 17 | def __init__(self, 18 | data_list_path, 19 | processor, 20 | mono=True, 21 | language=None, 22 | timestamps=False, 23 | sample_rate=16000, 24 | min_duration=0.5, 25 | max_duration=30, 26 | 
augment_config_path=None): 27 | """ 28 | Args: 29 | data_list_path: 数据列表文件的路径,或者二进制列表的头文件路径 30 | processor: Whisper的预处理工具,WhisperProcessor.from_pretrained获取 31 | mono: 是否将音频转换成单通道,这个必须是True 32 | language: 微调数据的语言 33 | timestamps: 微调时是否使用时间戳 34 | sample_rate: 音频的采样率,默认是16000 35 | min_duration: 小于这个时间段的音频将被截断,单位秒,不能小于0.5,默认0.5s 36 | max_duration: 大于这个时间段的音频将被截断,单位秒,不能大于30,默认30s 37 | augment_config_path: 数据增强配置参数文件路径 38 | """ 39 | super(CustomDataset, self).__init__() 40 | assert min_duration >= 0.5, f"min_duration不能小于0.5,当前为:{min_duration}" 41 | assert max_duration <= 30, f"max_duration不能大于30,当前为:{max_duration}" 42 | self.data_list_path = data_list_path 43 | self.processor = processor 44 | self.data_list_path = data_list_path 45 | self.sample_rate = sample_rate 46 | self.mono = mono 47 | self.language = language 48 | self.timestamps = timestamps 49 | self.min_duration = min_duration 50 | self.max_duration = max_duration 51 | self.vocab = self.processor.tokenizer.get_vocab() 52 | self.startoftranscript = self.vocab['<|startoftranscript|>'] 53 | self.endoftext = self.vocab['<|endoftext|>'] 54 | if '<|nospeech|>' in self.vocab.keys(): 55 | self.nospeech = self.vocab['<|nospeech|>'] 56 | self.timestamp_begin = None 57 | else: 58 | # 兼容旧模型 59 | self.nospeech = self.vocab['<|nocaptions|>'] 60 | self.timestamp_begin = self.vocab['<|notimestamps|>'] + 1 61 | self.data_list: List[dict] = [] 62 | # 加载数据列表 63 | self._load_data_list() 64 | # 数据增强配置参数 65 | self.augment_configs = None 66 | self.noises_path = None 67 | self.speed_rates = None 68 | if augment_config_path: 69 | with open(augment_config_path, 'r', encoding='utf-8') as f: 70 | self.augment_configs = json.load(f) 71 | 72 | # 加载数据列表 73 | def _load_data_list(self): 74 | if self.data_list_path.endswith(".header"): 75 | # 获取二进制的数据列表 76 | self.dataset_reader = DatasetReader(data_header_path=self.data_list_path, 77 | min_duration=self.min_duration, 78 | max_duration=self.max_duration) 79 | self.data_list = self.dataset_reader.get_keys() 80 | else: 81 | # 获取数据列表 82 | with open(self.data_list_path, 'r', encoding='utf-8') as f: 83 | lines = f.readlines() 84 | self.data_list = [] 85 | for line in tqdm(lines, desc='读取数据列表'): 86 | if isinstance(line, str): 87 | line = json.loads(line) 88 | if not isinstance(line, dict): continue 89 | # 跳过超出长度限制的音频 90 | if line["duration"] < self.min_duration: 91 | continue 92 | if self.max_duration != -1 and line["duration"] > self.max_duration: 93 | continue 94 | self.data_list.append(dict(line)) 95 | 96 | # 从数据列表里面获取音频数据、采样率和文本 97 | def _get_list_data(self, idx): 98 | if self.data_list_path.endswith(".header"): 99 | data_list = self.dataset_reader.get_data(self.data_list[idx]) 100 | else: 101 | data_list = self.data_list[idx] 102 | # 分割音频路径和标签 103 | audio_file = data_list["audio"]['path'] 104 | transcript = data_list["sentences"] if self.timestamps else data_list["sentence"] 105 | language = data_list["language"] if 'language' in data_list.keys() else None 106 | if 'start_time' not in data_list["audio"].keys(): 107 | sample, sample_rate = soundfile.read(audio_file, dtype='float32') 108 | else: 109 | start_time, end_time = data_list["audio"]["start_time"], data_list["audio"]["end_time"] 110 | # 分割读取音频 111 | sample, sample_rate = self.slice_from_file(audio_file, start=start_time, end=end_time) 112 | sample = sample.T 113 | # 转成单通道 114 | if self.mono: 115 | sample = librosa.to_mono(sample) 116 | # 数据增强 117 | if self.augment_configs: 118 | sample, sample_rate = self.augment(sample, sample_rate) 119 | # 重采样 120 | if 
self.sample_rate != sample_rate: 121 | sample = self.resample(sample, orig_sr=sample_rate, target_sr=self.sample_rate) 122 | return sample, sample_rate, transcript, language 123 | 124 | def _load_timestamps_transcript(self, transcript: List[dict]): 125 | assert isinstance(transcript, list), f"transcript应该为list,当前为:{type(transcript)}" 126 | data = dict() 127 | labels = self.processor.tokenizer.prefix_tokens[:3] 128 | for t in transcript: 129 | # 将目标文本编码为标签ID 130 | start = t['start'] if round(t['start'] * 100) % 2 == 0 else t['start'] + 0.01 131 | if self.timestamp_begin is None: 132 | start = self.vocab[f'<|{start:.2f}|>'] 133 | else: 134 | start = self.timestamp_begin + round(start * 100) // 2 135 | end = t['end'] if round(t['end'] * 100) % 2 == 0 else t['end'] - 0.01 136 | if self.timestamp_begin is None: 137 | end = self.vocab[f'<|{end:.2f}|>'] 138 | else: 139 | end = self.timestamp_begin + round(end * 100) // 2 140 | label = self.processor(text=t['text']).input_ids[4:-1] 141 | labels.extend([start]) 142 | labels.extend(label) 143 | labels.extend([end]) 144 | data['labels'] = labels + [self.endoftext] 145 | return data 146 | 147 | def __getitem__(self, idx): 148 | try: 149 | # 从数据列表里面获取音频数据、采样率和文本 150 | sample, sample_rate, transcript, language = self._get_list_data(idx=idx) 151 | # 可以为单独数据设置语言 152 | self.processor.tokenizer.set_prefix_tokens(language=language if language is not None else self.language) 153 | if len(transcript) > 0: 154 | # 加载带有时间戳的文本 155 | if self.timestamps: 156 | data = self._load_timestamps_transcript(transcript=transcript) 157 | # 从输入音频数组中计算log-Mel输入特征 158 | data["input_features"] = self.processor(audio=sample, sampling_rate=self.sample_rate).input_features 159 | else: 160 | # 获取log-Mel特征和标签ID 161 | data = self.processor(audio=sample, sampling_rate=self.sample_rate, text=transcript) 162 | else: 163 | # 如果没有文本,则使用<|nospeech|>标记 164 | data = self.processor(audio=sample, sampling_rate=self.sample_rate) 165 | data['labels'] = [self.startoftranscript, self.nospeech, self.endoftext] 166 | return data 167 | except Exception as e: 168 | print(f'读取数据出错,序号:{idx},错误信息:{e}', file=sys.stderr) 169 | return self.__getitem__(random.randint(0, self.__len__() - 1)) 170 | 171 | def __len__(self): 172 | return len(self.data_list) 173 | 174 | # 分割读取音频 175 | @staticmethod 176 | def slice_from_file(file, start, end): 177 | sndfile = soundfile.SoundFile(file) 178 | sample_rate = sndfile.samplerate 179 | duration = round(float(len(sndfile)) / sample_rate, 3) 180 | start = round(start, 3) 181 | end = round(end, 3) 182 | # 从末尾开始计 183 | if start < 0.0: start += duration 184 | if end < 0.0: end += duration 185 | # 保证数据不越界 186 | if start < 0.0: start = 0.0 187 | if end > duration: end = duration 188 | if end < 0.0: 189 | raise ValueError("切片结束位置(%f s)越界" % end) 190 | if start > end: 191 | raise ValueError("切片开始位置(%f s)晚于切片结束位置(%f s)" % (start, end)) 192 | start_frame = int(start * sample_rate) 193 | end_frame = int(end * sample_rate) 194 | sndfile.seek(start_frame) 195 | sample = sndfile.read(frames=end_frame - start_frame, dtype='float32') 196 | return sample, sample_rate 197 | 198 | # 数据增强 199 | def augment(self, sample, sample_rate): 200 | for config in self.augment_configs: 201 | if config['type'] == 'speed' and random.random() < config['prob']: 202 | if self.speed_rates is None: 203 | min_speed_rate, max_speed_rate, num_rates = config['params']['min_speed_rate'], \ 204 | config['params']['max_speed_rate'], config['params']['num_rates'] 205 | self.speed_rates = np.linspace(min_speed_rate, 
max_speed_rate, num_rates, endpoint=True) 206 | rate = random.choice(self.speed_rates) 207 | sample = self.change_speed(sample, speed_rate=rate) 208 | if config['type'] == 'shift' and random.random() < config['prob']: 209 | min_shift_ms, max_shift_ms = config['params']['min_shift_ms'], config['params']['max_shift_ms'] 210 | shift_ms = random.randint(min_shift_ms, max_shift_ms) 211 | sample = self.shift(sample, sample_rate, shift_ms=shift_ms) 212 | if config['type'] == 'volume' and random.random() < config['prob']: 213 | min_gain_dBFS, max_gain_dBFS = config['params']['min_gain_dBFS'], config['params']['max_gain_dBFS'] 214 | gain = random.randint(min_gain_dBFS, max_gain_dBFS) 215 | sample = self.volume(sample, gain=gain) 216 | if config['type'] == 'resample' and random.random() < config['prob']: 217 | new_sample_rates = config['params']['new_sample_rates'] 218 | new_sample_rate = np.random.choice(new_sample_rates) 219 | sample = self.resample(sample, orig_sr=sample_rate, target_sr=new_sample_rate) 220 | sample_rate = new_sample_rate 221 | if config['type'] == 'noise' and random.random() < config['prob']: 222 | min_snr_dB, max_snr_dB = config['params']['min_snr_dB'], config['params']['max_snr_dB'] 223 | if self.noises_path is None: 224 | self.noises_path = [] 225 | noise_dir = config['params']['noise_dir'] 226 | if os.path.exists(noise_dir): 227 | for file in os.listdir(noise_dir): 228 | self.noises_path.append(os.path.join(noise_dir, file)) 229 | noise_path = random.choice(self.noises_path) 230 | snr_dB = random.randint(min_snr_dB, max_snr_dB) 231 | sample = self.add_noise(sample, sample_rate, noise_path=noise_path, snr_dB=snr_dB) 232 | return sample, sample_rate 233 | 234 | # 改变语速 235 | @staticmethod 236 | def change_speed(sample, speed_rate): 237 | if speed_rate == 1.0: 238 | return sample 239 | if speed_rate <= 0: 240 | raise ValueError("速度速率应大于零") 241 | old_length = sample.shape[0] 242 | new_length = int(old_length / speed_rate) 243 | old_indices = np.arange(old_length) 244 | new_indices = np.linspace(start=0, stop=old_length, num=new_length) 245 | sample = np.interp(new_indices, old_indices, sample).astype(np.float32) 246 | return sample 247 | 248 | # 音频偏移 249 | @staticmethod 250 | def shift(sample, sample_rate, shift_ms): 251 | duration = sample.shape[0] / sample_rate 252 | if abs(shift_ms) / 1000.0 > duration: 253 | raise ValueError("shift_ms的绝对值应该小于音频持续时间") 254 | shift_samples = int(shift_ms * sample_rate / 1000) 255 | if shift_samples > 0: 256 | sample[:-shift_samples] = sample[shift_samples:] 257 | sample[-shift_samples:] = 0 258 | elif shift_samples < 0: 259 | sample[-shift_samples:] = sample[:shift_samples] 260 | sample[:-shift_samples] = 0 261 | return sample 262 | 263 | # 改变音量 264 | @staticmethod 265 | def volume(sample, gain): 266 | sample *= 10.**(gain / 20.) 267 | return sample 268 | 269 | # 声音重采样 270 | @staticmethod 271 | def resample(sample, orig_sr, target_sr): 272 | sample = librosa.resample(sample, orig_sr=orig_sr, target_sr=target_sr) 273 | return sample 274 | 275 | # 添加噪声 276 | def add_noise(self, sample, sample_rate, noise_path, snr_dB, max_gain_db=300.0): 277 | noise_sample, sr = librosa.load(noise_path, sr=sample_rate) 278 | # 标准化音频音量,保证噪声不会太大 279 | target_db = -20 280 | gain = min(max_gain_db, target_db - self.rms_db(sample)) 281 | sample *= 10. ** (gain / 20.) 282 | # 指定噪声音量 283 | sample_rms_db, noise_rms_db = self.rms_db(sample), self.rms_db(noise_sample) 284 | noise_gain_db = min(sample_rms_db - noise_rms_db - snr_dB, max_gain_db) 285 | noise_sample *= 10. 
** (noise_gain_db / 20.) 286 | # 固定噪声长度 287 | if noise_sample.shape[0] < sample.shape[0]: 288 | diff_duration = sample.shape[0] - noise_sample.shape[0] 289 | noise_sample = np.pad(noise_sample, (0, diff_duration), 'wrap') 290 | elif noise_sample.shape[0] > sample.shape[0]: 291 | start_frame = random.randint(0, noise_sample.shape[0] - sample.shape[0]) 292 | noise_sample = noise_sample[start_frame:sample.shape[0] + start_frame] 293 | sample += noise_sample 294 | return sample 295 | 296 | @staticmethod 297 | def rms_db(sample): 298 | mean_square = np.mean(sample ** 2) 299 | return 10 * np.log10(mean_square) 300 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import tarfile 4 | import urllib.request 5 | 6 | from tqdm import tqdm 7 | 8 | 9 | def print_arguments(args): 10 | print("----------- Configuration Arguments -----------") 11 | for arg, value in vars(args).items(): 12 | print("%s: %s" % (arg, value)) 13 | print("------------------------------------------------") 14 | 15 | 16 | def strtobool(val): 17 | val = val.lower() 18 | if val in ('y', 'yes', 't', 'true', 'on', '1'): 19 | return True 20 | elif val in ('n', 'no', 'f', 'false', 'off', '0'): 21 | return False 22 | else: 23 | raise ValueError("invalid truth value %r" % (val,)) 24 | 25 | 26 | def str_none(val): 27 | if val == 'None': 28 | return None 29 | else: 30 | return val 31 | 32 | 33 | def add_arguments(argname, type, default, help, argparser, **kwargs): 34 | type = strtobool if type == bool else type 35 | type = str_none if type == str else type 36 | argparser.add_argument("--" + argname, 37 | default=default, 38 | type=type, 39 | help=help + ' Default: %(default)s.', 40 | **kwargs) 41 | 42 | 43 | def md5file(fname): 44 | hash_md5 = hashlib.md5() 45 | f = open(fname, "rb") 46 | for chunk in iter(lambda: f.read(4096), b""): 47 | hash_md5.update(chunk) 48 | f.close() 49 | return hash_md5.hexdigest() 50 | 51 | 52 | def download(url, md5sum, target_dir): 53 | """Download file from url to target_dir, and check md5sum.""" 54 | if not os.path.exists(target_dir): os.makedirs(target_dir) 55 | filepath = os.path.join(target_dir, url.split("/")[-1]) 56 | if not (os.path.exists(filepath) and md5file(filepath) == md5sum): 57 | print(f"Downloading {url} to {filepath} ...") 58 | with urllib.request.urlopen(url) as source, open(filepath, "wb") as output: 59 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, 60 | unit_divisor=1024) as loop: 61 | while True: 62 | buffer = source.read(8192) 63 | if not buffer: 64 | break 65 | 66 | output.write(buffer) 67 | loop.update(len(buffer)) 68 | print(f"\nMD5 Chesksum {filepath} ...") 69 | if not md5file(filepath) == md5sum: 70 | raise RuntimeError("MD5 checksum failed.") 71 | else: 72 | print(f"File exists, skip downloading. ({filepath})") 73 | return filepath 74 | 75 | 76 | def unpack(filepath, target_dir, rm_tar=False): 77 | """Unpack the file to the target_dir.""" 78 | print("Unpacking %s ..." % filepath) 79 | tar = tarfile.open(filepath) 80 | tar.extractall(target_dir) 81 | tar.close() 82 | if rm_tar: 83 | os.remove(filepath) 84 | 85 | 86 | def make_inputs_require_grad(module, input, output): 87 | output.requires_grad_(True) 88 | --------------------------------------------------------------------------------
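A minimal usage sketch (not part of the repository) showing how utils/reader.py and utils/data_utils.py could be wired together to feed Whisper fine-tuning; the checkpoint name, data path, and batch size below are assumptions for illustration only.

# Illustrative only: model id and data_list_path are assumptions, not repo defaults.
from torch.utils.data import DataLoader
from transformers import WhisperProcessor

from utils.data_utils import DataCollatorSpeechSeq2SeqWithPadding
from utils.reader import CustomDataset

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")   # assumed checkpoint
train_dataset = CustomDataset(data_list_path="dataset/train.json",    # hypothetical manifest path
                              processor=processor,
                              language="chinese",
                              timestamps=False)
collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collator)

for batch in loader:
    # batch["input_features"]: padded log-Mel features; batch["labels"]: token ids padded with -100
    break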