├── Final_GUI.py
├── GMM_UBM.py
├── GUI.py
├── MFCC_DTW.py
├── README.md
├── UI
│   ├── GMM_UBM_GUI.py
│   ├── GMM_UBM_GUI.ui
│   ├── __init__.py
│   ├── final.py
│   ├── final.ui
│   └── tmp.py
├── VAD.py
├── d_vector.py
├── demo.mp4
├── report
│   ├── First.pdf
│   ├── GMM_UBM.pdf
│   ├── final.pdf
│   └── mfcc报告.pdf
├── requirements.txt
├── utils
│   ├── __init__.py
│   ├── processing.py
│   └── tools.py
└── 展示.pptx
/Final_GUI.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2019/5/30 10:12
4 | # @Author : chuyu zhang
5 | # @File : Final_GUI.py
6 | # @Software: PyCharm
7 |
8 | import sys
9 | from PyQt5 import QtWidgets
10 | from UI.tmp import Ui_MainWindow
11 |
12 | class MyPyQT_Form(QtWidgets.QMainWindow,Ui_MainWindow):
13 | def __init__(self):
14 | super(MyPyQT_Form,self).__init__()
15 | self.setupUi(self)
16 |
17 |
18 | if __name__ == '__main__':
19 | app = QtWidgets.QApplication(sys.argv)
20 | my_pyqt_form = MyPyQT_Form()
21 | my_pyqt_form.show()
22 | sys.exit(app.exec_())
23 |
--------------------------------------------------------------------------------
/GMM_UBM.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2019/4/12 22:24
4 | # @Author : chuyu zhang
5 | # @File : GMM_UBM.py
6 | # @Software: PyCharm
7 |
8 | import os
9 | from utils.tools import read, get_time
10 | from tqdm import tqdm
11 |
12 | # from utils.processing import MFCC
13 | import python_speech_features as psf
14 | import numpy as np
15 | import pickle as pkl
16 | from sklearn.mixture import GaussianMixture
17 | from sklearn.model_selection import train_test_split
18 | from sklearn import preprocessing
19 |
20 | from sidekit.frontend.features import plp,mfcc
21 | label_encoder = {}
22 |
23 |
24 | def load_data(path='dataset/ASR_GMM'):
25 | """
26 | load audio file.
27 | :param path: the dir to audio file
28 | :return: x type:list,each element is an audio, y type:list,it is the label of x
29 | """
30 | start_time = get_time()
31 | print("Loading data...")
32 | speaker_list = os.listdir(path)
33 | y = []
34 | x = []
35 | num = 0
36 | for speaker in tqdm(speaker_list):
37 | # encode each speaker name to a numeric label
38 | label_encoder[speaker] = num
39 | path1 = os.path.join(path, speaker)
40 | for _dir in os.listdir(path1):
41 | path2 = os.path.join(path1, _dir)
42 | for _wav in os.listdir(path2):
43 | samplerate, audio = read(os.path.join(path2, _wav))
44 | y.append(num)
45 | # the sample rate is 16000; you can downsample to 8000, but the results will be worse.
46 | x.append(audio)
47 |
48 | num += 1
49 | print("Complete! Spend {:.2f}s".format(get_time(start_time)))
50 | return x,y
51 |
52 |
53 | def delta(feat, N=2):
54 | """Compute delta features from a feature vector sequence.
55 | :param feat: A numpy array of size (NUMFRAMES by number of features) containing features. Each row holds 1 feature vector.
56 | :param N: For each frame, calculate delta features based on preceding and following N frames
57 | :returns: A numpy array of size (NUMFRAMES by number of features) containing delta features. Each row holds 1 delta feature vector.
58 | """
59 | if N < 1:
60 | raise ValueError('N must be an integer >= 1')
61 | NUMFRAMES = len(feat)
62 | denominator = 2 * sum([i**2 for i in range(1, N+1)])
63 | delta_feat = np.empty_like(feat)
64 | # padded version of feat
65 | padded = np.pad(feat, ((N, N), (0, 0)), mode='edge')
66 | for t in range(NUMFRAMES):
67 | # [t : t+2*N+1] == [(N+t)-N : (N+t)+N+1]
68 | delta_feat[t] = np.dot(np.arange(-N, N+1), padded[t : t+2*N+1]) / denominator
69 | return delta_feat
70 |
71 |
72 | def extract_feature(x, y, is_train=False, feature_type='MFCC'):
73 | """
74 | extract feature from x
75 | :param x: type list, each element is audio
76 | :param y: type list, each element is label of audio in x
77 | :param filepath: the path to save feature
78 | :param is_train: if true, generate train_data(type dict, key is lable, value is feature),
79 | if false, just extract feature from x
80 | :return:
81 | """
82 | start_time = get_time()
83 | print("Extract {} feature...".format(feature_type))
84 | feature = []
85 | train_data = {}
86 | for i in tqdm(range(len(x))):
87 | # extract the feature with sidekit's frontend; see the sidekit docs for details.
88 | if feature_type=='MFCC':
89 | _feature = mfcc(x[i])
90 | mfcc_delta = delta(_feature)
91 | _feature = np.hstack((_feature, mfcc_delta))
92 |
93 | _feature = preprocessing.scale(_feature)
94 | elif feature_type=='PLP':
95 | _feature = plp(x[i])
96 | mfcc_delta = delta(_feature)
97 | _feature = np.hstack((_feature, mfcc_delta))
98 |
99 | _feature = preprocessing.scale(_feature)
100 | else:
101 | raise NameError('feature_type must be MFCC or PLP')
102 |
103 | # append _feature to feature
104 | feature.append(_feature)
105 |
106 | if is_train:
107 | if y[i] in train_data:
108 | train_data[y[i]] = np.vstack((train_data[y[i]], _feature))
109 | else:
110 | train_data[y[i]] = _feature
111 |
112 | print("Complete! Spend {:.2f}s".format(get_time(start_time)))
113 |
114 |
115 | if is_train:
116 | return train_data, feature, y
117 | else:
118 | return feature, y
119 |
120 |
121 | def load_extract(test_size=0.3):
122 | # combine load_data and extract_feature
123 | x, y = load_data()
124 | # train test split
125 | x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=test_size, random_state=0)
126 | # extract feature from train
127 | train_data, x_train, y_train = extract_feature(x=x_train, y=y_train, is_train=True)
128 | # extract feature from test
129 | x_test, y_test = extract_feature(x=x_test,y=y_test)
130 |
131 | return train_data, x_train, y_train, x_test, y_test
132 |
133 |
134 | def GMM(train, x_train, y_train, x_test, y_test, n_components=16, model=False):
135 | print("Train GMM-UBM model !")
136 | # split x,y to train and test
137 | start_time = get_time()
138 | # if model is True, load the models from Model/GMM_MFCC_model.pkl and Model/UBM_MFCC_model.pkl;
139 | # if False, train and save the models.
140 |
141 | if model:
142 | print("load model from file...")
143 | with open("Model/GMM_MFCC_model.pkl", 'rb') as f:
144 | GMM = pkl.load(f)
145 | with open("Model/UBM_MFCC_model.pkl", 'rb') as f:
146 | UBM = pkl.load(f)
147 | else:
148 | # speaker_list = os.listdir('dataset/ASR_GMM')
149 | GMM = []
150 | ubm_train = None
151 | flag = False
152 | # UBM = []
153 | print("Train GMM!")
154 | for speaker in tqdm(label_encoder.values()):
155 | # print(type(speaker))
156 | # speaker = label_encoder[speaker]
157 | # GMM based on speaker
158 | gmm = GaussianMixture(n_components = n_components, covariance_type='diag')
159 | gmm.fit(train[speaker])
160 | GMM.append(gmm)
161 | if flag:
162 | ubm_train = np.vstack((ubm_train, train[speaker]))
163 | else:
164 | ubm_train = train[speaker]
165 | flag = True
166 |
167 | # UBM based on background
168 | print("Train UBM!")
169 | UBM = GaussianMixture(n_components = n_components, covariance_type='diag')
170 | UBM.fit(ubm_train)
171 | # UBM.append(gmm)
172 |
173 | if not os.path.exists('Model'):
174 | os.mkdir("Model")
175 |
176 | with open("Model/GMM_MFCC_model.pkl", 'wb') as f:
177 | pkl.dump(GMM, f)
178 | with open("Model/UBM_MFCC_model.pkl", 'wb') as f:
179 | pkl.dump(UBM, f)
180 |
181 | # train accuracy
182 | valid = np.zeros((len(x_train), len(GMM)))
183 | for i in range(len(GMM)):
184 | for j in range(len(x_train)):
185 | valid[j, i] = GMM[i].score(x_train[j]) - UBM.score(x_train[j])
186 |
187 | valid = valid.argmax(axis=1)
188 | acc_train = (valid==np.array(y_train)).sum()/len(x_train)
189 |
190 | # test accuracy
191 | pred = np.zeros((len(x_test), len(GMM)))
192 | for i in range(len(GMM)):
193 | for j in range(len(x_test)):
194 | pred[j, i] = GMM[i].score(x_test[j]) - UBM.score(x_test[j])
195 |
196 | pred = pred.argmax(axis=1)
197 | acc = (pred==np.array(y_test)).sum()/len(x_test)
198 |
199 | print("spend {:.2f}s, train acc {:.2%}, test acc {:.2%}".format(get_time(start_time), acc_train,acc))
200 |
201 |
202 | def main():
203 | train_data, mfcc_x_train, mfcc_y_train, mfcc_x_test, mfcc_y_test = load_extract()
204 | GMM(train_data, mfcc_x_train, mfcc_y_train, mfcc_x_test, mfcc_y_test, model=False)
205 |
206 |
207 | if __name__=='__main__':
208 | main()
--------------------------------------------------------------------------------
/GUI.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2019/4/20 22:49
4 | # @Author : chuyu zhang
5 | # @File : GUI.py
6 | # @Software: PyCharm
7 |
8 | # main window class inheriting from the UI file
9 | import sys
10 | from PyQt5 import QtWidgets
11 | from UI.GMM_UBM_GUI import Ui_Form
12 |
13 | class MyPyQT_Form(QtWidgets.QWidget,Ui_Form):
14 | def __init__(self):
15 | super(MyPyQT_Form,self).__init__()
16 | self.setupUi(self)
17 |
18 |
19 | if __name__ == '__main__':
20 | app = QtWidgets.QApplication(sys.argv)
21 | my_pyqt_form = MyPyQT_Form()
22 | my_pyqt_form.show()
23 | sys.exit(app.exec_())
24 |
--------------------------------------------------------------------------------
/MFCC_DTW.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2019/3/22 22:08
4 | # @Author : chuyu zhang
5 | # @File : MFCC_DTW.py
6 | # @Software: PyCharm
7 |
8 | import os
9 | import random
10 | from utils.tools import read, get_time,plot_confusion_matrix
11 | from utils.processing import enframe, MFCC
12 | import numpy as np
13 | from scipy.fftpack import fft
14 |
15 | import librosa
16 | # dtw is more accurate than fastdtw but slower; I will test the speed and accuracy later.
17 | from scipy.spatial.distance import euclidean
18 | from dtw import dtw,accelerated_dtw
19 | from fastdtw import fastdtw
20 |
21 | import matplotlib.pyplot as plt
22 | # import seaborn as sns
23 | from tqdm import tqdm
24 |
25 |
26 | eps = 1e-8
27 |
28 | def MFCC_lib(raw_signal, n_mfcc=13):
29 | feature = librosa.feature.mfcc(raw_signal.astype('float32'), n_mfcc=n_mfcc, sr=8000)
30 | # print(feature.T.shape)
31 | return feature.T.flatten()
32 |
33 | def _MFCC(raw_signal):
34 | """
35 | extract mfcc feature
36 | :param raw_signal: the original audio signal
37 | :param fs: sample frequency
38 | :param frameSize:the size of each frame
39 | :param step:
40 | :return: a series of mfcc feature of each frame and flatten to (num, )
41 | """
42 | # Signal normalization
43 |
44 | """
45 | raw_signal = np.double(raw_signal)
46 |
47 | raw_signal = raw_signal / (2.0 ** 15)
48 | DC = raw_signal.mean()
49 | MAX = (np.abs(raw_signal)).max()
50 | raw_signal = (raw_signal - DC) / (MAX + eps)
51 | """
52 | feature = MFCC(raw_signal, fs=8000, frameSize=512, step=256)
53 | # print(feature.shape)
54 | return feature.flatten()
55 |
56 |
57 | def distance_dtw(sample_x, sample_y, show=False, dtw_method=1, dist=euclidean):
58 | """
59 | calculate the distance between sample_x and sample_y using dtw
60 | :param sample_x: ndarray, mfcc feature for each frame
61 | :param sample_y: the same as sample_x
62 | :param show: bool, if true, show the path
63 | :param dtw_method: 1:accelerated_dtw, 2:fastdtw
64 | :return: the euclidean distance
65 | """
66 | # euclidean_norm = lambda x, y: np.abs(x - y)
67 | #
68 | #
69 | if dtw_method==2:
70 | d, path = fastdtw(sample_x, sample_y, dist=dist)
71 | elif dtw_method==1:
72 | d, cost_matrix, acc_cost_matrix, path = accelerated_dtw(sample_x, sample_y, dist='euclidean')
73 | if show:
74 | plt.imshow(acc_cost_matrix.T, origin='lower', cmap='gray', interpolation='nearest')
75 | plt.plot(path[0], path[1], 'w')
76 | plt.show()
77 |
78 | return d
79 |
80 |
81 | def distance_train(data):
82 | """
83 | calculate the distance of all data
84 | :param data: input data, list, mfcc feature of all audio
85 | :return: the distance matrix
86 | """
87 | start_time = get_time()
88 | distance = np.zeros((len(data), len(data)))
89 | for index, sample_x in enumerate(data):
90 | col = index + 1
91 | for sample_y in data[col:]:
92 | distance[index, col] = distance_dtw(sample_x, sample_y)
93 | distance[col, index] = distance[index, col]
94 | col += 1
95 | print('cost {}s'.format(get_time(start_time)))
96 | return distance
97 |
98 | def distance_test(x_test, x_train, show=False):
99 | """
100 | calculate the distance between x_test(one sample) and x_train(many sample)
101 | :param x_test: a sample
102 | :param x_train: the whole train dataset
103 | :return: distance based on dtw
104 | """
105 | distance = np.zeros((1, len(x_train)))
106 | for index in range(len(x_train)):
107 | distance[0, index] = distance_dtw(x_train[index], x_test, show=show)
108 | return distance
109 |
110 |
111 | def sample(x, y, sample_num=2, whole_num=8):
112 | index = random.sample(range(whole_num), sample_num)
113 | sample_x = []
114 | sample_y = []
115 | for i in range(4):  # the dataset has four classes (directories)
116 | for _index in index:
117 | sample_x.append(x[_index + whole_num*i])
118 | sample_y.append(y[_index + whole_num*i])
119 | return sample_x, sample_y
120 |
121 |
122 | def load_train(path='dataset/ASR/train', mfcc_extract=_MFCC):
123 | """
124 | load data from dataset/ASR/train and generate template
125 | :param path: the path of dataset
126 | :return: x is train data, y_label is the label of x
127 | """
128 | start_time = get_time()
129 | # wav_dir is a list of the four speaker directories under train.
130 | wav_dir = os.listdir(path)
131 | y_label = []
132 | x = []
133 | print("Generate template according to train set.")
134 | for _dir in tqdm(wav_dir):
135 | _x = []
136 | for _path in os.listdir(os.path.join(path, _dir)):
137 | _, data = read(os.path.join(path, _dir, _path))
138 | # Some audio files have two channels and some have one,
139 | # so I use try/except to handle both cases.
140 | # downsample the data to 8k
141 | try:
142 | _x.append(mfcc_extract(data[range(0, data.shape[0], 2), 0]))
143 | except:
144 | _x.append(mfcc_extract(data[range(0, data.shape[0], 2)]))
145 | del data
146 | # print(_x[-1].shape)
147 | # generate a template of different speaker.
148 | x.append(generate_template(_x))
149 | y_label.append(_dir)
150 |
151 | print('Loading the train data, extracting mfcc features and generating templates took {}s'.format(get_time(start_time)))
152 | return x,y_label
153 |
154 |
155 | def load_test(path='dataset/ASR/test', mfcc_extract=_MFCC, template=False):
156 | """
157 | load data from dataset/ASR/test
158 | :param path: the path of dataset
159 | :return: x is train data, y_label is the label of x
160 | """
161 | start_time = get_time()
162 | if template:
163 | # load template directly.
164 | pass
165 | # wav_dir is a list of the speaker directories under test.
166 | wav_dir = os.listdir(path)
167 | y_label = []
168 | x = []
169 | # enc = OrdinalEncoder()
170 | for _dir in wav_dir:
171 | for _path in os.listdir(os.path.join(path, _dir)):
172 | _, data = read(os.path.join(path, _dir, _path))
173 | # Some audio files have two channels and some have one,
174 | # so I use try/except to handle both cases.
175 | # downsample the data to 8k
176 | try:
177 | x.append(mfcc_extract(data[range(0, data.shape[0], 2), 0]))
178 | except:
179 | x.append(mfcc_extract(data[range(0, data.shape[0], 2)]))
180 | del data
181 | y_label.append(_dir)
182 |
183 | print('Loading the test data and extracting mfcc features took {}s'.format(get_time(start_time)))
184 | return x,y_label
185 |
186 |
187 | def generate_template(x):
188 | # max_length is the max length of audio in x.
189 | max_length = -1
190 |
191 | # max_length_index is the index of max length audio.
192 | max_length_index = 0
193 | template = None
194 | for index, _x in enumerate(x):
195 | if _x.shape[0] > max_length:
196 | max_length = _x.shape[0]
197 | max_length_index = index
198 |
199 | template = x[max_length_index]
200 | for index, _x in enumerate(x):
201 | if index != max_length_index:
202 | d, cost_matrix, acc_cost_matrix, path = accelerated_dtw(_x, template, dist='euclidean')
203 | template = (_x[path[0]] + template[path[1]])/2
204 | # the template grows longer after the previous step, so shrink it back
205 | # to its original length by merging repeated indices on the warping path.
206 | pre_road = -1
207 | ind = []
208 | for current_road in path[1]:
209 | if current_road!=pre_road:
210 | ind.append(True)
211 | else:
212 | ind.append(False)
213 | pre_road = current_road
214 |
215 | template = template[ind]
216 |
217 | return template
218 |
219 |
220 | def vote(label):
221 | label = np.array(label)
222 | _dict = {}
223 | for l in label:
224 | if l not in _dict:
225 | _dict[l] = 1
226 | else:
227 | _dict[l] += 1
228 |
229 | return sorted(_dict.items(), key=lambda x: x[1], reverse=True)[0][0]
230 |
231 |
232 | def test(threshold=100):
233 | x_train,y_train = load_train(path='dataset/ASR/train')
234 | # x_train,y_train = sample(x_train, y_train)
235 | x_test,y_test = load_test(path='dataset/ASR/test')
236 | # x_test, y_test = x_train,y_train
237 | y_pred = []
238 | # print(len(x_train))
239 | distances = np.zeros((len(x_test), len(x_train)))
240 | index = 0
241 | for x in tqdm(x_test):
242 | distance = distance_test(x, x_train)
243 | distances[index, :] = distance
244 | # top = np.argsort(distance)
245 | # print(top)
246 | y_pred.append(y_train[np.argmin(distance)])
247 | index += 1
248 | # when I set the threshold to 100 the results were very bad: many samples were
249 | # classified as "other", so I decided to give up on the threshold.
250 | """
251 | if np.min(distance) < threshold:
252 | y_pred.append(y_train[np.argmin(distance)])
253 | else:
254 | y_pred.append('other')
255 | """
256 | y_pred = np.array(y_pred)
257 | y_test = np.array(y_test)
258 | acc = (y_pred==y_test).sum()/y_test.shape[0]
259 | print("accuracy is {:.2%}".format(acc))
260 | # distances = np.concatenate([y_test.reshape(-1,1), distances], axis=1)
261 | # print(y_train)
262 | # np.savetxt('distance_template.csv', X=distances, delimiter=',')
263 | np.savetxt('res.csv', X=(y_pred==y_test), delimiter=',')
264 |
265 | plot_confusion_matrix(y_test, y_pred, classes=y_train)
266 | plt.show()
267 |
268 |
269 | def plot(filename):
270 | _, audio = read(filename)
271 | # waveform plot
272 | plt.figure()
273 | plt.plot(audio)
274 | # spectrogram
275 | # mel spectrogram
276 | # DTW path plot
277 | pass
278 |
279 |
280 | if __name__=='__main__':
281 | test()
282 | """
283 | x_train,y_train = load_wav(path='dataset/ASR/train')
284 | # distance = distance_dtw(x_train[0], x_train[1])
285 | print(x_train[0].shape)
286 | d, cost_matrix, acc_cost_matrix, path = accelerated_dtw(x_train[0], x_train[1], dist='euclidean')
287 | print(cost_matrix.shape)
288 | print(acc_cost_matrix.shape)
289 | print('*'*50)
290 | print(path[0].shape)
291 | print('*'*50)
292 | print(path[1].shape)
293 | plt.imshow(acc_cost_matrix.T, origin='lower', cmap='gray', interpolation='nearest')
294 | plt.plot(path[0], path[1], 'w')
295 | plt.show()
296 | """
297 |
298 | # np.savetxt('dis.csv', X=distance, delimiter=',')
299 | # print(distance_dtw(x[0], x[1], show=True))
300 | # print(distance_dtw(x[0], x[5], show=True))
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # speech_signal_processing
2 |
3 | If you have any questions, you can open an issue or email me.
4 |
5 | ## Description
6 |
7 | * VAD.py is the first project.
8 |
9 | * MFCC_DTW.py is the second project.
10 |
11 | * GMM_UBM.py is the third project, and GUI.py is the GUI of this project.
12 |
13 | * d_vector.py is the final project, and Final_GUI.py is its GUI.
14 |
15 | * The feature dir holds the saved models and feature files; you can download them from
16 | [here](https://pan.baidu.com/s/1P65s2OAqqsnnfFOn6XCbSA), extraction code: iwmf.
17 |
18 | * All the reports are in the report dir.
19 |
20 | ## Requirement
21 |
22 | Python 3.x, Windows
23 |
24 | For the remaining packages, run:
25 | ```
26 | pip install -r requirements.txt
27 | ```
28 | or
29 | ```
30 | pip install dtw librosa fastdtw tqdm sidekit tensorflow keras numpy scipy pyqt5 scikit-learn
31 | ```
32 |
33 | *NOTE*: you can use a pip mirror to speed up the downloads; see this [blog](https://www.cnblogs.com/microman/p/6107879.html).
34 |
35 | ## Dataset
36 | Download the dataset for the d-vector experiments from [voxceleb](http://www.robots.ox.ac.uk/~vgg/data/voxceleb/).
37 |
38 | ## Experiment Log
39 |
40 | #### MFCC+DTW
41 | | DTW |Time(s)| Acc(%) |
42 | |:---:|:---:|:---:|
43 | |accelerated_dtw|92|83.72|
44 | |accelerated_dtw+pre-emphasis|105|74.42|
45 | |fastdtw|71|60.47|
46 | |fastdtw+pre-emphasis|79|65.12|
47 |
48 | **Summary**: fastdtw gives worse results than accelerated_dtw, so I suggest using
49 | accelerated_dtw rather than fastdtw if you care more about accuracy.
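
For a quick feel of the two backends, here is a minimal sketch; the random arrays are
stand-ins for real per-frame MFCC features, and both calls match the APIs used in
MFCC_DTW.py:
```
import numpy as np
from scipy.spatial.distance import euclidean
from dtw import accelerated_dtw
from fastdtw import fastdtw

# Toy stand-ins for two per-frame MFCC sequences of different lengths.
x = np.random.rand(100, 13)
y = np.random.rand(120, 13)

d_fast, _ = fastdtw(x, y, dist=euclidean)                   # approximate, faster
d_exact, _, _, _ = accelerated_dtw(x, y, dist='euclidean')  # exact, slower
print(d_fast, d_exact)
```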
50 |
51 | ## MFCC
52 |
53 | [blog](https://kleinzcy.github.io/blog/speech%20signal%20processing/%E6%A2%85%E5%B0%94%E5%80%92%E8%B0%B1%E7%B3%BB%E6%95%B0)
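
To make the feature concrete, here is a minimal, illustrative MFCC extraction with
librosa (the parameter values are placeholders, not the ones used in the experiments):
```
import numpy as np
import librosa

# One second of fake audio at 16 kHz; replace with a real signal.
signal = np.random.randn(16000).astype('float32')
feat = librosa.feature.mfcc(y=signal, sr=16000, n_mfcc=13)
print(feat.shape)  # (13, num_frames)
```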
54 |
55 | ## GMM
56 |
57 | [blog](https://appliedmachinelearning.blog/2017/11/14/spoken-speaker-identification-based-on-gaussian-mixture-models-python-implementation/)
58 |
59 | [paper](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.117.338&rep=rep1&type=pdf)
60 |
61 | [scikit-learn](https://scikit-learn.org/stable/modules/mixture.html#gmm)
62 |
63 | [SIDEKIT](https://pypi.org/project/SIDEKIT/)
64 |
65 | ## MFCC+GMM
66 |
67 | Please read the report for more details.
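
The core of GMM_UBM.py is scoring each utterance by the log-likelihood ratio between a
per-speaker GMM and a universal background model (UBM). A minimal sketch on synthetic
features (in the real code the random arrays are scaled MFCC+delta frames):
```
import numpy as np
from sklearn.mixture import GaussianMixture

# Synthetic 26-dim features for two "speakers".
feat_a = np.random.randn(500, 26)
feat_b = np.random.randn(500, 26) + 1.0

gmm_a = GaussianMixture(n_components=16, covariance_type='diag').fit(feat_a)
gmm_b = GaussianMixture(n_components=16, covariance_type='diag').fit(feat_b)
ubm = GaussianMixture(n_components=16, covariance_type='diag').fit(np.vstack([feat_a, feat_b]))

# Score a test utterance: per-speaker log-likelihood minus UBM log-likelihood.
test = np.random.randn(100, 26) + 1.0
scores = [g.score(test) - ubm.score(test) for g in (gmm_a, gmm_b)]
print(int(np.argmax(scores)))  # expected: 1 (speaker b)
```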
68 |
69 | ## d-vector
70 |
71 | We trained our models on the voxceleb dataset; for more details, please read the report.
72 |
73 | |model|time(s)|train_acc|valid_acc|epoch|test_acc|test_time(s)|
74 | |:---:|:---:|:---:|:----:|:---:|:---:|:---:|
75 | |nn|56|0.5321|0.4672|50|0.3682|11.92|
76 | |lstm|2906|0.788|0.5472|100|0.4371|49.53|
77 | |gru|2977|0.9385|0.7484|30|0.3766|70.05|
78 |
79 | **reference**: [paper](https://ieeexplore.ieee.org/document/6854363)
80 |
81 | **reference_gru**: [paper](https://arxiv.org/pdf/1705.02304.pdf)
82 |
83 | **reference_lstm**: [paper](https://arxiv.org/abs/1509.08062)
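
At test time the d-vector pipeline reduces to a nearest-neighbour search in embedding
space with cosine distance, as in _dvector_test in UI/tmp.py. A sketch with made-up
enrolled vectors (the names and the 256-dim size are placeholders):
```
import numpy as np
from scipy.spatial.distance import cosine

# Hypothetical enrolled d-vectors: one averaged embedding per speaker.
enrolled = {'alice': np.random.rand(256), 'bob': np.random.rand(256)}
test_vec = np.random.rand(256)  # embedding of the test utterance

# Accept the enrolled speaker with the smallest cosine distance.
name = min(enrolled, key=lambda n: cosine(test_vec, enrolled[n]))
print(name)
```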
84 |
85 | ## Reference
86 |
87 | 1. [audio-mnist-with-person-detection](https://github.com/yogeshjadhav7/audio-mnist-with-person-detection)
88 |
89 | 2. [dVectorSpeakerRecognition](https://github.com/wangleiai/dVectorSpeakerRecognition)
90 |
91 | 3. [speaker-verification](https://github.com/rajathkmp/speaker-verification)
92 |
93 | 4. [voxceleb](http://www.robots.ox.ac.uk/~vgg/data/voxceleb/)
94 |
95 |
96 |
--------------------------------------------------------------------------------
/UI/GMM_UBM_GUI.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Form implementation generated from reading ui file 'GMM_UBM_GUI.ui'
4 | #
5 | # Created by: PyQt5 UI code generator 5.11.3
6 | #
7 | # WARNING! All changes made in this file will be lost!
8 |
9 | from PyQt5 import QtCore, QtGui, QtWidgets
10 | import pickle as pkl
11 | import os
12 | from playsound import playsound
13 | from GMM_UBM import delta
14 | from utils.tools import record,read
15 | import numpy as np
16 | from sidekit.frontend.features import plp,mfcc
17 | from sklearn import preprocessing
18 | from aip import AipSpeech
19 |
20 | import warnings
21 | warnings.filterwarnings("ignore")
22 |
23 | class Ui_Form(object):
24 | def setupUi(self, Form):
25 | Form.setObjectName("Form")
26 | Form.resize(300, 343)
27 | self.horizontalLayoutWidget = QtWidgets.QWidget(Form)
28 | self.horizontalLayoutWidget.setGeometry(QtCore.QRect(70, 20, 160, 80))
29 | self.horizontalLayoutWidget.setObjectName("horizontalLayoutWidget")
30 | self.horizontalLayout = QtWidgets.QHBoxLayout(self.horizontalLayoutWidget)
31 | self.horizontalLayout.setContentsMargins(0, 0, 0, 0)
32 | self.horizontalLayout.setObjectName("horizontalLayout")
33 | self.radioButton_2 = QtWidgets.QRadioButton(self.horizontalLayoutWidget)
34 | self.radioButton_2.setObjectName("radioButton_2")
35 | self.horizontalLayout.addWidget(self.radioButton_2)
36 | self.radioButton = QtWidgets.QRadioButton(self.horizontalLayoutWidget)
37 | self.radioButton.setObjectName("radioButton")
38 | self.horizontalLayout.addWidget(self.radioButton)
39 | self.pushButton = QtWidgets.QPushButton(Form)
40 | self.pushButton.setGeometry(QtCore.QRect(100, 120, 93, 28))
41 | self.pushButton.setObjectName("pushButton")
42 | self.pushButton_2 = QtWidgets.QPushButton(Form)
43 | self.pushButton_2.setGeometry(QtCore.QRect(100, 180, 93, 28))
44 | self.pushButton_2.setObjectName("pushButton_2")
45 | self.textBrowser = QtWidgets.QTextBrowser(Form)
46 | self.textBrowser.setGeometry(QtCore.QRect(20, 220, 256, 111))
47 | self.textBrowser.setObjectName("textBrowser")
48 |
49 | self.retranslateUi(Form)
50 | self.radioButton_2.clicked.connect(Form.load_plp)
51 | self.radioButton.clicked.connect(Form.load_mfcc)
52 | self.pushButton.clicked.connect(Form.record)
53 | self.pushButton_2.clicked.connect(Form.test)
54 | QtCore.QMetaObject.connectSlotsByName(Form)
55 |
56 | def retranslateUi(self, Form):
57 | _translate = QtCore.QCoreApplication.translate
58 | Form.setWindowTitle(_translate("Form", "Form"))
59 | self.radioButton_2.setText(_translate("Form", "PLP"))
60 | self.radioButton.setText(_translate("Form", "MFCC"))
61 | self.pushButton.setText(_translate("Form", "Record"))
62 | self.pushButton_2.setText(_translate("Form", "Test"))
63 |
64 | def load_plp(self):
65 | self.textBrowser.clear()
66 | self.textBrowser.append('Loading PLP feature model')
67 | self.feature_type = 'PLP'
68 | self.load()
69 |
70 | def load_mfcc(self):
71 | self.textBrowser.clear()
72 | self.textBrowser.append('Loading MFCC feature model')
73 | self.feature_type = 'MFCC'
74 | self.load()
75 |
76 | def load(self):
77 | with open("Model/GMM_{}_model.pkl".format(self.feature_type), 'rb') as f:
78 | self.GMM = pkl.load(f)
79 | with open("Model/UBM_{}_model.pkl".format(self.feature_type), 'rb') as f:
80 | self.UBM = pkl.load(f)
81 |
82 | self.textBrowser.append('complete')
83 |
84 |
85 | def record(self):
86 | self.textBrowser.append('Start the recording !')
87 | record(seconds=3)
88 | self.textBrowser.append('The 3-second recording is complete.')
89 | _, audio = read(filename='test.wav')
90 | if self.feature_type=='MFCC':
91 | feature = mfcc(audio)[0]
92 | else:
93 | feature = plp(audio)[0]
94 |
95 | _delta = delta(feature)
96 | feature = np.hstack((feature, _delta))
97 |
98 | feature = preprocessing.scale(feature)
99 | self.feature = feature
100 | os.remove('test.wav')
101 |
102 | def test(self):
103 | prob = np.zeros((1, len(self.GMM)))
104 | for i in range(len(self.GMM)):
105 | prob[0,i] = self.GMM[i].score(self.feature) - self.UBM.score(self.feature)
106 |
107 | res = prob.argmax(axis=1)
108 |
109 | num2name = ['班富景', '郭佳怡', '黄心羿', '居慧敏', '廖楚楚', '刘山', '任蕴菡', '阮煜文', '苏林林', '万之颖',
110 | '陈斌', '陈泓宇', '陈军栋', '蔡晓明', '邓刚刚', '董俊虎', '代旭辉', '高威', '龚兵庆', '姜宇伦',
111 | '靳子涵', '李恩', '罗远哲', '罗伟宇', '李想', '李晓波', '李彦能', '刘乙灼', '刘志航', '李忠亚']
112 |
113 | prob = np.exp(prob)/np.exp(prob).sum(axis=1)
114 | # print(prob)
115 | output = str(num2name[res[0]]) + ':' + str(prob[0,res[0]])
116 | self.textBrowser.append(output)
117 | APP_ID = '11719204'
118 | API_KEY = 'g7SpqGrkSKgTEBti3pfDsprD'
119 | SECRET_KEY = 'Tn5CS7EE26rDH34H8z7GV3p0DYYpsksZ'
120 |
121 | client = AipSpeech(APP_ID, API_KEY, SECRET_KEY)
122 | result = client.synthesis('{:.2%}的可能是{}'.format(prob[0,res[0]],num2name[res[0]]), 'zh', 0, {
123 | 'vol': 5,
124 | })
125 | if not isinstance(result, dict):
126 | with open('result.mp3', 'wb') as f:
127 | f.write(result)
128 |
129 | playsound("result.mp3")
130 | os.remove('result.mp3')
131 |
--------------------------------------------------------------------------------
/UI/GMM_UBM_GUI.ui:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <ui version="4.0">
3 |  <class>Form</class>
4 |  <widget class="QWidget" name="Form">
5 |   <property name="geometry">
6 |    <rect><x>0</x><y>0</y><width>300</width><height>343</height></rect>
7 |   </property>
8 |   <property name="windowTitle">
9 |    <string>Form</string>
10 |   </property>
11 |   <widget class="QWidget" name="horizontalLayoutWidget">
12 |    <property name="geometry">
13 |     <rect><x>70</x><y>20</y><width>160</width><height>80</height></rect>
14 |    </property>
15 |    <layout class="QHBoxLayout" name="horizontalLayout">
16 |     <item>
17 |      <widget class="QRadioButton" name="radioButton_2">
18 |       <property name="text"><string>PLP</string></property>
19 |      </widget>
20 |     </item>
21 |     <item>
22 |      <widget class="QRadioButton" name="radioButton">
23 |       <property name="text"><string>MFCC</string></property>
24 |      </widget>
25 |     </item>
26 |    </layout>
27 |   </widget>
28 |   <widget class="QPushButton" name="pushButton">
29 |    <property name="geometry">
30 |     <rect><x>100</x><y>120</y><width>93</width><height>28</height></rect>
31 |    </property>
32 |    <property name="text"><string>Record</string></property>
33 |   </widget>
34 |   <widget class="QPushButton" name="pushButton_2">
35 |    <property name="geometry">
36 |     <rect><x>100</x><y>180</y><width>93</width><height>28</height></rect>
37 |    </property>
38 |    <property name="text"><string>Test</string></property>
39 |   </widget>
40 |   <widget class="QTextBrowser" name="textBrowser">
41 |    <property name="geometry">
42 |     <rect><x>20</x><y>220</y><width>256</width><height>111</height></rect>
43 |    </property>
44 |   </widget>
45 |  </widget>
46 |  <resources/>
47 |  <connections>
48 |   <connection>
49 |    <sender>radioButton_2</sender>
50 |    <signal>clicked()</signal>
51 |    <receiver>Form</receiver>
52 |    <slot>load_plp()</slot>
53 |    <hints>
54 |     <hint type="sourcelabel"><x>73</x><y>57</y></hint>
55 |     <hint type="destinationlabel"><x>45</x><y>63</y></hint>
56 |    </hints>
57 |   </connection>
58 |   <connection>
59 |    <sender>radioButton</sender>
60 |    <signal>clicked()</signal>
61 |    <receiver>Form</receiver>
62 |    <slot>load_mfcc()</slot>
63 |    <hints>
64 |     <hint type="sourcelabel"><x>211</x><y>65</y></hint>
65 |     <hint type="destinationlabel"><x>260</x><y>62</y></hint>
66 |    </hints>
67 |   </connection>
68 |   <connection>
69 |    <sender>pushButton</sender>
70 |    <signal>clicked()</signal>
71 |    <receiver>Form</receiver>
72 |    <slot>record()</slot>
73 |    <hints>
74 |     <hint type="sourcelabel"><x>190</x><y>142</y></hint>
75 |     <hint type="destinationlabel"><x>229</x><y>134</y></hint>
76 |    </hints>
77 |   </connection>
78 |   <connection>
79 |    <sender>pushButton_2</sender>
80 |    <signal>clicked()</signal>
81 |    <receiver>Form</receiver>
82 |    <slot>test()</slot>
83 |    <hints>
84 |     <hint type="sourcelabel"><x>173</x><y>190</y></hint>
85 |     <hint type="destinationlabel"><x>234</x><y>189</y></hint>
86 |    </hints>
87 |   </connection>
88 |  </connections>
89 |  <slots>
90 |   <slot>load_plp()</slot>
91 |   <slot>load_mfcc()</slot>
92 |   <slot>record()</slot>
93 |   <slot>test()</slot>
94 |  </slots>
95 | </ui>
--------------------------------------------------------------------------------
/UI/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2019/4/20 22:56
4 | # @Author : chuyu zhang
5 | # @File : __init__.py.py
6 | # @Software: PyCharm
--------------------------------------------------------------------------------
/UI/final.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Form implementation generated from reading ui file 'final.ui'
4 | #
5 | # Created by: PyQt5 UI code generator 5.11.3
6 | #
7 | # WARNING! All changes made in this file will be lost!
8 |
9 | from PyQt5 import QtCore, QtGui, QtWidgets
10 |
11 | class Ui_MainWindow(object):
12 | def setupUi(self, MainWindow):
13 | MainWindow.setObjectName("MainWindow")
14 | MainWindow.resize(692, 561)
15 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding)
16 | sizePolicy.setHorizontalStretch(0)
17 | sizePolicy.setVerticalStretch(0)
18 | sizePolicy.setHeightForWidth(MainWindow.sizePolicy().hasHeightForWidth())
19 | MainWindow.setSizePolicy(sizePolicy)
20 | MainWindow.setFocusPolicy(QtCore.Qt.NoFocus)
21 | MainWindow.setAnimated(True)
22 | self.centralwidget = QtWidgets.QWidget(MainWindow)
23 | self.centralwidget.setObjectName("centralwidget")
24 | self.gridLayout = QtWidgets.QGridLayout(self.centralwidget)
25 | self.gridLayout.setObjectName("gridLayout")
26 | self.verticalLayout_4 = QtWidgets.QVBoxLayout()
27 | self.verticalLayout_4.setContentsMargins(-1, 0, -1, -1)
28 | self.verticalLayout_4.setObjectName("verticalLayout_4")
29 | self.label = QtWidgets.QLabel(self.centralwidget)
30 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding)
31 | sizePolicy.setHorizontalStretch(0)
32 | sizePolicy.setVerticalStretch(0)
33 | sizePolicy.setHeightForWidth(self.label.sizePolicy().hasHeightForWidth())
34 | self.label.setSizePolicy(sizePolicy)
35 | self.label.setMaximumSize(QtCore.QSize(480, 320))
36 | self.label.setSizeIncrement(QtCore.QSize(1, 1))
37 | self.label.setText("")
38 | self.label.setPixmap(QtGui.QPixmap(":/newPrefix/1.jpg"))
39 | self.label.setScaledContents(True)
40 | self.label.setAlignment(QtCore.Qt.AlignCenter)
41 | self.label.setObjectName("label")
42 | self.verticalLayout_4.addWidget(self.label)
43 | self.textBrowser_2 = QtWidgets.QTextBrowser(self.centralwidget)
44 | font = QtGui.QFont()
45 | font.setFamily("Adobe 宋体 Std L")
46 | font.setPointSize(15)
47 | self.textBrowser_2.setFont(font)
48 | self.textBrowser_2.setReadOnly(False)
49 | self.textBrowser_2.setCursorWidth(4)
50 | self.textBrowser_2.setObjectName("textBrowser_2")
51 | self.verticalLayout_4.addWidget(self.textBrowser_2)
52 | self.textBrowser_3 = QtWidgets.QTextBrowser(self.centralwidget)
53 | self.textBrowser_3.setObjectName("textBrowser_3")
54 | self.verticalLayout_4.addWidget(self.textBrowser_3)
55 | self.gridLayout.addLayout(self.verticalLayout_4, 0, 1, 1, 1)
56 | self.verticalLayout_1 = QtWidgets.QVBoxLayout()
57 | self.verticalLayout_1.setSpacing(0)
58 | self.verticalLayout_1.setObjectName("verticalLayout_1")
59 | self.horizontalLayout = QtWidgets.QHBoxLayout()
60 | self.horizontalLayout.setContentsMargins(-1, -1, 0, -1)
61 | self.horizontalLayout.setSpacing(10)
62 | self.horizontalLayout.setObjectName("horizontalLayout")
63 | self.groupBox = QtWidgets.QGroupBox(self.centralwidget)
64 | self.groupBox.setLocale(QtCore.QLocale(QtCore.QLocale.Chinese, QtCore.QLocale.China))
65 | self.groupBox.setObjectName("groupBox")
66 | self.verticalLayout_3 = QtWidgets.QVBoxLayout(self.groupBox)
67 | self.verticalLayout_3.setObjectName("verticalLayout_3")
68 | self.radioButton = QtWidgets.QRadioButton(self.groupBox)
69 | self.radioButton.setObjectName("radioButton")
70 | self.verticalLayout_3.addWidget(self.radioButton)
71 | self.radioButton_2 = QtWidgets.QRadioButton(self.groupBox)
72 | self.radioButton_2.setObjectName("radioButton_2")
73 | self.verticalLayout_3.addWidget(self.radioButton_2)
74 | self.radioButton_3 = QtWidgets.QRadioButton(self.groupBox)
75 | self.radioButton_3.setObjectName("radioButton_3")
76 | self.verticalLayout_3.addWidget(self.radioButton_3)
77 | self.horizontalLayout.addWidget(self.groupBox)
78 | self.groupBox_2 = QtWidgets.QGroupBox(self.centralwidget)
79 | self.groupBox_2.setLocale(QtCore.QLocale(QtCore.QLocale.Chinese, QtCore.QLocale.China))
80 | self.groupBox_2.setObjectName("groupBox_2")
81 | self.verticalLayout_2 = QtWidgets.QVBoxLayout(self.groupBox_2)
82 | self.verticalLayout_2.setObjectName("verticalLayout_2")
83 | self.radioButton_4 = QtWidgets.QRadioButton(self.groupBox_2)
84 | self.radioButton_4.setObjectName("radioButton_4")
85 | self.verticalLayout_2.addWidget(self.radioButton_4)
86 | self.radioButton_6 = QtWidgets.QRadioButton(self.groupBox_2)
87 | self.radioButton_6.setObjectName("radioButton_6")
88 | self.verticalLayout_2.addWidget(self.radioButton_6)
89 | self.radioButton_5 = QtWidgets.QRadioButton(self.groupBox_2)
90 | self.radioButton_5.setObjectName("radioButton_5")
91 | self.verticalLayout_2.addWidget(self.radioButton_5)
92 | self.horizontalLayout.addWidget(self.groupBox_2)
93 | self.verticalLayout_1.addLayout(self.horizontalLayout)
94 | self.verticalLayout = QtWidgets.QVBoxLayout()
95 | self.verticalLayout.setObjectName("verticalLayout")
96 | self.groupBox_3 = QtWidgets.QGroupBox(self.centralwidget)
97 | self.groupBox_3.setObjectName("groupBox_3")
98 | self.horizontalLayout_2 = QtWidgets.QHBoxLayout(self.groupBox_3)
99 | self.horizontalLayout_2.setContentsMargins(5, 5, 5, 5)
100 | self.horizontalLayout_2.setObjectName("horizontalLayout_2")
101 | self.radioButton_8 = QtWidgets.QRadioButton(self.groupBox_3)
102 | self.radioButton_8.setObjectName("radioButton_8")
103 | self.horizontalLayout_2.addWidget(self.radioButton_8)
104 | self.radioButton_7 = QtWidgets.QRadioButton(self.groupBox_3)
105 | self.radioButton_7.setObjectName("radioButton_7")
106 | self.horizontalLayout_2.addWidget(self.radioButton_7)
107 | self.verticalLayout.addWidget(self.groupBox_3)
108 | self.pushButton = QtWidgets.QPushButton(self.centralwidget)
109 | self.pushButton.setObjectName("pushButton")
110 | self.verticalLayout.addWidget(self.pushButton)
111 | self.pushButton_2 = QtWidgets.QPushButton(self.centralwidget)
112 | self.pushButton_2.setObjectName("pushButton_2")
113 | self.verticalLayout.addWidget(self.pushButton_2)
114 | self.pushButton_3 = QtWidgets.QPushButton(self.centralwidget)
115 | self.pushButton_3.setObjectName("pushButton_3")
116 | self.verticalLayout.addWidget(self.pushButton_3)
117 | self.verticalLayout_1.addLayout(self.verticalLayout)
118 | self.verticalLayout_1.setStretch(0, 1)
119 | self.verticalLayout_1.setStretch(1, 3)
120 | self.gridLayout.addLayout(self.verticalLayout_1, 0, 0, 1, 1)
121 | MainWindow.setCentralWidget(self.centralwidget)
122 | self.menubar = QtWidgets.QMenuBar(MainWindow)
123 | self.menubar.setGeometry(QtCore.QRect(0, 0, 692, 26))
124 | self.menubar.setObjectName("menubar")
125 | MainWindow.setMenuBar(self.menubar)
126 | self.statusbar = QtWidgets.QStatusBar(MainWindow)
127 | self.statusbar.setObjectName("statusbar")
128 | MainWindow.setStatusBar(self.statusbar)
129 |
130 | self.retranslateUi(MainWindow)
131 | QtCore.QMetaObject.connectSlotsByName(MainWindow)
132 |
133 | def retranslateUi(self, MainWindow):
134 | _translate = QtCore.QCoreApplication.translate
135 | MainWindow.setWindowTitle(_translate("MainWindow", "说话者识别"))
136 | self.textBrowser_2.setHtml(_translate("MainWindow", "<!DOCTYPE HTML PUBLIC \"-//W3C//DTD HTML 4.0//EN\" \"http://www.w3.org/TR/REC-html40/strict.dtd\">\n"
137 | "<html><head><meta name=\"qrichtext\" content=\"1\" /><style type=\"text/css\">\n"
138 | "p, li { white-space: pre-wrap; }\n"
139 | "</style></head><body style=\" font-family:'Adobe 宋体 Std L'; font-size:15pt; font-weight:400; font-style:normal;\">\n"
140 | "<p style=\"-qt-paragraph-type:empty; margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px; font-family:'SimSun'; font-size:9pt;\"><br /></p></body></html>"))
141 | self.textBrowser_3.setHtml(_translate("MainWindow", "<!DOCTYPE HTML PUBLIC \"-//W3C//DTD HTML 4.0//EN\" \"http://www.w3.org/TR/REC-html40/strict.dtd\">\n"
142 | "<html><head><meta name=\"qrichtext\" content=\"1\" /><style type=\"text/css\">\n"
143 | "p, li { white-space: pre-wrap; }\n"
144 | "</style></head><body style=\" font-family:'SimSun'; font-size:9pt; font-weight:400; font-style:normal;\">\n"
145 | "<p style=\"-qt-paragraph-type:empty; margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;\"><br /></p></body></html>"))
146 | self.groupBox.setTitle(_translate("MainWindow", "说话人识别"))
147 | self.radioButton.setText(_translate("MainWindow", "NN"))
148 | self.radioButton_2.setText(_translate("MainWindow", "GRU"))
149 | self.radioButton_3.setText(_translate("MainWindow", "LSTM"))
150 | self.groupBox_2.setTitle(_translate("MainWindow", "语种识别"))
151 | self.radioButton_4.setText(_translate("MainWindow", "MFCC"))
152 | self.radioButton_6.setText(_translate("MainWindow", "PLP"))
153 | self.radioButton_5.setText(_translate("MainWindow", "MFCC+PLP"))
154 | self.groupBox_3.setTitle(_translate("MainWindow", "mode"))
155 | self.radioButton_8.setText(_translate("MainWindow", "说话人识别"))
156 | self.radioButton_7.setText(_translate("MainWindow", "语种识别"))
157 | self.pushButton.setText(_translate("MainWindow", "Enroll"))
158 | self.pushButton_2.setText(_translate("MainWindow", "Test"))
159 | self.pushButton_3.setText(_translate("MainWindow", "Set"))
160 |
161 | import source_rc
162 |
--------------------------------------------------------------------------------
/UI/final.ui:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <!-- Qt Designer XML for the final GUI; the markup was lost in extraction and only
3 |      fragments survive (original lines 251-259 are missing entirely). Recoverable
4 |      structure (see the generated UI/final.py): a 692x561 QMainWindow titled
5 |      "说话者识别"; a left column with the group boxes 说话人识别 (NN/GRU/LSTM),
6 |      语种识别 (MFCC/PLP/MFCC+PLP) and mode (说话人识别/语种识别) plus
7 |      Enroll/Test/Set push buttons; a right column with an image label
8 |      (:/newPrefix/1.jpg) and two rich-text browsers; resources from source_rc. -->
--------------------------------------------------------------------------------
/UI/tmp.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Form implementation generated from reading ui file 'final.ui'
4 | #
5 | # Created by: PyQt5 UI code generator 5.11.3
6 | #
7 | # WARNING! All changes made in this file will be lost!
8 |
9 | from PyQt5 import QtCore, QtGui, QtWidgets
10 | from PyQt5.QtMultimedia import QMediaContent, QMediaPlayer
11 | from PyQt5.QtMultimediaWidgets import QVideoWidget
12 | from d_vector import nn_model
13 | from keras.models import load_model
14 | import pickle as pkl
15 | import numpy as np
16 | from scipy.spatial.distance import cosine
17 | from sidekit.frontend.features import plp,mfcc
18 | from sklearn import preprocessing
19 | from utils.tools import record,read
20 | from tqdm import tqdm
21 | import os
22 | import python_speech_features as psf
23 | #TODO test with recorded audio data
24 | #TODO experiment report
25 | class Ui_MainWindow(object):
26 | def setupUi(self, MainWindow):
27 | #TODO playback progress bar
28 | #TODO vertical stretching issue
29 | MainWindow.setObjectName("MainWindow")
30 | MainWindow.resize(739, 583)
31 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding)
32 | sizePolicy.setHorizontalStretch(0)
33 | sizePolicy.setVerticalStretch(0)
34 | sizePolicy.setHeightForWidth(MainWindow.sizePolicy().hasHeightForWidth())
35 | MainWindow.setSizePolicy(sizePolicy)
36 | MainWindow.setFocusPolicy(QtCore.Qt.NoFocus)
37 | MainWindow.setAnimated(True)
38 |
39 | self.centralwidget = QtWidgets.QWidget(MainWindow)
40 | self.centralwidget.setObjectName("centralwidget")
41 | self.gridLayout = QtWidgets.QGridLayout(self.centralwidget)
42 | self.gridLayout.setObjectName("gridLayout")
43 |
44 | # verticalLayout_1: the group boxes and push buttons stacked vertically
45 | self.verticalLayout_1 = QtWidgets.QVBoxLayout()
46 | self.verticalLayout_1.setSpacing(40)
47 | self.verticalLayout_1.setObjectName("verticalLayout_1")
48 |
49 | # verticalLayout_5: vertical arrangement of the three group boxes
50 | self.verticalLayout_5 = QtWidgets.QVBoxLayout()
51 | self.verticalLayout_5.setObjectName("verticalLayout_5")
52 |
53 | # groupbox_3
54 | self.groupBox_3 = QtWidgets.QGroupBox(self.centralwidget)
55 | self.groupBox_3.setObjectName("groupBox_3")
56 | self.horizontalLayout_2 = QtWidgets.QHBoxLayout(self.groupBox_3)
57 | self.horizontalLayout_2.setContentsMargins(5, 5, 5, 5)
58 | self.horizontalLayout_2.setObjectName("horizontalLayout_2")
59 | self.radioButton_7 = QtWidgets.QRadioButton(self.groupBox_3)
60 | self.radioButton_7.setObjectName("radioButton_7")
61 | self.horizontalLayout_2.addWidget(self.radioButton_7)
62 | self.radioButton_8 = QtWidgets.QRadioButton(self.groupBox_3)
63 | self.radioButton_8.setObjectName("radioButton_8")
64 | self.horizontalLayout_2.addWidget(self.radioButton_8)
65 |
66 | self.verticalLayout_5.addWidget(self.groupBox_3)
67 |
68 | # horizontalLayout: the two group boxes side by side
69 | self.horizontalLayout = QtWidgets.QHBoxLayout()
70 | self.horizontalLayout.setContentsMargins(-1, -1, 0, -1)
71 | self.horizontalLayout.setSpacing(10)
72 | self.horizontalLayout.setObjectName("horizontalLayout")
73 |
74 | # groupbox
75 | self.groupBox = QtWidgets.QGroupBox(self.centralwidget)
76 | self.groupBox.setLocale(QtCore.QLocale(QtCore.QLocale.Chinese, QtCore.QLocale.China))
77 | self.groupBox.setObjectName("groupBox")
78 |
79 | # verticalLayout_3: vertical arrangement inside groupBox
80 | self.verticalLayout_3 = QtWidgets.QVBoxLayout(self.groupBox)
81 | self.verticalLayout_3.setObjectName("verticalLayout_3")
82 | self.verticalLayout_3.setSpacing(10)
83 |
84 | self.radioButton = QtWidgets.QRadioButton(self.groupBox)
85 | self.radioButton.setObjectName("radioButton")
86 | self.verticalLayout_3.addWidget(self.radioButton)
87 | self.radioButton_2 = QtWidgets.QRadioButton(self.groupBox)
88 | self.radioButton_2.setObjectName("radioButton_2")
89 | self.verticalLayout_3.addWidget(self.radioButton_2)
90 | self.radioButton_3 = QtWidgets.QRadioButton(self.groupBox)
91 | self.radioButton_3.setObjectName("radioButton_3")
92 | self.verticalLayout_3.addWidget(self.radioButton_3)
93 | self.horizontalLayout.addWidget(self.groupBox)
94 |
95 | # groupbox_2
96 | self.groupBox_2 = QtWidgets.QGroupBox(self.centralwidget)
97 | self.groupBox_2.setLocale(QtCore.QLocale(QtCore.QLocale.Chinese, QtCore.QLocale.China))
98 | self.groupBox_2.setObjectName("groupBox_2")
99 |
100 | # verticalLayout_2: vertical arrangement inside groupBox_2
101 | self.verticalLayout_2 = QtWidgets.QVBoxLayout(self.groupBox_2)
102 | self.verticalLayout_2.setObjectName("verticalLayout_2")
103 | self.verticalLayout_2.setSpacing(10)
104 |
105 | self.radioButton_4 = QtWidgets.QRadioButton(self.groupBox_2)
106 | self.radioButton_4.setObjectName("radioButton_4")
107 | self.verticalLayout_2.addWidget(self.radioButton_4)
108 | self.radioButton_5 = QtWidgets.QRadioButton(self.groupBox_2)
109 | self.radioButton_5.setObjectName("radioButton_5")
110 | self.verticalLayout_2.addWidget(self.radioButton_5)
111 | self.radioButton_6 = QtWidgets.QRadioButton(self.groupBox_2)
112 | self.radioButton_6.setObjectName("radioButton_6")
113 | self.verticalLayout_2.addWidget(self.radioButton_6)
114 | self.horizontalLayout.addWidget(self.groupBox_2)
115 |
116 | self.verticalLayout_5.addLayout(self.horizontalLayout)
117 |
118 | self.verticalLayout_1.addLayout(self.verticalLayout_5)
119 |
120 | # verticalLayout: the three push buttons stacked vertically
121 | self.verticalLayout = QtWidgets.QVBoxLayout()
122 | self.verticalLayout.setObjectName("verticalLayout")
123 | self.verticalLayout.setSpacing(50)
124 |
125 | self.pushButton = QtWidgets.QPushButton(self.centralwidget)
126 | self.pushButton.setObjectName("pushButton")
127 | self.verticalLayout.addWidget(self.pushButton)
128 | self.pushButton_2 = QtWidgets.QPushButton(self.centralwidget)
129 | self.pushButton_2.setObjectName("pushButton_2")
130 | self.verticalLayout.addWidget(self.pushButton_2)
131 | self.pushButton_3 = QtWidgets.QPushButton(self.centralwidget)
132 | self.pushButton_3.setObjectName("pushButton_3")
133 | self.verticalLayout.addWidget(self.pushButton_3)
134 |
135 | self.verticalLayout_1.addLayout(self.verticalLayout)
136 |
137 | self.verticalLayout_1.setStretch(1, 4)
138 |
139 | self.gridLayout.addLayout(self.verticalLayout_1, 0, 0, 1, 1)
140 |
141 | # verticalLayout_4: the image label, player controls and text browser stacked vertically
142 | self.verticalLayout_4 = QtWidgets.QVBoxLayout()
143 | self.verticalLayout_4.setContentsMargins(-1, 0, -1, -1)
144 | self.verticalLayout_4.setObjectName("verticalLayout_4")
145 |
146 | self.label = QtWidgets.QLabel(self.centralwidget)
147 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding)
148 | sizePolicy.setHorizontalStretch(0)
149 | sizePolicy.setVerticalStretch(0)
150 | sizePolicy.setHeightForWidth(self.label.sizePolicy().hasHeightForWidth())
151 | self.label.setSizePolicy(sizePolicy)
152 | self.label.setMaximumSize(QtCore.QSize(1080, 960))
153 | self.label.setSizeIncrement(QtCore.QSize(1, 1))
154 | self.label.setText("")
155 | self.label.setPixmap(QtGui.QPixmap("img/1.jpg"))
156 | # self.label.setPixmap(QtGui.QPixmap(":/newPrefix/1.jpg"))
157 | self.label.setScaledContents(True)
158 | self.label.setAlignment(QtCore.Qt.AlignCenter)
159 | self.label.setObjectName("label")
160 |
161 | self.verticalLayout_4.addWidget(self.label)
162 |
163 |
164 | self.playButton = QtWidgets.QPushButton()
165 | self.playButton.setEnabled(False)
166 | self.playButton.setIcon(MainWindow.style().standardIcon(QtWidgets.QStyle.SP_MediaPlay))
167 |
168 | self.positionSlider = QtWidgets.QSlider(QtCore.Qt.Horizontal)
169 | self.positionSlider.setRange(0, 0)
170 |
171 | controlLayout = QtWidgets.QHBoxLayout()
172 | controlLayout.setContentsMargins(0, 0, 0, 0)
173 | controlLayout.addWidget(self.playButton)
174 | controlLayout.addWidget(self.positionSlider)
175 |
176 | self.verticalLayout_4.addLayout(controlLayout)
177 |
178 | self.textBrowser = QtWidgets.QTextBrowser(self.centralwidget)
179 | # self.textBrowser.setReadOnly(True)
180 | self.textBrowser.setObjectName("textBrowser")
181 | self.verticalLayout_4.addWidget(self.textBrowser)
182 |
183 | self.gridLayout.addLayout(self.verticalLayout_4, 0, 1, 1, 1)
184 | MainWindow.setCentralWidget(self.centralwidget)
185 | self.menubar = QtWidgets.QMenuBar(MainWindow)
186 | self.menubar.setGeometry(QtCore.QRect(0, 0, 739, 26))
187 | self.menubar.setObjectName("menubar")
188 | MainWindow.setMenuBar(self.menubar)
189 | self.statusbar = QtWidgets.QStatusBar(MainWindow)
190 | self.statusbar.setObjectName("statusbar")
191 | MainWindow.setStatusBar(self.statusbar)
192 |
193 | self.retranslateUi(MainWindow)
194 | self.signal(MainWindow)
195 | QtCore.QMetaObject.connectSlotsByName(MainWindow)
196 |
197 | def retranslateUi(self, MainWindow):
198 | _translate = QtCore.QCoreApplication.translate
199 | MainWindow.setWindowTitle(_translate("MainWindow", "说话者识别"))
200 | self.groupBox.setTitle(_translate("MainWindow", "说话人识别"))
201 | self.radioButton.setText(_translate("MainWindow", "NN"))
202 | self.radioButton_2.setText(_translate("MainWindow", "GRU"))
203 | self.radioButton_3.setText(_translate("MainWindow", "LSTM"))
204 | self.groupBox_2.setTitle(_translate("MainWindow", "语种识别"))
205 | self.radioButton_4.setText(_translate("MainWindow", "MFCC"))
206 | self.radioButton_5.setText(_translate("MainWindow", "PLP"))
207 | self.radioButton_6.setText(_translate("MainWindow", "MFCC+PLP"))
208 | self.groupBox_3.setTitle(_translate("MainWindow", "mode"))
209 | self.radioButton_7.setText(_translate("MainWindow", "说话人识别"))
210 | self.radioButton_8.setText(_translate("MainWindow", "语种识别"))
211 | self.pushButton.setText(_translate("MainWindow", "Enroll"))
212 | self.pushButton_2.setText(_translate("MainWindow", "Record"))
213 | self.pushButton_3.setText(_translate("MainWindow", "Test"))
214 | self.textBrowser.setHtml(_translate("MainWindow",
215 | "<!DOCTYPE HTML PUBLIC \"-//W3C//DTD HTML 4.0//EN\" \"http://www.w3.org/TR/REC-html40/strict.dtd\">\n"
216 | "<html><head><meta name=\"qrichtext\" content=\"1\" /><style type=\"text/css\">\n"
217 | "p, li { white-space: pre-wrap; }\n"
218 | "</style></head><body style=\" font-family:'SimSun'; font-size:9pt; font-weight:400; font-style:normal;\"></body></html>"))
219 |
220 | def signal(self, MainWindow):
221 | self.radioButton.clicked.connect(MainWindow.nn_model)
222 | self.radioButton_2.clicked.connect(MainWindow.gru_model)
223 | self.radioButton_3.clicked.connect(MainWindow.lstm_model)
224 | self.radioButton_4.clicked.connect(MainWindow.mfcc_fea)
225 | self.radioButton_5.clicked.connect(MainWindow.plp_fea)
226 | self.radioButton_6.clicked.connect(MainWindow.mfcc_plp_fea)
227 | self.radioButton_7.clicked.connect(MainWindow.speaker_model)
228 | self.radioButton_8.clicked.connect(MainWindow.language_model)
229 |
230 | self.pushButton.clicked.connect(MainWindow.enroll)
231 | self.pushButton_2.clicked.connect(MainWindow.records)
232 | self.pushButton_3.clicked.connect(MainWindow.test)
233 |
234 | # self.playButton.clicked.connect(MainWindow.play)
235 | # self.positionSlider.sliderMoved.connect(MainWindow.setPosition)
236 |
237 | def speaker_model(self):
238 | self.textBrowser.clear()
239 | self.textBrowser.append('Speaker model')
240 | self.state_speaker = True
241 | self.state_language = False
242 |
243 | def language_model(self):
244 | self.textBrowser.clear()
245 | self.textBrowser.append('Language model')
246 | self.state_speaker = False
247 | self.state_language = True
248 |
249 | def nn_model(self):
250 | self._load_speaker_model(model_type='nn')
251 |
252 | def gru_model(self):
253 | self._load_speaker_model(model_type='gru')
254 |
255 | def lstm_model(self):
256 | self._load_speaker_model(model_type='lstm')
257 |
258 | def _load_speaker_model(self, model_type='lstm'):
259 | if self.state_speaker:
260 | self.textBrowser.clear()
261 | self.textBrowser.append('Loading {} d-vector model'.format(model_type))
262 | self.model_type = model_type
263 | self.model = load_model('feature/d_vector/d_vector_{}.h5'.format(model_type))
264 | self.textBrowser.append('finished!')
265 |
266 | def mfcc_fea(self):
267 | self._load_language_model(feature_type='MFCC')
268 |
269 | def plp_fea(self):
270 | self._load_language_model(feature_type='PLP')
271 |
272 | def mfcc_plp_fea(self):
273 | self._load_language_model(feature_type='MFCC_PLP')
274 |
275 | def _load_language_model(self, feature_type='MFCC'):
276 | if self.state_language:
277 | self.textBrowser.clear()
278 | self.textBrowser.append('Loading GMM-UBM {} model'.format(feature_type))
279 | self.feature_type = feature_type
280 | with open("feature/language/GMM_8_" + feature_type + "_model.pkl", 'rb') as f:
281 | self.GMM = pkl.load(f)
282 | with open("feature/language/UBM_8_" + feature_type + "_model.pkl", 'rb') as f:
283 | self.UBM = pkl.load(f)
284 |
285 | self.textBrowser.append('finished!')
286 |
287 | def test(self):
288 | if self.state_language:
289 | # language identification
290 | self._GMM_test()
291 | elif self.state_speaker:
292 | # speaker identification
293 | self._dvector_test()
294 |
295 | def records(self):
296 | seconds, ok = QtWidgets.QInputDialog.getText(self, "records",
297 | "please input how much seconds:", QtWidgets.QLineEdit.Normal)
298 | self.textBrowser.append('{}s'.format(seconds))
299 | record(seconds=int(seconds))
300 | self.textBrowser.append('The {}-second recording is complete.'.format(seconds))
301 | sample, audio = read(filename='test.wav')
302 | # print("the length of audio:", str(len(audio)))
303 | testone = []
304 | testone_feature = []
305 | sample = 16000  # assume 16 kHz and split the audio into one-second chunks
306 | for i in range(len(audio) // sample):
307 | testone.append(audio[i * sample:(i + 1) * sample])
308 | if self.state_language:
309 | for i in tqdm(range(len(testone))):
310 | try:
311 | _feature = None
312 | if self.feature_type == 'MFCC':
313 | _feature = mfcc(testone[i])[0]
314 | elif self.feature_type == 'PLP':
315 | _feature = plp(testone[i])[0]
316 | elif self.feature_type == 'MFCC_PLP':
317 | _feature1 = mfcc(testone[i])[0]
318 | _feature2 = plp(testone[i])[0]
319 | _feature = np.hstack((_feature1, _feature2))
320 |
321 | _feature = preprocessing.scale(_feature)
322 | except ValueError:
323 | continue
324 | testone_feature.append(_feature)
325 | os.remove('test.wav')
326 | self.gmm_feature = testone_feature
327 | elif self.state_speaker:
328 | for i in tqdm(range(len(testone))):
329 | try:
330 | # calling mfcc with fs=8000 raises an error here
331 | _feature = mfcc(testone[i])[0]
332 | except :
333 | continue
334 | testone_feature.append(_feature)
335 | self.d_vector_feature = testone_feature
336 |
337 | def _GMM_test(self):
338 | testone = self.gmm_feature
339 | pred = np.zeros((len(testone), len(self.GMM)))
340 | for i in range(len(self.GMM)):
341 | for j in range(len(testone)):
342 | pred[j, i] = self.GMM[i].score(testone[j]) - self.UBM.score(testone[j])
343 |
344 | prob = np.exp(pred.max(axis=1))/np.exp(pred).sum(axis=1)
345 | prob_str = []
346 | for p in prob:
347 | prob_str.append('{:.2%}'.format(p))
348 | print(prob)
349 | pred = pred.argmax(axis=1)
350 | res = []
351 | for i in range(pred.shape[0]):
352 | if pred[i]==0:
353 | res.append('Chinese')
354 | elif pred[i]==1:
355 | res.append('English')
356 | else:
357 | res.append('Japanese')
358 | self.textBrowser.append('The result is: ')
359 | self.textBrowser.append(' '.join(res))
360 | self.textBrowser.append(' '.join(prob_str))
361 |
362 | def _dvector_test(self):
363 | # d_vector_feature is a list; each element holds the features of one second of speech.
364 | with open('feature/d_vector/d_vector.pkl', 'rb') as f:
365 | d_vector = pkl.load(f)
366 |
367 | pred = []
368 | target = np.array(self.d_vector_feature)
369 | for i in range(len(self.d_vector_feature)):
370 | # adjust the input shape to match the chosen model
371 | if self.model_type=='lstm':
372 | pass
373 | elif self.model_type=='nn':
374 | target = target.reshape(target.shape[0], -1)
375 | else:
376 | target = target[:,:,:,np.newaxis]
377 |
378 | target = self.model.predict(target)
379 |
380 | # TODO add the probability computation later
381 | prob = []
382 | for i in range(target.shape[0]):
383 | min_distance = np.inf  # cosine distance can exceed 1, so start from infinity
384 | target_name = None
385 | distance_list = []
386 | for name in d_vector.keys():
387 | distance_list.append(cosine(target[i,:], d_vector[name]))
388 | if min_distance > distance_list[-1]:
389 | min_distance = distance_list[-1]
390 | target_name = name
391 | pred.append(target_name)
392 | distance = -np.array(distance_list)
393 | prob.append('{:.2%}'.format(np.exp(distance.max())/np.exp(distance).sum()))
394 | self.textBrowser.append('The result is :')
395 | self.textBrowser.append(' '.join(pred))
396 | self.textBrowser.append(' '.join(prob))
397 |
398 | def enroll(self):
399 | """
400 | 注册一个陌生人到库中,以字典形式保存
401 | :param X_train: 样本语音
402 | :param name: 该样本语音的人名,唯一标识,不可重复。
403 | :param model_name: 使用模型的名字,nn,lstm,gru
404 | :return: none
405 | """
406 | name, ok = QtWidgets.QInputDialog.getText(self, "Enroll",
407 |                                           "Please input your name:", QtWidgets.QLineEdit.Normal)
408 |         self.textBrowser.append('Your name is {}.'.format(name))
409 | self.records()
410 | try:
411 | with open('feature/d_vector/d_vector.pkl', 'rb') as f:
412 | d_vector = pkl.load(f)
413 |         except (FileNotFoundError, EOFError):  # first enrollment: no library yet
414 | d_vector = {}
415 |
416 |         target = np.array(self.d_vector_feature)
417 |         # adjust the input format once for the corresponding model; the reshape
418 |         # must not be repeated in a loop, or the conv branch would keep adding axes
419 |         if self.model_type=='lstm':
420 |             pass
421 |         elif self.model_type=='nn':
422 |             target = target.reshape(target.shape[0], -1)
423 |         else:
424 |             target = target[:,:,:,np.newaxis]
425 |
426 | target = self.model.predict(target)
427 | print(target.shape)
428 | if name in d_vector:
429 |             self.textBrowser.append('This name already exists; averaging with the stored d-vector.')
430 | d_vector[name] = (d_vector[name] + target.mean(axis=0))/2
431 | else:
432 | d_vector[name] = target.mean(axis=0)
433 |
434 | with open('feature/d_vector/d_vector.pkl', 'wb') as f:
435 | pkl.dump(d_vector, f)
436 |
437 | self.textBrowser.append('Finished')
--------------------------------------------------------------------------------
/VAD.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2019/3/13 13:10
4 | # @Author : chuyu zhang
5 | # @File : VAD.py
6 | # @Software: PyCharm
7 |
8 | import math
9 | import numpy as np
10 | from scipy.io import loadmat
11 | from utils.tools import wave_read, read, plot_confusion_matrix
12 | import glob
13 | import time
14 | import matplotlib.pyplot as plt
15 | from sklearn.metrics import confusion_matrix
16 | import seaborn as sns
17 | from sklearn import metrics
18 | from bayes_opt import BayesianOptimization
19 | import pickle as pkl
20 | from scipy import signal
21 | # the zero-crossing rate of each frame is computed in ZCR() below
22 |
23 | frameSize = 256
24 | overlap = 128
25 |
26 | # framing function
27 | # (no windowing applied)
28 | def enframe(wavData):
29 | """
30 | frame the wav data, according to frameSize and overlap
31 | :param wavData: the input wav data, ndarray
32 |     :return: frameData, an array of shape (frameSize, frameNum)
33 | """
34 |     # coef = 0.97  # pre-emphasis coefficient
35 | wlen = wavData.shape[0]
36 | step = frameSize - overlap
37 |     frameNum: int = math.ceil(wlen / step)
38 | frameData = np.zeros((frameSize, frameNum))
39 |
40 | # hamwin = np.hamming(frameSize)
41 |
42 | for i in range(frameNum):
43 | singleFrame = wavData[np.arange(i * step, min(i * step + frameSize, wlen))]
44 | # b, a = signal.butter(8, 1, 'lowpass')
45 | # filtedData = signal.filtfilt(b, a, data)
46 |         # singleFrame = np.append(singleFrame[0], singleFrame[:-1] - coef * singleFrame[1:])  # pre-emphasis
47 |         frameData[:len(singleFrame), i] = singleFrame.reshape(-1, 1)[:, 0]
48 |         # frameData[:, i] = hamwin * frameData[:, i]  # windowing
49 |
50 | return frameData
51 |
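# A quick sizing example for enframe (illustrative): with frameSize = 256 and
# overlap = 128, the hop is step = 128 samples, so one second of 16 kHz audio
# (16000 samples) yields ceil(16000 / 128) = 125 frames, the trailing frame
# zero-padded to the full frame length.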
52 |
53 | def ZCR(frameData):
54 | frameNum = frameData.shape[1]
55 | frameSize = frameData.shape[0]
56 | zcr = np.zeros((frameNum, 1))
57 |
58 | for i in range(frameNum):
59 | singleFrame = frameData[:, i]
60 | temp = singleFrame[:frameSize-1] * singleFrame[1:frameSize]
61 | temp = np.sign(temp)
62 | zcr[i] = np.sum(temp<0)
63 |
64 | return zcr
65 |
66 | # compute the energy of each frame
67 | def energy(frameData):
68 | frameNum = frameData.shape[1]
69 |
70 | frame_energy = np.zeros((frameNum, 1))
71 |
72 | for i in range(frameNum):
73 | single_frame = frameData[:, i]
74 | frame_energy[i] = sum(single_frame * single_frame)
75 |
76 | return frame_energy
77 |
78 |
79 | def stSpectralEntropy(X, n_short_blocks=10, eps=1e-8):
80 | """Computes the spectral entropy"""
81 | L = len(X) # number of frame samples
82 | Eol = np.sum(X ** 2) # total spectral energy
83 |
84 | sub_win_len = int(np.floor(L / n_short_blocks)) # length of sub-frame
85 | if L != sub_win_len * n_short_blocks:
86 | X = X[0:sub_win_len * n_short_blocks]
87 |
88 | sub_wins = X.reshape(sub_win_len, n_short_blocks, order='F').copy() # define sub-frames (using matrix reshape)
89 | s = np.sum(sub_wins ** 2, axis=0) / (Eol + eps) # compute spectral sub-energies
90 | En = -np.sum(s*np.log2(s + eps)) # compute spectral entropy
91 |
92 | return En
93 |
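# In the notation above (summary note): with sub-band energies
# s_j = E_j / (E_total + eps), the entropy is En = -sum_j s_j * log2(s_j + eps).
# A flat, noise-like spectrum pushes En toward log2(n_short_blocks) ~= 3.32,
# while a peaky voiced spectrum keeps En small; VAD_frequency below therefore
# marks frames with entropy <= 0.4 as speech.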
94 |
95 | def spectrum_entropy(frameData):
96 | frameNum = frameData.shape[1]
97 |
98 | frame_spectrum_entropy = np.zeros((frameNum, 1))
99 |
100 | for i in range(frameNum):
101 | X = np.fft.fft(frameData[:, i])
102 | X = np.abs(X)
103 | frame_spectrum_entropy[i] = stSpectralEntropy(X[:int(frameSize/2)])
104 |
105 | return frame_spectrum_entropy
106 |
107 |
108 | def feature(waveData):
109 | # print("feature extract !")
110 | start = time.time()
111 | power = energy(waveData)
112 | zcr = ZCR(waveData) * (power>0.1)
113 | end = time.time()
114 | spectrumentropy = spectrum_entropy(waveData)
115 |
116 | print('feature extract completed, time feature spend {}s, frequency domain spend {}s'.
117 | format(end-start, time.time() - end))
118 |
119 | return zcr, power, spectrumentropy
120 |
121 |
122 | # frameSize is the frame length and overlap the frame overlap (step = frameSize - overlap)
123 | def wavdata(wavfile):
124 | f = wave_read(wavfile)
125 | params = f.getparams()
126 | nchannels, sampwidth, framerate, nframes = params[:4]
127 | strData = f.readframes(nframes) # 读取音频,字符串格式
128 | # print(type(strData))
129 |     waveData = np.frombuffer(strData, dtype=np.int16)  # np.fromstring is deprecated
130 | # print(waveData.shape)
131 | waveData = waveData/(max(abs(waveData)))
132 | return enframe(waveData)
133 |
134 |
135 | # Energy is checked first: below ampl the frame is treated as noise (silence), above amph as voiced speech, and in between as unvoiced.
136 | def VAD_detection(zcr, power, zcr_gate=35, ampl=0.3, amph=12):
137 |     # minimum number of frames for a speech segment
138 |     min_len = 16
139 |     # minimum gap (in frames) between two speech segments
140 |     min_distance = 21
141 |     # state flag: 0 = silence, 1 = unvoiced, 2 = voiced (only 0 and 1 are used below)
142 | status = 0
143 | # speech = 0
144 | start = 0
145 | end = 0
146 | last_end = -1
147 |
148 | res = np.zeros((zcr.shape[0], 1))
149 |
150 | for i in range(zcr.shape[0]):
151 | if power[i] > amph:
152 |             # voiced state: just keep extending end
153 | if status != 1:
154 | start = i
155 |
156 | end = i
157 | status = 1
158 | # print(start - end)
159 | elif end - start + 1 > min_len:
160 |
161 | while(power[start] > ampl or zcr[start] > zcr_gate):
162 | start -= 1
163 |
164 | start += 1
165 |
166 | while(power[end] > ampl or zcr[end] > zcr_gate):
167 | end += 1
168 | if end == power.shape[0]:
169 | break
170 |
171 | end -= 1
172 | if last_end > 0 and start - last_end < min_distance:
173 | res[last_end : end + 1] = 1
174 | last_end = end
175 | else:
176 | res[start: end + 1] = 1
177 |
178 | start = 0
179 | end = 0
180 | status = 0
181 |
182 | return res
183 |
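# Summary of the double-threshold logic above (descriptive note): frames with
# power > amph seed a voiced segment; once such a segment longer than min_len
# ends, its boundaries are extended backwards and forwards while
# power > ampl or zcr > zcr_gate, so that unvoiced consonants are captured;
# finally, segments closer than min_distance frames are merged into one.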
184 |
185 | def VAD_frequency(spectrum):
186 | return np.where(spectrum>0.4, 0, 1)
187 |
188 |
189 | def optimize(X, y):
190 | zcr, power, spectrumentropy = feature(X)
191 | """
192 | sns.distplot(zcr)
193 | plt.show()
194 | sns.distplot(power)
195 | plt.show()
196 | """
197 | params ={
198 | 'zcr_gate': (20, 40),
199 | 'ampl': (0.3, 4),
200 | 'amph': (5, 15)
201 | }
202 | y = y.reshape(1, -1)
203 |
204 | def cv(zcr_gate, ampl, amph):
205 | res = VAD_detection(zcr, power, zcr_gate=zcr_gate, amph=amph, ampl=ampl)
206 | # print((res==0).sum()/res.shape[0])
207 | res = res.reshape(1, -1)
208 | # metrics.precision_score(y[0], res[0])
209 | # accuracy = (y == res).sum() / y.shape[0]
210 | return metrics.f1_score(y[0], res[0])
211 |
212 | BO = BayesianOptimization(cv, params)
213 |
214 | start_time = time.time()
215 | BO.maximize(n_iter=30)
216 | end_time = time.time()
217 | print("Final result:{}, spend {}s".format(BO.max, end_time - start_time))
218 | best_params = BO.max['params']
219 |
220 | return best_params
221 |
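# (descriptive note) BayesianOptimization maximizes the F1 score of
# VAD_detection over the threshold box defined in `params`; after maximize(),
# BO.max holds the best score under 'target' and the best
# (zcr_gate, ampl, amph) under 'params'.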
222 | # Process the .mat file: count the silent and speech samples within each frame and assign the frame a label; the exact rule can be refined later.
223 | def label(mat_file):
224 | mat = loadmat(mat_file)
225 | y_label = mat['y_label']
226 | y_label = enframe(y_label)
227 |
228 | return np.where(y_label.sum(axis=0) > 0, 1, 0)
229 |
230 |
231 | def main(wav, mat):
232 | start = time.time()
233 | print(wav.split('\\')[-1])
234 | data = wavdata(wav)
235 | y_label = label(mat)
236 | y_label = y_label.reshape(-1, 1)
237 | zcr, power, spectrumentropy = feature(data)
238 |
239 | s1 = time.time()
240 | res1 = VAD_detection(zcr, power)
241 | s2 = time.time()
242 | # res1 = res1.reshape(1, -1)
243 | res2 = VAD_frequency(spectrumentropy)
244 | end = time.time()
245 | # res2 = res2.reshape(1, -1)
246 | td = metrics.f1_score(y_label, res1)
247 | fd = metrics.f1_score(y_label, res2)
248 | np.set_printoptions(precision=2)
249 |
250 | plot_confusion_matrix(y_label.tolist(), res1.tolist(), classes=['silence', 'speech'], normalize=False)
251 | plt.savefig(wav.split('\\')[-1].split('.')[0] + '.png')
252 | plt.show()
253 | end = time.time()
254 | print('time domain res:{:.2%}, spend {}s, frequency domain res:{:.2%}, spend {}s, '
255 | 'spend {}s totally'.format(td, s2-s1, fd, end-s2, end - start))
256 |
257 |
258 | if __name__=='__main__':
259 | wavfile = glob.glob(r'dataset\VAD\*.wav')
260 | matfile = glob.glob(r'dataset\VAD\*.mat')
261 |
262 | for wav, mat in zip(wavfile, matfile):
263 | main(wav, mat)
--------------------------------------------------------------------------------
/d_vector.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2019/5/15 22:56
4 | # @Author : chuyu zhang
5 | # @File : d_vector.py
6 | # @Software: PyCharm
7 |
8 | import os
9 | import pickle as pkl
10 | import numpy as np
11 | from utils.tools import read, get_time
12 | from tqdm import tqdm
13 | from sklearn import preprocessing
14 | from sklearn.model_selection import train_test_split
15 | from scipy.spatial.distance import cosine
16 | from keras.models import Model,Sequential,load_model
17 |
18 | from sidekit.frontend.features import plp,mfcc
19 |
20 | from keras.layers import Dense, Activation, Dropout, Input, GRU, LSTM, Flatten,Convolution2D, MaxPooling2D,Convolution1D
21 | from keras.optimizers import Adam
22 | from keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint,CSVLogger
23 | from keras import regularizers
24 | from keras.layers.core import Reshape,Masking,Lambda,Permute
25 | import keras.backend as K
26 | from keras.layers.wrappers import TimeDistributed
27 | import keras
28 |
29 | class Data_gen:
30 |     # data generation
31 | def __init__(self):
32 | pass
33 |
34 | def _load(self):
35 |         """
36 |         load audio files from self.path.
37 |         :return: x, a list with one audio array per file, and
38 |         y, the list of corresponding speaker labels
39 |         """
40 | start_time = get_time()
41 | path = self.path
42 | print("Loading data...")
43 | speaker_list = os.listdir(path)
44 | y = []
45 | x = []
46 | for speaker in tqdm(speaker_list):
47 | path1 = os.path.join(path, speaker)
48 | for _dir in os.listdir(path1):
49 | path2 = os.path.join(path1, _dir)
50 | for _wav in os.listdir(path2):
51 | self.sample_rate, audio = read(os.path.join(path2, _wav))
52 | y.append(speaker)
53 |                     # the sample rate is 16000; you can downsample to 8000, but results degrade
54 | x.append(audio)
55 |
56 | print("Complete! Spend {:.2f}s".format(get_time(start_time)))
57 | return x, y
58 |
59 | def extract_feature(self, feature_type='MFCC', datatype='dev'):
60 |         """
61 |         extract features from the loaded audio and cache them on disk
62 |         :param feature_type: 'MFCC' or 'PLP'
63 |         :param datatype: tag used in the cache file name, e.g. 'dev' or 'test'
64 |         :return: feature, a list of per-second feature arrays, and
65 |         label, the corresponding list of speaker names
66 |         (results are cached as feature/{datatype}_{feature_type}_feature.pkl
67 |         and reloaded on subsequent calls)
68 |         """
69 | start_time = get_time()
70 | if not os.path.exists('feature'):
71 | os.mkdir('feature')
72 |
73 | if not os.path.exists('feature/{}_{}_feature.pkl'.format(datatype, feature_type)):
74 | x, y = self._load()
75 | print("Extract {} feature...".format(feature_type))
76 | feature = []
77 | label = []
78 | new_x = []
79 | new_y = []
80 | for i in range(len(x)):
81 | for j in range(x[i].shape[0]//self.sample_rate):
82 | new_x.append(x[i][j*self.sample_rate:(j+1)*self.sample_rate])
83 | new_y.append(y[i])
84 |
85 | x = new_x
86 | y = new_y
87 | for i in tqdm(range(len(x))):
88 |             # MFCC and PLP default to 16000 Hz here; adjust if the sample rate differs
89 |             # mfcc: 25 ms window, 10 ms hop
90 | if feature_type == 'MFCC':
91 | _feature = mfcc(x[i], fs=self.sample_rate)[0]
92 | elif feature_type == 'PLP':
93 | _feature = plp(x[i], fs=self.sample_rate)[0]
94 | else:
95 | raise NameError
96 |             # some features contain infinities/NaNs, which make the gradients explode; skip them
97 | if np.isnan(_feature).sum()>0:
98 | continue
99 | # _feature = np.concatenate([_feature,self.delta(_feature)],axis=1)
100 | # _feature = preprocessing.scale(_feature)
101 | # _feature = preprocessing.StandardScaler().fit_transform(_feature)
102 |             # every 2*num frames form one input, overlapping by num
103 | feature.append(_feature)
104 | label.append(y[i])
105 |
106 | print(len(feature), feature[0].shape)
107 | self.save(feature, '{}_{}_feature'.format(datatype, feature_type))
108 | self.save(label, '{}_{}_label'.format(datatype, feature_type))
109 |
110 | else:
111 | feature = self.load('{}_{}_feature'.format(datatype, feature_type))
112 | label = self.load('{}_{}_label'.format(datatype, feature_type))
113 |
114 | print("Complete! Spend {:.2f}s".format(get_time(start_time)))
115 | return feature, label
116 |
117 | def load_data(self, path='dataset/ASR_GMM_big', reshape=True, test_size=0.3,datatype='dev'):
118 | self.path = path
119 |         feature, label = self.extract_feature(datatype=datatype)
120 | feature = np.array(feature)
121 |         # a fully connected net needs flattened input; conv/RNN models do not
122 | if reshape:
123 | feature = feature.reshape(feature.shape[0], -1)
124 |
125 | label = np.array(label).reshape(-1, 1)
126 | self.enc = preprocessing.OneHotEncoder()
127 | label = self.enc.fit_transform(label).toarray()
128 |
129 | X_train, X_val, y_train, y_val = train_test_split(feature, label, shuffle=True, test_size=test_size,
130 | random_state=2019)
131 | return X_train, X_val, y_train, y_val
132 |
133 | @staticmethod
134 | def save(data, file_name):
135 | with open('feature/{}.pkl'.format(file_name), 'wb') as f:
136 | pkl.dump(data, f)
137 |
138 | @staticmethod
139 | def load(file_name):
140 | with open('feature/{}.pkl'.format(file_name), 'rb') as f:
141 | return pkl.load(f)
142 |
143 | @staticmethod
144 | def delta(feat, N=2):
145 | """Compute delta features from a feature vector sequence.
146 | :param feat: A numpy array of size (NUMFRAMES by number of features) containing features. Each row holds 1 feature vector.
147 | :param N: For each frame, calculate delta features based on preceding and following N frames
148 | :returns: A numpy array of size (NUMFRAMES by number of features) containing delta features. Each row holds 1 delta feature vector.
149 | """
150 | if N < 1:
151 | raise ValueError('N must be an integer >= 1')
152 | NUMFRAMES = len(feat)
153 | denominator = 2 * sum([i ** 2 for i in range(1, N + 1)])
154 | delta_feat = np.empty_like(feat)
155 | # padded version of feat
156 | padded = np.pad(feat, ((N, N), (0, 0)), mode='edge')
157 | for t in range(NUMFRAMES):
158 | # [t : t+2*N+1] == [(N+t)-N : (N+t)+N+1]
159 | delta_feat[t] = np.dot(np.arange(-N, N + 1), padded[t: t + 2 * N + 1]) / denominator
160 | return delta_feat
161 |
162 |
163 | class nn_model:
164 | # TODO 用一个更大的数据集训练一个背景模型,然后利用这个背景模型直接得到新数据的d-vector
165 | def __init__(self, n_class=40):
166 | self.n_class = n_class
167 |
168 | def inference(self, X_train, Y_train, X_val, Y_val):
169 |         # input_shape and a few other parameters may need adjusting
170 | print("Training model")
171 | model = Sequential()
172 |
173 | model.add(Dense(256, input_shape=(X_train.shape[1],), name="dense1"))
174 | model.add(Activation('relu', name="activation1"))
175 | model.add(Dropout(rate=0, name="drop1"))
176 |
177 | model.add(Dense(256, name="dense2"))
178 | model.add(Activation('relu', name="activation2"))
179 | model.add(Dropout(rate=0, name="drop2"))
180 |
181 | model.add(Dense(256, name="dense3"))
182 | model.add(Activation('relu', name="activation3"))
183 | model.add(Dropout(rate=0.5, name="drop3"))
184 |
185 | model.add(Dense(256, name="dense4"))
186 |
187 | modelInput = Input(shape=(X_train.shape[1],))
188 | features = model(modelInput)
189 | spkModel = Model(inputs=modelInput, outputs=features)
190 |
191 | model1 = Activation('relu')(features)
192 | model1 = Dropout(rate=0.5)(model1)
193 |
194 | model1 = Dense(self.n_class, activation='softmax')(model1)
195 |
196 | spk = Model(inputs=modelInput, outputs=model1)
197 |
198 | sgd = Adam(lr=1e-4)
199 | # early_stopping = EarlyStopping(monitor='val_loss', patience=4)
200 | reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2, min_lr=1e-7)
201 | csv_logger = CSVLogger('feature/d_vector/nn_training.log')
202 |
203 | spk.compile(loss='categorical_crossentropy',optimizer=sgd, metrics=['accuracy'])
204 |
205 | spk.fit(X_train, Y_train, batch_size = 128, epochs=50, validation_data = (X_val, Y_val),
206 | callbacks=[reduce_lr, csv_logger])
207 |
208 | if not os.path.exists('feature/d_vector'):
209 | os.mkdir('feature/d_vector')
210 | spkModel.save('feature/d_vector/d_vector_nn.h5')
211 |
212 |
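    # The d-vector recipe implemented above (descriptive note): `spk`, a softmax
    # classifier over the background speakers, is trained and then discarded;
    # what is saved is `spkModel`, whose 256-dim "dense4" output serves as the
    # d-vector. test/enroll/eval below average these embeddings per speaker and
    # compare them with the cosine distance.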
213 | def inference_gru(self, X_train, Y_train, X_val, Y_val):
214 | model = Sequential()
215 |
216 | model.add(Convolution2D(64, (5, 5),
217 | padding='same',
218 | strides=(2, 2),
219 | input_shape=(X_train.shape[1], X_train.shape[2], 1), name="cov1",
220 | data_format="channels_last",
221 | kernel_regularizer=keras.regularizers.l2()))
222 |
223 |         # Permute would rearrange the input dimensions according to the given pattern
224 |         # model.add(Permute((2, 1, 3), name='permute'))
225 |         # TimeDistributed applies a layer to every time step, as the GRUs require
226 |         model.add(TimeDistributed(Flatten(), name='timedistrib'))
227 |
228 |         # three GRU layers
229 | model.add(GRU(units=1024, return_sequences=True, name="gru1"))
230 | model.add(GRU(units=1024, return_sequences=True, name="gru2"))
231 | model.add(GRU(units=1024, return_sequences=True, name="gru3"))
232 |
233 | # temporal average
234 | def temporalAverage(x):
235 | return K.mean(x, axis=1)
236 |
237 | model.add(Lambda(temporalAverage, name="temporal_average"))
238 |
239 | # affine
240 | model.add(Dense(units=512, name="dense1"))
241 |
242 | # length normalization
243 | def lengthNormalization(x):
244 | return K.l2_normalize(x, axis=-1)
245 |
246 | model.add(Lambda(lengthNormalization, name="ln"))
247 |
248 | modelInput = Input(shape=(X_train.shape[1], X_train.shape[2], 1))
249 | features = model(modelInput)
250 | spkModel = Model(inputs=modelInput, outputs=features)
251 |
252 | model1 = Dense(self.n_class, activation='softmax',name="dense2")(features)
253 |
254 | spk = Model(inputs=modelInput, outputs=model1)
255 |
256 | sgd = Adam(lr=1e-4)
257 |
258 | spk.compile(loss='categorical_crossentropy',
259 | optimizer=sgd, metrics=['accuracy'])
260 |
261 | reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2, min_lr=1e-7)
262 | csv_logger = CSVLogger('feature/d_vector/gru_training.log')
263 |
264 | spk.fit(X_train, Y_train, batch_size = 128, epochs=50, validation_data = (X_val, Y_val),
265 | callbacks=[reduce_lr, csv_logger])
266 |
267 | if not os.path.exists('feature/d_vector'):
268 | os.mkdir('feature/d_vector')
269 | spkModel.save('feature/d_vector/d_vector_gru.h5')
270 |
271 | def inference_lstm(self, X_train, Y_train, X_val, Y_val):
272 | model = Sequential()
273 |
274 | model.add(LSTM(128, input_shape=(X_train.shape[1],X_train.shape[2])))
275 | modelInput = Input(shape=(X_train.shape[1],X_train.shape[2]))
276 | features = model(modelInput)
277 | spkModel = Model(inputs=modelInput, outputs=features)
278 | model1 = Dense(self.n_class, activation='softmax',name="dense1")(features)
279 | spk = Model(inputs=modelInput, outputs=model1)
280 |
281 | sgd = Adam(lr=1e-4)
282 |
283 | spk.compile(loss='categorical_crossentropy',
284 | optimizer=sgd, metrics=['accuracy'])
285 |
286 | reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2, min_lr=1e-7)
287 | csv_logger = CSVLogger('feature/d_vector/lstm_training.log')
288 |
289 | spk.fit(X_train, Y_train, batch_size = 128, epochs=50, validation_data = (X_val, Y_val),
290 | callbacks=[reduce_lr, csv_logger])
291 |
292 | if not os.path.exists('feature/d_vector'):
293 | os.mkdir('feature/d_vector')
294 | spkModel.save('feature/d_vector/d_vector_lstm.h5')
295 |
296 | def test(self, X_train, Y_train, X_val, Y_val, model_name='nn'):
297 | spkModel = load_model('feature/d_vector/d_vector_{}.h5'.format(model_name))
298 | print(X_train.shape)
299 | X_train = spkModel.predict(X_train)
300 | X_val = spkModel.predict(X_val)
301 |         # num is the number of speakers in the test set
302 |         num = Y_train.shape[1]
303 |         # average each speaker's d-vectors to get avg, stored as that speaker's template
304 | with open('feature/d_vector/d_vector_lstm.pkl', 'wb') as f:
305 | pkl.dump(X_train, f)
306 |
307 | with open('feature/d_vector/d_vector_lstm_y.pkl', 'wb') as f:
308 | pkl.dump(Y_train, f)
309 |
310 | avg = np.zeros((num, X_train.shape[1]))
311 | print(X_train.shape[1])
312 | for i in range(num):
313 | avg[i,:] = X_train[np.argmax(Y_train, axis=1)==i].mean(axis=0)
314 |
315 | distance = np.zeros((X_val.shape[0], num))
316 | for i in range(X_val.shape[0]):
317 | for j in range(num):
318 | distance[i, j] = cosine(X_val[i], avg[j])
319 | acc = (np.argmax(Y_val, axis=1)==np.argmin(distance, axis=1)).sum()/X_val.shape[0]
320 | return acc
321 |
322 | def enroll(self, X_train, name, model_name='lstm'):
323 |         """
324 |         Enroll a new speaker into the library, saved as a dict.
325 |         :param X_train: sample speech features
326 |         :param name: the speaker's name; a unique identifier, no duplicates allowed
327 |         :param model_name: name of the model to use: nn, lstm or gru
328 |         :return: none
329 |         """
330 | spkModel = load_model('feature/d_vector/d_vector_{}.h5'.format(model_name))
331 | X_train = spkModel.predict(X_train)
332 | avg = X_train.mean(axis=0)
333 | try:
334 | with open('feature/d_vector/d_vector.pkl', 'rb') as f:
335 | d_vector = pkl.load(f)
336 |         except (FileNotFoundError, EOFError):  # first enrollment: no library yet
337 |             d_vector = {}
338 |
339 |         if name in d_vector:
340 |             print("This speaker already exists; overwriting the stored template!")
341 |         d_vector[name] = avg
342 |
343 | with open('feature/d_vector/d_vector.pkl', 'wb') as f:
344 | pkl.dump(d_vector, f)
345 |
346 | def eval(self, target, model_name='lstm'):
347 | spkModel = load_model('feature/d_vector/d_vector_{}.h5'.format(model_name))
348 | target = spkModel.predict(target)
349 | with open('feature/d_vector/d_vector.pkl', 'rb') as f:
350 | d_vector = pkl.load(f)
351 |
352 | min_distance = 1
353 | target_name = None
354 | distance_list = []
355 | for name in d_vector.keys():
356 | distance_list.append(cosine(target, d_vector[name]))
357 | if min_distance > distance_list[-1]:
358 | min_distance = distance_list[-1]
359 | target_name = name
360 |
361 | return target_name
362 |
363 |
364 | if __name__=="__main__":
365 | data_gen = Data_gen()
366 | train_bm = False
367 | model = nn_model()
368 | if train_bm:
369 |         # train the background model
370 | X_train, X_val, y_train, y_val = data_gen.load_data(reshape=False)
371 | # X_train = X_train[:, :, :,np.newaxis]
372 | # X_val = X_val[:, :, :,np.newaxis]
373 | model.inference_lstm(X_train, y_train, X_val, y_val)
374 |
375 | # X_train, X_val, y_train, y_val = data_gen.load_data(test_size=0.3,datatype='test',reshape=False)
376 | with open('feature/d_vector/MFCC_feature.pkl', 'rb') as f:
377 | feature = pkl.load(f)
378 |
379 | with open('feature/d_vector/MFCC_label.pkl', 'rb') as f:
380 | label = pkl.load(f)
381 | feature = np.array(feature)
382 | label = np.array(label).reshape(-1, 1)
383 | enc = preprocessing.OneHotEncoder()
384 | label = enc.fit_transform(label).toarray()
385 |
386 | X_train, X_val, y_train, y_val = train_test_split(feature, label, shuffle=True, test_size=0.1,
387 | random_state=2019)
388 | start_time = get_time()
389 | acc = model.test(X_train[:, :, :,np.newaxis], y_train, X_val[:, :, :,np.newaxis], y_val, model_name='lstm_conv')
390 | print(get_time(start_time))
391 | print(acc)
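    # A minimal enrollment/evaluation sketch for the classes above (illustrative
    # only; X_enroll and X_test are hypothetical (n, frames, coeffs) MFCC arrays
    # and d_vector_lstm.h5 is assumed to be trained):
    #   model.enroll(X_enroll, name='alice', model_name='lstm')
    #   print(model.eval(X_test, model_name='lstm'))  # -> best-matching name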
--------------------------------------------------------------------------------
/demo.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinzcy/speech_signal_processing/9d197cd3f1d9215cf57e992701b1529d46f242ef/demo.mp4
--------------------------------------------------------------------------------
/report/First.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinzcy/speech_signal_processing/9d197cd3f1d9215cf57e992701b1529d46f242ef/report/First.pdf
--------------------------------------------------------------------------------
/report/GMM_UBM.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinzcy/speech_signal_processing/9d197cd3f1d9215cf57e992701b1529d46f242ef/report/GMM_UBM.pdf
--------------------------------------------------------------------------------
/report/final.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinzcy/speech_signal_processing/9d197cd3f1d9215cf57e992701b1529d46f242ef/report/final.pdf
--------------------------------------------------------------------------------
/report/mfcc报告.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinzcy/speech_signal_processing/9d197cd3f1d9215cf57e992701b1529d46f242ef/report/mfcc报告.pdf
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | dtw
2 | fastdtw
3 | tqdm
4 | librosa
5 | sidekit
6 | tensorflow
7 | keras
8 | numpy
9 | scipy
10 | pyqt5
11 | scikit-learn
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2019/2/26 15:46
4 | # @Author : chuyu zhang
5 | # @File : __init__.py.py
6 | # @Software: PyCharm
--------------------------------------------------------------------------------
/utils/processing.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2019/3/23 22:18
4 | # @Author : chuyu zhang
5 | # @File : processing.py
6 | # @Software: PyCharm
7 |
8 | import numpy as np
9 | import math
10 | from scipy import signal
11 | from scipy.io import wavfile
12 | # from utils.tools import read,play
13 | from scipy.fftpack.realtransforms import dct
14 | from scipy.fftpack import fft
15 | import matplotlib.pyplot as plt
16 |
17 | eps = 1e-8
18 | ## speech preprocessing functions
19 | def enframe(wavData, frameSize=400, step=160):
20 | """
21 | frame the wav data, according to frameSize and overlap
22 | :param wavData: the input wav data, ndarray
23 |     :return: frameData, an array of shape (frameSize, frameNum)
24 | """
25 | coef = 0.97
26 | wlen = wavData.shape[0]
27 | frameNum = math.ceil(wlen / step)
28 | frameData = np.zeros((frameSize, frameNum))
29 |
30 | window = signal.windows.hamming(frameSize)
31 |
32 | for i in range(frameNum):
33 | singleFrame = wavData[i * step : min(i * step + frameSize, wlen)]
34 | # singleFrame[1:] = singleFrame[:-1] - coef * singleFrame[1:]
35 | frameData[:len(singleFrame), i] = singleFrame
36 | frameData[:, i] = window*frameData[:, i]
37 |
38 | return frameData
39 |
40 |
41 | # frequency domain feature
42 | def mfccInitFilterBanks(fs, nfft):
43 | """
44 | Computes the triangular filterbank for MFCC computation
45 | (used in the stFeatureExtraction function before the stMFCC function call)
46 | This function is taken from the scikits.talkbox library (MIT Licence):
47 | https://pypi.python.org/pypi/scikits.talkbox
48 | """
49 | # filter bank params:
50 | lowfreq = 133.33
51 | linsc = 200/3.
52 | logsc = 1.0711703
53 | numLinFiltTotal = 13
54 | numLogFilt = 27
55 |
56 |     if fs < 8000:
57 |         numLogFilt = 5  # use fewer log-spaced filters at low sample rates
58 |
59 | # Total number of filters
60 | nFiltTotal = numLinFiltTotal + numLogFilt
61 |
62 | # Compute frequency points of the triangle:
63 | freqs = np.zeros(nFiltTotal+2)
64 | freqs[:numLinFiltTotal] = lowfreq + np.arange(numLinFiltTotal) * linsc
65 | freqs[numLinFiltTotal:] = freqs[numLinFiltTotal-1] * logsc ** np.arange(1, numLogFilt + 3)
66 | heights = 2./(freqs[2:] - freqs[0:-2])
67 |
68 | # Compute filterbank coeff (in fft domain, in bins)
69 | fbank = np.zeros((nFiltTotal, nfft))
70 | nfreqs = np.arange(nfft) / (1. * nfft) * fs
71 |
72 | for i in range(nFiltTotal):
73 | lowTrFreq = freqs[i]
74 | cenTrFreq = freqs[i+1]
75 | highTrFreq = freqs[i+2]
76 |
77 |         lid = np.arange(np.floor(lowTrFreq * nfft / fs) + 1,
78 |                         np.floor(cenTrFreq * nfft / fs) + 1,
79 |                         dtype=int)  # np.int was removed in NumPy 1.24
80 |         lslope = heights[i] / (cenTrFreq - lowTrFreq)
81 |         rid = np.arange(np.floor(cenTrFreq * nfft / fs) + 1,
82 |                         np.floor(highTrFreq * nfft / fs) + 1,
83 |                         dtype=int)
84 | rslope = heights[i] / (highTrFreq - cenTrFreq)
85 | fbank[i][lid] = lslope * (nfreqs[lid] - lowTrFreq)
86 | fbank[i][rid] = rslope * (highTrFreq - nfreqs[rid])
87 |
88 | return fbank, freqs
89 |
90 |
91 | def stMFCC(X, fbank, n_mfcc_feats):
92 | """
93 | Computes the MFCCs of a frame, given the fft mag
94 | ARGUMENTS:
95 | X: fft magnitude abs(FFT)
96 | fbank: filter bank (see mfccInitFilterBanks)
97 | RETURN
98 | ceps: MFCCs (13 element vector)
99 | Note: MFCC calculation is, in general, taken from the
100 | scikits.talkbox library (MIT Licence),
101 | # with a small number of modifications to make it more
102 | compact and suitable for the pyAudioAnalysis Lib
103 | """
104 |
105 | mspec = np.log10(np.dot(X, fbank.T)+eps)
106 | ceps = dct(mspec, type=2, norm='ortho', axis=-1)[:n_mfcc_feats]
107 | return ceps
108 |
109 |
110 | def MFCC(raw_signal, fs=8000, frameSize=512, step=256):
111 | """
112 | extract mfcc feature
113 | :param raw_signal: the original audio signal
114 | :param fs: sample frequency
115 | :param frameSize:the size of each frame
116 |     :param step: hop size between successive frames, in samples
117 |     :return: an array of MFCC features, one 13-dimensional vector per frame
118 | """
119 | # Signal normalization
120 |
121 | """
122 | raw_signal = np.double(raw_signal)
123 |
124 | raw_signal = raw_signal / (2.0 ** 15)
125 | DC = raw_signal.mean()
126 | MAX = (np.abs(raw_signal)).max()
127 | raw_signal = (raw_signal - DC) / (MAX + eps)
128 | """
129 | nFFT = int(frameSize)
130 | [fbank, freqs] = mfccInitFilterBanks(fs, nFFT)
131 | n_mfcc_feats = 13
132 |
133 | signal = enframe(raw_signal, frameSize, step)
134 | feature = []
135 | for frame in range(signal.shape[1]):
136 | x = signal[:, frame]
137 |         X = abs(fft(x))  # FFT magnitude
138 |         X = X[0:nFFT]    # keep the first nFFT bins
139 |         X = X / len(X)   # normalize
140 | feature.append(stMFCC(X, fbank, n_mfcc_feats))
141 |
142 | feature = np.array(feature)
143 | # print(feature.shape)
144 | return feature
145 |
146 |
147 | def test():
148 | path = '../dataset/ASR/zcy/zcy1.wav'
149 | frameSize = 400
150 | nFFT = int(frameSize/2)
151 | fs, audio = wavfile.read(path)
152 | audio = audio[:, 0]
153 |
154 | [fbank, freqs] = mfccInitFilterBanks(fs, nFFT)
155 | n_mfcc_feats = 13
156 | x = audio[4000:4000+frameSize]
157 |     X = abs(fft(x))  # FFT magnitude
158 |     X = X[0:nFFT]    # keep the first nFFT bins
159 |     X = X / len(X)   # normalize
160 | feature = stMFCC(X, fbank, n_mfcc_feats)
161 | plt.figure()
162 | plt.subplot(121)
163 | plt.plot(X)
164 | plt.subplot(122)
165 | plt.plot(feature)
166 | plt.show()
167 |
168 |
169 | if __name__=='__main__':
170 | test()
171 |
172 | """
173 | path = '../dataset/ASR/zcy/zcy1.wav'
174 | play(path)
175 | samprate, data = read(path)
176 | framedata = enframe(data[:,0])
177 | print(framedata.shape)
178 | """
179 |
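# A minimal usage sketch for the MFCC pipeline above (illustrative; the path is
# hypothetical and the file is assumed to be a mono WAV):
#   from scipy.io import wavfile
#   fs, audio = wavfile.read('dataset/sample.wav')
#   feat = MFCC(audio, fs=fs, frameSize=512, step=256)  # -> (n_frames, 13)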
--------------------------------------------------------------------------------
/utils/tools.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2019/2/26 15:46
4 | # @Author : chuyu zhang
5 | # @File : read.py
6 | # @Software: PyCharm
7 |
8 | import wave
9 | from scipy.io import wavfile
10 | from scipy.fftpack import fft, ifft
11 | # import math
12 | # from scipy import signal
13 | import numpy as np
14 | from pyaudio import PyAudio, paInt16
15 | import time
16 | from sklearn.metrics import confusion_matrix
17 | import matplotlib.pyplot as plt
18 | # import sounddevice as sd
19 | import simpleaudio as sa
20 |
21 | from utils.processing import enframe
22 |
23 | """
24 | framerate=8000
25 | NUM_SAMPLES=2000
26 | channels=1
27 | sampwidth=2
28 | TIME=2
29 | """
30 |
31 | def get_time(start_time=None):
32 |     if start_time is None:
33 | return time.time()
34 | else:
35 | return time.time() - start_time
36 |
37 | ## Planned: add play(data), where data is the output of read()
38 | # read .wav file through wave, return a wave object
39 | def wave_read(filename='test.wav'):
40 | audio = wave.open(filename, mode='rb')
41 | return audio
42 |
43 |
44 | # read .wav file through scipy
45 | def read(filename='test.wav'):
46 | sampling_freq, audio = wavfile.read(filename)
47 | return sampling_freq, audio
48 |
49 |
50 | # save a .wav file with the given parameters
51 | def save_wave_file(filename, data, channels=1, sampwidth=2, framerate=8000):
52 | '''save the data to the wavfile'''
53 |     wf = wave.open(filename, 'wb')
54 |     wf.setnchannels(channels)   # number of channels
55 |     wf.setsampwidth(sampwidth)  # bytes per sample, 1 or 2
56 |     wf.setframerate(framerate)  # sample rate, 8000 or 16000
57 | wf.writeframes(b"".join(data))
58 | wf.close()
59 |
60 |
61 | # num_samples is the number of frames read per buffer (PyAudio's frames_per_buffer)
62 | def record(filename="test.wav",seconds=10,
63 | framerate=16000, format=paInt16, channels=1, num_samples=2000):
64 | p=PyAudio()
65 | stream=p.open(format = format, channels=channels, rate=framerate,
66 | input=True, frames_per_buffer=num_samples)
67 | my_buf=[]
68 |     # limit the recording duration
69 | print("start the recording !")
70 | start = time.time()
71 | while time.time() - start < seconds:
72 |         # number of frames captured per read
73 | string_audio_data = stream.read(num_samples)
74 | my_buf.append(string_audio_data)
75 |
76 |     save_wave_file(filename, my_buf, framerate=framerate)  # keep the capture rate in the header
77 | stream.close()
78 |     print("Recording of {} seconds has completed.".format(seconds))
79 |
80 |
81 | def play(audio=None, sampling_freq=8000, filename=None, chunk=1024):
82 | start_time = get_time()
83 |     if filename is None:
84 | # pass
85 | play_obj = sa.play_buffer(audio, 1, 2, sampling_freq)
86 | play_obj.wait_done()
87 |
88 | else:
89 | wf=wave.open(filename,'rb')
90 | p=PyAudio()
91 | stream=p.open(format=p.get_format_from_width(wf.getsampwidth()),
92 | channels=wf.getnchannels(),rate=wf.getframerate(),
93 | output=True)
94 | while True:
95 | data=wf.readframes(chunk)
96 |             # the b prefix is required: readframes returns bytes, and an empty
97 |             # bytes object marks the end of the file (see references 3 and 4)
98 | if data == b"":
99 | break
100 | stream.write(data)
101 |
102 | stream.stop_stream()
103 | stream.close()
104 | p.terminate()
105 | print("{} seconds".format(get_time(start_time)))
106 |
107 |
108 | def playback(filename='test.wav', silent=False):
109 | rate, audio = read(filename)
110 | new_filename = 'reverse_' + filename
111 | wavfile.write(new_filename, rate, audio[::-1])
112 | if not silent:
113 | play(filename=new_filename)
114 | print("complete!")
115 |
116 |
117 | def change_rate(filename='test.wav', new_rate=4000, silent=False):
118 | rate, audio = read(filename)
119 |     print("the original sampling rate is {}".format(rate))
120 | new_filename = str(new_rate) + "_" + filename
121 | wavfile.write(new_filename, new_rate, audio)
122 | if not silent:
123 | play(filename=new_filename)
124 | print("complete !")
125 |
126 | def change_volume(filename='test.wav', volume_rate=1, silent=False):
127 | rate, audio = read(filename)
128 | print("change volume to {}".format(volume_rate))
129 | new_filename = str(volume_rate) + "_" + filename
130 | new_audio = (audio*volume_rate).astype('int16')
131 | # print(audio.dtype, new_audio.dtype)
132 | wavfile.write(new_filename, rate, new_audio)
133 | if not silent:
134 | play(filename=new_filename)
135 | print("complete !")
136 |
137 | # synthesize a pure tone
138 | def Synthetic_tone(freq, duration=2, amp=1000, sampling_freq=44100):
139 |     # build the time axis
140 |     # scaling_factor = pow(2, 15) - 1  # convert to a 16-bit integer
141 |     t = np.linspace(0, duration, duration * sampling_freq)
142 |     # build the audio signal
143 | audio = amp * np.sin(2 * np.pi * freq * t)
144 | return audio.astype(np.int16)
145 |
146 |
147 | def simple_music(tone='A', duration=2, amplitude=10000, sampling_freq=44100):
148 | tone_freq_map = {'A': 440, 'Asharp': 466, 'B': 494, 'C': 523, 'Csharp': 554,
149 | 'D': 587, 'Dsharp': 622, 'E': 659, 'F': 698, 'Fsharp': 740,
150 | 'G': 784, 'Gsharp': 831}
151 |
152 | synthesized_tone = Synthetic_tone(tone_freq_map[tone], duration, amplitude, sampling_freq)
153 | wavfile.write('{}.wav'.format(tone), sampling_freq, synthesized_tone)
154 | play('{}.wav'.format(tone))
155 |
156 |
157 | def plot_confusion_matrix(y_true, y_pred, classes,
158 | normalize=False,
159 | title=None,
160 | cmap=plt.cm.Blues):
161 | """
162 | This function prints and plots the confusion matrix.
163 | Normalization can be applied by setting `normalize=True`.
164 | """
165 | if not title:
166 | if normalize:
167 | title = 'Normalized confusion matrix'
168 | else:
169 | title = 'Confusion matrix, without normalization'
170 |
171 | # Compute confusion matrix
172 | cm = confusion_matrix(y_true, y_pred)
173 | # Only use the labels that appear in the data
174 | if normalize:
175 | cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
176 | print("Normalized confusion matrix")
177 | else:
178 | print('Confusion matrix, without normalization')
179 |
180 | # print(cm)
181 |
182 | fig, ax = plt.subplots()
183 | im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
184 | ax.figure.colorbar(im, ax=ax)
185 | # We want to show all ticks...
186 | ax.set(xticks=np.arange(cm.shape[1]),
187 | yticks=np.arange(cm.shape[0]),
188 | # ... and label them with the respective list entries
189 | xticklabels=classes, yticklabels=classes,
190 | title=title,
191 | ylabel='True label',
192 | xlabel='Predicted label')
193 |
194 | # Rotate the tick labels and set their alignment.
195 | plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
196 | rotation_mode="anchor")
197 |
198 | # Loop over data dimensions and create text annotations.
199 | fmt = '.2f' if normalize else 'd'
200 | # thresh = cm.max() / 2.
201 | for i in range(cm.shape[0]):
202 | for j in range(cm.shape[1]):
203 | ax.text(j, i, format(cm[i, j], fmt),
204 | ha="center", va="center",
205 | color="black")
206 | fig.tight_layout()
207 | return ax
208 |
209 |
210 | def test(filename):
211 | framerate, audio = read(filename)
212 | play(audio, framerate)
213 | # print(audio.shape)
214 | audio_frame = enframe(audio[:,0], frameSize=512, step=256)
215 | audio_frame_new = np.zeros_like(audio_frame)
216 | for frame in range(audio_frame.shape[1]):
217 | # print(audio_frame[:, frame])
218 | audio_fft = fft(audio_frame[:, frame])
219 | # audio_fft_abs = abs(audio_fft)
220 | # angle = audio_fft.real/audio_fft_abs
221 | audio_fft_new = np.sqrt(0.8)*audio_fft
222 | audio_new = ifft(audio_fft_new)
223 | # print(audio_new)
224 | audio_frame_new[:, frame] = audio_new.real
225 | print((audio_new-audio_frame[:, frame]).sum())
226 | # break
227 | audio_new = audio_frame_new[:256, :].flatten('F')[:audio.shape[0]].reshape(-1, 1)
228 | audio_new = np.concatenate([audio_new, audio_new], axis=1)
229 | save_wave_file('zero_trans.wav', audio_new)
230 | print(audio_new.shape)
231 | # print(audio_new.astype(audio.dtype))
232 | # audio_fft_abs = np.abs(audio_fft)
233 | # angle = audio_fft.real()/audio_fft_abs
234 | # audio_fft_abs_new = audio_fft_abs*0.8
235 | # au
236 | play(audio_new.astype(audio.dtype), framerate)
237 | plt.figure()
238 | plt.plot(audio[:,0])
239 | plt.figure()
240 | plt.plot(audio_new[:,0])
241 | plt.show()
242 |
243 |
244 |
245 |
246 | if __name__=="__main__":
247 | test(filename='../dataset/ASR/test/zcy/zcy1.wav')
248 | """
249 | wave_read = wave.open('../dataset/ASR/test/zcy/zcy1.wav', 'rb')
250 | audio_data = wave_read.readframes(wave_read.getnframes())
251 | num_channels = wave_read.getnchannels()
252 | bytes_per_sample = wave_read.getsampwidth()
253 | sample_rate = wave_read.getframerate()
254 | print(type(audio_data))
255 | play_obj = sa.play_buffer(audio_data, num_channels, bytes_per_sample, sample_rate)
256 | """
257 |
258 | # record()
259 | # playback(filename='test.wav')
260 | # rate, audio = read(filename='01.wav')
261 | # print(rate, audio)
262 | # print(audio[::-1])
263 | # change_volume(volume_rate=1)
264 | """
265 | duration = 4
266 | music = 0.9*Synthetic_tone(freq=440, duration=duration) + \
267 | 0.75*Synthetic_tone(freq=880, duration=duration)
268 | wavfile.write('music.wav', 44100, music.astype('int16'))
269 | play(filename='music.wav')
270 |
271 | """
272 | # change_rate('../dataset/ASR/train/hyy/hyy1.wav', new_rate=8000)
273 | # simple_music(tone='Gsharp')
274 | # framerate, audio = read('../dataset/ASR/train/hyy/hyy1.wav')
275 | # downsample = audio[range(0, audio.shape[0], 2), 0]
276 | # save_wave_file('test.wav', downsample)
277 | # play(filename='test.wav')
278 | # play(downsample, 8000)
279 |
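# A minimal record-and-play sketch using the helpers above (illustrative):
#   record(filename='test.wav', seconds=3, framerate=16000)
#   fs, audio = read('test.wav')
#   play(audio, sampling_freq=fs)  # or play(filename='test.wav')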
--------------------------------------------------------------------------------
/展示.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinzcy/speech_signal_processing/9d197cd3f1d9215cf57e992701b1529d46f242ef/展示.pptx
--------------------------------------------------------------------------------