├── 0_george_0.wav
├── Images
│   ├── classification_report.png
│   ├── formula.png
│   ├── jan-huber-SqR_XkrwwPk-unsplash.jpg
│   ├── mel_filter.png
│   ├── mel_formula.png
│   ├── morlet.png
│   ├── morlet_image.png
│   ├── squished.png
│   ├── stretched.png
│   ├── title.png
│   └── wavelet_icon.png
├── README.md
├── recordings.zip
├── requirements.txt
├── speaker_classifier.h5
├── speaker_mean_std.pkl
├── testing_raw_audio.npz
├── training_raw_audio.npz
└── wavelet_tutorial.py

/0_george_0.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdityaDutt/Audio-Classification-Using-Wavelet-Transform/fb27563c7cf4a1eb442f80497fc9e5208e5b6df1/0_george_0.wav
--------------------------------------------------------------------------------
/Images/classification_report.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdityaDutt/Audio-Classification-Using-Wavelet-Transform/fb27563c7cf4a1eb442f80497fc9e5208e5b6df1/Images/classification_report.png
--------------------------------------------------------------------------------
/Images/formula.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdityaDutt/Audio-Classification-Using-Wavelet-Transform/fb27563c7cf4a1eb442f80497fc9e5208e5b6df1/Images/formula.png
--------------------------------------------------------------------------------
/Images/jan-huber-SqR_XkrwwPk-unsplash.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdityaDutt/Audio-Classification-Using-Wavelet-Transform/fb27563c7cf4a1eb442f80497fc9e5208e5b6df1/Images/jan-huber-SqR_XkrwwPk-unsplash.jpg
--------------------------------------------------------------------------------
/Images/mel_filter.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdityaDutt/Audio-Classification-Using-Wavelet-Transform/fb27563c7cf4a1eb442f80497fc9e5208e5b6df1/Images/mel_filter.png
--------------------------------------------------------------------------------
/Images/mel_formula.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdityaDutt/Audio-Classification-Using-Wavelet-Transform/fb27563c7cf4a1eb442f80497fc9e5208e5b6df1/Images/mel_formula.png
--------------------------------------------------------------------------------
/Images/morlet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdityaDutt/Audio-Classification-Using-Wavelet-Transform/fb27563c7cf4a1eb442f80497fc9e5208e5b6df1/Images/morlet.png
--------------------------------------------------------------------------------
/Images/morlet_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdityaDutt/Audio-Classification-Using-Wavelet-Transform/fb27563c7cf4a1eb442f80497fc9e5208e5b6df1/Images/morlet_image.png
--------------------------------------------------------------------------------
/Images/squished.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdityaDutt/Audio-Classification-Using-Wavelet-Transform/fb27563c7cf4a1eb442f80497fc9e5208e5b6df1/Images/squished.png
--------------------------------------------------------------------------------
/Images/stretched.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdityaDutt/Audio-Classification-Using-Wavelet-Transform/fb27563c7cf4a1eb442f80497fc9e5208e5b6df1/Images/stretched.png
--------------------------------------------------------------------------------
/Images/title.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdityaDutt/Audio-Classification-Using-Wavelet-Transform/fb27563c7cf4a1eb442f80497fc9e5208e5b6df1/Images/title.png
--------------------------------------------------------------------------------
/Images/wavelet_icon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdityaDutt/Audio-Classification-Using-Wavelet-Transform/fb27563c7cf4a1eb442f80497fc9e5208e5b6df1/Images/wavelet_icon.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Audio Classification using Wavelet Transform and Deep Learning
A step-by-step tutorial to classify audio signals using the continuous wavelet transform (CWT) as features.


- ## Steps to use this repository:

  - Create a virtual environment by using the command: ```virtualenv venv```
  - Activate the environment: ```source venv/bin/activate```
  - Install the requirements from requirements.txt by typing: ```pip install -r requirements.txt```
  - Extract the recordings.zip file

- ## Files Description

  - recordings.zip: It contains recordings from the Free Spoken Digit Dataset (FSDD). You can also find this data [here](https://github.com/Jakobovski/free-spoken-digit-dataset).
  - training_raw_audio.npz: We are only classifying 3 speakers here: george, jackson, and lucas. All the training data from these 3 speakers is in this numpy zip file (see the loading snippet below).
  - testing_raw_audio.npz: All the testing data from the same 3 speakers is in this numpy zip file.
  - requirements.txt: It contains the required libraries.
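
Here is a minimal sketch of loading the two `.npz` bundles, mirroring how wavelet_tutorial.py does it (paths are assumed to be relative to the repository root):

```python
import numpy as np

# Each bundle stores the audio entries under key 'a' and the speaker labels under key 'b'
train_data = np.load("training_raw_audio.npz", allow_pickle=True)
audio_train, y_train = train_data['a'], train_data['b']

test_data = np.load("testing_raw_audio.npz", allow_pickle=True)
audio_test, y_test = test_data['a'], test_data['b']
```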

![classification_report](Images/classification_report.png)

![title](Images/title.png)
--------------------------------------------------------------------------------
/recordings.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdityaDutt/Audio-Classification-Using-Wavelet-Transform/fb27563c7cf4a1eb442f80497fc9e5208e5b6df1/recordings.zip
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
absl-py==0.11.0
appdirs==1.4.4
appnope==0.1.2
astor==0.8.1
astunparse==1.6.3
audiomentations==0.18.0
audioread==2.1.9
backcall==0.2.0
cached-property==1.5.2
cachetools==4.2.0
certifi==2020.12.5
cffi==1.14.5
chardet==4.0.0
cycler==0.10.0
decorator==4.4.2
docopt==0.6.2
filelock==3.0.12
flatbuffers==1.12
gast==0.3.3
gdown==3.12.2
google-auth==1.24.0
google-auth-oauthlib==0.4.2
google-pasta==0.2.0
googledrivedownloader==0.4
grpcio==1.32.0
h5py==2.10.0
idna==2.10
importlib-metadata==3.4.0
ipykernel==5.5.0
ipython==7.22.0
ipython-genutils==0.2.0
jedi==0.18.0
joblib==1.0.0
jupyter-client==6.1.12
jupyter-core==4.7.1
Keras==2.4.3
Keras-Applications==1.0.8
Keras-Preprocessing==1.1.2
kiwisolver==1.3.1
librosa==0.8.0
llvmlite==0.35.0
Markdown==3.3.3
matplotlib==3.3.3
mock==4.0.3
numba==0.52.0
numpy==1.19.5
oauthlib==3.1.0
opencv-python==4.5.1.48
opt-einsum==3.3.0
packaging==20.9
pandas==1.2.1
parso==0.8.1
pexpect==4.8.0
pickleshare==0.7.5
Pillow==8.1.0
pipreqs==0.4.10
plotly==4.14.3
pooch==1.3.0
prompt-toolkit==3.0.18
protobuf==3.14.0
ptyprocess==0.7.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycparser==2.20
pydot==1.4.1
pydub==0.25.1
Pygments==2.8.1
pyparsing==2.4.7
PySocks==1.7.1
python-dateutil==2.8.1
pytz==2021.1
PyYAML==5.3.1
pyzmq==22.0.3
requests==2.25.1
requests-oauthlib==1.3.0
resampy==0.2.2
retrying==1.3.3
rsa==4.7
scikit-image==0.18.3
scikit-learn==0.24.0
scipy==1.6.0
seaborn==0.11.1
six==1.15.0
sklearn==0.0
SoundFile==0.10.3.post1
tensorboard==2.4.1
tensorboard-plugin-wit==1.7.0
tensorflow==2.4.1
tensorflow-estimator==2.4.0
termcolor==1.1.0
threadpoolctl==2.1.0
tornado==6.1
tqdm==4.56.0
traitlets==5.0.5
typing-extensions==3.7.4.3
urllib3==1.26.2
visualkeras==0.0.1
wcwidth==0.2.5
Werkzeug==1.0.1
wrapt==1.12.1
yarg==0.1.9
zipp==3.4.0
--------------------------------------------------------------------------------
/speaker_classifier.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdityaDutt/Audio-Classification-Using-Wavelet-Transform/fb27563c7cf4a1eb442f80497fc9e5208e5b6df1/speaker_classifier.h5
--------------------------------------------------------------------------------
/speaker_mean_std.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdityaDutt/Audio-Classification-Using-Wavelet-Transform/fb27563c7cf4a1eb442f80497fc9e5208e5b6df1/speaker_mean_std.pkl
--------------------------------------------------------------------------------
/testing_raw_audio.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdityaDutt/Audio-Classification-Using-Wavelet-Transform/fb27563c7cf4a1eb442f80497fc9e5208e5b6df1/testing_raw_audio.npz
--------------------------------------------------------------------------------
/training_raw_audio.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdityaDutt/Audio-Classification-Using-Wavelet-Transform/fb27563c7cf4a1eb442f80497fc9e5208e5b6df1/training_raw_audio.npz
--------------------------------------------------------------------------------
/wavelet_tutorial.py:
--------------------------------------------------------------------------------

# Import libraries
import os, sys, cv2, matplotlib.pyplot as plt, numpy as np, pandas as pd, pickle
import random
from random import seed, random, randint, sample

import tensorflow.keras as keras
from keras import backend as K
from keras.models import Model, load_model, Sequential
from keras.callbacks import ModelCheckpoint
from keras.layers import Input, Dense, GlobalMaxPool1D, Activation, MaxPool1D, Conv1D, Flatten, BatchNormalization
from keras.regularizers import l2
from keras.utils.vis_utils import plot_model
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l2

import librosa
import librosa.display
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.preprocessing import normalize
from mpl_toolkits.mplot3d import Axes3D
from skimage.transform import resize
from scipy.signal import hilbert, chirp
from sklearn.preprocessing import MinMaxScaler
from librosa.filters import mel
import pywt
import scipy
from tqdm import tqdm
from sklearn.model_selection import StratifiedKFold



# Step 1: Read the audio files and split into train/test data


# The data is in the folder "recordings" inside the current directory.
dir = os.getcwd() + "/recordings/"

# Read audio files from the directory. For this tutorial, we will only classify 3 speakers: george, jackson, and lucas.
# Audio files have this format: {digit}_{speaker}_{speaker_filenumber}.wav

audio = [] # List to store the audio file paths
y = []     # List to store the target class labels

for root, dirs, files in os.walk(dir, topdown=False):
    for name in files:

        if name.find(".wav") != -1 : # Check if the file has a .wav extension
            if name.find("george") != -1 or name.find("jackson") != -1 or name.find("lucas") != -1 : # Check if the speaker is george, jackson, or lucas
                fullname = os.path.join(root, name)
                audio.append(fullname) # Append the file path to the list
                if name.find("george") != -1 :
                    y.append(0)
                elif name.find("jackson") != -1 :
                    y.append(1)
                else :
                    y.append(2)
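
# Illustration of the filename-to-label mapping above (the first name is the bundled
# sample file; the other two are made-up examples of the same naming convention):
#     "0_george_0.wav"   -> speaker "george"  -> label 0
#     "5_jackson_12.wav" -> speaker "jackson" -> label 1
#     "7_lucas_3.wav"    -> speaker "lucas"   -> label 2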

# Write the audio data to an npz file so that we don't have to walk the directory again.
# We can load the data from the npz file later; the npz format is also very space efficient.
audio_train, audio_test, y_train, y_test = train_test_split(audio, y, test_size=0.3)
# np.savez_compressed(os.getcwd()+"/training_raw_audio", a=audio_train, b=y_train)
# np.savez_compressed(os.getcwd()+"/testing_raw_audio", a=audio_test, b=y_test)

print("Finished writing to npz file...")

# Print the class distribution
print("Training Data class distribution: ", np.unique(y_train, return_counts=True))
print("Testing Data class distribution: ", np.unique(y_test, return_counts=True))



# Load the data from the .npz file
train_data = np.load(os.getcwd()+"/training_raw_audio.npz", allow_pickle=True)
audio_train = train_data['a']
y_train = train_data['b']

test_data = np.load(os.getcwd()+"/testing_raw_audio.npz", allow_pickle=True)
audio_test = test_data['a']
y_test = test_data['b']




'''
Step 2: Write a function to compute continuous wavelet transform features of each audio sample
Human Voice Frequency Range:
   - The human ear can hear between 20 and 20,000 Hz (20 kHz) but it is most sensitive to everything that happens between 250 and 5,000 Hz.
   - The voiced speech of a typical adult male has a fundamental frequency from 85 to 180 Hz, and that of a typical adult female from 165 to 255 Hz.
   - For a child's voice, the average fundamental frequency is around 300 Hz.
   - Consonants take up the space between 2 kHz and 5 kHz.
   - Vowel sounds are prominent between 500 Hz and 2 kHz.

We will keep frequencies only between roughly 80 Hz and 1 kHz (see the scale filter below).
We will split each audio signal into frames of length 400.
'''

def compute_wavelet_features(X) :

    # Define a few parameters
    wavelet = 'morl' # wavelet type: Morlet
    sr = 8000 # sampling frequency: 8 kHz
    widths = np.arange(1, 256) # scales for the Morlet wavelet
    dt = 1/sr # timestep difference

    frequencies = pywt.scale2frequency(wavelet, widths) / dt # Get frequencies corresponding to scales

    # Create a filter to select frequencies between roughly 80 Hz and 1 kHz
    upper = ([x for x in range(len(widths)) if frequencies[x] > 1000])[-1]
    lower = ([x for x in range(len(widths)) if frequencies[x] < 80])[0]
    widths = widths[upper:lower] # Select scales in this frequency range

    # Compute the continuous wavelet transform of the audio numpy array
    wavelet_coeffs, freqs = pywt.cwt(X, widths, wavelet = wavelet, sampling_period=dt)
    # print(wavelet_coeffs.shape)
    # sys.exit(1)

    # Split the coefficients into frames of length frame_size (400)
    start = 0
    end = wavelet_coeffs.shape[1]
    frames = []
    frame_size = 400
    count = 0

    while start+frame_size <= end-1 :

        f = wavelet_coeffs[:,start:start+frame_size]

        # The number of coefficients is not always a multiple of frame_size; the loop condition above skips any final partial frame.
        assert f.shape[1] == frame_size # assert frame lengths are equal to the frame_size parameter

        frames.append(f)
        start += frame_size


    # Convert frames to a numpy array
    frames = np.array(frames)
    frames = frames.reshape((len(frames), wavelet_coeffs.shape[0], frame_size))

    return frames
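
# A quick sanity check of the scale-to-frequency mapping used above (approximate
# values, assuming PyWavelets' Morlet centre frequency of about 0.8125 and the
# 8 kHz sampling rate assumed inside compute_wavelet_features):
#
#     f(scale) = pywt.scale2frequency('morl', scale) / dt  ~  0.8125 * 8000 / scale
#
#     scale  6 -> ~1083 Hz,   scale 81 -> ~80 Hz
#
# so the selected scales (roughly 6 to 81) span about 80 Hz to 1 kHz, i.e. 76 scales,
# which matches the 76-feature input of the model built in Step 4.
# For example: print(pywt.scale2frequency('morl', np.array([6, 81])) * 8000)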



# Step 3: Compute the continuous wavelet transform of the training and testing data using the function from Step 2

### Compute training data features. Each sample is split into frames of length 400.

indices = []
WaveletFeatTrain = [] # Store wavelet features
WaveletYTrain = [] # Store class labels corresponding to wavelet features from an audio sample
uniq_id = []
count = 0

for i in range(3) :

    ind, = np.where(y_train == i)
    seed(i)
    ind = ind.tolist()
    ind = sample(ind, 100)
    audio_samples = audio_train[ind]
    num_rand_samp = 100

    for j in tqdm(range(len(audio_samples))) :

        # print("i ", i, " j ", j, "/", len(audio_samples))
        curr_sample = audio_samples[j]
        seq, _ = librosa.load(curr_sample) # Note: librosa.load resamples to 22050 Hz by default; pass sr=8000 to match the 8 kHz rate assumed in compute_wavelet_features
        F = compute_wavelet_features(seq)
        F = F.astype(np.float16)

        # Generate target labels corresponding to the frames of each sample
        indices = np.arange(0, len(F), 1)
        indices = indices.tolist()
        indices = sample(indices, min(num_rand_samp, len(indices)))
        F = F[indices]
        uniq_id += [count] * len(F)
        WaveletYTrain += [i] * len(F)

        if count == 0 :
            WaveletFeatTrain = F
        else :
            WaveletFeatTrain = np.concatenate((WaveletFeatTrain, F), axis=0)

        count += 1



print("X: ", WaveletFeatTrain.shape)

WaveletYTrain = np.array(WaveletYTrain) # Convert to numpy array
uniq_id = np.array(uniq_id)
print("Y: ", WaveletYTrain.shape, " unique: ", np.unique(WaveletYTrain, return_counts=True))
# Write all features to a .npz file
np.savez_compressed(os.getcwd()+"/training_features", a=WaveletFeatTrain, b=WaveletYTrain, c=uniq_id)



### Compute testing data features

WaveletFeatTest = [] # Store wavelet features. Each sample is split into frames of length 400
WaveletYTest = [] # Store class labels corresponding to wavelet features from an audio sample
uniq_id = []

for i in tqdm(range(len(audio_test))) :

    curr_sample = audio_test[i]
    seq, _ = librosa.load(curr_sample) # Same note as above regarding the default 22050 Hz resampling
    curr_target = y_test[i]
    F = compute_wavelet_features(seq)

    # Generate target labels corresponding to the frames of each sample
    WaveletYTest += [curr_target] * len(F)
    uniq_id += [i] * len(F)

    if i == 0 :
        WaveletFeatTest = F
    else :
        WaveletFeatTest = np.concatenate((WaveletFeatTest, F), axis=0)

WaveletYTest = np.array(WaveletYTest) # Convert to numpy array
uniq_id = np.array(uniq_id)
print("X: ", WaveletFeatTest.shape, " y: ", WaveletYTest.shape)

WaveletFeatTest = WaveletFeatTest.astype(np.float16)

# Write all features to a .npz file
np.savez_compressed(os.getcwd()+"/testing_features", a=WaveletFeatTest, b=WaveletYTest, c=uniq_id)



# Step 4: Build a deep learning model
def create_model(row, col) :

    n_filters = 32
    filter_width = 3
    dilation_rates = [2**i for i in range(6)] * 2

    # Define an input series and pass it through a stack of dilated causal convolution blocks
    history_seq = Input(shape=(row, col))
    x = history_seq

    skips = []
    count = 0
    # x = GaussianNoise(0.01)(x)
    for dilation_rate in dilation_rates:

        # preprocessing - equivalent to time-distributed dense

        # filter
        x = Conv1D(filters=n_filters,
                   kernel_size=filter_width,
                   padding='causal',
                   dilation_rate=dilation_rate, kernel_regularizer=l2(0.001),
                   bias_regularizer=l2(0.001))(x)

        x = BatchNormalization()(x)
        x = Activation('relu')(x)

    out = Conv1D(16, 3, padding='same', kernel_initializer= 'random_normal', kernel_regularizer=l2(0.001), bias_regularizer=l2(0.001))(x)
    out = BatchNormalization()(out)
    out = Activation('relu')(out)
    out = GlobalMaxPool1D()(out)

    out = Dense(3, kernel_regularizer=l2(0.001), bias_regularizer=l2(0.001))(out)
    out = Activation('softmax')(out)

    model = Model(history_seq, out)

    model.compile(loss='categorical_crossentropy', optimizer='adam')

    return model

model = create_model(400, 76)
print(model.summary())
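
# A rough receptive-field check for the stack above (illustrative arithmetic only):
# each causal Conv1D with kernel size 3 and dilation d widens the receptive field
# by (3 - 1) * d, so with dilation rates [1, 2, 4, 8, 16, 32] repeated twice:
#
#     receptive field = 1 + 2 * (1 + 2 + 4 + 8 + 16 + 32) * 2 = 253 timesteps
#
# i.e. the last timestep can "see" roughly 253 of the 400 timesteps in each frame
# (the Conv1D(16, 3, padding='same') head adds a couple more).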




# Step 5: Preprocess the data and train the model
# For the neural network, we need the data in the format Num_samples x timesteps x features, but it is currently Num_samples x features x timesteps.


# Load the data
training_data = np.load(os.getcwd()+"/training_features.npz")
X = training_data['a']
y = training_data['b']

X = X.transpose(0,2,1) # Put data in the correct format: Num_samples x timesteps x features
y = to_categorical(y) # Convert class labels to categorical vectors
print("X ", X.shape, "y ", y.shape)

# Standardize the data
mean = X.mean()
std = X.std()
X = (X-mean)/ std

print("Mean ", mean, " STD ", std, X.mean(), X.std())

X = X.astype(np.float32)

y = y.astype(np.uint8)

print("Input shapes ", X.shape, y.shape)

# Write the mean and standard deviation to a pickle file
# f = open(os.getcwd()+'/speaker_mean_std.pkl', 'wb')
# pickle.dump([mean, std, y], f)
# f.close()

r,c = X[0].shape

# Split data into training and validation
X1, Xval, y1, yval = train_test_split(X, y, test_size=0.20)#, random_state=int(time.time()))

# Use 5-fold cross validation
kfold = StratifiedKFold(n_splits=5, shuffle=True)

count = 0

# Train the model
# for train, test in kfold.split(X1, np.argmax(y1, axis= -1)):

#     print("K Fold Step ", count)
#     model.fit(X1[train], y1[train], validation_data= (X1[test], y1[test]), batch_size= 128, epochs= 80, verbose= 2)
#     model.save(os.getcwd()+"/speaker_classifier.h5")

#     count += 1

# scores = model.evaluate(Xval, yval, verbose=0)
# print("Metrics : ", scores)




# Step 6: Test the model
model = load_model(os.getcwd()+"/speaker_classifier.h5", compile= False)

print(model.summary())

# Load the mean and standard deviation
f = open(os.getcwd()+'/speaker_mean_std.pkl', 'rb')
mean, std, _ = pickle.load(f) # The pickle also stores the training labels, which are not needed here
f.close()

testing_data = np.load(os.getcwd()+"/testing_features.npz")
X = testing_data['a']
y = testing_data['b']
ind = testing_data['c']
unq_ind = np.unique(ind)

X = X.astype(np.float32)

X = X.transpose(0,2,1) # Put data in the correct format: Num_samples x timesteps x features
X = (X-mean)/ std

# Predict
ypred = model.predict(X)
ypred = np.argmax(ypred, axis=-1)
ypred = ypred.flatten()
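
# Illustration of the majority vote performed below: if the frames of one recording
# are predicted as [0, 0, 1, 0], then
#     np.unique([0, 0, 1, 0], return_counts=True) -> (array([0, 1]), array([3, 1]))
# so the recording is assigned class 0 (george), the label with the highest count.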

new_pred = []
new_truth = []

# Find unique ids and assign a class to each recording based on a majority vote over all of its frames
for i in range(len(unq_ind)) :

    curr = unq_ind[i]
    indices, = np.where(ind == curr)
    t = y[indices]
    t = t[0]

    p = ypred[indices]
    # p1 = ypred1[indices]
    # p2 = ypred2[indices]
    # p3 = ypred3[indices]
    # p4 = ypred4[indices]

    # all_model_pred = [get_best_candidate(x, t) for x in [p1, p2, p3, p4]]
    # un, fr = np.unique(all_model_pred, return_counts=True)

    un, fr = np.unique(p, return_counts=True)
    new_pred.append(un[np.argmax(fr)])
    print("Truth ", t, " Pred ", un, fr)

    new_truth.append(t)

new_truth = np.array(new_truth)
new_pred = np.array(new_pred)

# Print the classification report (labels 0, 1, 2 correspond to george, jackson, and lucas)
rep = classification_report(new_truth, new_pred, target_names=['george', 'jackson', 'lucas'])
print(rep)
--------------------------------------------------------------------------------