├── Final_GUI.py ├── GMM_UBM.py ├── GUI.py ├── MFCC_DTW.py ├── README.md ├── UI ├── GMM_UBM_GUI.py ├── GMM_UBM_GUI.ui ├── __init__.py ├── final.py ├── final.ui └── tmp.py ├── VAD.py ├── d_vector.py ├── demo.mp4 ├── report ├── First.pdf ├── GMM_UBM.pdf ├── final.pdf └── mfcc报告.pdf ├── requirements.txt ├── utils ├── __init__.py ├── processing.py └── tools.py └── 展示.pptx /Final_GUI.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2019/5/30 10:12 4 | # @Author : chuyu zhang 5 | # @File : Final_GUI.py 6 | # @Software: PyCharm 7 | 8 | import sys 9 | from PyQt5 import QtWidgets 10 | from UI.tmp import Ui_MainWindow 11 | 12 | class MyPyQT_Form(QtWidgets.QMainWindow,Ui_MainWindow): 13 | def __init__(self): 14 | super(MyPyQT_Form,self).__init__() 15 | self.setupUi(self) 16 | 17 | 18 | if __name__ == '__main__': 19 | app = QtWidgets.QApplication(sys.argv) 20 | my_pyqt_form = MyPyQT_Form() 21 | my_pyqt_form.show() 22 | sys.exit(app.exec_()) 23 | -------------------------------------------------------------------------------- /GMM_UBM.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2019/4/12 22:24 4 | # @Author : chuyu zhang 5 | # @File : GMM_UBM.py 6 | # @Software: PyCharm 7 | 8 | import os 9 | from utils.tools import read, get_time 10 | from tqdm import tqdm 11 | 12 | # from utils.processing import MFCC 13 | import python_speech_features as psf 14 | import numpy as np 15 | import pickle as pkl 16 | from sklearn.mixture import GaussianMixture 17 | from sklearn.model_selection import train_test_split 18 | from sklearn import preprocessing 19 | 20 | from sidekit.frontend.features import plp,mfcc 21 | label_encoder = {} 22 | 23 | 24 | def load_data(path='dataset/ASR_GMM'): 25 | """ 26 | load audio file. 27 | :param path: the dir to audio file 28 | :return: x type:list,each element is an audio, y type:list,it is the label of x 29 | """ 30 | start_time = get_time() 31 | print("Loading data...") 32 | speaker_list = os.listdir(path) 33 | y = [] 34 | x = [] 35 | num = 0 36 | for speaker in tqdm(speaker_list): 37 | # encoder the speaker to num 38 | label_encoder[speaker] = num 39 | path1 = os.path.join(path, speaker) 40 | for _dir in os.listdir(path1): 41 | path2 = os.path.join(path1, _dir) 42 | for _wav in os.listdir(path2): 43 | samplerate, audio = read(os.path.join(path2, _wav)) 44 | y.append(num) 45 | # sample rate is 16000, you can down sample it to 8000, but the result will be bad. 46 | x.append(audio) 47 | 48 | num += 1 49 | print("Complete! Spend {:.2f}s".format(get_time(start_time))) 50 | return x,y 51 | 52 | 53 | def delta(feat, N=2): 54 | """Compute delta features from a feature vector sequence. 55 | :param feat: A numpy array of size (NUMFRAMES by number of features) containing features. Each row holds 1 feature vector. 56 | :param N: For each frame, calculate delta features based on preceding and following N frames 57 | :returns: A numpy array of size (NUMFRAMES by number of features) containing delta features. Each row holds 1 delta feature vector. 
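    Note: this is the standard delta (differential) regression,
        d_t = sum_{n=1}^{N} n * (c_{t+n} - c_{t-n}) / (2 * sum_{n=1}^{N} n^2),
    implemented below as a dot product of the weights [-N, ..., N] with an edge-padded
    window of 2N+1 frames; with the default N=2 the denominator is 2*(1^2 + 2^2) = 10.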
58 | """ 59 | if N < 1: 60 | raise ValueError('N must be an integer >= 1') 61 | NUMFRAMES = len(feat) 62 | denominator = 2 * sum([i**2 for i in range(1, N+1)]) 63 | delta_feat = np.empty_like(feat) 64 | # padded version of feat 65 | padded = np.pad(feat, ((N, N), (0, 0)), mode='edge') 66 | for t in range(NUMFRAMES): 67 | # [t : t+2*N+1] == [(N+t)-N : (N+t)+N+1] 68 | delta_feat[t] = np.dot(np.arange(-N, N+1), padded[t : t+2*N+1]) / denominator 69 | return delta_feat 70 | 71 | 72 | def extract_feature(x, y, is_train=False, feature_type='MFCC'): 73 | """ 74 | extract feature from x 75 | :param x: type list, each element is audio 76 | :param y: type list, each element is label of audio in x 77 | :param filepath: the path to save feature 78 | :param is_train: if true, generate train_data(type dict, key is lable, value is feature), 79 | if false, just extract feature from x 80 | :return: 81 | """ 82 | start_time = get_time() 83 | print("Extract {} feature...".format(feature_type)) 84 | feature = [] 85 | train_data = {} 86 | for i in tqdm(range(len(x))): 87 | # extract mfcc feature based on psf, you can look more detail on psf's website. 88 | if feature_type=='MFCC': 89 | _feature = mfcc(x[i]) 90 | mfcc_delta = delta(_feature) 91 | _feature = np.hstack((_feature, mfcc_delta)) 92 | 93 | _feature = preprocessing.scale(_feature) 94 | elif feature_type=='PLP': 95 | _feature = plp(x[i]) 96 | mfcc_delta = delta(_feature) 97 | _feature = np.hstack((_feature, mfcc_delta)) 98 | 99 | _feature = preprocessing.scale(_feature) 100 | else: 101 | raise NameError 102 | 103 | # append _feature to feature 104 | feature.append(_feature) 105 | 106 | if is_train: 107 | if y[i] in train_data: 108 | train_data[y[i]] = np.vstack((train_data[y[i]], _feature)) 109 | else: 110 | train_data[y[i]] = _feature 111 | 112 | print("Complete! Spend {:.2f}s".format(get_time(start_time))) 113 | 114 | 115 | if is_train: 116 | return train_data, feature, y 117 | else: 118 | return feature, y 119 | 120 | 121 | def load_extract(test_size=0.3): 122 | # combination load_data and extract_feature 123 | x, y = load_data() 124 | # train test split 125 | x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=test_size, random_state=0) 126 | # extract feature from train 127 | train_data, x_train, y_train = extract_feature(x=x_train, y=y_train, is_train=True) 128 | # extract feature from test 129 | x_test, y_test = extract_feature(x=x_test,y=y_test) 130 | 131 | return train_data, x_train, y_train, x_test, y_test 132 | 133 | 134 | def GMM(train, x_train, y_train, x_test, y_test, n_components=16, model=False): 135 | print("Train GMM-UBM model !") 136 | # split x,y to train and test 137 | start_time = get_time() 138 | # if model is True, it will load model from Model/GMM_MFCC_model.pkl. 
139 | # if False,it will train and save model 140 | 141 | if model: 142 | print("load model from file...") 143 | with open("Model/GMM_MFCC_model.pkl", 'rb') as f: 144 | GMM = pkl.load(f) 145 | with open("Model/UBM_MFCC_model.pkl", 'rb') as f: 146 | UBM = pkl.load(f) 147 | else: 148 | # speaker_list = os.listdir('dataset/ASR_GMM') 149 | GMM = [] 150 | ubm_train = None 151 | flag = False 152 | # UBM = [] 153 | print("Train GMM!") 154 | for speaker in tqdm(label_encoder.values()): 155 | # print(type(speaker)) 156 | # speaker = label_encoder[speaker] 157 | # GMM based on speaker 158 | gmm = GaussianMixture(n_components = n_components, covariance_type='diag') 159 | gmm.fit(train[speaker]) 160 | GMM.append(gmm) 161 | if flag: 162 | ubm_train = np.vstack((ubm_train, train[speaker])) 163 | else: 164 | ubm_train = train[speaker] 165 | flag = True 166 | 167 | # UBM based on background 168 | print("Train UBM!") 169 | UBM = GaussianMixture(n_components = n_components, covariance_type='diag') 170 | UBM.fit(ubm_train) 171 | # UBM.append(gmm) 172 | 173 | if not os.path.exists('Model'): 174 | os.mkdir("Model") 175 | 176 | with open("Model/GMM_MFCC_model.pkl", 'wb') as f: 177 | pkl.dump(GMM, f) 178 | with open("Model/UBM_MFCC_model.pkl", 'wb') as f: 179 | pkl.dump(UBM, f) 180 | 181 | # train accuracy 182 | valid = np.zeros((len(x_train), len(GMM))) 183 | for i in range(len(GMM)): 184 | for j in range(len(x_train)): 185 | valid[j, i] = GMM[i].score(x_train[j]) - UBM.score(x_train[j]) 186 | 187 | valid = valid.argmax(axis=1) 188 | acc_train = (valid==np.array(y_train)).sum()/len(x_train) 189 | 190 | # test accuracy 191 | pred = np.zeros((len(x_test), len(GMM))) 192 | for i in range(len(GMM)): 193 | for j in range(len(x_test)): 194 | pred[j, i] = GMM[i].score(x_test[j]) - UBM.score(x_test[j]) 195 | 196 | pred = pred.argmax(axis=1) 197 | acc = (pred==np.array(y_test)).sum()/len(x_test) 198 | 199 | print("spend {:.2f}s, train acc {:.2%}, test acc {:.2%}".format(get_time(start_time), acc_train,acc)) 200 | 201 | 202 | def main(): 203 | train_data, mfcc_x_train, mfcc_y_train, mfcc_x_test, mfcc_y_test = load_extract() 204 | GMM(train_data, mfcc_x_train, mfcc_y_train, mfcc_x_test, mfcc_y_test, model=False) 205 | 206 | 207 | if __name__=='__main__': 208 | main() -------------------------------------------------------------------------------- /GUI.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2019/4/20 22:49 4 | # @Author : chuyu zhang 5 | # @File : GUI.py 6 | # @Software: PyCharm 7 | 8 | # 继承至界面文件的主窗口类 9 | import sys 10 | from PyQt5 import QtWidgets 11 | from UI.GMM_UBM_GUI import Ui_Form 12 | 13 | class MyPyQT_Form(QtWidgets.QWidget,Ui_Form): 14 | def __init__(self): 15 | super(MyPyQT_Form,self).__init__() 16 | self.setupUi(self) 17 | 18 | 19 | if __name__ == '__main__': 20 | app = QtWidgets.QApplication(sys.argv) 21 | my_pyqt_form = MyPyQT_Form() 22 | my_pyqt_form.show() 23 | sys.exit(app.exec_()) 24 | -------------------------------------------------------------------------------- /MFCC_DTW.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2019/3/22 22:08 4 | # @Author : chuyu zhang 5 | # @File : MFCC_DTW.py 6 | # @Software: PyCharm 7 | 8 | import os 9 | import random 10 | from utils.tools import read, get_time,plot_confusion_matrix 11 | from utils.processing import enframe, MFCC 12 | import numpy as np 13 | 
from scipy.fftpack import fft 14 | 15 | import librosa 16 | # dtw is accurate than fastdtw, but it is slower, I will test the speed and acc later 17 | from scipy.spatial.distance import euclidean 18 | from dtw import dtw,accelerated_dtw 19 | from fastdtw import fastdtw 20 | 21 | import matplotlib.pyplot as plt 22 | # import seaborn as sns 23 | from tqdm import tqdm 24 | 25 | 26 | eps = 1e-8 27 | 28 | def MFCC_lib(raw_signal, n_mfcc=13): 29 | feature = librosa.feature.mfcc(raw_signal.astype('float32'), n_mfcc=n_mfcc, sr=8000) 30 | # print(feature.T.shape) 31 | return feature.T.flatten() 32 | 33 | def _MFCC(raw_signal): 34 | """ 35 | extract mfcc feature 36 | :param raw_signal: the original audio signal 37 | :param fs: sample frequency 38 | :param frameSize:the size of each frame 39 | :param step: 40 | :return: a series of mfcc feature of each frame and flatten to (num, ) 41 | """ 42 | # Signal normalization 43 | 44 | """ 45 | raw_signal = np.double(raw_signal) 46 | 47 | raw_signal = raw_signal / (2.0 ** 15) 48 | DC = raw_signal.mean() 49 | MAX = (np.abs(raw_signal)).max() 50 | raw_signal = (raw_signal - DC) / (MAX + eps) 51 | """ 52 | feature = MFCC(raw_signal, fs=8000, frameSize=512, step=256) 53 | # print(feature.shape) 54 | return feature.flatten() 55 | 56 | 57 | def distance_dtw(sample_x, sample_y, show=False, dtw_method=1, dist=euclidean): 58 | """ 59 | calculate the distance between sample_x and sample_y using dtw 60 | :param sample_x: ndarray, mfcc feature for each frame 61 | :param sample_y: the same as sample_x 62 | :param show: bool, if true, show the path 63 | :param dtw_method: 1:accelerated_dtw, 2:fastdtw 64 | :return: the euclidean distance 65 | """ 66 | # euclidean_norm = lambda x, y: np.abs(x - y)euclidean 67 | # 68 | # 69 | if dtw_method==2: 70 | d, path = fastdtw(sample_x, sample_y, dist=dist) 71 | elif dtw_method==1: 72 | d, cost_matrix, acc_cost_matrix, path = accelerated_dtw(sample_x, sample_y, dist='euclidean') 73 | if show: 74 | plt.imshow(acc_cost_matrix.T, origin='lower', cmap='gray', interpolation='nearest') 75 | plt.plot(path[0], path[1], 'w') 76 | plt.show() 77 | 78 | return d 79 | 80 | 81 | def distance_train(data): 82 | """ 83 | calculate the distance of all data 84 | :param data: input data, list, mfcc feature of all audio 85 | :return: the distance matrix 86 | """ 87 | start_time = get_time() 88 | distance = np.zeros((len(data), len(data))) 89 | for index, sample_x in enumerate(data): 90 | col = index + 1 91 | for sample_y in data[col:]: 92 | distance[index, col] = distance_dtw(sample_x, sample_y) 93 | distance[col, index] = distance[index, col] 94 | col += 1 95 | print('cost {}s'.format(get_time(start_time))) 96 | return distance 97 | 98 | def distance_test(x_test, x_train, show=False): 99 | """ 100 | calculate the distance between x_test(one sample) and x_train(many sample) 101 | :param x_test: a sample 102 | :param x_train: the whole train dataset 103 | :return: distance based on dtw 104 | """ 105 | distance = np.zeros((1, len(x_train))) 106 | for index in range(len(x_train)): 107 | distance[0, index] = distance_dtw(x_train[index], x_test, show=show) 108 | return distance 109 | 110 | 111 | def sample(x, y, sample_num=2, whole_num=8): 112 | index = random.sample(range(whole_num), sample_num) 113 | sample_x = [] 114 | sample_y = [] 115 | for i in range(4): 116 | for _index in index: 117 | sample_x.append(x[_index + whole_num*i]) 118 | sample_y.append(y[_index + whole_num*i]) 119 | return sample_x, sample_y 120 | 121 | 122 | def 
load_train(path='dataset/ASR/train', mfcc_extract=_MFCC): 123 | """ 124 | load data from dataset/ASR/train and generate template 125 | :param path: the path of dataset 126 | :return: x is train data, y_label is the label of x 127 | """ 128 | start_time = get_time() 129 | # wav_dir is a list, which include four directory in train. 130 | wav_dir = os.listdir(path) 131 | y_label = [] 132 | x = [] 133 | print("Generate template according to train set.") 134 | for _dir in tqdm(wav_dir): 135 | _x = [] 136 | for _path in os.listdir(os.path.join(path, _dir)): 137 | _, data = read(os.path.join(path, _dir, _path)) 138 | # Some audio has two channel, but some audio has one channel. 139 | # so, I add "try except" to deal with such problem. 140 | # downsample the data to 8k 141 | try: 142 | _x.append(mfcc_extract(data[range(0, data.shape[0], 2), 0])) 143 | except: 144 | _x.append(mfcc_extract(data[range(0, data.shape[0], 2)])) 145 | del data 146 | # print(_x[-1].shape) 147 | # generate a template of different speaker. 148 | x.append(generate_template(_x)) 149 | y_label.append(_dir) 150 | 151 | print('Loading train data, extract mfcc feature and generate template spend {}s'.format(get_time(start_time))) 152 | return x,y_label 153 | 154 | 155 | def load_test(path='dataset/ASR/test', mfcc_extract=MFCC, template=False): 156 | """ 157 | load data from dataset/ASR/test 158 | :param path: the path of dataset 159 | :return: x is train data, y_label is the label of x 160 | """ 161 | start_time = get_time() 162 | if template: 163 | # load template directly. 164 | pass 165 | # wav_dir is a list, which include four directory in train. 166 | wav_dir = os.listdir(path) 167 | y_label = [] 168 | x = [] 169 | # enc = OrdinalEncoder() 170 | for _dir in wav_dir: 171 | for _path in os.listdir(os.path.join(path, _dir)): 172 | _, data = read(os.path.join(path, _dir, _path)) 173 | # Some audio has two channel, but some audio has one channel. 174 | # so, I add "try except" to deal with such problem. 175 | # downsample the data to 8k 176 | try: 177 | x.append(mfcc_extract(data[range(0, data.shape[0], 2), 0])) 178 | except: 179 | x.append(mfcc_extract(data[range(0, data.shape[0], 2)])) 180 | del data 181 | y_label.append(_dir) 182 | 183 | print('Loading test data and extract mfcc feature spend {}s'.format(get_time(start_time))) 184 | return x,y_label 185 | 186 | 187 | def generate_template(x): 188 | # max_length is the max length of audio in x. 189 | max_length = -1 190 | 191 | # max_length_index is the index of max length audio. 192 | max_length_index = 0 193 | template = None 194 | for index, _x in enumerate(x): 195 | if _x.shape[0] > max_length: 196 | max_length = _x.shape[0] 197 | max_length_index = index 198 | 199 | template = x[max_length_index] 200 | for index, _x in enumerate(x): 201 | if index != max_length_index: 202 | d, cost_matrix, acc_cost_matrix, path = accelerated_dtw(_x, template, dist='euclidean') 203 | template = (_x[path[0]] + template[path[1]])/2 204 | # the dimension of template will arise after previous step, 205 | # so I will decrease the dimension of template, to keep it to be the same as initial. 
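            # e.g. if path[1] == [0, 0, 1, 2, 2, 3], the mask built below is
            # [True, False, True, True, False, True]: only the first warped copy of each
            # original template frame is kept, so the averaged template shrinks back to
            # its pre-DTW length.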
206 | pre_road = -1 207 | ind = [] 208 | for current_road in path[1]: 209 | if current_road!=pre_road: 210 | ind.append(True) 211 | else: 212 | ind.append(False) 213 | pre_road = current_road 214 | 215 | template = template[ind] 216 | 217 | return template 218 | 219 | 220 | def vote(label): 221 | label = np.array(label) 222 | _dict = {} 223 | for l in label: 224 | if l not in _dict: 225 | _dict[l] = 1 226 | else: 227 | _dict[l] += 1 228 | 229 | return sorted(_dict.items(), key=lambda x: x[1], reverse=True)[0][0] 230 | 231 | 232 | def test(threshold=100): 233 | x_train,y_train = load_train(path='dataset/ASR/train') 234 | # x_train,y_train = sample(x_train, y_train) 235 | x_test,y_test = load_test(path='dataset/ASR/test') 236 | # x_test, y_test = x_train,y_train 237 | y_pred = [] 238 | # print(len(x_train)) 239 | distances = np.zeros((len(x_test), len(x_train))) 240 | index = 0 241 | for x in tqdm(x_test): 242 | distance = distance_test(x, x_train) 243 | distances[index, :] = distance 244 | # top = np.argsort(distance) 245 | # print(top) 246 | y_pred.append(y_train[np.argmin(distance)]) 247 | index += 1 248 | # when I set threshold to 100, the results is very bad, many sample are classified to other, 249 | # so, I decide to give up threshold, 250 | """ 251 | if np.min(distance) < threshold: 252 | y_pred.append(y_train[np.argmin(distance)]) 253 | else: 254 | y_pred.append('other') 255 | """ 256 | y_pred = np.array(y_pred) 257 | y_test = np.array(y_test) 258 | acc = (y_pred==y_test).sum()/y_test.shape[0] 259 | print("accuracy is {:.2%}".format(acc)) 260 | # distances = np.concatenate([y_test.reshape(-1,1), distances], axis=1) 261 | # print(y_train) 262 | # np.savetxt('distance_template.csv', X=distances, delimiter=',') 263 | np.savetxt('res.csv', X=(y_pred==y_test), delimiter=',') 264 | 265 | plot_confusion_matrix(y_test, y_pred, classes=y_train) 266 | plt.show() 267 | 268 | 269 | def plot(filename): 270 | _, audio = read(filename) 271 | # 语音图 272 | plt.figure() 273 | plt.plot(audio) 274 | # 频谱图 275 | # mel频谱图 276 | # DTW路径图 277 | pass 278 | 279 | 280 | if __name__=='__main__': 281 | test() 282 | """ 283 | x_train,y_train = load_wav(path='dataset/ASR/train') 284 | # distance = distance_dtw(x_train[0], x_train[1]) 285 | print(x_train[0].shape) 286 | d, cost_matrix, acc_cost_matrix, path = accelerated_dtw(x_train[0], x_train[1], dist='euclidean') 287 | print(cost_matrix.shape) 288 | print(acc_cost_matrix.shape) 289 | print('*'*50) 290 | print(path[0].shape) 291 | print('*'*50) 292 | print(path[1].shape) 293 | plt.imshow(acc_cost_matrix.T, origin='lower', cmap='gray', interpolation='nearest') 294 | plt.plot(path[0], path[1], 'w') 295 | plt.show() 296 | """ 297 | 298 | # np.savetxt('dis.csv', X=distance, delimiter=',') 299 | # print(distance_dtw(x[0], x[1], show=True)) 300 | # print(distance_dtw(x[0], x[5], show=True)) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # speech_signal_processing 2 | 3 | Any question, you can pull a issue or email me. 4 | 5 | ## Description 6 | 7 | * VAD.py is the first project. 8 | 9 | * MFCC_DTW.py is the second project. 10 | 11 | * GMM_UBM.py is the third project, and GUI.py is the GUI of this project. 12 | 13 | * d_vector.py is final project, and Final_GUI.py is the GUI of this project. 14 | 15 | * feature dir saved model and feature file. 
you can download it from 16 | [here](https://pan.baidu.com/s/1P65s2OAqqsnnfFOn6XCbSA), code is iwmf. 17 | 18 | * All the reports are in report dir. 19 | 20 | ## Requirement 21 | 22 | python 3.x,windows 23 | 24 | Any other package, run code below 25 | ``` 26 | pip install -r requirements.txt 27 | ``` 28 | or 29 | ``` 30 | pip install dtw librosa fastdtw tqdm sidekit tensorflow keras numpy scipy pyqt sklearn 31 | ``` 32 | 33 | *NOTE*:you can use mirror to speed up,refer [blog](https://www.cnblogs.com/microman/p/6107879.html) 34 | 35 | ## dataset 36 | download dataset for d-vector from [voxceleb](http://www.robots.ox.ac.uk/~vgg/data/voxceleb/). 37 | 38 | ## Experiment Log 39 | 40 | #### MFCC+DTW 41 | | DTW |Time(s)| Acc(%) | 42 | |:---:|:---:|:---:| 43 | |accelerated_dtw|92|83.72| 44 | |accelerated_dtw+pre-emphasis|105|74.42| 45 | |fastdtw|71|60.47| 46 | |fastdtw+pre-emphasis|79|65.12| 47 | 48 | **Summury**:The results of fastdtw is bad than accelerated_dtw, so I suggest you to use accelerated 49 | rather than fastdtw if you prefer more on accuracy. 50 | 51 | ## MFCC 52 | 53 | [blog](https://kleinzcy.github.io/blog/speech%20signal%20processing/%E6%A2%85%E5%B0%94%E5%80%92%E8%B0%B1%E7%B3%BB%E6%95%B0) 54 | 55 | ## GMM 56 | 57 | [blog](https://appliedmachinelearning.blog/2017/11/14/spoken-speaker-identification-based-on-gaussian-mixture-models-python-implementation/) 58 | 59 | [paper](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.117.338&rep=rep1&type=pdf) 60 | 61 | [scikit-learn](https://scikit-learn.org/stable/modules/mixture.html#gmm) 62 | 63 | [SIDEKIT](https://pypi.org/project/SIDEKIT/) 64 | 65 | ## MFCC+GMM 66 | 67 | Please read report for more details. 68 | 69 | ## d-vector 70 | 71 | we train our model on voxceleb dataset, more details, please read report. 72 | 73 | |model|time(s)|train_acc|valid_acc|epoch|test_acc|test_time| 74 | |:---:|:---:|:---:|:----:|:---:|:---:|:---:| 75 | |nn|56s|0.5321|0.4672|50|0.3682|11.92| 76 | |lstm|2906|0.788|0.5472|100|0.4371|49.53| 77 | |gru|2977|0.9385|0.7484|30|0.3766|70.05| 78 | 79 | **inference**:[paper](https://ieeexplore.ieee.org/document/6854363) 80 | 81 | **inference_gru**:[paper](https://arxiv.org/pdf/1705.02304.pdf) 82 | 83 | **inference_lstm**:[paper](https://arxiv.org/abs/1509.08062) 84 | 85 | ## Reference 86 | 87 | 1. [audio-mnist-with-person-detection](https://github.com/yogeshjadhav7/audio-mnist-with-person-detection) 88 | 89 | 2. [dVectorSpeakerRecognition](https://github.com/wangleiai/dVectorSpeakerRecognition) 90 | 91 | 3. [speaker-verification](https://github.com/rajathkmp/speaker-verification) 92 | 93 | 4. [voxceleb](http://www.robots.ox.ac.uk/~vgg/data/voxceleb/) 94 | 95 | 96 | -------------------------------------------------------------------------------- /UI/GMM_UBM_GUI.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Form implementation generated from reading ui file 'GMM_UBM_GUI.ui' 4 | # 5 | # Created by: PyQt5 UI code generator 5.11.3 6 | # 7 | # WARNING! All changes made in this file will be lost! 
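# NOTE: this class skeleton comes from the PyQt5 UI compiler (presumably regenerated with
# something like `pyuic5 GMM_UBM_GUI.ui -o GMM_UBM_GUI.py`), but the load/record/test
# slots further down appear to be hand-written additions, so regenerating the file from
# the .ui would discard them -- hence the warning above.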
8 | 9 | from PyQt5 import QtCore, QtGui, QtWidgets 10 | import pickle as pkl 11 | import os 12 | from playsound import playsound 13 | from GMM_UBM import delta 14 | from utils.tools import record,read 15 | import numpy as np 16 | from sidekit.frontend.features import plp,mfcc 17 | from sklearn import preprocessing 18 | from aip import AipSpeech 19 | 20 | import warnings 21 | warnings.filterwarnings("ignore") 22 | 23 | class Ui_Form(object): 24 | def setupUi(self, Form): 25 | Form.setObjectName("Form") 26 | Form.resize(300, 343) 27 | self.horizontalLayoutWidget = QtWidgets.QWidget(Form) 28 | self.horizontalLayoutWidget.setGeometry(QtCore.QRect(70, 20, 160, 80)) 29 | self.horizontalLayoutWidget.setObjectName("horizontalLayoutWidget") 30 | self.horizontalLayout = QtWidgets.QHBoxLayout(self.horizontalLayoutWidget) 31 | self.horizontalLayout.setContentsMargins(0, 0, 0, 0) 32 | self.horizontalLayout.setObjectName("horizontalLayout") 33 | self.radioButton_2 = QtWidgets.QRadioButton(self.horizontalLayoutWidget) 34 | self.radioButton_2.setObjectName("radioButton_2") 35 | self.horizontalLayout.addWidget(self.radioButton_2) 36 | self.radioButton = QtWidgets.QRadioButton(self.horizontalLayoutWidget) 37 | self.radioButton.setObjectName("radioButton") 38 | self.horizontalLayout.addWidget(self.radioButton) 39 | self.pushButton = QtWidgets.QPushButton(Form) 40 | self.pushButton.setGeometry(QtCore.QRect(100, 120, 93, 28)) 41 | self.pushButton.setObjectName("pushButton") 42 | self.pushButton_2 = QtWidgets.QPushButton(Form) 43 | self.pushButton_2.setGeometry(QtCore.QRect(100, 180, 93, 28)) 44 | self.pushButton_2.setObjectName("pushButton_2") 45 | self.textBrowser = QtWidgets.QTextBrowser(Form) 46 | self.textBrowser.setGeometry(QtCore.QRect(20, 220, 256, 111)) 47 | self.textBrowser.setObjectName("textBrowser") 48 | 49 | self.retranslateUi(Form) 50 | self.radioButton_2.clicked.connect(Form.load_plp) 51 | self.radioButton.clicked.connect(Form.load_mfcc) 52 | self.pushButton.clicked.connect(Form.record) 53 | self.pushButton_2.clicked.connect(Form.test) 54 | QtCore.QMetaObject.connectSlotsByName(Form) 55 | 56 | def retranslateUi(self, Form): 57 | _translate = QtCore.QCoreApplication.translate 58 | Form.setWindowTitle(_translate("Form", "Form")) 59 | self.radioButton_2.setText(_translate("Form", "PLP")) 60 | self.radioButton.setText(_translate("Form", "MFCC")) 61 | self.pushButton.setText(_translate("Form", "Record")) 62 | self.pushButton_2.setText(_translate("Form", "Test")) 63 | 64 | def load_plp(self): 65 | self.textBrowser.clear() 66 | self.textBrowser.append('Loading PLP feature model') 67 | self.feature_type = 'PLP' 68 | self.load() 69 | 70 | def load_mfcc(self): 71 | self.textBrowser.clear() 72 | self.textBrowser.append('Loading MFCC feature model') 73 | self.feature_type = 'MFCC' 74 | self.load() 75 | 76 | def load(self): 77 | with open("Model/GMM_{}_model.pkl".format(self.feature_type), 'rb') as f: 78 | self.GMM = pkl.load(f) 79 | with open("Model/UBM_{}_model.pkl".format(self.feature_type), 'rb') as f: 80 | self.UBM = pkl.load(f) 81 | 82 | self.textBrowser.append('complete') 83 | 84 | 85 | def record(self): 86 | self.textBrowser.append('Start the recording !') 87 | record(seconds=3) 88 | self.textBrowser.append('3 seconds record has completed.') 89 | _, audio = read(filename='test.wav') 90 | if self.feature_type=='MFCC': 91 | feature = mfcc(audio)[0] 92 | else: 93 | feature = plp(audio)[0] 94 | 95 | _delta = delta(feature) 96 | feature = np.hstack((feature, _delta)) 97 | 98 | feature = 
preprocessing.scale(feature) 99 | self.feature = feature 100 | os.remove('test.wav') 101 | 102 | def test(self): 103 | prob = np.zeros((1, len(self.GMM))) 104 | for i in range(len(self.GMM)): 105 | prob[0,i] = self.GMM[i].score(self.feature) - self.UBM.score(self.feature) 106 | 107 | res = prob.argmax(axis=1) 108 | 109 | num2name = ['班富景', '郭佳怡', '黄心羿', '居慧敏', '廖楚楚', '刘山', '任蕴菡', '阮煜文', '苏林林', '万之颖', 110 | '陈斌', '陈泓宇', '陈军栋', '蔡晓明', '邓刚刚', '董俊虎', '代旭辉', '高威', '龚兵庆', '姜宇伦', 111 | '靳子涵', '李恩', '罗远哲', '罗伟宇', '李想', '李晓波', '李彦能', '刘乙灼', '刘志航', '李忠亚'] 112 | 113 | prob = np.exp(prob)/np.exp(prob).sum(axis=1) 114 | # print(prob) 115 | output = str(num2name[res[0]]) + ':' + str(prob[0,res[0]]) 116 | self.textBrowser.append(output) 117 | APP_ID = '11719204' 118 | API_KEY = 'g7SpqGrkSKgTEBti3pfDsprD' 119 | SECRET_KEY = 'Tn5CS7EE26rDH34H8z7GV3p0DYYpsksZ' 120 | 121 | client = AipSpeech(APP_ID, API_KEY, SECRET_KEY) 122 | result = client.synthesis('{:.2%}的可能是{}'.format(prob[0,res[0]],num2name[res[0]]), 'zh', 0, { 123 | 'vol': 5, 124 | }) 125 | if not isinstance(result, dict): 126 | with open('result.mp3', 'wb') as f: 127 | f.write(result) 128 | 129 | playsound("result.mp3") 130 | os.remove('result.mp3') 131 | -------------------------------------------------------------------------------- /UI/GMM_UBM_GUI.ui: -------------------------------------------------------------------------------- 1 | 2 | 3 | Form 4 | 5 | 6 | 7 | 0 8 | 0 9 | 300 10 | 343 11 | 12 | 13 | 14 | Form 15 | 16 | 17 | 18 | 19 | 70 20 | 20 21 | 160 22 | 80 23 | 24 | 25 | 26 | 27 | 28 | 29 | PLP 30 | 31 | 32 | 33 | 34 | 35 | 36 | MFCC 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 100 46 | 120 47 | 93 48 | 28 49 | 50 | 51 | 52 | Record 53 | 54 | 55 | 56 | 57 | 58 | 100 59 | 180 60 | 93 61 | 28 62 | 63 | 64 | 65 | Test 66 | 67 | 68 | 69 | 70 | 71 | 20 72 | 220 73 | 256 74 | 111 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | radioButton_2 83 | clicked() 84 | Form 85 | load_plp() 86 | 87 | 88 | 73 89 | 57 90 | 91 | 92 | 45 93 | 63 94 | 95 | 96 | 97 | 98 | radioButton 99 | clicked() 100 | Form 101 | load_mfcc() 102 | 103 | 104 | 211 105 | 65 106 | 107 | 108 | 260 109 | 62 110 | 111 | 112 | 113 | 114 | pushButton 115 | clicked() 116 | Form 117 | record() 118 | 119 | 120 | 190 121 | 142 122 | 123 | 124 | 229 125 | 134 126 | 127 | 128 | 129 | 130 | pushButton_2 131 | clicked() 132 | Form 133 | test() 134 | 135 | 136 | 173 137 | 190 138 | 139 | 140 | 234 141 | 189 142 | 143 | 144 | 145 | 146 | 147 | load_plp() 148 | load_mfcc() 149 | record() 150 | test() 151 | 152 | 153 | -------------------------------------------------------------------------------- /UI/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2019/4/20 22:56 4 | # @Author : chuyu zhang 5 | # @File : __init__.py.py 6 | # @Software: PyCharm -------------------------------------------------------------------------------- /UI/final.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Form implementation generated from reading ui file 'final.ui' 4 | # 5 | # Created by: PyQt5 UI code generator 5.11.3 6 | # 7 | # WARNING! All changes made in this file will be lost! 
8 | 9 | from PyQt5 import QtCore, QtGui, QtWidgets 10 | 11 | class Ui_MainWindow(object): 12 | def setupUi(self, MainWindow): 13 | MainWindow.setObjectName("MainWindow") 14 | MainWindow.resize(692, 561) 15 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding) 16 | sizePolicy.setHorizontalStretch(0) 17 | sizePolicy.setVerticalStretch(0) 18 | sizePolicy.setHeightForWidth(MainWindow.sizePolicy().hasHeightForWidth()) 19 | MainWindow.setSizePolicy(sizePolicy) 20 | MainWindow.setFocusPolicy(QtCore.Qt.NoFocus) 21 | MainWindow.setAnimated(True) 22 | self.centralwidget = QtWidgets.QWidget(MainWindow) 23 | self.centralwidget.setObjectName("centralwidget") 24 | self.gridLayout = QtWidgets.QGridLayout(self.centralwidget) 25 | self.gridLayout.setObjectName("gridLayout") 26 | self.verticalLayout_4 = QtWidgets.QVBoxLayout() 27 | self.verticalLayout_4.setContentsMargins(-1, 0, -1, -1) 28 | self.verticalLayout_4.setObjectName("verticalLayout_4") 29 | self.label = QtWidgets.QLabel(self.centralwidget) 30 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding) 31 | sizePolicy.setHorizontalStretch(0) 32 | sizePolicy.setVerticalStretch(0) 33 | sizePolicy.setHeightForWidth(self.label.sizePolicy().hasHeightForWidth()) 34 | self.label.setSizePolicy(sizePolicy) 35 | self.label.setMaximumSize(QtCore.QSize(480, 320)) 36 | self.label.setSizeIncrement(QtCore.QSize(1, 1)) 37 | self.label.setText("") 38 | self.label.setPixmap(QtGui.QPixmap(":/newPrefix/1.jpg")) 39 | self.label.setScaledContents(True) 40 | self.label.setAlignment(QtCore.Qt.AlignCenter) 41 | self.label.setObjectName("label") 42 | self.verticalLayout_4.addWidget(self.label) 43 | self.textBrowser_2 = QtWidgets.QTextBrowser(self.centralwidget) 44 | font = QtGui.QFont() 45 | font.setFamily("Adobe 宋体 Std L") 46 | font.setPointSize(15) 47 | self.textBrowser_2.setFont(font) 48 | self.textBrowser_2.setReadOnly(False) 49 | self.textBrowser_2.setCursorWidth(4) 50 | self.textBrowser_2.setObjectName("textBrowser_2") 51 | self.verticalLayout_4.addWidget(self.textBrowser_2) 52 | self.textBrowser_3 = QtWidgets.QTextBrowser(self.centralwidget) 53 | self.textBrowser_3.setObjectName("textBrowser_3") 54 | self.verticalLayout_4.addWidget(self.textBrowser_3) 55 | self.gridLayout.addLayout(self.verticalLayout_4, 0, 1, 1, 1) 56 | self.verticalLayout_1 = QtWidgets.QVBoxLayout() 57 | self.verticalLayout_1.setSpacing(0) 58 | self.verticalLayout_1.setObjectName("verticalLayout_1") 59 | self.horizontalLayout = QtWidgets.QHBoxLayout() 60 | self.horizontalLayout.setContentsMargins(-1, -1, 0, -1) 61 | self.horizontalLayout.setSpacing(10) 62 | self.horizontalLayout.setObjectName("horizontalLayout") 63 | self.groupBox = QtWidgets.QGroupBox(self.centralwidget) 64 | self.groupBox.setLocale(QtCore.QLocale(QtCore.QLocale.Chinese, QtCore.QLocale.China)) 65 | self.groupBox.setObjectName("groupBox") 66 | self.verticalLayout_3 = QtWidgets.QVBoxLayout(self.groupBox) 67 | self.verticalLayout_3.setObjectName("verticalLayout_3") 68 | self.radioButton = QtWidgets.QRadioButton(self.groupBox) 69 | self.radioButton.setObjectName("radioButton") 70 | self.verticalLayout_3.addWidget(self.radioButton) 71 | self.radioButton_2 = QtWidgets.QRadioButton(self.groupBox) 72 | self.radioButton_2.setObjectName("radioButton_2") 73 | self.verticalLayout_3.addWidget(self.radioButton_2) 74 | self.radioButton_3 = QtWidgets.QRadioButton(self.groupBox) 75 | self.radioButton_3.setObjectName("radioButton_3") 76 | 
self.verticalLayout_3.addWidget(self.radioButton_3) 77 | self.horizontalLayout.addWidget(self.groupBox) 78 | self.groupBox_2 = QtWidgets.QGroupBox(self.centralwidget) 79 | self.groupBox_2.setLocale(QtCore.QLocale(QtCore.QLocale.Chinese, QtCore.QLocale.China)) 80 | self.groupBox_2.setObjectName("groupBox_2") 81 | self.verticalLayout_2 = QtWidgets.QVBoxLayout(self.groupBox_2) 82 | self.verticalLayout_2.setObjectName("verticalLayout_2") 83 | self.radioButton_4 = QtWidgets.QRadioButton(self.groupBox_2) 84 | self.radioButton_4.setObjectName("radioButton_4") 85 | self.verticalLayout_2.addWidget(self.radioButton_4) 86 | self.radioButton_6 = QtWidgets.QRadioButton(self.groupBox_2) 87 | self.radioButton_6.setObjectName("radioButton_6") 88 | self.verticalLayout_2.addWidget(self.radioButton_6) 89 | self.radioButton_5 = QtWidgets.QRadioButton(self.groupBox_2) 90 | self.radioButton_5.setObjectName("radioButton_5") 91 | self.verticalLayout_2.addWidget(self.radioButton_5) 92 | self.horizontalLayout.addWidget(self.groupBox_2) 93 | self.verticalLayout_1.addLayout(self.horizontalLayout) 94 | self.verticalLayout = QtWidgets.QVBoxLayout() 95 | self.verticalLayout.setObjectName("verticalLayout") 96 | self.groupBox_3 = QtWidgets.QGroupBox(self.centralwidget) 97 | self.groupBox_3.setObjectName("groupBox_3") 98 | self.horizontalLayout_2 = QtWidgets.QHBoxLayout(self.groupBox_3) 99 | self.horizontalLayout_2.setContentsMargins(5, 5, 5, 5) 100 | self.horizontalLayout_2.setObjectName("horizontalLayout_2") 101 | self.radioButton_8 = QtWidgets.QRadioButton(self.groupBox_3) 102 | self.radioButton_8.setObjectName("radioButton_8") 103 | self.horizontalLayout_2.addWidget(self.radioButton_8) 104 | self.radioButton_7 = QtWidgets.QRadioButton(self.groupBox_3) 105 | self.radioButton_7.setObjectName("radioButton_7") 106 | self.horizontalLayout_2.addWidget(self.radioButton_7) 107 | self.verticalLayout.addWidget(self.groupBox_3) 108 | self.pushButton = QtWidgets.QPushButton(self.centralwidget) 109 | self.pushButton.setObjectName("pushButton") 110 | self.verticalLayout.addWidget(self.pushButton) 111 | self.pushButton_2 = QtWidgets.QPushButton(self.centralwidget) 112 | self.pushButton_2.setObjectName("pushButton_2") 113 | self.verticalLayout.addWidget(self.pushButton_2) 114 | self.pushButton_3 = QtWidgets.QPushButton(self.centralwidget) 115 | self.pushButton_3.setObjectName("pushButton_3") 116 | self.verticalLayout.addWidget(self.pushButton_3) 117 | self.verticalLayout_1.addLayout(self.verticalLayout) 118 | self.verticalLayout_1.setStretch(0, 1) 119 | self.verticalLayout_1.setStretch(1, 3) 120 | self.gridLayout.addLayout(self.verticalLayout_1, 0, 0, 1, 1) 121 | MainWindow.setCentralWidget(self.centralwidget) 122 | self.menubar = QtWidgets.QMenuBar(MainWindow) 123 | self.menubar.setGeometry(QtCore.QRect(0, 0, 692, 26)) 124 | self.menubar.setObjectName("menubar") 125 | MainWindow.setMenuBar(self.menubar) 126 | self.statusbar = QtWidgets.QStatusBar(MainWindow) 127 | self.statusbar.setObjectName("statusbar") 128 | MainWindow.setStatusBar(self.statusbar) 129 | 130 | self.retranslateUi(MainWindow) 131 | QtCore.QMetaObject.connectSlotsByName(MainWindow) 132 | 133 | def retranslateUi(self, MainWindow): 134 | _translate = QtCore.QCoreApplication.translate 135 | MainWindow.setWindowTitle(_translate("MainWindow", "说护者识别")) 136 | self.textBrowser_2.setHtml(_translate("MainWindow", "\n" 137 | "\n" 140 | "
")) 141 | self.textBrowser_3.setHtml(_translate("MainWindow", "\n" 142 | "\n" 145 | "
")) 146 | self.groupBox.setTitle(_translate("MainWindow", "说话人识别")) 147 | self.radioButton.setText(_translate("MainWindow", "NN")) 148 | self.radioButton_2.setText(_translate("MainWindow", "GRU")) 149 | self.radioButton_3.setText(_translate("MainWindow", "LSTM")) 150 | self.groupBox_2.setTitle(_translate("MainWindow", "语种识别")) 151 | self.radioButton_4.setText(_translate("MainWindow", "MFCC")) 152 | self.radioButton_6.setText(_translate("MainWindow", "PLP")) 153 | self.radioButton_5.setText(_translate("MainWindow", "MFCC+PLP")) 154 | self.groupBox_3.setTitle(_translate("MainWindow", "mode")) 155 | self.radioButton_8.setText(_translate("MainWindow", "说话人识别")) 156 | self.radioButton_7.setText(_translate("MainWindow", "语种识别")) 157 | self.pushButton.setText(_translate("MainWindow", "Enroll")) 158 | self.pushButton_2.setText(_translate("MainWindow", "Test")) 159 | self.pushButton_3.setText(_translate("MainWindow", "Set")) 160 | 161 | import source_rc 162 | -------------------------------------------------------------------------------- /UI/final.ui: -------------------------------------------------------------------------------- 1 | 2 | 3 | MainWindow 4 | 5 | 6 | 7 | 0 8 | 0 9 | 692 10 | 561 11 | 12 | 13 | 14 | 15 | 0 16 | 0 17 | 18 | 19 | 20 | Qt::NoFocus 21 | 22 | 23 | 说护者识别 24 | 25 | 26 | true 27 | 28 | 29 | 30 | 31 | 32 | 33 | 0 34 | 35 | 36 | 37 | 38 | 39 | 0 40 | 0 41 | 42 | 43 | 44 | 45 | 480 46 | 320 47 | 48 | 49 | 50 | 51 | 1 52 | 1 53 | 54 | 55 | 56 | 57 | 58 | 59 | :/newPrefix/1.jpg 60 | 61 | 62 | true 63 | 64 | 65 | Qt::AlignCenter 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | Adobe 宋体 Std L 74 | 15 75 | 76 | 77 | 78 | false 79 | 80 | 81 | <!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.0//EN" "http://www.w3.org/TR/REC-html40/strict.dtd"> 82 | <html><head><meta name="qrichtext" content="1" /><style type="text/css"> 83 | p, li { white-space: pre-wrap; } 84 | </style></head><body style=" font-family:'Adobe 宋体 Std L'; font-size:15pt; font-weight:400; font-style:normal;"> 85 | <p style="-qt-paragraph-type:empty; margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px; font-family:'SimSun'; font-size:9pt;"><br /></p></body></html> 86 | 87 | 88 | 4 89 | 90 | 91 | 92 | 93 | 94 | 95 | <!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.0//EN" "http://www.w3.org/TR/REC-html40/strict.dtd"> 96 | <html><head><meta name="qrichtext" content="1" /><style type="text/css"> 97 | p, li { white-space: pre-wrap; } 98 | </style></head><body style=" font-family:'SimSun'; font-size:9pt; font-weight:400; font-style:normal;"> 99 | <p style="-qt-paragraph-type:empty; margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;"><br /></p></body></html> 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 0 109 | 110 | 111 | 112 | 113 | 10 114 | 115 | 116 | 0 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 说话人识别 125 | 126 | 127 | 128 | 129 | 130 | NN 131 | 132 | 133 | 134 | 135 | 136 | 137 | GRU 138 | 139 | 140 | 141 | 142 | 143 | 144 | LSTM 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 语种识别 158 | 159 | 160 | 161 | 162 | 163 | MFCC 164 | 165 | 166 | 167 | 168 | 169 | 170 | PLP 171 | 172 | 173 | 174 | 175 | 176 | 177 | MFCC+PLP 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | mode 192 | 193 | 194 | 195 | 5 196 | 197 | 198 | 5 199 | 200 | 201 | 5 202 | 203 | 204 | 5 205 | 206 | 207 | 208 | 209 | 说话人识别 210 | 211 | 212 | 213 | 214 | 215 | 216 | 语种识别 217 | 218 | 219 | 220 | 
221 | 222 | 223 | 224 | 225 | 226 | Enroll 227 | 228 | 229 | 230 | 231 | 232 | 233 | Test 234 | 235 | 236 | 237 | 238 | 239 | 240 | Set 241 | 242 | 243 | 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | 253 | 0 254 | 0 255 | 692 256 | 26 257 | 258 | 259 | 260 | 261 | 262 | 263 | 264 | 265 | 266 | 267 | -------------------------------------------------------------------------------- /UI/tmp.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Form implementation generated from reading ui file 'final.ui' 4 | # 5 | # Created by: PyQt5 UI code generator 5.11.3 6 | # 7 | # WARNING! All changes made in this file will be lost! 8 | 9 | from PyQt5 import QtCore, QtGui, QtWidgets 10 | from PyQt5.QtMultimedia import QMediaContent, QMediaPlayer 11 | from PyQt5.QtMultimediaWidgets import QVideoWidget 12 | from d_vector import nn_model 13 | from keras.models import load_model 14 | import pickle as pkl 15 | import numpy as np 16 | from scipy.spatial.distance import cosine 17 | from sidekit.frontend.features import plp,mfcc 18 | from sklearn import preprocessing 19 | from utils.tools import record,read 20 | from tqdm import tqdm 21 | import os 22 | import python_speech_features as psf 23 | #TODO 测试录音数据 24 | #TODO 实验报告 25 | class Ui_MainWindow(object): 26 | def setupUi(self, MainWindow): 27 | #TODO 播放条 28 | #TODO 上下拉升问题 29 | MainWindow.setObjectName("MainWindow") 30 | MainWindow.resize(739, 583) 31 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding) 32 | sizePolicy.setHorizontalStretch(0) 33 | sizePolicy.setVerticalStretch(0) 34 | sizePolicy.setHeightForWidth(MainWindow.sizePolicy().hasHeightForWidth()) 35 | MainWindow.setSizePolicy(sizePolicy) 36 | MainWindow.setFocusPolicy(QtCore.Qt.NoFocus) 37 | MainWindow.setAnimated(True) 38 | 39 | self.centralwidget = QtWidgets.QWidget(MainWindow) 40 | self.centralwidget.setObjectName("centralwidget") 41 | self.gridLayout = QtWidgets.QGridLayout(self.centralwidget) 42 | self.gridLayout.setObjectName("gridLayout") 43 | 44 | # verticalLayout_1这是groupbox和pushbutton垂直排列 45 | self.verticalLayout_1 = QtWidgets.QVBoxLayout() 46 | self.verticalLayout_1.setSpacing(40) 47 | self.verticalLayout_1.setObjectName("verticalLayout_1") 48 | 49 | # verticalLayout_5是三个groupbox横向排列 50 | self.verticalLayout_5 = QtWidgets.QVBoxLayout() 51 | self.verticalLayout_5.setObjectName("verticalLayout_5") 52 | 53 | # groupbox_3 54 | self.groupBox_3 = QtWidgets.QGroupBox(self.centralwidget) 55 | self.groupBox_3.setObjectName("groupBox_3") 56 | self.horizontalLayout_2 = QtWidgets.QHBoxLayout(self.groupBox_3) 57 | self.horizontalLayout_2.setContentsMargins(5, 5, 5, 5) 58 | self.horizontalLayout_2.setObjectName("horizontalLayout_2") 59 | self.radioButton_7 = QtWidgets.QRadioButton(self.groupBox_3) 60 | self.radioButton_7.setObjectName("radioButton_7") 61 | self.horizontalLayout_2.addWidget(self.radioButton_7) 62 | self.radioButton_8 = QtWidgets.QRadioButton(self.groupBox_3) 63 | self.radioButton_8.setObjectName("radioButton_8") 64 | self.horizontalLayout_2.addWidget(self.radioButton_8) 65 | 66 | self.verticalLayout_5.addWidget(self.groupBox_3) 67 | 68 | # horizontalLayout这是两个groupbox横向排列 69 | self.horizontalLayout = QtWidgets.QHBoxLayout() 70 | self.horizontalLayout.setContentsMargins(-1, -1, 0, -1) 71 | self.horizontalLayout.setSpacing(10) 72 | self.horizontalLayout.setObjectName("horizontalLayout") 73 | 74 | # groupbox 75 | self.groupBox = 
QtWidgets.QGroupBox(self.centralwidget) 76 | self.groupBox.setLocale(QtCore.QLocale(QtCore.QLocale.Chinese, QtCore.QLocale.China)) 77 | self.groupBox.setObjectName("groupBox") 78 | 79 | # verticalLayout_3这是groupbox内垂直排列 80 | self.verticalLayout_3 = QtWidgets.QVBoxLayout(self.groupBox) 81 | self.verticalLayout_3.setObjectName("verticalLayout_3") 82 | self.verticalLayout_3.setSpacing(10) 83 | 84 | self.radioButton = QtWidgets.QRadioButton(self.groupBox) 85 | self.radioButton.setObjectName("radioButton") 86 | self.verticalLayout_3.addWidget(self.radioButton) 87 | self.radioButton_2 = QtWidgets.QRadioButton(self.groupBox) 88 | self.radioButton_2.setObjectName("radioButton_2") 89 | self.verticalLayout_3.addWidget(self.radioButton_2) 90 | self.radioButton_3 = QtWidgets.QRadioButton(self.groupBox) 91 | self.radioButton_3.setObjectName("radioButton_3") 92 | self.verticalLayout_3.addWidget(self.radioButton_3) 93 | self.horizontalLayout.addWidget(self.groupBox) 94 | 95 | # groupbox_2 96 | self.groupBox_2 = QtWidgets.QGroupBox(self.centralwidget) 97 | self.groupBox_2.setLocale(QtCore.QLocale(QtCore.QLocale.Chinese, QtCore.QLocale.China)) 98 | self.groupBox_2.setObjectName("groupBox_2") 99 | 100 | # verticalLayout_2这是groupbox_2内垂直排列 101 | self.verticalLayout_2 = QtWidgets.QVBoxLayout(self.groupBox_2) 102 | self.verticalLayout_2.setObjectName("verticalLayout_2") 103 | self.verticalLayout_2.setSpacing(10) 104 | 105 | self.radioButton_4 = QtWidgets.QRadioButton(self.groupBox_2) 106 | self.radioButton_4.setObjectName("radioButton_4") 107 | self.verticalLayout_2.addWidget(self.radioButton_4) 108 | self.radioButton_5 = QtWidgets.QRadioButton(self.groupBox_2) 109 | self.radioButton_5.setObjectName("radioButton_5") 110 | self.verticalLayout_2.addWidget(self.radioButton_5) 111 | self.radioButton_6 = QtWidgets.QRadioButton(self.groupBox_2) 112 | self.radioButton_6.setObjectName("radioButton_6") 113 | self.verticalLayout_2.addWidget(self.radioButton_6) 114 | self.horizontalLayout.addWidget(self.groupBox_2) 115 | 116 | self.verticalLayout_5.addLayout(self.horizontalLayout) 117 | 118 | self.verticalLayout_1.addLayout(self.verticalLayout_5) 119 | 120 | # verticalLayout这是三个pushbutton垂直排列 121 | self.verticalLayout = QtWidgets.QVBoxLayout() 122 | self.verticalLayout.setObjectName("verticalLayout") 123 | self.verticalLayout.setSpacing(50) 124 | 125 | self.pushButton = QtWidgets.QPushButton(self.centralwidget) 126 | self.pushButton.setObjectName("pushButton") 127 | self.verticalLayout.addWidget(self.pushButton) 128 | self.pushButton_2 = QtWidgets.QPushButton(self.centralwidget) 129 | self.pushButton_2.setObjectName("pushButton_2") 130 | self.verticalLayout.addWidget(self.pushButton_2) 131 | self.pushButton_3 = QtWidgets.QPushButton(self.centralwidget) 132 | self.pushButton_3.setObjectName("pushButton_3") 133 | self.verticalLayout.addWidget(self.pushButton_3) 134 | 135 | self.verticalLayout_1.addLayout(self.verticalLayout) 136 | 137 | self.verticalLayout_1.setStretch(1, 4) 138 | 139 | self.gridLayout.addLayout(self.verticalLayout_1, 0, 0, 1, 1) 140 | 141 | # verticalLayout_4这是三个文本框的垂直排列 142 | self.verticalLayout_4 = QtWidgets.QVBoxLayout() 143 | self.verticalLayout_4.setContentsMargins(-1, 0, -1, -1) 144 | self.verticalLayout_4.setObjectName("verticalLayout_4") 145 | 146 | self.label = QtWidgets.QLabel(self.centralwidget) 147 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding) 148 | sizePolicy.setHorizontalStretch(0) 149 | sizePolicy.setVerticalStretch(0) 150 | 
sizePolicy.setHeightForWidth(self.label.sizePolicy().hasHeightForWidth()) 151 | self.label.setSizePolicy(sizePolicy) 152 | self.label.setMaximumSize(QtCore.QSize(1080, 960)) 153 | self.label.setSizeIncrement(QtCore.QSize(1, 1)) 154 | self.label.setText("") 155 | self.label.setPixmap(QtGui.QPixmap("img/1.jpg")) 156 | # self.label.setPixmap(QtGui.QPixmap(":/newPrefix/1.jpg")) 157 | self.label.setScaledContents(True) 158 | self.label.setAlignment(QtCore.Qt.AlignCenter) 159 | self.label.setObjectName("label") 160 | 161 | self.verticalLayout_4.addWidget(self.label) 162 | 163 | 164 | self.playButton = QtWidgets.QPushButton() 165 | self.playButton.setEnabled(False) 166 | self.playButton.setIcon(MainWindow.style().standardIcon(QtWidgets.QStyle.SP_MediaPlay)) 167 | 168 | self.positionSlider = QtWidgets.QSlider(QtCore.Qt.Horizontal) 169 | self.positionSlider.setRange(0, 0) 170 | 171 | controlLayout = QtWidgets.QHBoxLayout() 172 | controlLayout.setContentsMargins(0, 0, 0, 0) 173 | controlLayout.addWidget(self.playButton) 174 | controlLayout.addWidget(self.positionSlider) 175 | 176 | self.verticalLayout_4.addLayout(controlLayout) 177 | 178 | self.textBrowser = QtWidgets.QTextBrowser(self.centralwidget) 179 | # self.textBrowser.setReadOnly(True) 180 | self.textBrowser.setObjectName("textBrowser") 181 | self.verticalLayout_4.addWidget(self.textBrowser) 182 | 183 | self.gridLayout.addLayout(self.verticalLayout_4, 0, 1, 1, 1) 184 | MainWindow.setCentralWidget(self.centralwidget) 185 | self.menubar = QtWidgets.QMenuBar(MainWindow) 186 | self.menubar.setGeometry(QtCore.QRect(0, 0, 739, 26)) 187 | self.menubar.setObjectName("menubar") 188 | MainWindow.setMenuBar(self.menubar) 189 | self.statusbar = QtWidgets.QStatusBar(MainWindow) 190 | self.statusbar.setObjectName("statusbar") 191 | MainWindow.setStatusBar(self.statusbar) 192 | 193 | self.retranslateUi(MainWindow) 194 | self.signal(MainWindow) 195 | QtCore.QMetaObject.connectSlotsByName(MainWindow) 196 | 197 | def retranslateUi(self, MainWindow): 198 | _translate = QtCore.QCoreApplication.translate 199 | MainWindow.setWindowTitle(_translate("MainWindow", "说护者识别")) 200 | self.groupBox.setTitle(_translate("MainWindow", "说话人识别")) 201 | self.radioButton.setText(_translate("MainWindow", "NN")) 202 | self.radioButton_2.setText(_translate("MainWindow", "GRU")) 203 | self.radioButton_3.setText(_translate("MainWindow", "LSTM")) 204 | self.groupBox_2.setTitle(_translate("MainWindow", "语种识别")) 205 | self.radioButton_4.setText(_translate("MainWindow", "MFCC")) 206 | self.radioButton_5.setText(_translate("MainWindow", "PLP")) 207 | self.radioButton_6.setText(_translate("MainWindow", "MFCC+PLP")) 208 | self.groupBox_3.setTitle(_translate("MainWindow", "mode")) 209 | self.radioButton_7.setText(_translate("MainWindow", "说话人识别")) 210 | self.radioButton_8.setText(_translate("MainWindow", "语种识别")) 211 | self.pushButton.setText(_translate("MainWindow", "Enroll")) 212 | self.pushButton_2.setText(_translate("MainWindow", "Record")) 213 | self.pushButton_3.setText(_translate("MainWindow", "Test")) 214 | self.textBrowser.setHtml(_translate("MainWindow", 215 | "\n" 216 | "\n")) 219 | 220 | def signal(self, MainWindow): 221 | self.radioButton.clicked.connect(MainWindow.nn_model) 222 | self.radioButton_2.clicked.connect(MainWindow.gru_model) 223 | self.radioButton_3.clicked.connect(MainWindow.lstm_model) 224 | self.radioButton_4.clicked.connect(MainWindow.mfcc_fea) 225 | self.radioButton_5.clicked.connect(MainWindow.plp_fea) 226 | 
self.radioButton_6.clicked.connect(MainWindow.mfcc_plp_fea) 227 | self.radioButton_7.clicked.connect(MainWindow.speaker_model) 228 | self.radioButton_8.clicked.connect(MainWindow.language_model) 229 | 230 | self.pushButton.clicked.connect(MainWindow.enroll) 231 | self.pushButton_2.clicked.connect(MainWindow.records) 232 | self.pushButton_3.clicked.connect(MainWindow.test) 233 | 234 | # self.playButton.clicked.connect(MainWindow.play) 235 | # self.positionSlider.sliderMoved.connect(MainWindow.setPosition) 236 | 237 | def speaker_model(self): 238 | self.textBrowser.clear() 239 | self.textBrowser.append('Speaker model') 240 | self.state_speaker = True 241 | self.state_language = False 242 | 243 | def language_model(self): 244 | self.textBrowser.clear() 245 | self.textBrowser.append('Language model') 246 | self.state_speaker = False 247 | self.state_language = True 248 | 249 | def nn_model(self): 250 | self._load_speaker_model(model_type='nn') 251 | 252 | def gru_model(self): 253 | self._load_speaker_model(model_type='gru') 254 | 255 | def lstm_model(self): 256 | self._load_speaker_model(model_type='lstm') 257 | 258 | def _load_speaker_model(self, model_type='lstm'): 259 | if self.state_speaker: 260 | self.textBrowser.clear() 261 | self.textBrowser.append('Loading {} d-vector model'.format(model_type)) 262 | self.model_type = model_type 263 | self.model = load_model('feature/d_vector/d_vector_{}.h5'.format(model_type)) 264 | self.textBrowser.append('finished!') 265 | 266 | def mfcc_fea(self): 267 | self._load_language_model(feature_type='MFCC') 268 | 269 | def plp_fea(self): 270 | self._load_language_model(feature_type='PLP') 271 | 272 | def mfcc_plp_fea(self): 273 | self._load_language_model(feature_type='MFCC_PLP') 274 | 275 | def _load_language_model(self, feature_type='MFCC'): 276 | if self.state_language: 277 | self.textBrowser.clear() 278 | self.textBrowser.append('Loading GMM-UBM {} model'.format(feature_type)) 279 | self.feature_type = feature_type 280 | with open("feature/language/GMM_8_" + feature_type + "_model.pkl", 'rb') as f: 281 | self.GMM = pkl.load(f) 282 | with open("feature/language/UBM_8_" + feature_type + "_model.pkl", 'rb') as f: 283 | self.UBM = pkl.load(f) 284 | 285 | self.textBrowser.append('finished!') 286 | 287 | def test(self): 288 | if self.state_language: 289 | # 语种识别 290 | self._GMM_test() 291 | elif self.state_speaker: 292 | # 说话人识别 293 | self._dvector_test() 294 | 295 | def records(self): 296 | seconds, ok = QtWidgets.QInputDialog.getText(self, "records", 297 | "please input how much seconds:", QtWidgets.QLineEdit.Normal) 298 | self.textBrowser.append('{}s'.format(seconds)) 299 | record(seconds=int(seconds)) 300 | self.textBrowser.append('{} seconds record has completed.'.format(seconds)) 301 | sample, audio = read(filename='test.wav') 302 | # print("the length of audio:", str(len(audio))) 303 | testone = [] 304 | testone_feature = [] 305 | sample = 16000 306 | for i in range(len(audio) // sample): 307 | testone.append(audio[i * sample:(i + 1) * sample]) 308 | if self.state_language: 309 | for i in tqdm(range(len(testone))): 310 | try: 311 | _feature = None 312 | if self.feature_type == 'MFCC': 313 | _feature = mfcc(testone[i])[0] 314 | elif self.feature_type == 'PLP': 315 | _feature = plp(testone[i])[0] 316 | elif self.feature_type == 'MFCC_PLP': 317 | _feature1 = mfcc(testone[i])[0] 318 | _feature2 = plp(testone[i])[0] 319 | _feature = np.hstack((_feature1, _feature2)) 320 | 321 | _feature = preprocessing.scale(_feature) 322 | except ValueError: 323 | 
continue 324 | testone_feature.append(_feature) 325 | os.remove('test.wav') 326 | self.gmm_feature = testone_feature 327 | elif self.state_speaker: 328 | for i in tqdm(range(len(testone))): 329 | try: 330 | # fs=8000会出错 331 | _feature = mfcc(testone[i])[0] 332 | except : 333 | continue 334 | testone_feature.append(_feature) 335 | self.d_vector_feature = testone_feature 336 | 337 | def _GMM_test(self): 338 | testone = self.gmm_feature 339 | pred = np.zeros((len(testone), len(self.GMM))) 340 | for i in range(len(self.GMM)): 341 | for j in range(len(testone)): 342 | pred[j, i] = self.GMM[i].score(testone[j]) - self.UBM.score(testone[j]) 343 | 344 | prob = np.exp(pred.max(axis=1))/np.exp(pred).sum(axis=1) 345 | prob_str = [] 346 | for p in prob: 347 | prob_str.append('{:.2%}'.format(p)) 348 | print(prob) 349 | pred = pred.argmax(axis=1) 350 | res = [] 351 | for i in range(pred.shape[0]): 352 | if pred[i]==0: 353 | res.append('Chinese') 354 | elif pred[i]==1: 355 | res.append('English') 356 | else: 357 | res.append('Japanese') 358 | self.textBrowser.append('The result is: ') 359 | self.textBrowser.append(' '.join(res)) 360 | self.textBrowser.append(' '.join(prob_str)) 361 | 362 | def _dvector_test(self): 363 | # d_vector_feature是一个列表,每一个元素存储一秒语音的特征。 364 | with open('feature/d_vector/d_vector.pkl', 'rb') as f: 365 | d_vector = pkl.load(f) 366 | 367 | pred = [] 368 | target = np.array(self.d_vector_feature) 369 | for i in range(len(self.d_vector_feature)): 370 | # 根据对应的模型,调整输入格式 371 | if self.model_type=='lstm': 372 | pass 373 | elif self.model_type=='nn': 374 | target = target.reshape(target.shape[0], -1) 375 | else: 376 | target = target[:,:,:,np.newaxis] 377 | 378 | target = self.model.predict(target) 379 | 380 | # TODO 后续增加概率计算 381 | prob = [] 382 | for i in range(target.shape[0]): 383 | min_distance = 1 384 | target_name = None 385 | distance_list = [] 386 | for name in d_vector.keys(): 387 | distance_list.append(cosine(target[i,:], d_vector[name])) 388 | if min_distance > distance_list[-1]: 389 | min_distance = distance_list[-1] 390 | target_name = name 391 | pred.append(target_name) 392 | distance = -np.array(distance_list) 393 | prob.append('{:.2%}'.format(np.exp(distance.max())/np.exp(distance).sum())) 394 | self.textBrowser.append('The result is :') 395 | self.textBrowser.append(' '.join(pred)) 396 | self.textBrowser.append(' '.join(prob)) 397 | 398 | def enroll(self): 399 | """ 400 | 注册一个陌生人到库中,以字典形式保存 401 | :param X_train: 样本语音 402 | :param name: 该样本语音的人名,唯一标识,不可重复。 403 | :param model_name: 使用模型的名字,nn,lstm,gru 404 | :return: none 405 | """ 406 | name, ok = QtWidgets.QInputDialog.getText(self, "Enroll", 407 | "please input your name:", QtWidgets.QLineEdit.Normal) 408 | self.textBrowser.append('your name is {}'.format(name)) 409 | self.records() 410 | try: 411 | with open('feature/d_vector/d_vector.pkl', 'rb') as f: 412 | d_vector = pkl.load(f) 413 | except: 414 | d_vector = {} 415 | 416 | target = np.array(self.d_vector_feature) 417 | for i in range(len(self.d_vector_feature)): 418 | # 根据对应的模型,调整输入格式 419 | if self.model_type=='lstm': 420 | pass 421 | elif self.model_type=='nn': 422 | target = target.reshape(target.shape[0], -1) 423 | else: 424 | target = target[:,:,:,np.newaxis] 425 | 426 | target = self.model.predict(target) 427 | print(target.shape) 428 | if name in d_vector: 429 | self.textBrowser.append('your name is already exist') 430 | d_vector[name] = (d_vector[name] + target.mean(axis=0))/2 431 | else: 432 | d_vector[name] = target.mean(axis=0) 433 | 434 | with 
open('feature/d_vector/d_vector.pkl', 'wb') as f: 435 | pkl.dump(d_vector, f) 436 | 437 | self.textBrowser.append('Finished') -------------------------------------------------------------------------------- /VAD.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2019/3/13 13:10 4 | # @Author : chuyu zhang 5 | # @File : VAD.py 6 | # @Software: PyCharm 7 | 8 | import math 9 | import numpy as np 10 | from scipy.io import loadmat 11 | from utils.tools import wave_read, read, plot_confusion_matrix 12 | import glob 13 | import time 14 | import matplotlib.pyplot as plt 15 | from sklearn.metrics import confusion_matrix 16 | import seaborn as sns 17 | from sklearn import metrics 18 | from bayes_opt import BayesianOptimization 19 | import pickle as pkl 20 | from scipy import signal 21 | # 计算每一帧的过零率 22 | 23 | frameSize = 256 24 | overlap = 128 25 | 26 | # 分帧处理函数 27 | # 不加窗 28 | def enframe(wavData): 29 | """ 30 | frame the wav data, according to frameSize and overlap 31 | :param wavData: the input wav data, ndarray 32 | :return:frameData, shape 33 | """ 34 | # coef = 0.97 # 预加重系数 35 | wlen = wavData.shape[0] 36 | step = frameSize - overlap 37 | frameNum:int = math.ceil(wlen / step) 38 | frameData = np.zeros((frameSize, frameNum)) 39 | 40 | # hamwin = np.hamming(frameSize) 41 | 42 | for i in range(frameNum): 43 | singleFrame = wavData[np.arange(i * step, min(i * step + frameSize, wlen))] 44 | # b, a = signal.butter(8, 1, 'lowpass') 45 | # filtedData = signal.filtfilt(b, a, data) 46 | # singleFrame = np.append(singleFrame[0], singleFrame[:-1] - coef * singleFrame[1:]) # 预加重 47 | frameData[:len(singleFrame), i] = singleFrame.reshape(-1, 1)[:, 0] 48 | # frameData[:, i] = hamwin * frameData[:, i] # 加窗 49 | 50 | return frameData 51 | 52 | 53 | def ZCR(frameData): 54 | frameNum = frameData.shape[1] 55 | frameSize = frameData.shape[0] 56 | zcr = np.zeros((frameNum, 1)) 57 | 58 | for i in range(frameNum): 59 | singleFrame = frameData[:, i] 60 | temp = singleFrame[:frameSize-1] * singleFrame[1:frameSize] 61 | temp = np.sign(temp) 62 | zcr[i] = np.sum(temp<0) 63 | 64 | return zcr 65 | 66 | # 计算每一帧能量 67 | def energy(frameData): 68 | frameNum = frameData.shape[1] 69 | 70 | frame_energy = np.zeros((frameNum, 1)) 71 | 72 | for i in range(frameNum): 73 | single_frame = frameData[:, i] 74 | frame_energy[i] = sum(single_frame * single_frame) 75 | 76 | return frame_energy 77 | 78 | 79 | def stSpectralEntropy(X, n_short_blocks=10, eps=1e-8): 80 | """Computes the spectral entropy""" 81 | L = len(X) # number of frame samples 82 | Eol = np.sum(X ** 2) # total spectral energy 83 | 84 | sub_win_len = int(np.floor(L / n_short_blocks)) # length of sub-frame 85 | if L != sub_win_len * n_short_blocks: 86 | X = X[0:sub_win_len * n_short_blocks] 87 | 88 | sub_wins = X.reshape(sub_win_len, n_short_blocks, order='F').copy() # define sub-frames (using matrix reshape) 89 | s = np.sum(sub_wins ** 2, axis=0) / (Eol + eps) # compute spectral sub-energies 90 | En = -np.sum(s*np.log2(s + eps)) # compute spectral entropy 91 | 92 | return En 93 | 94 | 95 | def spectrum_entropy(frameData): 96 | frameNum = frameData.shape[1] 97 | 98 | frame_spectrum_entropy = np.zeros((frameNum, 1)) 99 | 100 | for i in range(frameNum): 101 | X = np.fft.fft(frameData[:, i]) 102 | X = np.abs(X) 103 | frame_spectrum_entropy[i] = stSpectralEntropy(X[:int(frameSize/2)]) 104 | 105 | return frame_spectrum_entropy 106 | 107 | 108 | def feature(waveData): 109 | # 
print("feature extract !")
110 |     start = time.time()
111 |     power = energy(waveData)
112 |     zcr = ZCR(waveData) * (power>0.1)
113 |     end = time.time()
114 |     spectrumentropy = spectrum_entropy(waveData)
115 |
116 |     print('feature extraction completed, time-domain features took {}s, frequency-domain features took {}s'.
117 |           format(end-start, time.time() - end))
118 |
119 |     return zcr, power, spectrumentropy
120 |
121 |
122 | # frameSize is the frame length, overlap is the hop between frames
123 | def wavdata(wavfile):
124 |     f = wave_read(wavfile)
125 |     params = f.getparams()
126 |     nchannels, sampwidth, framerate, nframes = params[:4]
127 |     strData = f.readframes(nframes)  # read the audio as raw bytes
128 |     # print(type(strData))
129 |     waveData = np.frombuffer(strData, dtype=np.int16)  # np.fromstring is deprecated for binary data
130 |     # print(waveData.shape)
131 |     waveData = waveData/(max(abs(waveData)))
132 |     return enframe(waveData)
133 |
134 |
135 | # Energy-based decision: frames below ampl are treated as noise (silence), frames above amph as speech, and frames in between as unvoiced speech.
136 | def VAD_detection(zcr, power, zcr_gate=35, ampl=0.3, amph=12):
137 |     # minimum number of frames for a valid speech segment
138 |     min_len = 16
139 |     # minimum gap (in frames) between two speech segments
140 |     min_distance = 21
141 |     # state flag, status: 0 = silence, 1 = speech (only these two states are used below)
142 |     status = 0
143 |     # speech = 0
144 |     start = 0
145 |     end = 0
146 |     last_end = -1
147 |
148 |     res = np.zeros((zcr.shape[0], 1))
149 |
150 |     for i in range(zcr.shape[0]):
151 |         if power[i] > amph:
152 |             # voiced frame: extend the current segment by updating end
153 |             if status != 1:
154 |                 start = i
155 |
156 |             end = i
157 |             status = 1
158 |             # print(start - end)
159 |         elif end - start + 1 > min_len:
160 |
161 |             while(power[start] > ampl or zcr[start] > zcr_gate):
162 |                 start -= 1
163 |
164 |             start += 1
165 |
166 |             while(power[end] > ampl or zcr[end] > zcr_gate):
167 |                 end += 1
168 |                 if end == power.shape[0]:
169 |                     break
170 |
171 |             end -= 1
172 |             if last_end > 0 and start - last_end < min_distance:
173 |                 res[last_end : end + 1] = 1
174 |                 last_end = end
175 |             else:
176 |                 res[start: end + 1] = 1
177 |
178 |             start = 0
179 |             end = 0
180 |             status = 0
181 |
182 |     return res
183 |
184 |
185 | def VAD_frequency(spectrum):
186 |     return np.where(spectrum>0.4, 0, 1)
187 |
188 |
189 | def optimize(X, y):
190 |     zcr, power, spectrumentropy = feature(X)
191 |     """
192 |     sns.distplot(zcr)
193 |     plt.show()
194 |     sns.distplot(power)
195 |     plt.show()
196 |     """
197 |     params = {
198 |         'zcr_gate': (20, 40),
199 |         'ampl': (0.3, 4),
200 |         'amph': (5, 15)
201 |     }
202 |     y = y.reshape(1, -1)
203 |
204 |     def cv(zcr_gate, ampl, amph):
205 |         res = VAD_detection(zcr, power, zcr_gate=zcr_gate, amph=amph, ampl=ampl)
206 |         # print((res==0).sum()/res.shape[0])
207 |         res = res.reshape(1, -1)
208 |         # metrics.precision_score(y[0], res[0])
209 |         # accuracy = (y == res).sum() / y.shape[0]
210 |         return metrics.f1_score(y[0], res[0])
211 |
212 |     BO = BayesianOptimization(cv, params)
213 |
214 |     start_time = time.time()
215 |     BO.maximize(n_iter=30)
216 |     end_time = time.time()
217 |     print("Final result: {}, spent {}s".format(BO.max, end_time - start_time))
218 |     best_params = BO.max['params']
219 |
220 |     return best_params
221 |
222 | # Process the .mat label file: count the silence and speech samples inside each frame and assign the frame a single label; the rule can be refined later.
223 | def label(mat_file):
224 |     mat = loadmat(mat_file)
225 |     y_label = mat['y_label']
226 |     y_label = enframe(y_label)
227 |
228 |     return np.where(y_label.sum(axis=0) > 0, 1, 0)
229 |
230 |
231 | def main(wav, mat):
232 |     start = time.time()
233 |     print(wav.split('\\')[-1])
234 |     data = wavdata(wav)
235 |     y_label = label(mat)
236 |     y_label = y_label.reshape(-1, 1)
237 |     zcr, power, spectrumentropy = feature(data)
238 |
239 |     s1 = time.time()
240 |     res1 = VAD_detection(zcr, power)
241 |     s2 = time.time()
242 |     # res1 = 
res1.reshape(1, -1) 243 | res2 = VAD_frequency(spectrumentropy) 244 | end = time.time() 245 | # res2 = res2.reshape(1, -1) 246 | td = metrics.f1_score(y_label, res1) 247 | fd = metrics.f1_score(y_label, res2) 248 | np.set_printoptions(precision=2) 249 | 250 | plot_confusion_matrix(y_label.tolist(), res1.tolist(), classes=['silence', 'speech'], normalize=False) 251 | plt.savefig(wav.split('\\')[-1].split('.')[0] + '.png') 252 | plt.show() 253 | end = time.time() 254 | print('time domain res:{:.2%}, spend {}s, frequency domain res:{:.2%}, spend {}s, ' 255 | 'spend {}s totally'.format(td, s2-s1, fd, end-s2, end - start)) 256 | 257 | 258 | if __name__=='__main__': 259 | wavfile = glob.glob(r'dataset\VAD\*.wav') 260 | matfile = glob.glob(r'dataset\VAD\*.mat') 261 | 262 | for wav, mat in zip(wavfile, matfile): 263 | main(wav, mat) -------------------------------------------------------------------------------- /d_vector.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2019/5/15 22:56 4 | # @Author : chuyu zhang 5 | # @File : d_vector.py 6 | # @Software: PyCharm 7 | 8 | import os 9 | import pickle as pkl 10 | import numpy as np 11 | from utils.tools import read, get_time 12 | from tqdm import tqdm 13 | from sklearn import preprocessing 14 | from sklearn.model_selection import train_test_split 15 | from scipy.spatial.distance import cosine 16 | from keras.models import Model,Sequential,load_model 17 | 18 | from sidekit.frontend.features import plp,mfcc 19 | 20 | from keras.layers import Dense, Activation, Dropout, Input, GRU, LSTM, Flatten,Convolution2D, MaxPooling2D,Convolution1D 21 | from keras.optimizers import Adam 22 | from keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint,CSVLogger 23 | from keras import regularizers 24 | from keras.layers.core import Reshape,Masking,Lambda,Permute 25 | import keras.backend as K 26 | from keras.layers.wrappers import TimeDistributed 27 | import keras 28 | 29 | class Data_gen: 30 | # 生成数据 31 | def __init__(self): 32 | pass 33 | 34 | def _load(self): 35 | """ 36 | load audio file. 37 | :param path: the dir to audio file 38 | :return: x type:list,each element is an audio, y type:list,it is the label of x 39 | """ 40 | start_time = get_time() 41 | path = self.path 42 | print("Loading data...") 43 | speaker_list = os.listdir(path) 44 | y = [] 45 | x = [] 46 | for speaker in tqdm(speaker_list): 47 | path1 = os.path.join(path, speaker) 48 | for _dir in os.listdir(path1): 49 | path2 = os.path.join(path1, _dir) 50 | for _wav in os.listdir(path2): 51 | self.sample_rate, audio = read(os.path.join(path2, _wav)) 52 | y.append(speaker) 53 | # sample rate is 16000, you can down sample it to 8000, but the result will be bad. 54 | x.append(audio) 55 | 56 | print("Complete! 
Spend {:.2f}s".format(get_time(start_time))) 57 | return x, y 58 | 59 | def extract_feature(self, feature_type='MFCC', datatype='dev'): 60 | """ 61 | extract feature from x 62 | :param x: type list, each element is audio 63 | :param y: type list, each element is label of audio in x 64 | :param filepath: the path to save feature 65 | :param is_train: if true, generate train_data(type dict, key is lable, value is feature), 66 | if false, just extract feature from x 67 | :return: 68 | """ 69 | start_time = get_time() 70 | if not os.path.exists('feature'): 71 | os.mkdir('feature') 72 | 73 | if not os.path.exists('feature/{}_{}_feature.pkl'.format(datatype, feature_type)): 74 | x, y = self._load() 75 | print("Extract {} feature...".format(feature_type)) 76 | feature = [] 77 | label = [] 78 | new_x = [] 79 | new_y = [] 80 | for i in range(len(x)): 81 | for j in range(x[i].shape[0]//self.sample_rate): 82 | new_x.append(x[i][j*self.sample_rate:(j+1)*self.sample_rate]) 83 | new_y.append(y[i]) 84 | 85 | x = new_x 86 | y = new_y 87 | for i in tqdm(range(len(x))): 88 | # 这里MFCC和PLP默认是16000Hz,注意修改 89 | # mfcc 25ms窗长,10ms重叠 90 | if feature_type == 'MFCC': 91 | _feature = mfcc(x[i], fs=self.sample_rate)[0] 92 | elif feature_type == 'PLP': 93 | _feature = plp(x[i], fs=self.sample_rate)[0] 94 | else: 95 | raise NameError 96 | # 特征出了问题,存在一些无穷大,导致整个网络的梯度爆炸了,需要特殊处理才行 97 | if np.isnan(_feature).sum()>0: 98 | continue 99 | # _feature = np.concatenate([_feature,self.delta(_feature)],axis=1) 100 | # _feature = preprocessing.scale(_feature) 101 | # _feature = preprocessing.StandardScaler().fit_transform(_feature) 102 | # 每2*num为一个输入,并且重叠num 103 | feature.append(_feature) 104 | label.append(y[i]) 105 | 106 | print(len(feature), feature[0].shape) 107 | self.save(feature, '{}_{}_feature'.format(datatype, feature_type)) 108 | self.save(label, '{}_{}_label'.format(datatype, feature_type)) 109 | 110 | else: 111 | feature = self.load('{}_{}_feature'.format(datatype, feature_type)) 112 | label = self.load('{}_{}_label'.format(datatype, feature_type)) 113 | 114 | print("Complete! Spend {:.2f}s".format(get_time(start_time))) 115 | return feature, label 116 | 117 | def load_data(self, path='dataset/ASR_GMM_big', reshape=True, test_size=0.3,datatype='dev'): 118 | self.path = path 119 | feature, label = data_gen.extract_feature(datatype=datatype) 120 | feature = np.array(feature) 121 | # 由于是全连接,故需要reshape,如果是卷积或者rnn系列,就不需要reshape 122 | if reshape: 123 | feature = feature.reshape(feature.shape[0], -1) 124 | 125 | label = np.array(label).reshape(-1, 1) 126 | self.enc = preprocessing.OneHotEncoder() 127 | label = self.enc.fit_transform(label).toarray() 128 | 129 | X_train, X_val, y_train, y_val = train_test_split(feature, label, shuffle=True, test_size=test_size, 130 | random_state=2019) 131 | return X_train, X_val, y_train, y_val 132 | 133 | @staticmethod 134 | def save(data, file_name): 135 | with open('feature/{}.pkl'.format(file_name), 'wb') as f: 136 | pkl.dump(data, f) 137 | 138 | @staticmethod 139 | def load(file_name): 140 | with open('feature/{}.pkl'.format(file_name), 'rb') as f: 141 | return pkl.load(f) 142 | 143 | @staticmethod 144 | def delta(feat, N=2): 145 | """Compute delta features from a feature vector sequence. 146 | :param feat: A numpy array of size (NUMFRAMES by number of features) containing features. Each row holds 1 feature vector. 
147 | :param N: For each frame, calculate delta features based on preceding and following N frames 148 | :returns: A numpy array of size (NUMFRAMES by number of features) containing delta features. Each row holds 1 delta feature vector. 149 | """ 150 | if N < 1: 151 | raise ValueError('N must be an integer >= 1') 152 | NUMFRAMES = len(feat) 153 | denominator = 2 * sum([i ** 2 for i in range(1, N + 1)]) 154 | delta_feat = np.empty_like(feat) 155 | # padded version of feat 156 | padded = np.pad(feat, ((N, N), (0, 0)), mode='edge') 157 | for t in range(NUMFRAMES): 158 | # [t : t+2*N+1] == [(N+t)-N : (N+t)+N+1] 159 | delta_feat[t] = np.dot(np.arange(-N, N + 1), padded[t: t + 2 * N + 1]) / denominator 160 | return delta_feat 161 | 162 | 163 | class nn_model: 164 | # TODO 用一个更大的数据集训练一个背景模型,然后利用这个背景模型直接得到新数据的d-vector 165 | def __init__(self, n_class=40): 166 | self.n_class = n_class 167 | 168 | def inference(self, X_train, Y_train, X_val, Y_val): 169 | #需要修改input_shape等一些参数 170 | print("Training model") 171 | model = Sequential() 172 | 173 | model.add(Dense(256, input_shape=(X_train.shape[1],), name="dense1")) 174 | model.add(Activation('relu', name="activation1")) 175 | model.add(Dropout(rate=0, name="drop1")) 176 | 177 | model.add(Dense(256, name="dense2")) 178 | model.add(Activation('relu', name="activation2")) 179 | model.add(Dropout(rate=0, name="drop2")) 180 | 181 | model.add(Dense(256, name="dense3")) 182 | model.add(Activation('relu', name="activation3")) 183 | model.add(Dropout(rate=0.5, name="drop3")) 184 | 185 | model.add(Dense(256, name="dense4")) 186 | 187 | modelInput = Input(shape=(X_train.shape[1],)) 188 | features = model(modelInput) 189 | spkModel = Model(inputs=modelInput, outputs=features) 190 | 191 | model1 = Activation('relu')(features) 192 | model1 = Dropout(rate=0.5)(model1) 193 | 194 | model1 = Dense(self.n_class, activation='softmax')(model1) 195 | 196 | spk = Model(inputs=modelInput, outputs=model1) 197 | 198 | sgd = Adam(lr=1e-4) 199 | # early_stopping = EarlyStopping(monitor='val_loss', patience=4) 200 | reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2, min_lr=1e-7) 201 | csv_logger = CSVLogger('feature/d_vector/nn_training.log') 202 | 203 | spk.compile(loss='categorical_crossentropy',optimizer=sgd, metrics=['accuracy']) 204 | 205 | spk.fit(X_train, Y_train, batch_size = 128, epochs=50, validation_data = (X_val, Y_val), 206 | callbacks=[reduce_lr, csv_logger]) 207 | 208 | if not os.path.exists('feature/d_vector'): 209 | os.mkdir('feature/d_vector') 210 | spkModel.save('feature/d_vector/d_vector_nn.h5') 211 | 212 | 213 | def inference_gru(self, X_train, Y_train, X_val, Y_val): 214 | model = Sequential() 215 | 216 | model.add(Convolution2D(64, (5, 5), 217 | padding='same', 218 | strides=(2, 2), 219 | input_shape=(X_train.shape[1], X_train.shape[2], 1), name="cov1", 220 | data_format="channels_last", 221 | kernel_regularizer=keras.regularizers.l2())) 222 | 223 | # 将输入的维度按照给定模式进行重排 224 | # model.add(Permute((2, 1, 3), name='permute')) 225 | # 该包装器可以把一个层应用到输入的每一个时间步上,GRU需要 226 | model.add(TimeDistributed(Flatten(), name='timedistrib')) 227 | 228 | # 三层GRU 229 | model.add(GRU(units=1024, return_sequences=True, name="gru1")) 230 | model.add(GRU(units=1024, return_sequences=True, name="gru2")) 231 | model.add(GRU(units=1024, return_sequences=True, name="gru3")) 232 | 233 | # temporal average 234 | def temporalAverage(x): 235 | return K.mean(x, axis=1) 236 | 237 | model.add(Lambda(temporalAverage, name="temporal_average")) 238 | 239 | # affine 240 | 
model.add(Dense(units=512, name="dense1")) 241 | 242 | # length normalization 243 | def lengthNormalization(x): 244 | return K.l2_normalize(x, axis=-1) 245 | 246 | model.add(Lambda(lengthNormalization, name="ln")) 247 | 248 | modelInput = Input(shape=(X_train.shape[1], X_train.shape[2], 1)) 249 | features = model(modelInput) 250 | spkModel = Model(inputs=modelInput, outputs=features) 251 | 252 | model1 = Dense(self.n_class, activation='softmax',name="dense2")(features) 253 | 254 | spk = Model(inputs=modelInput, outputs=model1) 255 | 256 | sgd = Adam(lr=1e-4) 257 | 258 | spk.compile(loss='categorical_crossentropy', 259 | optimizer=sgd, metrics=['accuracy']) 260 | 261 | reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2, min_lr=1e-7) 262 | csv_logger = CSVLogger('feature/d_vector/gru_training.log') 263 | 264 | spk.fit(X_train, Y_train, batch_size = 128, epochs=50, validation_data = (X_val, Y_val), 265 | callbacks=[reduce_lr, csv_logger]) 266 | 267 | if not os.path.exists('feature/d_vector'): 268 | os.mkdir('feature/d_vector') 269 | spkModel.save('feature/d_vector/d_vector_gru.h5') 270 | 271 | def inference_lstm(self, X_train, Y_train, X_val, Y_val): 272 | model = Sequential() 273 | 274 | model.add(LSTM(128, input_shape=(X_train.shape[1],X_train.shape[2]))) 275 | modelInput = Input(shape=(X_train.shape[1],X_train.shape[2])) 276 | features = model(modelInput) 277 | spkModel = Model(inputs=modelInput, outputs=features) 278 | model1 = Dense(self.n_class, activation='softmax',name="dense1")(features) 279 | spk = Model(inputs=modelInput, outputs=model1) 280 | 281 | sgd = Adam(lr=1e-4) 282 | 283 | spk.compile(loss='categorical_crossentropy', 284 | optimizer=sgd, metrics=['accuracy']) 285 | 286 | reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2, min_lr=1e-7) 287 | csv_logger = CSVLogger('feature/d_vector/lstm_training.log') 288 | 289 | spk.fit(X_train, Y_train, batch_size = 128, epochs=50, validation_data = (X_val, Y_val), 290 | callbacks=[reduce_lr, csv_logger]) 291 | 292 | if not os.path.exists('feature/d_vector'): 293 | os.mkdir('feature/d_vector') 294 | spkModel.save('feature/d_vector/d_vector_lstm.h5') 295 | 296 | def test(self, X_train, Y_train, X_val, Y_val, model_name='nn'): 297 | spkModel = load_model('feature/d_vector/d_vector_{}.h5'.format(model_name)) 298 | print(X_train.shape) 299 | X_train = spkModel.predict(X_train) 300 | X_val = spkModel.predict(X_val) 301 | # num为测试集中人数 302 | num = Y_train.shape[1] 303 | # 对同一个人的d-vector取平均,得到avg,作为这个人的模板储存起来。 304 | with open('feature/d_vector/d_vector_lstm.pkl', 'wb') as f: 305 | pkl.dump(X_train, f) 306 | 307 | with open('feature/d_vector/d_vector_lstm_y.pkl', 'wb') as f: 308 | pkl.dump(Y_train, f) 309 | 310 | avg = np.zeros((num, X_train.shape[1])) 311 | print(X_train.shape[1]) 312 | for i in range(num): 313 | avg[i,:] = X_train[np.argmax(Y_train, axis=1)==i].mean(axis=0) 314 | 315 | distance = np.zeros((X_val.shape[0], num)) 316 | for i in range(X_val.shape[0]): 317 | for j in range(num): 318 | distance[i, j] = cosine(X_val[i], avg[j]) 319 | acc = (np.argmax(Y_val, axis=1)==np.argmin(distance, axis=1)).sum()/X_val.shape[0] 320 | return acc 321 | 322 | def enroll(self, X_train, name, model_name='lstm'): 323 | """ 324 | 注册一个陌生人到库中,以字典形式保存 325 | :param X_train: 样本语音 326 | :param name: 该样本语音的人名,唯一标识,不可重复。 327 | :param model_name: 使用模型的名字,nn,lstm,gru 328 | :return: none 329 | """ 330 | spkModel = load_model('feature/d_vector/d_vector_{}.h5'.format(model_name)) 331 | X_train = spkModel.predict(X_train) 332 
| avg = X_train.mean(axis=0)
333 |         try:
334 |             with open('feature/d_vector/d_vector.pkl', 'rb') as f:
335 |                 d_vector = pkl.load(f)
336 |         except (FileNotFoundError, EOFError):
337 |             d_vector = {}
338 |
339 |         if name in d_vector:
340 |             print("This speaker is already enrolled; overwriting the stored d-vector.")
341 |         d_vector[name] = avg
342 |
343 |         with open('feature/d_vector/d_vector.pkl', 'wb') as f:
344 |             pkl.dump(d_vector, f)
345 |
346 |     def eval(self, target, model_name='lstm'):
347 |         spkModel = load_model('feature/d_vector/d_vector_{}.h5'.format(model_name))
348 |         target = spkModel.predict(target)
349 |         with open('feature/d_vector/d_vector.pkl', 'rb') as f:
350 |             d_vector = pkl.load(f)
351 |
352 |         min_distance = np.inf
353 |         target_name = None
354 |         distance_list = []
355 |         for name in d_vector.keys():
356 |             distance_list.append(cosine(target, d_vector[name]))
357 |             if min_distance > distance_list[-1]:
358 |                 min_distance = distance_list[-1]
359 |                 target_name = name
360 |
361 |         return target_name
362 |
363 |
364 | if __name__=="__main__":
365 |     data_gen = Data_gen()
366 |     train_bm = False
367 |     model = nn_model()
368 |     if train_bm:
369 |         # train the background model
370 |         X_train, X_val, y_train, y_val = data_gen.load_data(reshape=False)
371 |         # X_train = X_train[:, :, :,np.newaxis]
372 |         # X_val = X_val[:, :, :,np.newaxis]
373 |         model.inference_lstm(X_train, y_train, X_val, y_val)
374 |
375 |     # X_train, X_val, y_train, y_val = data_gen.load_data(test_size=0.3,datatype='test',reshape=False)
376 |     with open('feature/d_vector/MFCC_feature.pkl', 'rb') as f:
377 |         feature = pkl.load(f)
378 |
379 |     with open('feature/d_vector/MFCC_label.pkl', 'rb') as f:
380 |         label = pkl.load(f)
381 |     feature = np.array(feature)
382 |     label = np.array(label).reshape(-1, 1)
383 |     enc = preprocessing.OneHotEncoder()
384 |     label = enc.fit_transform(label).toarray()
385 |
386 |     X_train, X_val, y_train, y_val = train_test_split(feature, label, shuffle=True, test_size=0.1,
387 |                                                        random_state=2019)
388 |     start_time = get_time()
389 |     acc = model.test(X_train[:, :, :,np.newaxis], y_train, X_val[:, :, :,np.newaxis], y_val, model_name='lstm_conv')
390 |     print(get_time(start_time))
391 |     print(acc)
--------------------------------------------------------------------------------
/demo.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinzcy/speech_signal_processing/9d197cd3f1d9215cf57e992701b1529d46f242ef/demo.mp4
--------------------------------------------------------------------------------
/report/First.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinzcy/speech_signal_processing/9d197cd3f1d9215cf57e992701b1529d46f242ef/report/First.pdf
--------------------------------------------------------------------------------
/report/GMM_UBM.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinzcy/speech_signal_processing/9d197cd3f1d9215cf57e992701b1529d46f242ef/report/GMM_UBM.pdf
--------------------------------------------------------------------------------
/report/final.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinzcy/speech_signal_processing/9d197cd3f1d9215cf57e992701b1529d46f242ef/report/final.pdf
--------------------------------------------------------------------------------
/report/mfcc报告.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinzcy/speech_signal_processing/9d197cd3f1d9215cf57e992701b1529d46f242ef/report/mfcc报告.pdf -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | dtw 2 | fastdtw 3 | tqdm 4 | librosa 5 | sidekit 6 | tensorflow 7 | keras 8 | numpy 9 | scipy 10 | pyqt5 11 | sklearn -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2019/2/26 15:46 4 | # @Author : chuyu zhang 5 | # @File : __init__.py.py 6 | # @Software: PyCharm -------------------------------------------------------------------------------- /utils/processing.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2019/3/23 22:18 4 | # @Author : chuyu zhang 5 | # @File : processing.py 6 | # @Software: PyCharm 7 | 8 | import numpy as np 9 | import math 10 | from scipy import signal 11 | from scipy.io import wavfile 12 | # from utils.tools import read,play 13 | from scipy.fftpack.realtransforms import dct 14 | from scipy.fftpack import fft 15 | import matplotlib.pyplot as plt 16 | 17 | eps = 1e-8 18 | ## 语音的预处理函数 19 | def enframe(wavData, frameSize=400, step=160): 20 | """ 21 | frame the wav data, according to frameSize and overlap 22 | :param wavData: the input wav data, ndarray 23 | :return:frameData, shape 24 | """ 25 | coef = 0.97 26 | wlen = wavData.shape[0] 27 | frameNum = math.ceil(wlen / step) 28 | frameData = np.zeros((frameSize, frameNum)) 29 | 30 | window = signal.windows.hamming(frameSize) 31 | 32 | for i in range(frameNum): 33 | singleFrame = wavData[i * step : min(i * step + frameSize, wlen)] 34 | # singleFrame[1:] = singleFrame[:-1] - coef * singleFrame[1:] 35 | frameData[:len(singleFrame), i] = singleFrame 36 | frameData[:, i] = window*frameData[:, i] 37 | 38 | return frameData 39 | 40 | 41 | # frequency domain feature 42 | def mfccInitFilterBanks(fs, nfft): 43 | """ 44 | Computes the triangular filterbank for MFCC computation 45 | (used in the stFeatureExtraction function before the stMFCC function call) 46 | This function is taken from the scikits.talkbox library (MIT Licence): 47 | https://pypi.python.org/pypi/scikits.talkbox 48 | """ 49 | # filter bank params: 50 | lowfreq = 133.33 51 | linsc = 200/3. 52 | logsc = 1.0711703 53 | numLinFiltTotal = 13 54 | numLogFilt = 27 55 | 56 | if fs < 8000: 57 | nlogfil = 5 58 | 59 | # Total number of filters 60 | nFiltTotal = numLinFiltTotal + numLogFilt 61 | 62 | # Compute frequency points of the triangle: 63 | freqs = np.zeros(nFiltTotal+2) 64 | freqs[:numLinFiltTotal] = lowfreq + np.arange(numLinFiltTotal) * linsc 65 | freqs[numLinFiltTotal:] = freqs[numLinFiltTotal-1] * logsc ** np.arange(1, numLogFilt + 3) 66 | heights = 2./(freqs[2:] - freqs[0:-2]) 67 | 68 | # Compute filterbank coeff (in fft domain, in bins) 69 | fbank = np.zeros((nFiltTotal, nfft)) 70 | nfreqs = np.arange(nfft) / (1. 
* nfft) * fs 71 | 72 | for i in range(nFiltTotal): 73 | lowTrFreq = freqs[i] 74 | cenTrFreq = freqs[i+1] 75 | highTrFreq = freqs[i+2] 76 | 77 | lid = np.arange(np.floor(lowTrFreq * nfft / fs) + 1, 78 | np.floor(cenTrFreq * nfft / fs) + 1, 79 | dtype=np.int) 80 | lslope = heights[i] / (cenTrFreq - lowTrFreq) 81 | rid = np.arange(np.floor(cenTrFreq * nfft / fs) + 1, 82 | np.floor(highTrFreq * nfft / fs) + 1, 83 | dtype=np.int) 84 | rslope = heights[i] / (highTrFreq - cenTrFreq) 85 | fbank[i][lid] = lslope * (nfreqs[lid] - lowTrFreq) 86 | fbank[i][rid] = rslope * (highTrFreq - nfreqs[rid]) 87 | 88 | return fbank, freqs 89 | 90 | 91 | def stMFCC(X, fbank, n_mfcc_feats): 92 | """ 93 | Computes the MFCCs of a frame, given the fft mag 94 | ARGUMENTS: 95 | X: fft magnitude abs(FFT) 96 | fbank: filter bank (see mfccInitFilterBanks) 97 | RETURN 98 | ceps: MFCCs (13 element vector) 99 | Note: MFCC calculation is, in general, taken from the 100 | scikits.talkbox library (MIT Licence), 101 | # with a small number of modifications to make it more 102 | compact and suitable for the pyAudioAnalysis Lib 103 | """ 104 | 105 | mspec = np.log10(np.dot(X, fbank.T)+eps) 106 | ceps = dct(mspec, type=2, norm='ortho', axis=-1)[:n_mfcc_feats] 107 | return ceps 108 | 109 | 110 | def MFCC(raw_signal, fs=8000, frameSize=512, step=256): 111 | """ 112 | extract mfcc feature 113 | :param raw_signal: the original audio signal 114 | :param fs: sample frequency 115 | :param frameSize:the size of each frame 116 | :param step: 117 | :return: a series of mfcc feature of each frame and flatten to (num, ) 118 | """ 119 | # Signal normalization 120 | 121 | """ 122 | raw_signal = np.double(raw_signal) 123 | 124 | raw_signal = raw_signal / (2.0 ** 15) 125 | DC = raw_signal.mean() 126 | MAX = (np.abs(raw_signal)).max() 127 | raw_signal = (raw_signal - DC) / (MAX + eps) 128 | """ 129 | nFFT = int(frameSize) 130 | [fbank, freqs] = mfccInitFilterBanks(fs, nFFT) 131 | n_mfcc_feats = 13 132 | 133 | signal = enframe(raw_signal, frameSize, step) 134 | feature = [] 135 | for frame in range(signal.shape[1]): 136 | x = signal[:, frame] 137 | X = abs(fft(x)) # get fft magnitude 138 | X = X[0:nFFT] # normalize fft 139 | X = X / len(X) 140 | feature.append(stMFCC(X, fbank, n_mfcc_feats)) 141 | 142 | feature = np.array(feature) 143 | # print(feature.shape) 144 | return feature 145 | 146 | 147 | def test(): 148 | path = '../dataset/ASR/zcy/zcy1.wav' 149 | frameSize = 400 150 | nFFT = int(frameSize/2) 151 | fs, audio = wavfile.read(path) 152 | audio = audio[:, 0] 153 | 154 | [fbank, freqs] = mfccInitFilterBanks(fs, nFFT) 155 | n_mfcc_feats = 13 156 | x = audio[4000:4000+frameSize] 157 | X = abs(fft(x)) # get fft magnitude 158 | X = X[0:nFFT] # normalize fft 159 | X = X / len(X) 160 | feature = stMFCC(X, fbank, n_mfcc_feats) 161 | plt.figure() 162 | plt.subplot(121) 163 | plt.plot(X) 164 | plt.subplot(122) 165 | plt.plot(feature) 166 | plt.show() 167 | 168 | 169 | if __name__=='__main__': 170 | test() 171 | 172 | """ 173 | path = '../dataset/ASR/zcy/zcy1.wav' 174 | play(path) 175 | samprate, data = read(path) 176 | framedata = enframe(data[:,0]) 177 | print(framedata.shape) 178 | """ 179 | -------------------------------------------------------------------------------- /utils/tools.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2019/2/26 15:46 4 | # @Author : chuyu zhang 5 | # @File : read.py 6 | # @Software: PyCharm 7 | 8 | import wave 9 | from 
scipy.io import wavfile 10 | from scipy.fftpack import fft, ifft 11 | # import math 12 | # from scipy import signal 13 | import numpy as np 14 | from pyaudio import PyAudio, paInt16 15 | import time 16 | from sklearn.metrics import confusion_matrix 17 | import matplotlib.pyplot as plt 18 | # import sounddevice as sd 19 | import simpleaudio as sa 20 | 21 | from utils.processing import enframe 22 | 23 | """ 24 | framerate=8000 25 | NUM_SAMPLES=2000 26 | channels=1 27 | sampwidth=2 28 | TIME=2 29 | """ 30 | 31 | def get_time(start_time=None): 32 | if start_time == None: 33 | return time.time() 34 | else: 35 | return time.time() - start_time 36 | 37 | ## 开发计划,后续添加,play(data),data是read函数的输出 38 | # read .wav file through wave, return a wave object 39 | def wave_read(filename='test.wav'): 40 | audio = wave.open(filename, mode='rb') 41 | return audio 42 | 43 | 44 | # read .wav file through scipy 45 | def read(filename='test.wav'): 46 | sampling_freq, audio = wavfile.read(filename) 47 | return sampling_freq, audio 48 | 49 | 50 | # 根据给定的参数保存.wav文件 51 | def save_wave_file(filename, data, channels=1, sampwidth=2, framerate=8000): 52 | '''save the data to the wavfile''' 53 | wf=wave.open(filename,'wb') 54 | wf.setnchannels(channels)#声道 55 | wf.setsampwidth(sampwidth)#采样字节 1 or 2 56 | wf.setframerate(framerate)#采样频率 8000 or 16000 57 | wf.writeframes(b"".join(data)) 58 | wf.close() 59 | 60 | 61 | # num_samples这个参数的意义?? 62 | def record(filename="test.wav",seconds=10, 63 | framerate=16000, format=paInt16, channels=1, num_samples=2000): 64 | p=PyAudio() 65 | stream=p.open(format = format, channels=channels, rate=framerate, 66 | input=True, frames_per_buffer=num_samples) 67 | my_buf=[] 68 | # 控制录音时间 69 | print("start the recording !") 70 | start = time.time() 71 | while time.time() - start < seconds: 72 | # 一次性录音采样字节大小 73 | string_audio_data = stream.read(num_samples) 74 | my_buf.append(string_audio_data) 75 | 76 | save_wave_file(filename, my_buf) 77 | stream.close() 78 | print("{} seconds record has completed.".format(seconds)) 79 | 80 | 81 | def play(audio=None, sampling_freq=8000, filename=None, chunk=1024): 82 | start_time = get_time() 83 | if filename==None: 84 | # pass 85 | play_obj = sa.play_buffer(audio, 1, 2, sampling_freq) 86 | play_obj.wait_done() 87 | 88 | else: 89 | wf=wave.open(filename,'rb') 90 | p=PyAudio() 91 | stream=p.open(format=p.get_format_from_width(wf.getsampwidth()), 92 | channels=wf.getnchannels(),rate=wf.getframerate(), 93 | output=True) 94 | while True: 95 | data=wf.readframes(chunk) 96 | # char b is absolutely necessary. It represents that the str is byte. 
97 | # For more detail, please refer to 3,4 98 | if data == b"": 99 | break 100 | stream.write(data) 101 | 102 | stream.stop_stream() 103 | stream.close() 104 | p.terminate() 105 | print("{} seconds".format(get_time(start_time))) 106 | 107 | 108 | def playback(filename='test.wav', silent=False): 109 | rate, audio = read(filename) 110 | new_filename = 'reverse_' + filename 111 | wavfile.write(new_filename, rate, audio[::-1]) 112 | if not silent: 113 | play(filename=new_filename) 114 | print("complete!") 115 | 116 | 117 | def change_rate(filename='test.wav', new_rate=4000, silent=False): 118 | rate, audio = read(filename) 119 | print("the original frequent rate is {}".format(rate)) 120 | new_filename = str(new_rate) + "_" + filename 121 | wavfile.write(new_filename, new_rate, audio) 122 | if not silent: 123 | play(filename=new_filename) 124 | print("complete !") 125 | 126 | def change_volume(filename='test.wav', volume_rate=1, silent=False): 127 | rate, audio = read(filename) 128 | print("change volume to {}".format(volume_rate)) 129 | new_filename = str(volume_rate) + "_" + filename 130 | new_audio = (audio*volume_rate).astype('int16') 131 | # print(audio.dtype, new_audio.dtype) 132 | wavfile.write(new_filename, rate, new_audio) 133 | if not silent: 134 | play(filename=new_filename) 135 | print("complete !") 136 | 137 | # 定义合成音调 138 | def Synthetic_tone(freq, duration=2, amp=1000, sampling_freq=44100): 139 | # 建立时间轴 140 | # scaling_factor = pow(2, 15) - 1 # 转换为16位整型数 141 | t = np.linspace(0, duration, duration * sampling_freq) 142 | # 构建音频信号 143 | audio = amp * np.sin(2 * np.pi * freq * t) 144 | return audio.astype(np.int16) 145 | 146 | 147 | def simple_music(tone='A', duration=2, amplitude=10000, sampling_freq=44100): 148 | tone_freq_map = {'A': 440, 'Asharp': 466, 'B': 494, 'C': 523, 'Csharp': 554, 149 | 'D': 587, 'Dsharp': 622, 'E': 659, 'F': 698, 'Fsharp': 740, 150 | 'G': 784, 'Gsharp': 831} 151 | 152 | synthesized_tone = Synthetic_tone(tone_freq_map[tone], duration, amplitude, sampling_freq) 153 | wavfile.write('{}.wav'.format(tone), sampling_freq, synthesized_tone) 154 | play('{}.wav'.format(tone)) 155 | 156 | 157 | def plot_confusion_matrix(y_true, y_pred, classes, 158 | normalize=False, 159 | title=None, 160 | cmap=plt.cm.Blues): 161 | """ 162 | This function prints and plots the confusion matrix. 163 | Normalization can be applied by setting `normalize=True`. 164 | """ 165 | if not title: 166 | if normalize: 167 | title = 'Normalized confusion matrix' 168 | else: 169 | title = 'Confusion matrix, without normalization' 170 | 171 | # Compute confusion matrix 172 | cm = confusion_matrix(y_true, y_pred) 173 | # Only use the labels that appear in the data 174 | if normalize: 175 | cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] 176 | print("Normalized confusion matrix") 177 | else: 178 | print('Confusion matrix, without normalization') 179 | 180 | # print(cm) 181 | 182 | fig, ax = plt.subplots() 183 | im = ax.imshow(cm, interpolation='nearest', cmap=cmap) 184 | ax.figure.colorbar(im, ax=ax) 185 | # We want to show all ticks... 186 | ax.set(xticks=np.arange(cm.shape[1]), 187 | yticks=np.arange(cm.shape[0]), 188 | # ... and label them with the respective list entries 189 | xticklabels=classes, yticklabels=classes, 190 | title=title, 191 | ylabel='True label', 192 | xlabel='Predicted label') 193 | 194 | # Rotate the tick labels and set their alignment. 
195 | plt.setp(ax.get_xticklabels(), rotation=45, ha="right", 196 | rotation_mode="anchor") 197 | 198 | # Loop over data dimensions and create text annotations. 199 | fmt = '.2f' if normalize else 'd' 200 | # thresh = cm.max() / 2. 201 | for i in range(cm.shape[0]): 202 | for j in range(cm.shape[1]): 203 | ax.text(j, i, format(cm[i, j], fmt), 204 | ha="center", va="center", 205 | color="black") 206 | fig.tight_layout() 207 | return ax 208 | 209 | 210 | def test(filename): 211 | framerate, audio = read(filename) 212 | play(audio, framerate) 213 | # print(audio.shape) 214 | audio_frame = enframe(audio[:,0], frameSize=512, step=256) 215 | audio_frame_new = np.zeros_like(audio_frame) 216 | for frame in range(audio_frame.shape[1]): 217 | # print(audio_frame[:, frame]) 218 | audio_fft = fft(audio_frame[:, frame]) 219 | # audio_fft_abs = abs(audio_fft) 220 | # angle = audio_fft.real/audio_fft_abs 221 | audio_fft_new = np.sqrt(0.8)*audio_fft 222 | audio_new = ifft(audio_fft_new) 223 | # print(audio_new) 224 | audio_frame_new[:, frame] = audio_new.real 225 | print((audio_new-audio_frame[:, frame]).sum()) 226 | # break 227 | audio_new = audio_frame_new[:256, :].flatten('F')[:audio.shape[0]].reshape(-1, 1) 228 | audio_new = np.concatenate([audio_new, audio_new], axis=1) 229 | save_wave_file('zero_trans.wav', audio_new) 230 | print(audio_new.shape) 231 | # print(audio_new.astype(audio.dtype)) 232 | # audio_fft_abs = np.abs(audio_fft) 233 | # angle = audio_fft.real()/audio_fft_abs 234 | # audio_fft_abs_new = audio_fft_abs*0.8 235 | # au 236 | play(audio_new.astype(audio.dtype), framerate) 237 | plt.figure() 238 | plt.plot(audio[:,0]) 239 | plt.figure() 240 | plt.plot(audio_new[:,0]) 241 | plt.show() 242 | 243 | 244 | 245 | 246 | if __name__=="__main__": 247 | test(filename='../dataset/ASR/test/zcy/zcy1.wav') 248 | """ 249 | wave_read = wave.open('../dataset/ASR/test/zcy/zcy1.wav', 'rb') 250 | audio_data = wave_read.readframes(wave_read.getnframes()) 251 | num_channels = wave_read.getnchannels() 252 | bytes_per_sample = wave_read.getsampwidth() 253 | sample_rate = wave_read.getframerate() 254 | print(type(audio_data)) 255 | play_obj = sa.play_buffer(audio_data, num_channels, bytes_per_sample, sample_rate) 256 | """ 257 | 258 | # record() 259 | # playback(filename='test.wav') 260 | # rate, audio = read(filename='01.wav') 261 | # print(rate, audio) 262 | # print(audio[::-1]) 263 | # change_volume(volume_rate=1) 264 | """ 265 | duration = 4 266 | music = 0.9*Synthetic_tone(freq=440, duration=duration) + \ 267 | 0.75*Synthetic_tone(freq=880, duration=duration) 268 | wavfile.write('music.wav', 44100, music.astype('int16')) 269 | play(filename='music.wav') 270 | 271 | """ 272 | # change_rate('../dataset/ASR/train/hyy/hyy1.wav', new_rate=8000) 273 | # simple_music(tone='Gsharp') 274 | # framerate, audio = read('../dataset/ASR/train/hyy/hyy1.wav') 275 | # downsample = audio[range(0, audio.shape[0], 2), 0] 276 | # save_wave_file('test.wav', downsample) 277 | # play(filename='test.wav') 278 | # play(downsample, 8000) 279 | -------------------------------------------------------------------------------- /展示.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kleinzcy/speech_signal_processing/9d197cd3f1d9215cf57e992701b1529d46f242ef/展示.pptx --------------------------------------------------------------------------------