├── .idea
│   ├── TIMITspeech.iml
│   ├── misc.xml
│   ├── modules.xml
│   ├── preferred-vcs.xml
│   └── workspace.xml
├── README.md
├── RNN.py
├── RNN_implementation.py
├── background
│   ├── __init__.py
│   ├── htkbook-3.5.pdf
│   ├── plot_data.py
│   ├── timit_phones.txt
│   └── timit_words.txt
├── environment.yml
├── getResults.py
└── tools
    ├── __init__.py
    ├── __init__.pyc
    ├── createMLF.py
    ├── datasetToPkl.py
    ├── formatting.py
    ├── formatting.pyc
    ├── general_tools.py
    ├── general_tools.pyc
    ├── helpFunctions
    │   ├── __init__.py
    │   ├── copyFilesOfType.py
    │   ├── progress_bar.py
    │   ├── removeEmptyDirs.py
    │   ├── resample.py
    │   ├── resampleExperiment.py
    │   ├── sa1.wav
    │   ├── visualizeMFC.m
    │   ├── wavToPng.py
    │   └── writeToTxt.py
    ├── mergeAudioFiles.py
    ├── phoneme_set.py
    ├── preprocessWavs.py
    └── transform.py

/README.md:
--------------------------------------------------------------------------------

With this repo you can preprocess an audio dataset (modify phoneme classes, resample audio, etc.) and train LSTM networks for framewise phoneme classification.
You can achieve 82% accuracy on the TIMIT dataset, similar to the results from [Graves et al. (2013)](https://arxiv.org/abs/1303.5778), although CTC is not used here.
Instead, the network generates predictions in the middle of each phoneme interval, as specified by the labels. This keeps things simple, but adding CTC shouldn't be too much trouble.

In order to create and train a network on a dataset, you need to:

0. Install the software. I recommend using [Anaconda](https://www.anaconda.com/download/) and a virtual environment.
    - create an environment from the provided file: `conda env create -f environment.yml`

1. Generate a binary file from the source dataset. It's easiest if you structure your data as in the TIMIT dataset, although that's not strictly needed. Just make sure that each wav and its corresponding phn file have the same path except for the extension; otherwise they won't get matched and your labels will be off.
    - WAV files, 16 kHz sampling rate, folder structure `dataset/TRAIN/speakerName/videoName/`.
      Each `videoName/` directory contains a `videoName.wav` and a `videoName.phn`.
      The phn file contains the sample numbers (at 16 kHz) where each phoneme starts and ends.
    - If your files are in a different format, you can use the functions from `fixDataset/` (run `transform.py` with the appropriate arguments; see the bottom of that file) to:
        - fix wav headers and resample wavs. They are stored under `dataRoot/dataset/fixed(nbPhonemes)/`:
          `transform.py phonemes -i dataRoot/TIMIT/original/ -o dataRoot/TIMIT/fixed`
        - fix label files: replace phonemes (e.g. to use a reduced phoneme set; I used the 39 phonemes from Lee and Hon (1989)). These are stored next to the fixed wavs, under `root/dataset/fixed(nbPhonemes)/`
        - create an MLF file (as produced by the HTK tools, and as used in the TCDTIMIT dataset)
    - The scripts should be case-agnostic, but you can convert lowercase to uppercase and vice versa by running `find . -depth -print0 | xargs -0 rename '$_ = lc $_'` in the root dataset directory (change `lc` to `uc` to convert to uppercase). Repeat until you get no more output.
    - Then set the variables in datasetToPkl.py (source and target dir, number of MFCCs to use, etc.) and run the file.
    - The result is stored as `root/dataset/binary(nbPhonemes)/dataset/dataset_nbPhonemes_ch.pkl`, e.g. `root/TIMIT/binary39/TIMIT/TIMIT_39_ch.pkl`.
    - The mean and standard deviation of the train data are stored as `root/dataset/binary_nbPhonemes/dataset_MeanStd.pkl`; they are useful for normalization when evaluating (see the loading sketch further down).
1. Use RNN.py to start training. Its functions are implemented in RNN_implementation.py, but you can set the parameters from RNN.py:
    - set the location of the pkl generated by datasetToPkl.py
    - specify the number of LSTM layers and the number of units per layer
    - use bidirectional LSTM layers
    - add some dense layers (though this did not improve performance for me)
    - set the learning rate and decay (the LR is updated at the end of RNN_implementation.py); it is decreased when performance hasn't improved for some time

   It will automatically give the model a name based on the specified parameters. A log file, the model parameters and a pkl file containing training info (accuracy, error etc. for each epoch) are stored as well.
   The storage location is `root/dataset/results`.

1. To evaluate on another dataset, change the test_dataset variable to whatever you want (TIMIT/TCDTIMIT/combined).
1. You can generate test datasets with noise (either white noise or simultaneous speakers) of a certain level, using mergeAudioFiles.py to create the wavs and testdataToPkl.py to convert them to pkl files.
1. If this noisy audio is to be used for combinedSR, you need to generate the pkl files a bit differently, using audioToPkl_perVideo.py. That pkl file can then be combined with the images and labels generated by combinedSR/datasetToPkl. You can enable this by setting some parameters in `combinedSR/combinedNN.py`.

On TIMIT, you should get about 82% accuracy using a 2-layer bidirectional LSTM network with 256 units per layer.
You should get about 67% on TCD-TIMIT.

The TIMIT dataset is non-free and available from [https://catalog.ldc.upenn.edu/LDC93S1](https://catalog.ldc.upenn.edu/LDC93S1).
The TCD-TIMIT dataset is free for research and available from [https://sigmedia.tcd.ie/TCDTIMIT/](https://sigmedia.tcd.ie/TCDTIMIT/).
If you want to use TCD-TIMIT, I recommend using my repo [TCDTIMITprocessing](https://github.com/matthijsvk/TCDTIMITprocessing) to download and extract the database; it's quite a nasty job otherwise. You can use `extractTCDTIMITaudio.py` to get the phoneme and wav files.
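As a quick sanity check of step 1 (and to see how the MeanStd file is meant to be used), here is a minimal loading sketch. It is illustrative rather than part of the repo: the exact paths below and the `(mean, std)` ordering inside the MeanStd pkl are assumptions based on the layout described above, so adjust them to your setup (RNN.py itself looks for the pkl under `root/audioSR/<dataset>/binary39/<dataset>/`).

```python
import os
import pickle  # tools/general_tools.py provides an unpickle() helper that does the same

root = os.path.expanduser("~/TCDTIMIT/")  # same root as in RNN.py

# assumed paths, following the binary39 layout described above
data_path = os.path.join(root, "TIMIT", "binary39", "TIMIT", "TIMIT_39_ch.pkl")
mean_std_path = os.path.join(root, "TIMIT", "binary39", "TIMIT_MeanStd.pkl")

with open(data_path, 'rb') as f:
    (X_train, y_train, valid_frames_train,
     X_val, y_val, valid_frames_val,
     X_test, y_test, valid_frames_test) = pickle.load(f)  # same unpacking as RNN.py's load_data()

print("utterances (train/val/test): %d / %d / %d" % (len(X_train), len(X_val), len(X_test)))
print("first utterance: %d frames x %d MFCC features" % X_train[0].shape)

# normalize evaluation data with the stored training statistics
# (assumed to be pickled as a (mean, std) pair)
with open(mean_std_path, 'rb') as f:
    mean_val, std_val = pickle.load(f)
X_test_normalized = [(x - mean_val) / std_val for x in X_test]
```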

If you want to do lipreading or audio-visual speech recognition, check out my other repository [MultimodalSR](https://github.com/matthijsvk/multimodalSR).
--------------------------------------------------------------------------------
/RNN.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | 
3 | import warnings
4 | from time import gmtime, strftime
5 | 
6 | warnings.simplefilter("ignore", UserWarning)  # cuDNN warning
7 | 
8 | import logging
9 | import tools.formatting as formatting
10 | 
11 | import time
12 | program_start_time = time.time()
13 | 
14 | print("\n * Importing libraries...")
15 | from RNN_implementation import *
16 | from pprint import pprint  # printing properties of networkToRun objects
17 | from tqdm import tqdm  # % done bar
18 | 
19 | 
20 | ##### SCRIPT META VARIABLES #####
21 | VERBOSE = True
22 | num_epochs = 20
23 | nbMFCCs = 39  # number of features to use -> see 'utils.py' in convertToPkl under processDatabase
24 | nbPhonemes = 39  # number of output neurons
25 | 
26 | # Root path: data is stored here, models and results will be stored here as well
27 | # see README.md for info about where you should store your data
28 | root = os.path.expanduser("~" + os.sep + "TCDTIMIT" + os.sep)
29 | 
30 | # Choose which datasets to run
31 | datasets = ["TIMIT"]  # "TCDTIMIT", "combined"
32 | 
33 | # for each dataset type as key, this dictionary contains as value a list of all the network architectures that need to be trained for that dataset
34 | many_n_hidden_lists = {}
35 | many_n_hidden_lists['TIMIT'] = [[50]]  # [8,8], [32], [64], [256], [512], [1024]]
36 | # [8, 8], [32, 32], [64, 64], [256, 256], [512, 512]]#, [1024,1024]]
37 | # [8,8,8],[32,32,32],[64,64,64],[256,256,256],[512,512,512],
38 | # [8,8,8,8],[32,32,32,32],[64,64,64,64],[256,256,256,256],[512,512,512,512]]
39 | # [512,512,512,512],
40 | # [1024,1024,1024],[1024,1024,1024,1024]]
41 | many_n_hidden_lists['TCDTIMIT'] = [[256, 256], [512, 512], [256, 256, 256, 256]]
42 | # combined is TIMIT and TCDTIMIT put together
43 | many_n_hidden_lists['combined'] = [[256, 256]]  # [32, 32], [64, 64], [256, 256]]#, [512, 512]]
44 | 
45 | #######################
46 | 
47 | bidirectional = True
48 | add_dense_layers = False
49 | 
50 | run_test = True  # if the network exists, don't retrain, just evaluate it on the test set
51 | autoTrain = False  # if the network doesn't exist, train a new one automatically
52 | round_params = False
53 | 
54 | # Decaying LR: each epoch LR = LR * LR_decay
55 | LR_start = 0.01  # will be divided by 10 if retraining an existing model
56 | LR_fin = 0.0000001
57 | # LR_decay = (LR_fin / LR_start) ** (1. / num_epochs)
58 | LR_decay = 0.5
59 | 
60 | print("LR_start = %s" % LR_start)
61 | print("LR_fin = %s" % LR_fin)
62 | print("LR_decay = %s" % LR_decay)
63 | 
64 | 
65 | # quickly generate many networks
66 | def create_network_list(dataset, networkArchs):
67 |     network_list = []
68 |     for networkArch in networkArchs:
69 |         network_list.append(NetworkToRun(run_type="audio", n_hidden_list=networkArch,
70 |                                          dataset=dataset, test_dataset=dataset, run_test=True))
71 |     return network_list
72 | 
73 | def main():
74 |     # Use this if you only want to start training when the network doesn't exist yet
75 |     network_list = create_network_list(dataset="TIMIT", networkArchs=many_n_hidden_lists['TIMIT'])
76 | 
77 |     net_runner = NetworkRunner(network_list)
78 |     net_runner.get_network_results()
79 |     print("\n got all results")
80 |     net_runner.exportResultsToExcel()
81 | 
82 | 
83 | class NetworkToRun:
84 |     def __init__(self,
85 |                  n_hidden_list=(256, 256,), nbMFCCs=39, audio_bidirectional=True,
86 |                  LR_start=0.001, round_params=False, run_test=False, force_train=False,
87 |                  run_type='audio', dataset="TCDTIMIT", test_dataset=None,
88 |                  with_noise=False, noise_types=('white',), ratio_dBs=(0, -3, -5, -10,)):
89 |         # Audio
90 |         self.n_hidden_list = n_hidden_list  # LSTM architecture for the audio part
91 |         self.nbMFCCs = nbMFCCs
92 |         self.audio_bidirectional = audio_bidirectional
93 | 
94 |         # Others
95 |         self.run_type = run_type
96 |         self.LR_start = LR_start
97 |         self.run_test = run_test
98 |         self.force_train = force_train  # If False, just test the network outputs when the network already exists.
99 |         # If force_train == True, train it anyway before testing.
100 |         # In that case, set LR_start low enough so you don't move too far out of the objective minimum.
101 | 
102 |         self.dataset = dataset
103 |         if test_dataset is None:
104 |             self.test_dataset = self.dataset
105 |         else:
106 |             self.test_dataset = test_dataset
107 | 
108 |         self.round_params = round_params
109 |         self.with_noise = with_noise
110 |         self.noise_types = noise_types
111 |         self.ratio_dBs = ratio_dBs
112 | 
113 |         self.model_name, self.nice_name = self.get_model_name()
114 |         self.model_path = self.get_model_path()
115 |         self.model_path_noNPZ = self.model_path.replace('.npz', '')
116 | 
117 |     # this generates the correct path based on the chosen parameters, and gets the train/val/test data
118 |     def load_data(self, noise_type='white', ratio_dB='0'):
119 |         dataset = self.dataset
120 |         test_dataset = self.test_dataset
121 |         with_noise = self.with_noise
122 | 
123 |         data_dir = os.path.join(root, self.run_type + "SR", dataset, "binary" + str(nbPhonemes),
124 |                                 dataset)  # output dir from datasetToPkl.py
125 |         data_path = os.path.join(data_dir, dataset + '_' + str(nbMFCCs) + '_ch.pkl')
126 |         if self.run_test:
127 |             test_data_dir = os.path.join(root, self.run_type + "SR", test_dataset, "binary" + str(nbPhonemes) + \
128 |                     ('_'.join([noise_type, os.sep, "ratio", str(ratio_dB)]) if with_noise else ""), test_dataset)
129 |             test_data_path = os.path.join(test_data_dir, test_dataset + '_' + str(nbMFCCs) + '_ch.pkl')
130 | 
131 |         self.logger.info(' data source: %s', data_path)
132 | 
133 |         dataset = unpickle(data_path)
134 |         x_train, y_train, valid_frames_train, x_val, y_val, valid_frames_val, x_test, y_test, valid_frames_test = dataset
135 | 
136 |         # if run_test, you can use another dataset for evaluation than the one used for training
137 |         if self.run_test:
138 |             self.logger.info(" test data source: %s", test_data_path)
139 |             if with_noise:
140 |                 x_test, y_test, valid_frames_test = 
unpickle(test_data_path) 141 | else: 142 | _, _, _, _, _, _, x_test, y_test, valid_frames_test = unpickle(test_data_path) 143 | 144 | datasetFiles = [x_train, y_train, valid_frames_train, x_val, y_val, valid_frames_val, x_test, y_test, 145 | valid_frames_test] 146 | # Print some information 147 | debug = False 148 | if debug: 149 | self.logger.info("\n* Data information") 150 | self.logger.info('X train') 151 | self.logger.info(' %s %s', type(x_train), len(x_train)) 152 | self.logger.info(' %s %s', type(x_train[0]), x_train[0].shape) 153 | self.logger.info(' %s %s', type(x_train[0][0]), x_train[0][0].shape) 154 | self.logger.info(' %s', type(x_train[0][0][0])) 155 | 156 | self.logger.info('y train') 157 | self.logger.info(' %s %s', type(y_train), len(y_train)) 158 | self.logger.info(' %s %s', type(y_train[0]), y_train[0].shape) 159 | self.logger.info(' %s %s', type(y_train[0][0]), y_train[0][0].shape) 160 | 161 | self.logger.info('valid_frames train') 162 | self.logger.info(' %s %s', type(valid_frames_train), len(valid_frames_train)) 163 | self.logger.info(' %s %s', type(valid_frames_train[0]), valid_frames_train[0].shape) 164 | self.logger.info(' %s %s', type(valid_frames_train[0][0]), valid_frames_train[0][0].shape) 165 | 166 | return datasetFiles 167 | 168 | # this builds the chosen network architecture, loads network weights and compiles the functions 169 | def setup_network(self, batch_size): 170 | dataset = self.dataset 171 | test_dataset = self.test_dataset 172 | n_hidden_list = self.n_hidden_list 173 | round_params = self.round_params 174 | 175 | store_dir = root + dataset + "/results" 176 | if not os.path.exists(store_dir): os.makedirs(store_dir) 177 | 178 | # log file 179 | fh = self.setupLogging(store_dir) 180 | ############################################################# 181 | 182 | 183 | self.logger.info("\n\n\n\n STARTING NEW TRAINING SESSION AT " + strftime("%Y-%m-%d %H:%M:%S", gmtime())) 184 | 185 | ##### IMPORTING DATA ##### 186 | 187 | self.logger.info(' model target: %s', self.model_name) 188 | 189 | self.logger.info('\n* Building network using batch size: %s...', batch_size) 190 | RNN_network = NeuralNetwork('RNN', None, batch_size=batch_size, 191 | num_features=nbMFCCs, n_hidden_list=n_hidden_list, 192 | num_output_units=nbPhonemes, 193 | bidirectional=bidirectional, addDenseLayers=add_dense_layers, 194 | debug=False, 195 | dataset=dataset, test_dataset=test_dataset, logger=self.logger) 196 | 197 | # print number of parameters 198 | nb_params = lasagne.layers.count_params(RNN_network.network_lout_batch) 199 | self.logger.info(" Number of parameters of this network: %s", nb_params) 200 | 201 | # Try to load stored model 202 | self.logger.info(' Network built. 
\nTrying to load stored model: %s', self.model_name + '.npz') 203 | success = RNN_network.load_model(self.model_path_noNPZ, round_params=round_params) 204 | 205 | RNN_network.loadPreviousResults(self.model_path_noNPZ) 206 | 207 | ##### COMPILING FUNCTIONS ##### 208 | self.logger.info("\n* Compiling functions ...") 209 | RNN_network.build_functions(train=True, debug=False) 210 | 211 | return RNN_network, success, fh 212 | 213 | def setupLogging(self, store_dir): 214 | self.logger = logging.getLogger(self.model_name) 215 | self.logger.setLevel(logging.DEBUG) 216 | FORMAT = '[$BOLD%(filename)s$RESET:%(lineno)d][%(levelname)-5s]: %(message)s ' 217 | formatter = logging.Formatter(formatting.formatter_message(FORMAT, False)) 218 | # formatter2 = logging.Formatter( 219 | # '%(asctime)s - %(name)-5s - %(levelname)-10s - (%(filename)s:%(lineno)d): %(message)s') 220 | 221 | # create console handler with a higher log level 222 | ch = logging.StreamHandler() 223 | ch.setLevel(logging.DEBUG) 224 | ch.setFormatter(formatter) 225 | self.logger.addHandler(ch) 226 | 227 | logFile = store_dir + os.sep + self.model_name + '.log' 228 | if os.path.exists(logFile): 229 | self.fh = logging.FileHandler(logFile) # append to existing log 230 | else: 231 | self.fh = logging.FileHandler(logFile, 'w') # create new logFile 232 | self.fh.setLevel(logging.DEBUG) 233 | self.fh.setFormatter(formatter) 234 | self.logger.addHandler(self.fh) 235 | 236 | def get_model_name(self): 237 | n_hidden_list = self.n_hidden_list 238 | dataset = self.dataset 239 | model_name = str(len(n_hidden_list)) + "_LSTMLayer" + '_'.join([str(layer) for layer in n_hidden_list]) \ 240 | + "_nbMFCC" + str(nbMFCCs) + ("_bidirectional" if bidirectional else "_unidirectional") + \ 241 | ("_withDenseLayers" if add_dense_layers else "") + "_" + dataset 242 | 243 | nice_name = "Audio:" + ' '.join( 244 | ["LSTM", str(n_hidden_list[0]), "/", str(len(n_hidden_list))]) 245 | 246 | return model_name, nice_name 247 | 248 | def get_model_path(self): 249 | model_name, nice_name = self.get_model_name() 250 | model_path = os.path.join(root, self.run_type + 'SR', self.dataset, 'results', model_name + '.npz') 251 | return model_path 252 | 253 | # this takes the prepared data, built network and some parameters, and trains/evaluates the network 254 | def executeNetwork(self, RNN_network, load_params_success, batch_size, datasetFiles, 255 | noiseType='white', ratio_dB=0, fh=None): 256 | with_noise = self.with_noise 257 | run_test = self.run_test 258 | 259 | LR = LR_start 260 | if load_params_success == 0: LR = LR_start / 10.0 261 | 262 | ##### TRAINING ##### 263 | self.logger.info("\n* Training ...") 264 | results = RNN_network.train(datasetFiles, self.model_path_noNPZ, num_epochs=num_epochs, 265 | batch_size=batch_size, LR_start=LR, LR_decay=LR_decay, 266 | compute_confusion=True, justTest=run_test, debug=False, 267 | withNoise=with_noise, noiseType=noiseType, ratio_dB=ratio_dB) 268 | 269 | self.fh.close() 270 | self.logger.removeHandler(self.fh) 271 | 272 | return results 273 | 274 | # estimate a good batchsize based on the size of the network 275 | def getBatchSizes(self): 276 | n_hidden_list = self.n_hidden_list 277 | if n_hidden_list[0] > 128: 278 | batch_sizes = [64, 32, 16, 8, 4] 279 | elif n_hidden_list[0] > 64: 280 | batch_sizes = [128, 64, 32, 16, 8, 4] 281 | else: 282 | batch_sizes = [256, 128, 64, 32, 16, 8, 4] 283 | return batch_sizes 284 | 285 | # run the network at the maximum batch size 286 | def runNetwork(self): 287 | 288 | batch_sizes = 
self.getBatchSizes() 289 | results = [0, 0, 0] 290 | for batch_size in batch_sizes: 291 | try: 292 | RNN_network, load_params_success, fh = self.setup_network(batch_size) 293 | datasetFiles = self.load_data() 294 | results = self.executeNetwork(RNN_network, load_params_success, batch_size=batch_size, 295 | datasetFiles=datasetFiles, fh=fh) 296 | break 297 | except: 298 | print('caught this error: ' + traceback.format_exc()); 299 | self.logger.info("batch size too large; trying again with lower batch size") 300 | pass # just try again with the next batch_size 301 | 302 | return results 303 | 304 | # run a network for all noise types 305 | def testAudio(self, batch_size, fh, load_params_success, RNN_network, 306 | with_noise=False, noiseTypes=('white',), ratio_dBs=(0, -3, -5, -10,)): 307 | for noiseType in noiseTypes: 308 | for ratio_dB in ratio_dBs: 309 | datasetFiles = self.load_data() 310 | self.executeNetwork(RNN_network=RNN_network, load_params_success=load_params_success, 311 | batch_size=batch_size, datasetFiles=datasetFiles, 312 | noiseType=noiseType, ratio_dB=ratio_dB, fh=fh) 313 | if not with_noise: # only need to run once, not for all noise types as we're not testing on noisy audio anyway 314 | return 0 315 | return 0 316 | 317 | # try different batch sizes for testing a network 318 | def testNetwork(self, with_noise, noiseTypes, ratio_dBs): 319 | 320 | batch_sizes = self.getBatchSizes() 321 | for batch_size in batch_sizes: 322 | try: 323 | RNN_network, load_params_success, fh = self.setup_network(batch_size) 324 | # evaluate on test dataset for all noise types 325 | self.testAudio(batch_size, fh, load_params_success, RNN_network, 326 | with_noise, noiseTypes, ratio_dBs) 327 | except: 328 | print('caught this error: ' + traceback.format_exc()); 329 | self.logger.info("batch size too large; trying again with lower batch size") 330 | pass # just try again with the next batch_size 331 | 332 | def get_clean_results(self, network_train_info, nice_name, noise_type='white', ratio_dB='0'): 333 | with_noise = self.with_noise 334 | results_type = ("round_params" if self.round_params else "") + ( 335 | "_Noise" + noise_type + "_" + str(ratio_dB) if with_noise else "") 336 | 337 | this_results = {'results_type': results_type} 338 | this_results['values'] = [] 339 | this_results['dataset'] = self.dataset 340 | this_results['test_dataset'] = self.test_dataset 341 | this_results['audio_dataset'] = self.dataset 342 | 343 | # audio networks can be run on TIMIT or combined as well 344 | if self.run_type != 'audio' and self.test_dataset != self.dataset: 345 | test_type = "_" + self.test_dataset 346 | else: 347 | test_type = "" 348 | if self.round_params: 349 | test_type = "_round_params" + test_type 350 | if self.run_type != 'lipreading' and with_noise: 351 | this_results['values'] = [ 352 | network_train_info['final_test_cost_' + noise_type + "_" + "ratio" + str(ratio_dB) + test_type], 353 | network_train_info['final_test_acc_' + noise_type + "_" + "ratio" + str(ratio_dB) + test_type], 354 | network_train_info['final_test_top3_acc_' + noise_type + "_" + "ratio" + str(ratio_dB) + test_type]] 355 | else: 356 | try: 357 | val_acc = max(network_train_info['val_acc']) 358 | except: 359 | try: 360 | val_acc = max(network_train_info['test_acc']) 361 | except: 362 | val_acc = network_train_info['final_test_acc'] 363 | this_results['values'] = [network_train_info['final_test_cost' + test_type], 364 | network_train_info['final_test_acc' + test_type], 365 | network_train_info['final_test_top3_acc' + 
test_type], val_acc] 366 | this_results['nb_params'] = network_train_info['nb_params'] 367 | this_results['niceName'] = nice_name 368 | 369 | return this_results 370 | 371 | def get_network_train_info(self, save_path,): 372 | save_name = save_path.replace('.npz','') 373 | if os.path.exists(save_path) and os.path.exists(save_name + "_trainInfo.pkl"): 374 | network_train_info = unpickle(save_name + '_trainInfo.pkl') 375 | 376 | if not 'final_test_cost' in network_train_info.keys(): 377 | network_train_info['final_test_cost'] = min(network_train_info['test_cost']) 378 | if not 'final_test_acc' in network_train_info.keys(): 379 | network_train_info['final_test_acc'] = max(network_train_info['test_acc']) 380 | if not 'final_test_top3_acc' in network_train_info.keys(): 381 | network_train_info['final_test_top3_acc'] = max(network_train_info['test_topk_acc']) 382 | return network_train_info 383 | else: 384 | return -1 385 | 386 | 387 | # This class runs many networks, storing the weight files, training data and log in the appropriate location 388 | # It can also just evaluate a network on a test set, or simply load the results from previous runs 389 | # All results are stored in an Excel file. 390 | # Audio networks can also be evaluated on audio data which has been polluted with noise (set the appropriate parameters in the declaration of the networks you want to test) 391 | # It receives a network_list as input, containing NetworkToRun objects that specify all the relevant parameters to define a network and run it. 392 | class NetworkRunner: 393 | def __init__(self, network_list): 394 | self.network_list = network_list 395 | self.results = {} 396 | self.setup_logging(root) 397 | 398 | def setup_logging(self,store_dir): 399 | self.logger = logging.getLogger('audioNetworkRunner') 400 | self.logger.setLevel(logging.DEBUG) 401 | FORMAT = '[$BOLD%(filename)s$RESET:%(lineno)d][%(levelname)-5s]: %(message)s ' 402 | formatter = logging.Formatter(formatting.formatter_message(FORMAT, False)) 403 | # formatter2 = logging.Formatter( 404 | # '%(asctime)s - %(name)-5s - %(levelname)-10s - (%(filename)s:%(lineno)d): %(message)s') 405 | 406 | # create console handler with a higher log level 407 | ch = logging.StreamHandler() 408 | ch.setLevel(logging.DEBUG) 409 | ch.setFormatter(formatter) 410 | self.logger.addHandler(ch) 411 | 412 | logFile = store_dir + os.sep + 'audioNetworkRunner' + '.log' 413 | if os.path.exists(logFile): 414 | self.fh = logging.FileHandler(logFile) # append to existing log 415 | else: 416 | self.fh = logging.FileHandler(logFile, 'w') # create new logFile 417 | self.fh.setLevel(logging.DEBUG) 418 | self.fh.setFormatter(formatter) 419 | self.logger.addHandler(self.fh) 420 | 421 | # this loads the specified results from networks in network_list 422 | # more efficient if we don't have to reload each network file; just retrieve the data you found earlier 423 | def get_network_results(self): 424 | self.results_path = root + 'storedResults' + ".pkl" 425 | try: 426 | prev_results = unpickle(self.results_path) 427 | except: 428 | prev_results = {} 429 | 430 | # get results for networks that were trained before. If a network has run_test=True, run it on the test set 431 | results, to_retrain, failures = self.get_trained_network_results(self.network_list) 432 | 433 | # failures mean that the network does not exist yet. We have to generate and train it. 
434 |         if len(failures) > 0:
435 |             results2 = self.train_networks(failures)
436 |             results.update(results2)
437 | 
438 |         # to_retrain are networks that had force_train==True. We have to train them.
439 |         if len(to_retrain) > 0:
440 |             results3 = self.train_networks(to_retrain)
441 |             results.update(results3)
442 | 
443 |         # update and store the results
444 |         prev_results.update(results)
445 |         saveToPkl(self.results_path, prev_results)
446 |         self.results = prev_results
447 | 
448 |     # train the networks
449 |     def train_networks(self, networks):
450 |         results = {}  # dict, so callers can merge it with .update()
451 |         self.logger.info("Couldn't get results from %s networks...", len(networks))
452 |         for network in networks:
453 |             pprint(vars(network))
454 |         if autoTrain or query_yes_no("\nWould you like to train the networks now?\n\n"):
455 |             self.logger.info("Running networks...")
456 | 
457 |             failures = []
458 |             for network in tqdm(networks, total=len(networks)):
459 |                 print("\n\n\n\n ################################")
460 |                 print("Training new network...")
461 |                 print("Network properties: ")
462 |                 pprint(vars(network))
463 |                 try:
464 |                     network.runNetwork()
465 |                 except:
466 |                     print('caught this error: ' + traceback.format_exc())
467 |                     pprint(vars(network))
468 |                     failures.append(network)
469 | 
470 |             if len(failures) > 0:
471 |                 print("Some networks failed to train...")
472 |                 import pdb; pdb.set_trace()
473 | 
474 |             # now that the networks are trained, we can load their results
475 |             results, _, failures = self.get_trained_network_results(networks)
476 | 
477 |         return results
478 | 
479 |     # get the stored results from networks that were trained previously
480 |     # for networks that have run_test=True, run the network on the test set before getting the results
481 |     # for networks that have force_train=True, add them to the 'to_retrain' list
482 |     # for networks that fail to give results (most probably because they haven't been trained yet), add them to the 'failures' list
483 |     def get_trained_network_results(self, networks):
484 |         results = {}
485 |         results['audio'] = {}
486 |         results['lipreading'] = {}
487 |         results['combined'] = {}
488 | 
489 |         failures = []
490 |         to_retrain = []
491 | 
492 |         for network in tqdm(networks, total=len(networks)):
493 |             self.logger.info("\n\n\n\n ################################")
494 |             # pprint(vars(network_params))
495 |             try:
496 |                 # if forced test evaluation, test the network before trying to load the results
497 |                 if network.run_test:
498 |                     network.testNetwork(network.with_noise, network.noise_types, network.ratio_dBs)
499 | 
500 |                 # if forced training, we don't need to get the results; that will only be done after retraining
501 |                 if network.force_train:
502 |                     to_retrain.append(network)
503 |                     # by leaving here we make sure the network is not appended to the failures list
504 |                     # (that happens when we try to get its stored results while they don't exist, and we haven't done that yet here)
505 | continue 506 | 507 | model_name, nice_name = network.get_model_name() 508 | model_path = network.get_model_path() 509 | 510 | self.logger.info("Getting results for %s", model_path) 511 | network_train_info = network.get_network_train_info(model_path) 512 | if network_train_info == -1: 513 | raise IOError("this model doesn't have any stored results") 514 | 515 | if network.with_noise: 516 | for noise_type in network.noise_types: 517 | for ratio_dB in network.ratio_dBs: 518 | this_results = network.get_clean_results(network_train_info=network_train_info, 519 | nice_name=nice_name, 520 | noise_type=noise_type, 521 | ratio_dB=ratio_dB) 522 | 523 | results[network.run_type][model_name] = this_results 524 | else: 525 | this_results = network.get_clean_results(network_train_info=network_train_info, 526 | nice_name=nice_name) 527 | 528 | # eg results['audio']['2Layer_256_256_TIMIT']['values'] = [0.8, 79.5, 92,6] #test cost, acc, top3 acc 529 | results[network.run_type][model_name] = this_results 530 | 531 | except: 532 | self.logger.info('caught this error: ' + traceback.format_exc()); 533 | failures.append(network) 534 | 535 | self.logger.info("\n\nDONE getting stored results from networks") 536 | self.logger.info("####################################################") 537 | 538 | return results, to_retrain, failures 539 | 540 | 541 | def exportResultsToExcel(self): 542 | path =self.results_path 543 | results = self.results 544 | 545 | storePath = path.replace(".pkl", ".xlsx") 546 | import xlsxwriter 547 | workbook = xlsxwriter.Workbook(storePath) 548 | 549 | for run_type in results.keys()[1:]: # audio, lipreading, combined: 550 | worksheet = workbook.add_worksheet(run_type) # one worksheet per run_type, but then everything is spread out... 551 | row = 0 552 | 553 | allNets = results[run_type] 554 | 555 | # get and write the column titles 556 | # get the number of parameters. #for audio, only 1 value. 
For combined/lipreadin: lots of values in a dictionary 557 | try: 558 | nb_paramNames = allNets.items()[0][1][ 559 | 'nb_params'].keys() # first key-value pair, get the value ([1]), then get names of nbParams (=the keys) 560 | except: 561 | nb_paramNames = ['nb_params'] 562 | startVals = 4 + len(nb_paramNames) # column number of first value 563 | 564 | colNames = ['Network Full Name', 'Network Name', 'Dataset', 'Test Dataset'] + nb_paramNames + ['Test Cost', 565 | 'Test Accuracy', 566 | 'Test Top 3 Accuracy', 567 | 'Validation accuracy'] 568 | for i in range(len(colNames)): 569 | worksheet.write(0, i, colNames[i]) 570 | 571 | # write the data for each network 572 | for netName in allNets.keys(): 573 | row += 1 574 | 575 | thisNet = allNets[netName] 576 | # write the path and name 577 | worksheet.write(row, 0, os.path.basename(netName)) # netName) 578 | worksheet.write(row, 1, thisNet['niceName']) 579 | if run_type == 'audio': 580 | worksheet.write(row, 2, thisNet['audio_dataset']) 581 | worksheet.write(row, 3, thisNet['test_dataset']) 582 | else: 583 | worksheet.write(row, 2, thisNet['dataset']) 584 | worksheet.write(row, 3, thisNet['test_dataset']) 585 | 586 | # now write the params 587 | try: 588 | vals = thisNet['nb_params'].values() # vals is list of [test_cost, test_acc, test_top3_acc] 589 | except: 590 | vals = [thisNet['nb_params']] 591 | for i in range(len(vals)): 592 | worksheet.write(row, 4 + i, vals[i]) 593 | 594 | # now write the values 595 | vals = thisNet['values'] # vals is list of [test_cost, test_acc, test_top3_acc] 596 | for i in range(len(vals)): 597 | worksheet.write(row, startVals + i, vals[i]) 598 | 599 | workbook.close() 600 | 601 | self.logger.info("Excel file stored in %s", storePath) 602 | self.fh.close() 603 | self.logger.removeHandler(self.fh) 604 | 605 | def exportResultsToExcelManyNoise(self, resultsList, path): 606 | storePath = path.replace(".pkl", ".xlsx") 607 | import xlsxwriter 608 | workbook = xlsxwriter.Workbook(storePath) 609 | 610 | storePath = path.replace(".pkl", ".xlsx") 611 | import xlsxwriter 612 | workbook = xlsxwriter.Workbook(storePath) 613 | 614 | row = 0 615 | 616 | if len(resultsList[0]['audio'].keys()) > 0: thisrun_type = 'audio' 617 | if len(resultsList[0]['lipreading'].keys()) > 0: thisrun_type = 'lipreading' 618 | if len(resultsList[0]['combined'].keys()) > 0: thisrun_type = 'combined' 619 | worksheetAudio = workbook.add_worksheet('audio'); 620 | audioRow = 0 621 | worksheetLipreading = workbook.add_worksheet('lipreading'); 622 | lipreadingRow = 0 623 | worksheetCombined = workbook.add_worksheet('combined'); 624 | combinedRow = 0 625 | 626 | for r in range(len(resultsList)): 627 | results = resultsList[r] 628 | 629 | for run_type in results.keys()[1:]: 630 | if len(results[run_type]) == 0: continue 631 | if run_type == 'audio': worksheet = worksheetAudio; row = audioRow 632 | if run_type == 'lipreading': worksheet = worksheetLipreading; row = lipreadingRow 633 | if run_type == 'combined': worksheet = worksheetCombined; row = combinedRow 634 | 635 | allNets = results[run_type] 636 | 637 | # write the column titles 638 | startVals = 5 639 | colNames = ['Network Full Name', 'Network Name', 'Dataset', 'Test Dataset', 'Noise Type', 'Test Cost', 640 | 'Test Accuracy', 'Test Top 3 Accuracy'] 641 | for i in range(len(colNames)): 642 | worksheet.write(0, i, colNames[i]) 643 | 644 | # write the data for each network 645 | for netName in allNets.keys(): 646 | row += 1 647 | 648 | thisNet = allNets[netName] 649 | # write the path and name 650 | 
worksheet.write(row, 0, os.path.basename(netName)) # netName) 651 | worksheet.write(row, 1, thisNet['niceName']) 652 | worksheet.write(row, 2, thisNet['dataset']) 653 | worksheet.write(row, 3, thisNet['test_dataset']) 654 | worksheet.write(row, 4, thisNet['results_type']) 655 | 656 | # now write the values 657 | vals = thisNet['values'] # vals is list of [test_cost, test_acc, test_top3_acc] 658 | for i in range(len(vals)): 659 | worksheet.write(row, startVals + i, vals[i]) 660 | 661 | if run_type == 'audio': audioRow = row 662 | if run_type == 'lipreading': lipreadingRow = row 663 | if run_type == 'combined': combinedRow = row 664 | 665 | row += 1 666 | 667 | workbook.close() 668 | 669 | self.logger.info("Excel file stored in %s", storePath) 670 | 671 | self.fh.close() 672 | self.logger.removeHandler(self.fh) 673 | 674 | 675 | if __name__ == "__main__": 676 | main() 677 | -------------------------------------------------------------------------------- /RNN_implementation.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import logging # debug < info < warn < error < critical # from https://docs.python.org/3/howto/logging-cookbook.html 4 | import time 5 | import traceback 6 | 7 | import lasagne 8 | import lasagne.layers as L 9 | import theano 10 | import theano.tensor as T 11 | from tqdm import tqdm 12 | 13 | logger_RNNtools = logging.getLogger('audioSR.tools') 14 | logger_RNNtools.setLevel(logging.DEBUG) 15 | 16 | from tools.general_tools import * 17 | 18 | 19 | class NeuralNetwork: 20 | network = None 21 | training_fn = None 22 | best_param = None 23 | best_error = 100 24 | curr_epoch, best_epoch = 0, 0 25 | X = None 26 | Y = None 27 | 28 | def __init__(self, architecture, data=None, batch_size=1, max_seq_length=1000, num_features=26, 29 | n_hidden_list=(100,), num_output_units=61, 30 | bidirectional=False, addDenseLayers=False, seed=int(time.time()), debug=False, logger=logger_RNNtools, 31 | dataset="", test_dataset=""): 32 | self.num_output_units = num_output_units 33 | self.num_features = num_features 34 | self.batch_size = batch_size 35 | self.max_seq_length = max_seq_length # currently unused 36 | self.epochsNotImproved = 0 # keep track, to know when to stop training 37 | self.updates = {} 38 | # self.network_train_info = [[], [], [], [], []] # train cost, val cost, val acc, test cost, test acc 39 | self.network_train_info = { 40 | 'train_cost': [], 41 | 'val_cost': [], 'val_acc': [], 'val_topk_acc': [], 42 | 'test_cost': [], 'test_acc': [], 'test_topk_acc': [] 43 | } 44 | self.dataset = dataset 45 | self.test_dataset = test_dataset 46 | self.logger= logger 47 | 48 | if architecture == 'RNN': 49 | if data != None: 50 | X_train, y_train, valid_frames_train, X_val, y_val, valid_frames_val, X_test, y_test, valid_frames_test = data 51 | 52 | X = X_train[:batch_size] 53 | y = y_train[:batch_size] 54 | self.valid_frames = valid_frames_train[:batch_size] 55 | self.masks = generate_masks(X, valid_frames=self.valid_frames, batch_size=len(X)) 56 | 57 | self.X = pad_sequences_X(X) 58 | self.Y = pad_sequences_y(y) 59 | # self.valid_frames = pad_sequences_y(self.valid_frames) 60 | 61 | self.logger.debug('X.shape: %s', self.X.shape) 62 | self.logger.debug('X[0].shape: %s', self.X[0].shape) 63 | self.logger.debug('X[0][0][0].type: %s', type(self.X[0][0][0])) 64 | self.logger.debug('y.shape: %s', self.Y.shape) 65 | self.logger.debug('y[0].shape: %s', self.Y[0].shape) 66 | self.logger.debug('y[0][0].type: %s', 
type(self.Y[0][0])) 67 | self.logger.debug('masks.shape: %s', self.masks.shape) 68 | self.logger.debug('masks[0].shape: %s', self.masks[0].shape) 69 | self.logger.debug('masks[0][0].type: %s', type(self.masks[0][0])) 70 | 71 | self.logger.info("NUM FEATURES: %s", num_features) 72 | 73 | self.audio_inputs_var = T.tensor3('audio_inputs') 74 | self.audio_masks_var = T.matrix( 75 | 'audio_masks') # set MATRIX, not iMatrix!! Otherwise all mask calculations are done by CPU, and everything will be ~2x slowed down!! Also in general_tools.generate_masks() 76 | self.audio_valid_indices_var = T.imatrix('audio_valid_indices') 77 | self.audio_targets_var = T.imatrix('audio_targets') 78 | 79 | self.build_RNN(n_hidden_list=n_hidden_list, bidirectional=bidirectional, addDenseLayers=addDenseLayers, 80 | seed=seed, debug=debug) 81 | else: 82 | print("ERROR: Invalid argument: The valid architecture arguments are: 'RNN'") 83 | 84 | def build_RNN(self, n_hidden_list=(100,), bidirectional=False, addDenseLayers=False, 85 | seed=int(time.time()), debug=False): 86 | # some inspiration from http://colinraffel.com/talks/hammer2015recurrent.pdf 87 | 88 | # if debug: 89 | # self.self.logger.debug('\nInputs:'); 90 | # self.self.logger.debug(' X.shape: %s', self.X[0].shape) 91 | # self.self.logger.debug(' X[0].shape: %s %s %s \n%s', self.X[0][0].shape, type(self.X[0][0]), 92 | # type(self.X[0][0][0]), self.X[0][0][:5]) 93 | # 94 | # self.self.logger.debug('Targets: '); 95 | # self.self.logger.debug(' Y.shape: %s', self.Y.shape) 96 | # self.self.logger.debug(' Y[0].shape: %s %s %s \n%s', self.Y[0].shape, type(self.Y[0]), type(self.Y[0][0]), 97 | # self.Y[0][:5]) 98 | # self.self.logger.debug('Layers: ') 99 | 100 | # fix these at initialization because it allows for compiler opimizations 101 | num_output_units = self.num_output_units 102 | num_features = self.num_features 103 | batch_size = self.batch_size 104 | 105 | audio_inputs = self.audio_inputs_var 106 | audio_masks = self.audio_masks_var # set MATRIX, not iMatrix!! Otherwise all mask calculations are done by CPU, and everything will be ~2x slowed down!! Also in general_tools.generate_masks() 107 | valid_indices = self.audio_valid_indices_var 108 | 109 | net = {} 110 | # net['l1_in_valid'] = L.InputLayer(shape=(batch_size, None), input_var=valid_indices) 111 | 112 | # shape = (batch_size, batch_max_seq_length, num_features) 113 | net['l1_in'] = L.InputLayer(shape=(batch_size, None, num_features), input_var=audio_inputs) 114 | # We could do this and set all input_vars to None, but that is slower -> fix batch_size and num_features at initialization 115 | # batch_size, n_time_steps, n_features = net['l1_in'].input_var.shape 116 | 117 | # This input will be used to provide the network with masks. 
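        # (A sketch of what these masks look like, not from the original code: generate_masks()
        #  builds them from the valid_frames lists, so an entry is 1.0 exactly at the labelled
        #  frames and 0.0 elsewhere. E.g. valid frames {1, 3} of a 5-step padded sequence give
        #      [0., 1., 0., 1., 0.]
        #  They are kept as float matrices on purpose; see the warning on audio_masks_var in
        #  __init__ about integer masks forcing the computation onto the CPU.)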
118 |         # Masks are matrices of shape (batch_size, n_time_steps);
119 |         net['l1_mask'] = L.InputLayer(shape=(batch_size, None), input_var=audio_masks)
120 | 
121 |         if debug:
122 |             get_l_in = L.get_output(net['l1_in'])
123 |             l_in_val = get_l_in.eval({net['l1_in'].input_var: self.X})
124 |             # self.logger.debug(l_in_val)
125 |             self.logger.debug(' l_in size: %s', l_in_val.shape)
126 | 
127 |             get_l_mask = L.get_output(net['l1_mask'])
128 |             l_mask_val = get_l_mask.eval({net['l1_mask'].input_var: self.masks})
129 |             # self.logger.debug(l_mask_val)
130 |             self.logger.debug(' l_mask size: %s', l_mask_val.shape)
131 | 
132 |             n_batch, n_time_steps, n_features = net['l1_in'].input_var.shape
133 |             self.logger.debug(" n_batch: %s | n_time_steps: %s | n_features: %s", n_batch, n_time_steps,
134 |                               n_features)
135 | 
136 |         ## LSTM parameters
137 |         # All gates have initializers for the input-to-gate and hidden state-to-gate
138 |         # weight matrices, the cell-to-gate weight vector, the bias vector, and the nonlinearity.
139 |         # The convention is that gates use the standard sigmoid nonlinearity,
140 |         # which is the default for the Gate class.
141 |         gate_parameters = L.recurrent.Gate(
142 |             W_in=lasagne.init.Orthogonal(), W_hid=lasagne.init.Orthogonal(),
143 |             b=lasagne.init.Constant(0.))
144 |         cell_parameters = L.recurrent.Gate(
145 |             W_in=lasagne.init.Orthogonal(), W_hid=lasagne.init.Orthogonal(),
146 |             # Setting W_cell to None denotes that no cell connection will be used.
147 |             W_cell=None, b=lasagne.init.Constant(0.),
148 |             # By convention, the cell nonlinearity is tanh in an LSTM.
149 |             nonlinearity=lasagne.nonlinearities.tanh)
150 | 
151 |         # generate layers of stacked LSTMs, possibly bidirectional
152 |         net['l2_lstm'] = []
153 | 
154 |         for i in range(len(n_hidden_list)):
155 |             n_hidden = n_hidden_list[i]
156 | 
157 |             if i == 0: input = net['l1_in']
158 |             else: input = net['l2_lstm'][-1]
159 | 
160 |             nextForwardLSTMLayer = L.recurrent.LSTMLayer(
161 |                 input, n_hidden,
162 |                 # We need to specify a separate input for masks
163 |                 mask_input=net['l1_mask'],
164 |                 # Here, we supply the gate parameters for each gate
165 |                 ingate=gate_parameters, forgetgate=gate_parameters,
166 |                 cell=cell_parameters, outgate=gate_parameters,
167 |                 # We'll learn the initialization and use gradient clipping
168 |                 learn_init=True, grad_clipping=100.)
169 |             net['l2_lstm'].append(nextForwardLSTMLayer)
170 | 
171 |             if bidirectional:
172 |                 if i == 0: input = net['l1_in']
173 |                 else: input = net['l2_lstm'][-2]
174 |                 # Use a backward LSTM:
175 |                 # the "backwards" layer is the same as the first,
176 |                 # except that the backwards argument is set to True.
177 |                 nextBackwardLSTMLayer = L.recurrent.LSTMLayer(
178 |                     input, n_hidden, ingate=gate_parameters,
179 |                     mask_input=net['l1_mask'], forgetgate=gate_parameters,
180 |                     cell=cell_parameters, outgate=gate_parameters,
181 |                     learn_init=True, grad_clipping=100., backwards=True)
182 |                 net['l2_lstm'].append(nextBackwardLSTMLayer)
183 | 
184 |                 # if debug:
185 |                 #     # Backwards LSTM
186 |                 #     get_l_lstm_back = theano.function([net['l1_in'].input_var, net['l1_mask'].input_var],
187 |                 #                                       L.get_output(net['l2_lstm'][-1]))
188 |                 #     l_lstmBack_val = get_l_lstm_back(self.X, self.masks)
189 |                 #     self.logger.debug(' l_lstm_back size: %s', l_lstmBack_val.shape)
190 | 
191 |                 # We'll combine the forward and backward layer output by summing.
192 |                 # Merge layers take in lists of layers to merge as input.
193 |                 # The output of l_sum will be of shape (n_batch, max_n_time_steps, n_features)
194 |                 net['l2_lstm'].append(L.ElemwiseSumLayer([net['l2_lstm'][-2], net['l2_lstm'][-1]]))
195 | 
196 |         # we need to convert (batch_size, seq_length, num_features) to (batch_size * seq_length, num_features) because Dense layers can't deal with 2 unknown sizes
197 |         net['l3_reshape'] = L.ReshapeLayer(net['l2_lstm'][-1], (-1, n_hidden_list[-1]))
198 | 
199 |         if debug:
200 |             get_l_reshape = theano.function([net['l1_in'].input_var, net['l1_mask'].input_var],
201 |                                             L.get_output(net['l3_reshape']))
202 |             l_reshape_val = get_l_reshape(self.X, self.masks)
203 |             self.logger.debug(' l_reshape size: %s', l_reshape_val.shape)
204 | 
205 |             # Forwards LSTM
206 |             get_l_lstm = theano.function([net['l1_in'].input_var, net['l1_mask'].input_var],
207 |                                          L.get_output(net['l2_lstm'][-1]))
208 |             l_lstm_val = get_l_lstm(self.X, self.masks)
209 |             self.logger.debug(' l2_lstm size: %s', l_lstm_val.shape)
210 | 
211 |         if addDenseLayers:
212 |             net['l4_dense'] = L.DenseLayer(net['l3_reshape'], nonlinearity=lasagne.nonlinearities.rectify,
213 |                                            num_units=256)
214 |             dropoutLayer = L.DropoutLayer(net['l4_dense'], p=0.3)
215 |             net['l5_dense'] = L.DenseLayer(dropoutLayer, nonlinearity=lasagne.nonlinearities.rectify, num_units=64)
216 |             # Now we can apply feed-forward layers as usual for classification
217 |             net['l6_dense'] = L.DenseLayer(net['l5_dense'], num_units=num_output_units,
218 |                                            nonlinearity=lasagne.nonlinearities.softmax)
219 |         else:
220 |             # Now we can apply feed-forward layers as usual for classification
221 |             net['l6_dense'] = L.DenseLayer(net['l3_reshape'], num_units=num_output_units,
222 |                                            nonlinearity=lasagne.nonlinearities.softmax)
223 | 
224 |         # Now the shape will be (n_batch * n_timesteps, num_output_units). We can then reshape to
225 |         # n_batch to get num_output_units values for each timestep of each sequence
226 |         net['l7_out_flattened'] = L.ReshapeLayer(net['l6_dense'], (-1, num_output_units))
227 |         net['l7_out'] = L.ReshapeLayer(net['l6_dense'], (batch_size, -1, num_output_units))
228 | 
229 |         # we only want the predictions for the valid frames, not for every frame. We can do this with theano mask functions for audio-only nets,
230 |         # but that's not possible if there are subsequent networks, as lasagne needs lasagne layers, not theano functions.
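        # (per the README, the valid frames are the middles of the phoneme intervals given by the
        #  labels, so slicing them out keeps exactly one prediction per labelled phoneme)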
231 |         # -> use a lasagne SliceLayer to extract the valid predictions
232 |         net['l7_out_valid_basic'] = L.SliceLayer(net['l7_out'], indices=valid_indices, axis=1)
233 |         net['l7_out_valid'] = L.ReshapeLayer(net['l7_out_valid_basic'], (batch_size, -1, num_output_units))
234 |         net['l7_out_valid_flattened'] = L.ReshapeLayer(net['l7_out_valid_basic'], (-1, num_output_units))
235 | 
236 |         if debug:
237 |             get_l_out = theano.function([net['l1_in'].input_var, net['l1_mask'].input_var], L.get_output(net['l7_out']))
238 |             l_out = get_l_out(self.X, self.masks)
239 | 
240 |             # this only works for batch_size == 1
241 |             get_l_out_valid = theano.function([audio_inputs, audio_masks, valid_indices],
242 |                                               L.get_output(net['l7_out_valid']))
243 |             try:
244 |                 l_out_valid = get_l_out_valid(self.X, self.masks, self.valid_frames)
245 |                 self.logger.debug('\n\n\n l_out: %s | l_out_valid: %s', l_out.shape, l_out_valid.shape)
246 |             except:
247 |                 self.logger.warning("batch size is not 1, get_l_out_valid not working")
248 | 
249 |         if debug: self.print_network_structure(net)
250 |         self.network_lout = net['l7_out_flattened']
251 |         self.network_lout_batch = net['l7_out']
252 |         self.network_lout_valid = net['l7_out_valid']
253 |         self.network_lout_valid_flattened = net['l7_out_valid_flattened']
254 | 
255 |         self.network = net
256 | 
257 |     def print_network_structure(self, net=None):
258 |         if net is None: net = self.network
259 | 
260 |         self.logger.debug("\n PRINTING Network structure: \n %s ", sorted(net.keys()))
261 |         for key in sorted(net.keys()):
262 |             if 'lstm' in key:
263 |                 for layer in net['l2_lstm']:
264 |                     try:
265 |                         self.logger.debug('Layer: %12s | in: %s | out: %s', key, layer.input_shape, layer.output_shape)
266 |                     except:
267 |                         self.logger.debug('Layer: %12s | out: %s', key, layer.output_shape)
268 |             else:
269 |                 try:
270 |                     self.logger.debug('Layer: %12s | in: %s | out: %s', key, net[key].input_shape, net[key].output_shape)
271 |                 except:
272 |                     self.logger.debug('Layer: %12s | out: %s', key, net[key].output_shape)
273 |         return 0
274 | 
275 |     def use_best_param(self):
276 |         L.set_all_param_values(self.network, self.best_param)
277 |         self.curr_epoch = self.best_epoch
278 |         # Remove the network_train_info entries newer than self.best_epoch (it is a dict of lists now)
279 |         del self.network_train_info['train_cost'][self.best_epoch:]
280 |         del self.network_train_info['val_cost'][self.best_epoch:]
281 |         del self.network_train_info['val_acc'][self.best_epoch:]
282 | 
283 |     def load_model(self, model_name, round_params=False):
284 |         model_path = model_name + '.npz'
285 |         if self.network is not None:
286 |             try:
287 |                 # self.logger.info("Loading stored model %s...", model_path)
288 | 
289 |                 # restore network weights
290 |                 with np.load(model_path) as f:
291 |                     param_values = [f['arr_%d' % i] for i in range(len(f.files))][0]
292 |                     if round_params:
293 |                         self.logger.info("ROUND PARAMS")
294 |                         param_values = self.roundParams(param_values)
295 |                     L.set_all_param_values(self.network_lout, param_values)
296 | 
297 |                 self.logger.info("Loading parameters successful.")
298 |                 return 0
299 | 
300 |             except IOError as e:
301 |                 print(os.strerror(e.errno))
302 |                 # self.logger.warning('Model: {} not found. 
No weights loaded'.format(model_path)) 303 | return -1 304 | else: 305 | raise IOError('You must build the network before loading the weights.') 306 | return -1 307 | 308 | def roundParams(self, param_values): 309 | # round by converting to float16 and back to flort32 310 | # up to 12 bit should be possible without accuracy loss, but I don't know how to do this 311 | for i in range(len(param_values)): 312 | param_values[i] = param_values[i].astype(np.float16) 313 | param_values[i] = param_values[i].astype(np.float32) 314 | 315 | return param_values 316 | 317 | def save_model(self, model_name): 318 | if not os.path.exists(os.path.dirname(model_name)): 319 | os.makedirs(os.path.dirname(model_name)) 320 | np.savez(model_name + '.npz', self.best_param) 321 | 322 | # also restore the updates variables to continue training. LR should also be saved and restored... 323 | # updates_vals = [p.get_value() for p in self.best_updates.keys()] 324 | # np.savez(model_name + '_updates.npz', updates_vals) 325 | 326 | # TODO use combinedSR.py for a working version 327 | def create_confusion(self, X, y, debug=False): 328 | argmax_fn = self.training_fn[1] 329 | 330 | y_pred = [] 331 | for X_obs in X: 332 | for x in argmax_fn(X_obs): 333 | for j in x: 334 | y_pred.append(j) 335 | 336 | y_actu = [] 337 | for Y in y: 338 | for y in Y: 339 | y_actu.append(y) 340 | 341 | conf_img = np.zeros([61, 61]) 342 | assert (len(y_pred) == len(y_actu)) 343 | 344 | for i in range(len(y_pred)): 345 | row_idx = y_actu[i] 346 | col_idx = y_pred[i] 347 | conf_img[row_idx, col_idx] += 1 348 | 349 | return conf_img, y_pred, y_actu 350 | 351 | def build_functions(self, train=False, debug=False): 352 | 353 | # LSTM in lasagne: see https://github.com/craffel/Lasagne-tutorial/blob/master/examples/recurrent.py 354 | # and also http://colinraffel.com/talks/hammer2015recurrent.pdf 355 | target_var = self.audio_targets_var # T.imatrix('audio_targets') 356 | 357 | if debug: self.print_network_structure() 358 | 359 | network_output = L.get_output(self.network_lout_batch) 360 | network_output_flattened = L.get_output(self.network_lout) # (batch_size * batch_max_seq_length, nb_phonemes) 361 | 362 | # compare targets with highest output probability. 
Take maximum of all probs (3rd axis (index 2) of output: 363 | # 1=batch_size (input files), 2 = time_seq (frames), 3 = n_features (phonemes) 364 | # network_output.shape = (len(X), 39) -> (nb_inputs, nb_classes) 365 | predictions = (T.argmax(network_output, axis=2)) 366 | 367 | if debug: 368 | self.predictions_fn = theano.function([self.audio_inputs_var, self.audio_masks_var], predictions, 369 | name='predictions_fn') 370 | 371 | predicted = self.predictions_fn(self.X, self.masks) 372 | self.logger.debug('predictions_fn(X).shape: %s', predicted.shape) 373 | # self.logger.debug('predictions_fn(X)[0], value: %s', predicted[0]) 374 | 375 | self.output_fn = theano.function([self.audio_inputs_var, self.audio_masks_var], network_output, 376 | name='output_fn') 377 | n_out = self.output_fn(self.X, self.masks) 378 | self.logger.debug('network_output.shape: \t%s', n_out.shape); 379 | # self.logger.debug('network_output[0]: \n%s', n_out[0]); 380 | 381 | # # Function to determine the number of correct classifications 382 | # which video, and which frames in the video 383 | valid_indices_example, valid_indices_seqNr = self.audio_masks_var.nonzero() 384 | valid_indices_fn = theano.function([self.audio_masks_var], [valid_indices_example, valid_indices_seqNr], 385 | name='valid_indices_fn') 386 | 387 | # this gets a FLATTENED array of all the valid predictions of all examples of this batch (so not one row per example) 388 | # if you want to get the valid predictions per example, you need to use the valid_frames list (it tells you the number of valid frames per wav, so where to split this valid_predictions array) 389 | # of course this is trivial for batch_size_audio = 1, as all valid_predictions will belong to the one input wav 390 | valid_predictions = predictions[valid_indices_example, valid_indices_seqNr] 391 | valid_targets = target_var[valid_indices_example, valid_indices_seqNr] 392 | self.valid_targets_fn = theano.function([self.audio_masks_var, target_var], valid_targets, 393 | name='valid_targets_fn') 394 | self.valid_predictions_fn = theano.function([self.audio_inputs_var, self.audio_masks_var], valid_predictions, 395 | name='valid_predictions_fn') 396 | 397 | # get valid network output 398 | valid_network_output = network_output[valid_indices_example, valid_indices_seqNr] 399 | if debug: 400 | self.valid_network_output_fn = theano.function([self.audio_inputs_var, self.audio_masks_var], 401 | valid_network_output) 402 | 403 | # Functions for computing cost and training 404 | top1_acc = T.mean(lasagne.objectives.categorical_accuracy(valid_network_output, valid_targets, top_k=1)) 405 | self.top1_acc_fn = theano.function( 406 | [self.audio_inputs_var, self.audio_masks_var, self.audio_targets_var], top1_acc) 407 | top3_acc = T.mean(lasagne.objectives.categorical_accuracy(valid_network_output, valid_targets, top_k=3)) 408 | self.top3_acc_fn = theano.function( 409 | [self.audio_inputs_var, self.audio_masks_var, self.audio_targets_var], top3_acc) 410 | 411 | 412 | if debug: 413 | try: 414 | # only works with batch_size == 1 415 | # valid_preds2 = self.valid_predictions2_fn(self.X, self.masks, self.valid_frames) 416 | # self.logger.debug("all valid predictions of this batch: ") 417 | # self.logger.debug('valid_preds2.shape: %s', valid_preds2.shape) 418 | # self.logger.debug('valid_preds2, value: \n%s', valid_preds2) 419 | 420 | # valid_out = self.valid_network_fn(self.X, self.masks, self.valid_frames) 421 | # self.logger.debug('valid_out.shape: %s', valid_out.shape) 422 | # # 
self.logger.debug('valid_out, value: \n%s', valid_out) 423 | 424 | valid_example, valid_seqNr = valid_indices_fn(self.masks) 425 | self.logger.debug('valid_inds(masks).shape: %s', valid_example.shape) 426 | 427 | valid_output = self.valid_network_output_fn(self.X, self.masks) 428 | self.logger.debug("all valid outputs of this batch: ") 429 | self.logger.debug('valid_output.shape: %s', valid_output.shape) 430 | 431 | valid_preds = self.valid_predictions_fn(self.X, self.masks) 432 | self.logger.debug("all valid predictions of this batch: ") 433 | self.logger.debug('valid_preds.shape: %s', valid_preds.shape) 434 | self.logger.debug('valid_preds, value: \n%s', valid_preds) 435 | 436 | valid_targs = self.valid_targets_fn(self.masks, self.Y) 437 | self.logger.debug('valid_targets.shape: %s', valid_targs.shape) 438 | self.logger.debug('valid_targets, value: \n%s', valid_targs) 439 | 440 | top1 = self.top1_acc_fn(self.X, self.masks, self.Y) 441 | self.logger.debug("top 1 accuracy: %s", top1 * 100.0) 442 | 443 | top3 = self.top3_acc_fn(self.X, self.masks, self.Y) 444 | self.logger.debug("top 3 accuracy: %s", top3 * 100.0) 445 | 446 | except Exception as error: 447 | print('caught this error: ' + traceback.format_exc()); 448 | import pdb; 449 | pdb.set_trace() 450 | # pdb.set_trace() 451 | 452 | ## from https://groups.google.com/forum/#!topic/lasagne-users/os0j3f_Th5Q 453 | # Pad your vector of labels and then mask the cost: 454 | # It's important to pad the label vectors with something valid such as zeros, 455 | # since they will still have to give valid costs that can be multiplied by the mask. 456 | # The shape of predictions, targets and mask should match: 457 | # (predictions as (batch_size*max_seq_len, n_features), the other two as (batch_size*max_seq_len,)) -> we need to get the flattened output of the network for this 458 | 459 | 460 | # this works, using theano masks 461 | cost_pointwise = lasagne.objectives.categorical_crossentropy(network_output_flattened, target_var.flatten()) 462 | cost = lasagne.objectives.aggregate(cost_pointwise, self.audio_masks_var.flatten()) 463 | weight_decay = 1e-5 464 | weightsl2 = lasagne.regularization.regularize_network_params(self.network_lout, lasagne.regularization.l2) 465 | cost += weight_decay * weightsl2 466 | 467 | self.validate_fn = theano.function([self.audio_inputs_var, self.audio_masks_var, 468 | self.audio_targets_var], 469 | [cost, top1_acc, top3_acc], name='validate_fn') 470 | self.cost_pointwise_fn = theano.function([self.audio_inputs_var, self.audio_masks_var, target_var], 471 | cost_pointwise, name='cost_pointwise_fn') 472 | 473 | if debug: 474 | self.logger.debug('cost pointwise: %s', self.cost_pointwise_fn(self.X, self.masks, self.Y)) 475 | 476 | try: 477 | evaluate_cost = self.validate_fn(self.X, self.masks, self.Y) 478 | except: 479 | print('caught this error: ' + traceback.format_exc()); 480 | pdb.set_trace() 481 | self.logger.debug('cost: {:.3f}'.format(float(evaluate_cost[0]))) 482 | self.logger.debug('accuracy: {:.3f}'.format(float(evaluate_cost[1] * 100.0))) 483 | self.logger.debug('top 3 accuracy: {:.3f}'.format(float(evaluate_cost[2] * 100.0))) 484 | 485 | # pdb.set_trace() 486 | 487 | if train: 488 | LR = T.scalar('LR', dtype=theano.config.floatX) 489 | # Retrieve all trainable parameters from the network 490 | all_params = L.get_all_params(self.network_lout, trainable=True) 491 | self.updates = lasagne.updates.adam(loss_or_grads=cost, params=all_params, learning_rate=LR) 492 | self.train_fn = 
487 | if train:
488 | LR = T.scalar('LR', dtype=theano.config.floatX)
489 | # Retrieve all trainable parameters from the network
490 | all_params = L.get_all_params(self.network_lout, trainable=True)
491 | self.updates = lasagne.updates.adam(loss_or_grads=cost, params=all_params, learning_rate=LR)
492 | self.train_fn = theano.function([self.audio_inputs_var, self.audio_masks_var,
493 | target_var, LR],
494 | [cost, top1_acc, top3_acc], updates=self.updates, name='train_fn')
495 |
496 | def shuffle(X, y, valid_frames):
497 |
498 | chunk_size = len(X)
499 | shuffled_range = range(chunk_size)
500 |
501 | X_buffer = np.copy(X[0:chunk_size])
502 | y_buffer = np.copy(y[0:chunk_size])
503 | valid_frames_buffer = np.copy(valid_frames[0:chunk_size])
504 |
505 | np.random.shuffle(shuffled_range)
506 |
507 | for i in range(chunk_size):
508 | X_buffer[i] = X[shuffled_range[i]]
509 | y_buffer[i] = y[shuffled_range[i]]
510 | valid_frames_buffer[i] = valid_frames[shuffled_range[i]]
511 |
512 | X[0: chunk_size] = X_buffer
513 | y[0: chunk_size] = y_buffer
514 | valid_frames[0: chunk_size] = valid_frames_buffer
515 |
516 | return X, y, valid_frames
517 |
518 | # This function trains the model a full epoch (on the whole dataset)
519 | def run_epoch(self, X, y, valid_frames, get_predictions=False, LR=None, batch_size=-1):
520 | if batch_size == -1: batch_size = self.batch_size
521 |
522 | cost = 0
523 | accuracy = 0
524 | top3_accuracy = 0
525 | nb_batches = len(X) / batch_size  # integer division (Python 2)
526 |
527 | predictions = []  # only used if get_predictions = True
528 | for i in tqdm(range(nb_batches), total=nb_batches):
529 | batch_X = X[i * batch_size:(i + 1) * batch_size]
530 | batch_y = y[i * batch_size:(i + 1) * batch_size]
531 | batch_valid_frames = valid_frames[i * batch_size:(i + 1) * batch_size]
532 | batch_masks = generate_masks(batch_X, valid_frames=batch_valid_frames, batch_size=batch_size)
533 | # now pad inputs and target to maxLen
534 | batch_X = pad_sequences_X(batch_X)
535 | batch_y = pad_sequences_y(batch_y)
536 | # batch_valid_frames = pad_sequences_y(batch_valid_frames)
537 | # print("batch_X.shape: ", batch_X.shape)
538 | # print("batch_y.shape: ", batch_y.shape)
540 | if LR is not None:
541 | cst, acc, top3_acc = self.train_fn(batch_X, batch_masks, batch_y, LR)  # training
542 | else:
543 | cst, acc, top3_acc = self.validate_fn(batch_X, batch_masks, batch_y)  # validation
544 | cost += cst
545 | accuracy += acc
546 | top3_accuracy += top3_acc
547 |
548 | if get_predictions:
549 | prediction = self.predictions_fn(batch_X, batch_masks)
550 | # prediction = np.reshape(prediction, (nb_inputs, -1))  # only needed if predictions_fn is the flattened and not the batched version (see RNN_implementation.py)
551 | prediction = list(prediction)
552 | predictions = predictions + prediction
553 | # some tests of the valid-predictions functions (this works :) )
554 | valid_predictions = self.valid_predictions_fn(batch_X, batch_masks)
555 | self.logger.debug("valid predictions: %s", valid_predictions.shape)
556 |
557 | # # get valid predictions for video 0
558 | # self.get_validPredictions_video(valid_predictions, valid_frames, 0)
559 | # # and the targets for video 0
560 | # targets[0][valid_frames[0]]
562 | cost /= nb_batches
563 | accuracy /= nb_batches
564 | top3_accuracy /= nb_batches
565 | if get_predictions:
566 | return cost, accuracy * 100.0, top3_accuracy * 100.0, predictions
567 | return cost, accuracy * 100.0, top3_accuracy * 100.0
568 |
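# --- Added illustration (hypothetical, not the repo's actual helpers): roughly what
# --- generate_masks / pad_sequences_X are assumed to do for run_epoch above: pad the
# --- variable-length MFCC sequences to the batch maximum and flag the frames that
# --- actually carry a label (the valid_frames).
import numpy as np

def pad_and_mask_sketch(batch_X, batch_valid_frames):
    max_len = max(len(x) for x in batch_X)
    n_features = batch_X[0].shape[1]
    padded = np.zeros((len(batch_X), max_len, n_features), dtype='float32')
    mask = np.zeros((len(batch_X), max_len), dtype='float32')
    for i, (x, valid) in enumerate(zip(batch_X, batch_valid_frames)):
        padded[i, :len(x)] = x            # zero-pad the tail
        mask[i, np.asarray(valid)] = 1.0  # only labelled frames count towards cost/accuracy
    return padded, mask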
569 | def train(self, dataset, save_name='Best_model', num_epochs=100, batch_size=1, LR_start=1e-4, LR_decay=1,
570 | compute_confusion=False, justTest=False, debug=False, roundParams=False,
571 | withNoise=False, noiseType='white', ratio_dB=0):
572 |
573 | X_train, y_train, valid_frames_train, X_val, y_val, valid_frames_val, X_test, y_test, valid_frames_test = dataset
574 |
575 | confusion_matrices = []
576 |
577 | # try to load performance metrics of a stored model
578 | best_val_acc, test_acc, old_train_info = self.loadPreviousResults(
579 | save_name)  # stores old_train_info into self.network_train_info
580 |
581 | self.logger.info("Initial best Val acc: %s", best_val_acc)
582 | self.logger.info("Initial best test acc: %s\n", test_acc)
583 | self.best_val_acc = best_val_acc
584 |
585 | self.logger.info("Pass over Test Set")
586 | test_cost, test_acc, test_topk_acc = self.run_epoch(X=X_test, y=y_test,
587 | valid_frames=valid_frames_test)
588 | self.logger.info("Test cost:\t\t{:.6f} ".format(test_cost))
589 | self.logger.info("Test accuracy:\t\t{:.6f} %".format(test_acc))
590 | self.logger.info("Test Top 3 accuracy:\t{:.6f} %".format(test_topk_acc))
591 |
592 | self.network_train_info['nb_params'] = lasagne.layers.count_params(self.network_lout_batch)
593 | if justTest:
594 | if os.path.exists(save_name + ".npz"):
595 | self.saveFinalResults(noiseType, ratio_dB, roundParams, save_name, test_acc, test_cost,
596 | test_topk_acc, withNoise)
597 | return 0
598 | # else: the model doesn't exist yet, so train it anyway
599 | else:
600 | self.network_train_info['test_cost'].append(test_cost)
601 | self.network_train_info['test_acc'].append(test_acc)
602 | self.network_train_info['test_topk_acc'].append(test_topk_acc)
603 |
604 | self.logger.info("\n* Starting training...")
605 | LR = LR_start
606 | self.best_cost = 100
607 | for epoch in range(num_epochs):
608 | self.curr_epoch += 1
609 | epoch_time = time.time()
610 | self.logger.info("\n\nCURRENT EPOCH: %s", self.curr_epoch)
611 |
612 | self.logger.info("Pass over Training Set")
613 | train_cost, train_acc, train_topk_acc = self.run_epoch(X=X_train, y=y_train,
614 | valid_frames=valid_frames_train, LR=LR)
615 |
616 | self.logger.info("Pass over Validation Set")
617 | val_cost, val_acc, val_topk_acc = self.run_epoch(X=X_val, y=y_val, valid_frames=valid_frames_val)
618 |
619 | # Print epoch summary
620 | self.logger.info("Epoch {} of {} took {:.3f}s.".format(
621 | epoch + 1, num_epochs, time.time() - epoch_time))
622 | self.logger.info("Learning Rate:\t\t{:.6f}".format(LR))
623 | self.logger.info("Training cost:\t{:.6f}".format(train_cost))
624 | self.logger.info("Training accuracy:\t{:.6f} %".format(train_acc))
625 |
626 | self.logger.info("Validation cost:\t{:.6f} ".format(val_cost))
627 | self.logger.info("Validation accuracy:\t\t{:.6f} %".format(val_acc))
628 | self.logger.info("Validation Top 3 accuracy:\t{:.6f} %".format(val_topk_acc))
629 |
630 | # better model, so save parameters
631 | if val_acc > self.best_val_acc:
632 | # only reset the patience counter on a significant improvement
633 | if val_acc - self.best_val_acc > 0.2:
634 | self.epochsNotImproved = 0
635 | # store new parameters
636 | self.best_cost = val_cost
637 | self.best_val_acc = val_acc
638 | self.best_epoch = self.curr_epoch
639 | self.best_param = L.get_all_param_values(self.network_lout)
640 | self.best_updates = [p.get_value() for p in self.updates.keys()]
641 | self.logger.info("New best model found!")
642 | if save_name is not None:
643 | self.logger.info("Model saved as " + save_name)
644 | self.save_model(save_name)
645 |
646 | self.logger.info("Pass over Test Set")
647 | test_cost, test_acc, test_topk_acc = self.run_epoch(X=X_test, y=y_test,
648 | valid_frames=valid_frames_test)
649 | self.logger.info("Test cost:\t\t{:.6f} ".format(test_cost))
650 | self.logger.info("Test accuracy:\t\t{:.6f} %".format(test_acc))
self.logger.info("Test Top 3 accuracy:\t{:.6f} %".format(test_topk_acc)) 652 | 653 | # save the training info 654 | self.network_train_info['train_cost'].append(train_cost) 655 | self.network_train_info['val_cost'].append(val_cost) 656 | self.network_train_info['val_acc'].append(val_acc) 657 | self.network_train_info['val_topk_acc'].append(val_topk_acc) 658 | self.network_train_info['test_cost'].append(test_cost) 659 | self.network_train_info['test_acc'].append(test_acc) 660 | self.network_train_info['test_topk_acc'].append(test_topk_acc) 661 | 662 | saveToPkl(save_name + '_trainInfo.pkl', self.network_train_info) 663 | self.logger.info("Train info written to:\t %s", save_name + '_trainInfo.pkl') 664 | 665 | if compute_confusion: 666 | print("does not work here") 667 | 668 | # update LR, see if we can stop training 669 | LR = self.updateLR(LR, LR_decay) 670 | 671 | if self.epochsNotImproved >= 3: 672 | logging.warning("\n\nNo more improvements, stopping training...") 673 | self.logger.info("Pass over Test Set") 674 | test_cost, test_acc, test_topk_acc = self.run_epoch(X=X_test, y=y_test, 675 | valid_frames=valid_frames_test) 676 | self.logger.info("Test cost:\t\t{:.6f} ".format(test_cost)) 677 | self.logger.info("Test accuracy:\t\t{:.6f} %".format(test_acc)) 678 | self.logger.info("Test Top 3 accuracy:\t{:.6f} %".format(test_topk_acc)) 679 | 680 | self.network_train_info['test_cost'][-1] = test_cost 681 | self.network_train_info['test_acc'][-1] = test_acc 682 | self.network_train_info['test_topk_acc'][-1] = test_topk_acc 683 | 684 | self.saveFinalResults(noiseType, ratio_dB, roundParams, save_name, test_acc, test_cost, 685 | test_topk_acc, withNoise) 686 | 687 | return test_cost, test_acc, test_topk_acc 688 | return test_cost, test_acc, test_topk_acc 689 | 690 | 691 | def saveFinalResults(self, noiseType, ratio_dB, roundParams, save_name, test_acc, test_cost, test_topk_acc, 692 | withNoise): 693 | if self.test_dataset != self.dataset: 694 | testType = "_" + self.test_dataset 695 | else: 696 | testType = "" 697 | if roundParams: 698 | testType = "_roundParams" + testType 699 | 700 | if withNoise: 701 | self.network_train_info[ 702 | 'final_test_cost_' + noiseType + "_" + "ratio" + str(ratio_dB) + testType] = test_cost 703 | self.network_train_info['final_test_acc_' + noiseType + "_" + "ratio" + str(ratio_dB) + testType] = test_acc 704 | self.network_train_info[ 705 | 'final_test_top3_acc_' + noiseType + "_" + "ratio" + str(ratio_dB) + testType] = test_topk_acc 706 | else: 707 | self.network_train_info['final_test_cost' + testType] = test_cost 708 | self.network_train_info['final_test_acc' + testType] = test_acc 709 | self.network_train_info['final_test_top3_acc' + testType] = test_topk_acc 710 | 711 | saveToPkl(save_name + '_trainInfo.pkl', self.network_train_info) 712 | self.logger.info("Train info written to:\t %s", save_name + '_trainInfo.pkl') 713 | 714 | def get_validPredictions_video(self, valid_predictions, valid_frames, videoIndexInBatch): 715 | # get indices of the valid frames for each video, using the valid_frames 716 | nbValidsPerVideo = [len(el) for el in valid_frames] 717 | 718 | # each el is the sum of the els before. -> for example video 3, you need valid_predictions from indices[2] (inclusive) till indices[3] (not inclusive) 719 | indices = [0] + [np.sum(nbValidsPerVideo[:i + 1]) for i in range(len(nbValidsPerVideo))] 720 | 721 | # make a 2D list. Each el of the list is a list with the valid frames per video. 
729 | def loadPreviousResults(self, save_name):
730 | # try to load performance metrics of a stored model
731 | best_val_acc = 0
732 | test_topk_acc = 0
733 | test_cost = 0
734 | test_acc = 0
735 | old_train_info = {}
736 | try:
737 | if os.path.exists(save_name + "_trainInfo.pkl"):
738 | old_train_info = unpickle(save_name + '_trainInfo.pkl')
739 | best_val_acc = max(old_train_info['val_acc'])
740 | test_cost = min(old_train_info['test_cost'])
741 | test_acc = max(old_train_info['test_acc'])
742 | self.network_train_info = old_train_info
743 | try:
744 | test_topk_acc = max(old_train_info['test_topk_acc'])
745 | except:
746 | pass  # older models may not have stored top-k accuracy
747 | except:
748 | pass  # no previous results stored
749 | return best_val_acc, test_acc, old_train_info
750 |
751 | def updateLR(self, LR, LR_decay):
752 | this_acc = self.network_train_info['val_acc'][-1]
753 | this_cost = self.network_train_info['val_cost'][-1]
754 | try:
755 | last_acc = self.network_train_info['val_acc'][-2]
756 | last_cost = self.network_train_info['val_cost'][-2]
757 | except:
758 | last_acc = -10
759 | last_cost = 10 * this_cost  # fails the first time, when only one result is stored
760 |
761 | # only reduce LR if there is not much improvement anymore
762 | if this_cost / float(last_cost) >= 0.98 or this_acc - last_acc < 0.2:
763 | self.logger.info(" Error not much reduced: %s vs %s. Reducing LR: %s", this_cost, last_cost, LR * LR_decay)
764 | self.epochsNotImproved += 1
765 | return LR * LR_decay
766 | else:
767 | self.epochsNotImproved = max(self.epochsNotImproved - 1, 0)  # reduce by 1, minimum 0
768 | return LR
769 |
--------------------------------------------------------------------------------
/background/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthijsvk/TIMITspeech/4294fe4af760d19dc807c4e01d01d07662ff7bde/background/__init__.py
--------------------------------------------------------------------------------
/background/htkbook-3.5.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthijsvk/TIMITspeech/4294fe4af760d19dc807c4e01d01d07662ff7bde/background/htkbook-3.5.pdf
--------------------------------------------------------------------------------
/background/plot_data.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 |
3 | from general_tools import *
4 |
5 | store_dir = os.path.expanduser('~/TCDTIMIT/audioSR/TIMIT/binary39/TIMIT')
6 | dataset_path = os.path.join(store_dir, 'TIMIT_39_ch.pkl')
7 |
8 | X_train, y_train, valid_frames_train, X_val, y_val, valid_frames_val, X_test, y_test, valid_frames_test = load_dataset(
9 | dataset_path)
10 |
11 | plt.figure(1)
12 | plt.title('Preprocessed data visualization')
13 | for i in range(1, 5):
14 | plt.subplot(2, 2, i)
15 | plt.axis('off')
17 | plt.imshow(np.log(X_train[i].T))  # log scale makes the MFCC features much easier to see
18 | print(X_train[i].shape)
19 |
20 | plt.tight_layout()
21 | plt.show()
22 |
--------------------------------------------------------------------------------
/background/timit_phones.txt:
--------------------------------------------------------------------------------
1 | 
iy 2 | ih 3 | eh 4 | ae 5 | ah 6 | uw 7 | uh 8 | aa 9 | ey 10 | ay 11 | oy 12 | aw 13 | ow 14 | l 15 | r 16 | y 17 | w 18 | er 19 | m 20 | n 21 | ng 22 | ch 23 | jh 24 | dh 25 | b 26 | d 27 | dx 28 | g 29 | p 30 | t 31 | k 32 | z 33 | v 34 | f 35 | th 36 | s 37 | sh 38 | hh 39 | sil -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: DeepLearning 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - libgpuarray=0.7.1=vc9_0 7 | - pygpu=0.7.1=np113py27_0 8 | - _license=1.1=py27_1 9 | - alabaster=0.7.10=py27_0 10 | - anaconda-client=1.6.3=py27_0 11 | - anaconda=custom=py27_0 12 | - anaconda-navigator=1.6.2=py27_0 13 | - anaconda-project=0.6.0=py27_0 14 | - asn1crypto=0.22.0=py27_0 15 | - astroid=1.4.9=py27_0 16 | - astropy=2.0.1=np113py27_0 17 | - babel=2.4.0=py27_0 18 | - backports=1.0=py27_0 19 | - backports_abc=0.5=py27_0 20 | - beautifulsoup4=4.6.0=py27_0 21 | - bitarray=0.8.1=py27_1 22 | - blaze=0.10.1=py27_0 23 | - bleach=1.5.0=py27_0 24 | - bokeh=0.12.5=py27_1 25 | - boto=2.46.1=py27_0 26 | - bottleneck=1.2.1=np113py27_0 27 | - bzip2=1.0.6=vc9_3 28 | - cdecimal=2.3=py27_2 29 | - certifi=2016.2.28=py27_0 30 | - cffi=1.10.0=py27_0 31 | - chardet=3.0.3=py27_0 32 | - click=6.7=py27_0 33 | - cloudpickle=0.2.2=py27_0 34 | - clyent=1.2.2=py27_0 35 | - colorama=0.3.9=py27_0 36 | - comtypes=1.1.2=py27_0 37 | - configparser=3.5.0=py27_0 38 | - console_shortcut=0.1.1=py27_1 39 | - contextlib2=0.5.5=py27_0 40 | - cryptography=1.8.1=py27_0 41 | - cudatoolkit=8.0=1 42 | - curl=7.52.1=vc9_0 43 | - cycler=0.10.0=py27_0 44 | - cython=0.25.2=py27_0 45 | - cytoolz=0.8.2=py27_0 46 | - dask=0.14.3=py27_1 47 | - datashape=0.5.4=py27_0 48 | - decorator=4.0.11=py27_0 49 | - distributed=1.16.3=py27_0 50 | - docutils=0.13.1=py27_0 51 | - entrypoints=0.2.2=py27_1 52 | - enum34=1.1.6=py27_0 53 | - et_xmlfile=1.0.1=py27_0 54 | - fastcache=1.0.2=py27_1 55 | - flask=0.12.2=py27_0 56 | - flask-cors=3.0.2=py27_0 57 | - freetype=2.5.5=vc9_2 58 | - funcsigs=1.0.2=py27_0 59 | - functools32=3.2.3.2=py27_0 60 | - futures=3.1.1=py27_0 61 | - get_terminal_size=1.0.0=py27_0 62 | - gevent=1.2.1=py27_0 63 | - greenlet=0.4.12=py27_0 64 | - grin=1.2.1=py27_3 65 | - h5py=2.7.0=np113py27_0 66 | - hdf5=1.8.15.1=vc9_4 67 | - heapdict=1.0.0=py27_1 68 | - html5lib=0.999=py27_0 69 | - icu=57.1=vc9_0 70 | - idna=2.5=py27_0 71 | - imagesize=0.7.1=py27_0 72 | - ipaddress=1.0.18=py27_0 73 | - ipykernel=4.6.1=py27_0 74 | - ipython=5.3.0=py27_0 75 | - ipython_genutils=0.2.0=py27_0 76 | - ipywidgets=6.0.0=py27_0 77 | - isort=4.2.5=py27_0 78 | - itsdangerous=0.24=py27_0 79 | - jdcal=1.3=py27_0 80 | - jedi=0.10.2=py27_2 81 | - jinja2=2.9.6=py27_0 82 | - jpeg=9b=vc9_0 83 | - jsonschema=2.6.0=py27_0 84 | - jupyter=1.0.0=py27_3 85 | - jupyter_client=5.0.1=py27_0 86 | - jupyter_console=5.1.0=py27_0 87 | - jupyter_core=4.3.0=py27_0 88 | - lasagne=0.1=py27_0 89 | - lazy-object-proxy=1.2.2=py27_0 90 | - libpng=1.6.27=vc9_0 91 | - libpython=2.0=py27_0 92 | - libtiff=4.0.6=vc9_3 93 | - llvmlite=0.20.0=py27_0 94 | - locket=0.2.0=py27_1 95 | - lxml=3.7.3=py27_0 96 | - m2w64-binutils=2.25.1=5 97 | - m2w64-bzip2=1.0.6=6 98 | - m2w64-crt-git=5.0.0.4636.2595836=2 99 | - m2w64-gcc=5.3.0=6 100 | - m2w64-gcc-ada=5.3.0=6 101 | - m2w64-gcc-fortran=5.3.0=6 102 | - m2w64-gcc-libgfortran=5.3.0=6 103 | - m2w64-gcc-libs=5.3.0=7 104 | - m2w64-gcc-libs-core=5.3.0=7 105 | - m2w64-gcc-objc=5.3.0=6 106 | - 
m2w64-gmp=6.1.0=2 107 | - m2w64-headers-git=5.0.0.4636.c0ad18a=2 108 | - m2w64-isl=0.16.1=2 109 | - m2w64-libiconv=1.14=6 110 | - m2w64-libmangle-git=5.0.0.4509.2e5a9a2=2 111 | - m2w64-libwinpthread-git=5.0.0.4634.697f757=2 112 | - m2w64-make=4.1.2351.a80a8b8=2 113 | - m2w64-mpc=1.0.3=3 114 | - m2w64-mpfr=3.1.4=4 115 | - m2w64-pkg-config=0.29.1=2 116 | - m2w64-toolchain=5.3.0=7 117 | - m2w64-tools-git=5.0.0.4592.90b8472=2 118 | - m2w64-windows-default-manifest=6.4=3 119 | - m2w64-winpthreads-git=5.0.0.4634.697f757=2 120 | - m2w64-zlib=1.2.8=10 121 | - mako=1.0.6=py27_0 122 | - markupsafe=0.23=py27_2 123 | - matplotlib=2.0.2=np113py27_0 124 | - menuinst=1.4.7=py27_0 125 | - mistune=0.7.4=py27_0 126 | - mkl=2017.0.1=0 127 | - mkl-service=1.1.2=py27_3 128 | - mpmath=0.19=py27_1 129 | - msgpack-python=0.4.8=py27_0 130 | - msys2-conda-epoch=20160418=1 131 | - multipledispatch=0.4.9=py27_0 132 | - navigator-updater=0.1.0=py27_0 133 | - nbconvert=5.1.1=py27_0 134 | - nbformat=4.3.0=py27_0 135 | - networkx=1.11=py27_0 136 | - nltk=3.2.3=py27_0 137 | - nose=1.3.7=py27_1 138 | - nose-parameterized=0.6.0=py27_0 139 | - notebook=5.0.0=py27_0 140 | - numba=0.35.0=np113py27_0 141 | - numexpr=2.6.2=np113py27_0 142 | - numpy=1.13.1=py27_0 143 | - numpydoc=0.6.0=py27_0 144 | - odo=0.5.0=py27_1 145 | - olefile=0.44=py27_0 146 | - openpyxl=2.4.7=py27_0 147 | - openssl=1.0.2l=vc9_0 148 | - packaging=16.8=py27_0 149 | - pandas=0.20.3=py27_0 150 | - pandocfilters=1.4.1=py27_0 151 | - partd=0.3.8=py27_0 152 | - path.py=10.3.1=py27_0 153 | - pathlib2=2.2.1=py27_0 154 | - patsy=0.4.1=py27_0 155 | - pep8=1.7.0=py27_0 156 | - pickleshare=0.7.4=py27_0 157 | - pillow=4.1.1=py27_0 158 | - pip=9.0.1=py27_1 159 | - ply=3.10=py27_0 160 | - progressbar=2.3=py27_0 161 | - prompt_toolkit=1.0.14=py27_0 162 | - psutil=5.2.2=py27_0 163 | - py=1.4.33=py27_0 164 | - pycosat=0.6.2=py27_0 165 | - pycparser=2.17=py27_0 166 | - pycrypto=2.6.1=py27_6 167 | - pyculib=1.0.2=np113py27_2 168 | - pyculib_sorting=1.0.0=8 169 | - pycurl=7.43.0=py27_2 170 | - pydot-ng=1.0.0.15=py27_0 171 | - pyflakes=1.5.0=py27_0 172 | - pygments=2.2.0=py27_0 173 | - pylint=1.6.4=py27_1 174 | - pyodbc=4.0.16=py27_0 175 | - pyopenssl=17.0.0=py27_0 176 | - pyparsing=2.1.4=py27_0 177 | - pyqt=5.6.0=py27_2 178 | - pytables=3.2.2=np113py27_4 179 | - pytest=3.0.7=py27_0 180 | - python=2.7.13=1 181 | - python-dateutil=2.6.0=py27_0 182 | - pytz=2017.2=py27_0 183 | - pywavelets=0.5.2=np113py27_0 184 | - pywin32=220=py27_2 185 | - pyyaml=3.12=py27_0 186 | - pyzmq=16.0.2=py27_0 187 | - qt=5.6.2=vc9_4 188 | - qtawesome=0.4.4=py27_0 189 | - qtconsole=4.3.0=py27_0 190 | - qtpy=1.2.1=py27_0 191 | - requests=2.14.2=py27_0 192 | - rope=0.9.4=py27_1 193 | - ruamel_yaml=0.11.14=py27_1 194 | - scandir=1.5=py27_0 195 | - scikit-image=0.13.0=np113py27_0 196 | - scikit-learn=0.19.0=np113py27_0 197 | - scipy=0.19.1=np113py27_0 198 | - seaborn=0.7.1=py27_0 199 | - setuptools=36.4.0=py27_1 200 | - simplegeneric=0.8.1=py27_1 201 | - singledispatch=3.4.0.3=py27_0 202 | - sip=4.18=py27_0 203 | - six=1.10.0=py27_0 204 | - snowballstemmer=1.2.1=py27_0 205 | - sortedcollections=0.5.3=py27_0 206 | - sortedcontainers=1.5.7=py27_0 207 | - sphinx=1.6.3=py27_0 208 | - sphinxcontrib=1.0=py27_0 209 | - sphinxcontrib-websupport=1.0.1=py27_0 210 | - spyder=3.1.4=py27_0 211 | - sqlalchemy=1.1.9=py27_0 212 | - ssl_match_hostname=3.4.0.2=py27_1 213 | - statsmodels=0.8.0=np113py27_0 214 | - subprocess32=3.2.7=py27_0 215 | - sympy=1.0=py27_0 216 | - tblib=1.3.2=py27_0 217 | - testpath=0.3=py27_0 218 | - 
theano=0.9.0=py27_0 219 | - tk=8.5.18=vc9_0 220 | - toolz=0.8.2=py27_0 221 | - tornado=4.5.1=py27_0 222 | - traitlets=4.3.2=py27_0 223 | - typing=3.6.2=py27_0 224 | - unicodecsv=0.14.1=py27_0 225 | - vs2008_runtime=9.00.30729.5054=0 226 | - vs2015_runtime=14.0.25123=0 227 | - wcwidth=0.1.7=py27_0 228 | - werkzeug=0.12.2=py27_0 229 | - wheel=0.29.0=py27_0 230 | - widgetsnbextension=2.0.0=py27_0 231 | - win_unicode_console=0.5=py27_0 232 | - wincertstore=0.2=py27_0 233 | - wrapt=1.10.10=py27_0 234 | - xlrd=1.0.0=py27_0 235 | - xlsxwriter=0.9.6=py27_0 236 | - xlwings=0.10.4=py27_0 237 | - xlwt=1.2.0=py27_0 238 | - zict=0.1.2=py27_0 239 | - zlib=1.2.8=vc9_3 240 | - pip: 241 | - backports-abc==0.5 242 | - backports.shutil-get-terminal-size==1.0.0 243 | - backports.ssl-match-hostname==3.4.0.2 244 | - et-xmlfile==1.0.1 245 | - gpustat==0.3.2 246 | - ipython-genutils==0.2.0 247 | - jupyter-client==5.0.1 248 | - jupyter-console==5.1.0 249 | - jupyter-core==4.3.0 250 | - keras==2.0.8 251 | - lmdb==0.93 252 | - msgpack-numpy==0.4.1 253 | - prompt-toolkit==1.0.14 254 | - pypet==0.3.0 255 | - python-speech-features==0.6 256 | - tables==3.2.2 257 | - tabulate==0.7.7 258 | - tensorpack==0.5.0 259 | - termcolor==1.1.0 260 | - tqdm==4.17.1 261 | - win-unicode-console==0.5 262 | prefix: C:\Users\matthijsv\AppData\Local\Continuum\Anaconda2\envs\DeepLearning 263 | 264 | -------------------------------------------------------------------------------- /getResults.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import warnings 4 | from pprint import pprint # printing properties of networkToRun objects 5 | from time import gmtime, strftime 6 | 7 | # pprint(vars(a)) 8 | 9 | warnings.simplefilter("ignore", UserWarning) # cuDNN warning 10 | 11 | import logging 12 | import formatting 13 | from tqdm import tqdm 14 | 15 | logger_combined = logging.getLogger('combined') 16 | logger_combined.setLevel(logging.DEBUG) 17 | FORMAT = '[$BOLD%(filename)s$RESET:%(lineno)d][%(levelname)-5s]: %(message)s ' 18 | formatter = logging.Formatter(formatting.formatter_message(FORMAT, False)) 19 | 20 | # create console handler with a higher log level 21 | ch = logging.StreamHandler() 22 | ch.setLevel(logging.DEBUG) 23 | ch.setFormatter(formatter) 24 | logger_combined.addHandler(ch) 25 | 26 | # File logger: see below META VARIABLES 27 | 28 | import time 29 | 30 | program_start_time = time.time() 31 | 32 | print("\n * Importing libraries...") 33 | from combinedNN_tools import * 34 | from general_tools import * 35 | import traceback 36 | 37 | ###################### Script settings ####################################### 38 | root = os.path.expanduser('~/TCDTIMIT/') 39 | resultsPath = root + 'combinedSR/TCDTIMIT/results/allEvalResults.pkl' 40 | 41 | logToFile = True; 42 | overwriteResults = False 43 | 44 | # if you wish to force retrain of networks, set justTest to False, forceTrain in main() to True, and overwriteSubnets to True. 45 | # if True, and justTest=False, even if a network exists it will continue training. If False, it will just be evaluated 46 | forceTrain = False 47 | 48 | # JustTest: If True, mainGetResults just runs over the trained networks. If a network doesn't exist, it's skipped 49 | # If False, ask user to train networks, then start training networks that don't exist. 
50 | justTest = False
51 |
52 | getConfusionMatrix = True  # if True, stores confusionMatrix where the .npz and train_info.pkl are stored
53 | # use this for testing with reduced precision. It converts the network weights to float16, then back to float32 for execution.
54 | # This amounts to rounding. Performance should hardly be impacted.
55 | ROUND_PARAMS = False
56 |
57 | # use this to TEST trained networks on the test dataset with noise added.
58 | # This data is generated using audioSR/fixDataset/mergeAudiofiles.py + audioToPkl_perVideo.py and combinedSR/dataToPkl_lipspeakers.py
59 | # It is loaded in combinedNN_tools/finalEvaluation (also just before training in 'train'). You could also generate noisy training data and train on that, but I haven't tried that.
60 | withNoise = False
61 | noiseTypes = ['white', 'voices']
62 | ratio_dBs = [0, -3, -5, -10]
63 |
64 |
65 | ###################### Script code #######################################
66 |
67 | # quickly generate many networks
68 | def createNetworkList(dataset, networkArchs):
69 | networkList = []
70 | for networkArch in networkArchs:
71 | networkList.append(networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=networkArch,
72 | audio_dataset=dataset, test_dataset=dataset))
73 | return networkList
74 |
75 | networkList = [
76 | # # # # # ### AUDIO ### -> see audioSR/RNN.py, there it can run in batch mode which is much faster
77 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[64,64],
78 | # audio_dataset="TCDTIMIT", test_dataset="TCDTIMIT"),#,forceTrain=True),
79 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[256, 256],
80 | # audio_dataset="combined", test_dataset="TCDTIMIT"),#, forceTrain=True)
81 | networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[256, 256],
82 | audio_dataset="TCDTIMIT", test_dataset="TCDTIMIT"),  # , forceTrain=True)
83 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[512,512],
84 | # audio_dataset="TCDTIMIT", test_dataset="TCDTIMIT"),#,forceTrain=True),
85 |
86 | # run TCDTIMIT-trained network on lipspeakers, and on the real TCDTIMIT test set (volunteers)
87 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[256, 256],
88 | # audio_dataset="TCDTIMIT", test_dataset="TIMIT"),  # , forceTrain=True) % 66.75 / 89.19
89 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[256, 256],
90 | # audio_dataset="TCDTIMIT", test_dataset="TCDTIMITvolunteers"),  # ,forceTrain=True),
91 |
92 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[64,64],
93 | # audio_dataset="combined", test_dataset="TCDTIMIT"),  # ,forceTrain=True),
94 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[256, 256],
95 | # audio_dataset="combined", test_dataset="TCDTIMIT"),  # ,forceTrain=True),
96 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[512, 512],
97 | # audio_dataset="combined", test_dataset="TCDTIMIT"),  # ,forceTrain=True),
98 |
99 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[256, 256],
100 | # audio_dataset="combined", test_dataset="TCDTIMIT"),  # ,forceTrain=True),
101 |
102 |
103 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[8], audio_dataset="TIMIT", test_dataset="TIMIT"),
104 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[32], audio_dataset="TIMIT", test_dataset="TIMIT"),
105 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[64], audio_dataset="TIMIT", test_dataset="TIMIT"),
106 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[256], audio_dataset="TIMIT", test_dataset="TIMIT"),
107 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[512], audio_dataset="TIMIT", test_dataset="TIMIT"),
108 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[1024], audio_dataset="TIMIT", test_dataset="TIMIT"),
109 | #
110 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[8, 8], audio_dataset="TIMIT", test_dataset="TIMIT"),
111 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[32, 32], audio_dataset="TIMIT", test_dataset="TIMIT"),
112 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[64, 64], audio_dataset="TIMIT", test_dataset="TIMIT"),
113 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[256, 256], audio_dataset="TIMIT", test_dataset="TIMIT"),
114 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[512, 512], audio_dataset="TIMIT", test_dataset="TIMIT"),
115 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[1024,1024], audio_dataset="TIMIT", test_dataset="TIMIT"),
116 | #
117 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[8,8,8], audio_dataset="TIMIT", test_dataset="TIMIT"),
118 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[32,32,32], audio_dataset="TIMIT", test_dataset="TIMIT"),
119 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[64,64,64], audio_dataset="TIMIT", test_dataset="TIMIT"),
120 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[256,256,256], audio_dataset="TIMIT", test_dataset="TIMIT"),
121 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[512,512,512], audio_dataset="TIMIT", test_dataset="TIMIT"),
122 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[1024, 1024, 1024], audio_dataset="TIMIT", test_dataset="TIMIT"),
123 | #
124 | #
125 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[8,8,8,8], audio_dataset="TIMIT", test_dataset="TIMIT"),
126 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[32,32,32,32], audio_dataset="TIMIT", test_dataset="TIMIT"),
127 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[64,64,64,64], audio_dataset="TIMIT", test_dataset="TIMIT"),
128 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[256,256,256,256], audio_dataset="TIMIT", test_dataset="TIMIT"),
129 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[512,512,512,512], audio_dataset="TIMIT", test_dataset="TIMIT"),
130 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[1024,1024,1024,1024], audio_dataset="TIMIT", test_dataset="TIMIT"),
131 |
132 | # # get the MFCC results
133 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[64, 64], audio_dataset="TIMIT", test_dataset="TIMIT", nbMFCCs=13),
134 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[64, 64], audio_dataset="TIMIT", test_dataset="TIMIT", nbMFCCs=26),
135 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[64, 64], audio_dataset="TIMIT", test_dataset="TIMIT", nbMFCCs=39),
136 | ]
137 |
138 | def main():
139 |
140 | # Use this if you only want to start training when the network doesn't exist yet
141 | if withNoise:
142 | allResults = []
143 | for noiseType in noiseTypes:
144 | for ratio_dB in ratio_dBs:
145 | results, resultsPath = mainGetResults(networkList, withNoise, noiseType, ratio_dB)
146 | allResults.append(results)
147 | # print(allResults)
149 |
150 | allNoisePath = root + 'resultsNoisy.pkl'
151 | exportResultsToExcelManyNoise(allResults, allNoisePath)
152 | else:
153 | results, resultsPath = mainGetResults(networkList)
154 | print("\n got all results")
155 | exportResultsToExcel(results, resultsPath)
156 |
157 | # Use this if you want to force-run the networks on the train sets. If justTest==True, it will only evaluate performance on the test set.
158 | runManyNetworks(networkList, withNoise=withNoise)
159 |
160 | class networkToRun:
161 | def __init__(self,
162 | AUDIO_LSTM_HIDDEN_LIST=[256, 256], audio_dataset="TCDTIMIT", nbMFCCs=39, audio_bidirectional=True,
163 | LR_start=0.001,
164 | forceTrain=False, runType='audio',
165 | dataset="TCDTIMIT", test_dataset=None):
166 | # Audio
167 | self.AUDIO_LSTM_HIDDEN_LIST = AUDIO_LSTM_HIDDEN_LIST  # LSTM architecture for the audio part
168 | self.audio_dataset = audio_dataset  # training here only works for TCDTIMIT at the moment; for other datasets use audioSR/RNN.py. This variable is used to fetch the stored results from that script.
169 | self.nbMFCCs = nbMFCCs
170 | self.audio_bidirectional = audio_bidirectional
171 |
172 | # Others
173 | self.runType = runType
174 | self.LR_start = LR_start
175 | self.forceTrain = forceTrain  # If False, just test the network outputs when the network already exists.
176 | # If forceTrain == True, train it anyway before testing.
177 | # In that case, set LR_start low enough so you don't move too far out of the objective minimum.
178 |
179 | self.dataset = dataset
180 | if test_dataset is None: self.test_dataset = self.dataset
181 | else: self.test_dataset = test_dataset
182 |
183 |
184 |
185 |
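# --- Added illustration (hypothetical): createNetworkList() above builds such a
# --- networkList programmatically, e.g. for an architecture sweep on TIMIT:
def exampleNetworkList():
    return createNetworkList(dataset="TIMIT", networkArchs=[[64, 64], [256, 256], [512, 512]])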
186 | def exportResultsToExcel(results, path):
187 | storePath = path.replace(".pkl", ".xlsx")
188 | import xlsxwriter
189 | workbook = xlsxwriter.Workbook(storePath)
190 |
191 | for runType in results.keys()[1:]:  # audio, lipreading, combined
192 | worksheet = workbook.add_worksheet(runType)  # one worksheet per runType, but then everything is spread out...
193 | row = 0
194 |
195 | allNets = results[runType]
196 |
197 | # get and write the column titles
198 |
199 | # get the number of parameters. For audio there is only 1 value; for combined/lipreading there are lots of values in a dictionary.
200 | try:
201 | nb_paramNames = allNets.items()[0][1][
202 | 'nb_params'].keys()  # first key-value pair, get the value ([1]), then get the names of nbParams (= the keys)
203 | except:
204 | nb_paramNames = ['nb_params']
205 | startVals = 4 + len(nb_paramNames)  # column number of the first value
206 |
207 | colNames = ['Network Full Name', 'Network Name', 'Dataset', 'Test Dataset'] + nb_paramNames + ['Test Cost',
208 | 'Test Accuracy',
209 | 'Test Top 3 Accuracy',
210 | 'Validation accuracy']
211 | for i in range(len(colNames)):
212 | worksheet.write(0, i, colNames[i])
213 |
214 | # write the data for each network
215 | for netName in allNets.keys():
216 | row += 1
217 |
218 | thisNet = allNets[netName]
219 | # write the path and name
220 | worksheet.write(row, 0, os.path.basename(netName))  # netName)
221 | worksheet.write(row, 1, thisNet['niceName'])
222 | if runType == 'audio':
223 | worksheet.write(row, 2, thisNet['audio_dataset'])
224 | worksheet.write(row, 3, thisNet['test_dataset'])
225 | else:
226 | worksheet.write(row, 2, thisNet['dataset'])
227 | worksheet.write(row, 3, thisNet['test_dataset'])
228 |
229 | # now write the params
230 | try:
231 | vals = thisNet['nb_params'].values()  # the parameter counts (a dict for combined/lipreading, a single value for audio)
232 | except:
233 | vals = [thisNet['nb_params']]
234 | for i in range(len(vals)):
235 | worksheet.write(row, 4 + i, vals[i])
236 |
237 | # now write the values
238 | vals = thisNet['values']  # vals is a list of [test_cost, test_acc, test_top3_acc]
239 | for i in range(len(vals)):
240 | worksheet.write(row, startVals + i, vals[i])
241 |
242 | workbook.close()
243 |
244 | logger_combined.info("Excel file stored in %s", storePath)
245 |
246 | def exportResultsToExcelManyNoise(resultsList, path):
247 | storePath = path.replace(".pkl", ".xlsx")
248 | import xlsxwriter
249 | workbook = xlsxwriter.Workbook(storePath)
250 |
255 | row = 0
256 |
257 | if len(resultsList[0]['audio'].keys()) > 0: thisRunType = 'audio'
258 | if len(resultsList[0]['lipreading'].keys()) > 0: thisRunType = 'lipreading'
259 | if len(resultsList[0]['combined'].keys()) > 0: thisRunType = 'combined'
260 | worksheetAudio = workbook.add_worksheet('audio')
261 | audioRow = 0
262 | worksheetLipreading = workbook.add_worksheet('lipreading')
263 | lipreadingRow = 0
264 | worksheetCombined = workbook.add_worksheet('combined')
265 | combinedRow = 0
266 |
267 | for r in range(len(resultsList)):
268 | results = resultsList[r]
269 | noiseType = results['resultsType']
270 |
271 | for runType in results.keys()[1:]:
272 | if len(results[runType]) == 0: continue
273 | if runType == 'audio': worksheet = worksheetAudio; row = audioRow
274 | if runType == 'lipreading': worksheet = worksheetLipreading; row = lipreadingRow
275 | if runType == 'combined': worksheet = worksheetCombined; row = combinedRow
276 |
277 | allNets = results[runType]
278 |
279 | # write the column titles
280 | startVals = 5
281 | colNames = ['Network Full Name', 'Network Name', 'Dataset', 'Test Dataset', 'Noise Type', 'Test Cost',
282 | 'Test Accuracy', 'Test Top 3 Accuracy']
283 | for i in range(len(colNames)):
284 | worksheet.write(0, i, colNames[i])
285 |
286 | # write the data for each network
287 | for netName in allNets.keys():
288 | row += 1
289 |
290 | thisNet = allNets[netName]
291 | # write the path and name
292 | worksheet.write(row, 0, 
os.path.basename(netName)) # netName) 293 | worksheet.write(row, 1, thisNet['niceName']) 294 | worksheet.write(row, 2, thisNet['dataset']) 295 | worksheet.write(row, 3, thisNet['test_dataset']) 296 | worksheet.write(row, 4, noiseType) 297 | 298 | # now write the values 299 | vals = thisNet['values'] # vals is list of [test_cost, test_acc, test_top3_acc] 300 | for i in range(len(vals)): 301 | worksheet.write(row, startVals + i, vals[i]) 302 | 303 | if runType == 'audio': audioRow = row 304 | if runType == 'lipreading': lipreadingRow = row 305 | if runType == 'combined': combinedRow = row 306 | 307 | row += 1 308 | 309 | workbook.close() 310 | 311 | logger_combined.info("Excel file stored in %s", storePath) 312 | 313 | def getManyNetworkResults(networks, resultsType="unknownResults", roundParams=False, withNoise=False, 314 | noiseType='white', ratio_dB=0): 315 | results = {'resultsType': resultsType} 316 | results['audio'] = {} 317 | results['lipreading'] = {} 318 | results['combined'] = {} 319 | 320 | failures = [] 321 | 322 | for networkParams in tqdm(networks, total=len(networks)): 323 | logger_combined.info("\n\n\n\n ################################") 324 | logger_combined.info("Getting results from network...") 325 | logger_combined.info("Network properties: ") 326 | # pprint(vars(networkParams)) 327 | try: 328 | if networkParams.forceTrain == True: 329 | runManyNetworks([networkParams]) 330 | thisResults = {} 331 | thisResults['values'] = [] 332 | thisResults['dataset'] = networkParams.dataset 333 | thisResults['test_dataset'] = networkParams.test_dataset 334 | thisResults['audio_dataset'] = networkParams.audio_dataset 335 | 336 | model_name, nice_name = getModelName(networkParams.AUDIO_LSTM_HIDDEN_LIST, networkParams.dataset) 337 | 338 | logger_combined.info("Getting results for %s", model_name + '.npz') 339 | network_train_info = getNetworkResults(model_name) 340 | if network_train_info == -1: 341 | raise IOError("this model doesn't have any stored results") 342 | # import pdb;pdb.set_trace() 343 | 344 | # audio networks can be run on TIMIT or combined as well 345 | if networkParams.runType != 'audio' and networkParams.test_dataset != networkParams.dataset: 346 | testType = "_" + networkParams.test_dataset 347 | else: 348 | testType = "" 349 | 350 | if roundParams: 351 | testType = "_roundParams" + testType 352 | 353 | if networkParams.runType != 'lipreading' and withNoise: 354 | thisResults['values'] = [ 355 | network_train_info['final_test_cost_' + noiseType + "_" + "ratio" + str(ratio_dB) + testType], 356 | network_train_info['final_test_acc_' + noiseType + "_" + "ratio" + str(ratio_dB) + testType], 357 | network_train_info['final_test_top3_acc_' + noiseType + "_" + "ratio" + str(ratio_dB) + testType]] 358 | else: 359 | try: 360 | val_acc = max(network_train_info['val_acc']) 361 | except: 362 | try: 363 | val_acc = max(network_train_info['test_acc']) 364 | except: 365 | val_acc = network_train_info['final_test_acc'] 366 | thisResults['values'] = [network_train_info['final_test_cost' + testType], 367 | network_train_info['final_test_acc' + testType], 368 | network_train_info['final_test_top3_acc' + testType], val_acc] 369 | 370 | thisResults['nb_params'] = network_train_info['nb_params'] 371 | thisResults['niceName'] = nice_name 372 | 373 | # eg results['audio']['2Layer_256_256_TIMIT'] = [0.8, 79.5, 92,6] #test cost, test acc, test top3 acc 374 | results[networkParams.runType][model_name] = thisResults 375 | 376 | except: 377 | logger_combined.info('caught this error: ' + 
traceback.format_exc()); 378 | # import pdb;pdb.set_trace() 379 | failures.append(networkParams) 380 | 381 | logger_combined.info("\n\nDONE getting stored results from networks") 382 | logger_combined.info("####################################################") 383 | 384 | if len(failures) > 0: 385 | logger_combined.info("Couldn't get %s results from %s networks...", resultsType, len(failures)) 386 | for failure in failures: 387 | pprint(vars(failure)) 388 | if autoTrain or query_yes_no("\nWould you like to evalute the networks now?\n\n"): 389 | logger_combined.info("Running networks...") 390 | runManyNetworks(failures, withNoise=withNoise, noiseType=noiseType, ratio_dB=ratio_dB) 391 | mainGetResults(failures, withNoise=withNoise, noiseType=noiseType, ratio_dB=ratio_dB) 392 | 393 | logger_combined.info("Done training.\n\n") 394 | # import pdb; pdb.set_trace() 395 | return results 396 | 397 | 398 | def getNetworkResults(save_name, logger=logger_combined): # copy-pasted from loadPreviousResults 399 | if os.path.exists(save_name + ".npz") and os.path.exists(save_name + "_trainInfo.pkl"): 400 | old_train_info = unpickle(save_name + '_trainInfo.pkl') 401 | # import pdb;pdb.set_trace() 402 | if type(old_train_info) == dict: # normal case 403 | network_train_info = old_train_info # load old train info so it won't get lost on retrain 404 | 405 | if not 'final_test_cost' in network_train_info.keys(): 406 | network_train_info['final_test_cost'] = min(network_train_info['test_cost']) 407 | if not 'final_test_acc' in network_train_info.keys(): 408 | network_train_info['final_test_acc'] = max(network_train_info['test_acc']) 409 | if not 'final_test_top3_acc' in network_train_info.keys(): 410 | network_train_info['final_test_top3_acc'] = max(network_train_info['test_topk_acc']) 411 | else: 412 | logger.warning("old trainInfo found, but wrong format: %s", save_name + "_trainInfo.pkl") 413 | # do nothing 414 | else: 415 | return -1 416 | return network_train_info 417 | 418 | 419 | # networks is a list of dictionaries, where each dictionary contains the needed parameters for training 420 | def runManyNetworks(networks, withNoise=False, noiseType='white', ratio_dB=0): 421 | results = {} 422 | failures = [] 423 | if justTest: 424 | logger_combined.warning("\n\n!!!!!!!!! WARNING !!!!!!!!!! 
\n justTest = True") 425 | if not query_yes_no("\nAre you sure you want to continue?\n\n"): 426 | return -1 427 | for network in tqdm(networks, total=len(networks)): 428 | print("\n\n\n\n ################################") 429 | print("Training new network...") 430 | print("Network properties: ") 431 | pprint(vars(network)) 432 | try: 433 | model_save, test_results = runNetwork(AUDIO_LSTM_HIDDEN_LIST=network.AUDIO_LSTM_HIDDEN_LIST, 434 | audio_features=network.audio_features, 435 | audio_bidirectional=network.audio_bidirectional, 436 | CNN_NETWORK=network.CNN_NETWORK, 437 | cnn_features=network.cnn_features, 438 | LIP_RNN_HIDDEN_LIST=network.LIP_RNN_HIDDEN_LIST, 439 | lipRNN_bidirectional=network.lipRNN_bidirectional, 440 | lipRNN_features=network.lipRNN_features, 441 | DENSE_HIDDEN_LIST=network.DENSE_HIDDEN_LIST, 442 | combinationType=network.combinationType, 443 | dataset=network.dataset, datasetType=network.datasetType, 444 | test_dataset=network.test_dataset, 445 | addNoisyAudio=network.addNoisyAudio, 446 | runType=network.runType, 447 | LR_start=network.LR_start, 448 | allowSubnetTraining=network.allowSubnetTraining, 449 | forceTrain=network.forceTrain, 450 | overwriteSubnets=network.overwriteSubnets, 451 | audio_dataset=network.audio_dataset, 452 | withNoise=withNoise, noiseType=noiseType, ratio_dB=ratio_dB) 453 | print(model_save) 454 | name = model_save + ("_Noise" + noiseType + "_" + str(ratio_dB) if withNoise else "") 455 | results[name] = test_results # should be test_cost, test_acc, test_topk_acc 456 | 457 | except: 458 | print('caught this error: ' + traceback.format_exc()); 459 | # import pdb; pdb.set_trace() 460 | 461 | failures.append(network) 462 | print("#########################################################") 463 | print("\n\n\n DONE running all networks") 464 | 465 | if len(failures) > 0: 466 | print("Some networks failed to run...") 467 | # import pdb;pdb.set_trace() 468 | return results 469 | 470 | 471 | 472 | if __name__ == "__main__": 473 | main() 474 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthijsvk/TIMITspeech/4294fe4af760d19dc807c4e01d01d07662ff7bde/tools/__init__.py -------------------------------------------------------------------------------- /tools/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthijsvk/TIMITspeech/4294fe4af760d19dc807c4e01d01d07662ff7bde/tools/__init__.pyc -------------------------------------------------------------------------------- /tools/createMLF.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import helpFunctions.writeToTxt as wrTxt 5 | 6 | """ 7 | License: WTFPL http://www.wtfpl.net 8 | Copyright: Gabriel Synnaeve 2013 9 | """ 10 | 11 | doc = """ 12 | Usage: 13 | python createMLF.py [$folder_path] 14 | 15 | Will create "${folder}.mlf" and "labels.txt" files in $folder_path. 16 | 17 | If you run it only on the training folder, all the phones that you will 18 | encounter in the test should be present in training so that the "labels" 19 | corresponds. 
20 | """ 21 | 22 | 23 | def process(folder): 24 | folder = folder.rstrip('/') 25 | countPhonemes = {} 26 | master_label_fname = folder + '/' + folder.split('/')[-1] + '.mlf' 27 | labels_fpath = folder + '/labels.txt' 28 | master_label_file = open(master_label_fname, 'w') 29 | master_label_file.write("#!MLF!#\n") 30 | 31 | for d, ds, fs in os.walk(folder): 32 | for fname in fs: 33 | fullname = d.rstrip('/') + '/' + fname 34 | print("Processing: ", fullname) 35 | extension = fname[-4:] 36 | 37 | phones = [] 38 | if extension.lower() == '.phn': 39 | master_label_file.write('"' + fullname + '"\n') 40 | for line in open(fullname): 41 | master_label_file.write(line) 42 | phones.append(line.split()[2]) 43 | for tmp_phn in phones: 44 | countPhonemes[tmp_phn] = countPhonemes.get(tmp_phn, 0) + 1 45 | master_label_file.write('\n.\n') 46 | 47 | master_label_file.close() 48 | print("written MLF file", master_label_fname) 49 | 50 | wrTxt.writeToTxt(sorted(countPhonemes.items()), labels_fpath) 51 | print("written labels", labels_fpath) 52 | 53 | print("phones counts:", countPhonemes) 54 | print("number of phones:", len(countPhonemes)) 55 | 56 | 57 | if __name__ == '__main__': 58 | if len(sys.argv) > 1: 59 | if '--help' in sys.argv: 60 | print(doc) 61 | sys.exit(0) 62 | l = filter(lambda x: not '--' in x[0:2], sys.argv) 63 | print(l) 64 | foldername = '.' 65 | if len(l) > 1: 66 | foldername = l[1] 67 | print(foldername) 68 | process(foldername) 69 | else: 70 | process('.') # default 71 | -------------------------------------------------------------------------------- /tools/datasetToPkl.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import timeit; 4 | 5 | program_start_time = timeit.default_timer() 6 | import random 7 | 8 | random.seed(int(timeit.default_timer())) 9 | 10 | from phoneme_set import phoneme_set_39_list 11 | import formatting, preprocessWavs 12 | import general_tools 13 | 14 | import logging 15 | 16 | # prepare logging (File logger: see below META VARIABLES) 17 | logger = logging.getLogger('PrepTCDTIMIT') 18 | logger.setLevel(logging.DEBUG) 19 | FORMAT = '[$BOLD%(filename)s$RESET:%(lineno)d][%(levelname)-5s]: %(message)s ' 20 | formatter = logging.Formatter(formatting.formatter_message(FORMAT, False)) 21 | 22 | # create console handler with a higher log level 23 | ch = logging.StreamHandler() 24 | ch.setLevel(logging.DEBUG) 25 | ch.setFormatter(formatter) 26 | logger.addHandler(ch) 27 | 28 | 29 | 30 | ##### META VARIABLES ##### 31 | DEBUG = False 32 | debug_size = 50 33 | 34 | # TODO: MODIFY THESE PARAMETERS for other nbPhonemes of mfccTypes. Save location is updated automatically. 35 | nbMFCCs = 39 # 13 => just mfcc (13 features). 26 => also derivative (26 features). 39 => also 2nd derivative (39 features) 36 | 37 | # set phoneme type and get dictionary with phoneme-number mappings 38 | nbPhonemes = 39 39 | phoneme_set_list = phoneme_set_39_list # import list of phonemes, 40 | values = [i for i in range(0, len(phoneme_set_list))] 41 | phoneme_classes = dict(zip(phoneme_set_list, values)) 42 | 43 | ##### DATA Settings ##### 44 | 45 | FRAC_VAL = 0.1 # fraction of training data to be used for validation 46 | 47 | root = os.path.expanduser("~/TCDTIMIT/audioSR/") # (keep the trailing slash) 48 | dataPreSplit = True # some datasets have a pre-defined TEST set (eg TIMIT). Set the dataset name below 49 | 50 | if dataPreSplit: 51 | dataset = "TIMIT" # eg TIMIT. 
You can also manually split up TCDTIMIT according to the train/test split in Harte, N.; Gillen, E., "TCD-TIMIT: An Audio-Visual Corpus of Continuous Speech," doi: 10.1109/TMM.2015.2407694
52 | ## eg TIMIT ##
53 | dataRootDir = root + dataset + "/fixed" + str(nbPhonemes) + os.sep + dataset
54 | train_source_path = os.path.join(dataRootDir, 'TRAIN')
55 | test_source_path = os.path.join(dataRootDir, 'TEST')
56 | outputDir = root + dataset + "/binary" + str(nbPhonemes) + os.sep + dataset
57 |
58 | else:
59 | ## just a bunch of wav and phn files, not split up into train and test -> you need to create the split yourself
60 | dataset = "TCDTIMIT"
61 | dataRootDir = root + dataset + "/fixed" + str(nbPhonemes) + "_nonSplit" + os.sep + dataset
62 | outputDir = root + dataset + "/binary" + str(nbPhonemes) + os.sep + os.path.basename(dataRootDir)
63 | # TOTAL = TRAINING + TEST = TRAIN + VALIDATION + TEST
64 | FRAC_TEST = 0.1
65 | FRAC_TRAINING = 1 - FRAC_TEST  # the val set will be FRAC_TRAINING * FRAC_VAL = 9% of the data; TRAIN is 90 - 9 = 81%, TEST is 10%
66 |
67 |
68 | ##### Everything below is calculated automatically ##########
69 | #############################################################
70 |
71 | # store path
72 | target = os.path.join(outputDir, os.path.basename(dataRootDir) + '_' + str(nbMFCCs) + '_ch')
73 | target_path = target + '.pkl'
74 | if not os.path.exists(outputDir):
75 | os.makedirs(outputDir)
76 |
77 | # Already exists, ask if overwrite
78 | if (os.path.exists(target_path)):
79 | if (not general_tools.query_yes_no(target_path + " exists. Overwrite?", "no")):
80 | raise Exception("Not Overwriting")
81 |
82 | # set log file
83 | logFile = outputDir + os.sep + os.path.basename(target) + '.log'
84 | fh = logging.FileHandler(logFile, 'w')  # create a new logFile
85 | fh.setLevel(logging.DEBUG)
86 | fh.setFormatter(formatter)
87 | logger.addHandler(fh)
88 |
89 | if DEBUG:
90 | logger.info('DEBUG mode: \tACTIVE, only a small dataset will be preprocessed')
91 | target_path = target + '_DEBUG.pkl'
92 | else:
93 | logger.info('DEBUG mode: \tDEACTIVE')
94 | debug_size = None
95 |
96 | ##### The PREPROCESSING itself #####
97 | logger.info('Preprocessing data ...')
98 |
99 |
100 | # FIRST, gather the WAV and PHN files, generate MFCCs, extract labels to make inputs and targets for the network
101 | # for a dataset containing no TRAIN/TEST subdivision, just a bunch of wavs -> choose the training set yourself
102 | def processDataset(FRAC_TRAINING, data_source_path, logger=None):
103 | logger.info(' Data: %s ', data_source_path)
104 | X_all, y_all, valid_frames_all = preprocessWavs.preprocess_dataset(source_path=data_source_path, nbMFCCs=nbMFCCs,
105 | logger=logger, debug=debug_size)
106 | assert len(X_all) == len(y_all) == len(valid_frames_all)
107 |
108 | logger.info(' Loading data complete.')
109 | logger.debug('Type and shape/len of X_all')
110 | logger.debug('type(X_all): {}'.format(type(X_all)))
111 | logger.debug('type(X_all[0]): {}'.format(type(X_all[0])))
112 | logger.debug('type(X_all[0][0]): {}'.format(type(X_all[0][0])))
113 | logger.debug('type(X_all[0][0][0]): {}'.format(type(X_all[0][0][0])))
114 | logger.info('Creating Validation index ...')
115 |
116 | total_size = len(X_all)  # TOTAL = TRAINING + TEST = TRAIN + VAL + TEST
117 | total_training_size = int(math.ceil(FRAC_TRAINING * total_size))  # TRAINING = TRAIN + VAL
118 | test_size = total_size - total_training_size
119 |
120 | # split off a 'test' dataset: sample test_size indices from the whole dataset
121 | test_idx = random.sample(range(0, total_size), test_size)
122 | test_idx = [int(i) for i in test_idx]
123 | # ensure that the test set isn't empty
124 | if DEBUG:
125 | test_idx[0] = 0
126 | test_idx[1] = 1
127 | logger.info('Separating test and training set ...')
128 | X_training = []
129 | y_training = []
130 | valid_frames_training = []
131 | X_test = []
132 | y_test = []
133 | valid_frames_test = []
134 | for i in range(len(X_all)):
135 | if i in test_idx:
136 | X_test.append(X_all[i])
137 | y_test.append(y_all[i])
138 | valid_frames_test.append(valid_frames_all[i])
139 | else:
140 | X_training.append(X_all[i])
141 | y_training.append(y_all[i])
142 | valid_frames_training.append(valid_frames_all[i])
143 |
144 | assert len(X_test) == test_size
145 | assert len(X_training) == total_training_size
146 |
147 | return X_training, y_training, valid_frames_training, X_test, y_test, valid_frames_test
148 |
149 |
150 | def processDatasetSplit(train_source_path, test_source_path, logger=None):
151 | logger.info(' Training data: %s ', train_source_path)
152 | X_training, y_training, valid_frames_training = preprocessWavs.preprocess_dataset(source_path=train_source_path,
153 | logger=logger,
154 | nbMFCCs=nbMFCCs, debug=debug_size)
155 | logger.info(' Test data: %s', test_source_path)
156 | X_test, y_test, valid_frames_test = preprocessWavs.preprocess_dataset(source_path=test_source_path, logger=logger,
157 | nbMFCCs=nbMFCCs, debug=debug_size)
158 | return X_training, y_training, valid_frames_training, X_test, y_test, valid_frames_test
159 |
160 |
161 | if dataPreSplit:
162 | X_training, y_training, valid_frames_training, X_test, y_test, valid_frames_test = \
163 | processDatasetSplit(train_source_path, test_source_path, logger)
164 | else:
165 | X_training, y_training, valid_frames_training, X_test, y_test, valid_frames_test = \
166 | processDataset(FRAC_TRAINING, dataRootDir, logger)
167 |
168 | # SECOND, split off a 'validation' set from the training set. The remainder is the 'train' set
169 | total_training_size = len(X_training)
170 | val_size = int(math.ceil(total_training_size * FRAC_VAL))
171 | train_size = total_training_size - val_size
172 | val_idx = random.sample(range(0, total_training_size), val_size)  # choose random indices to be validation data
173 | val_idx = [int(i) for i in val_idx]
174 |
175 | logger.info('Length of training')
176 | logger.info(" train X: %s", len(X_training))
177 |
178 | # ensure that the validation set isn't empty
179 | if DEBUG:
180 | val_idx[0] = 0
181 | val_idx[1] = 1
182 |
183 | logger.info('Separating training set into validation and train ...')
184 | X_train = []
185 | y_train = []
186 | valid_frames_train = []
187 | X_val = []
188 | y_val = []
189 | valid_frames_val = []
190 | for i in range(len(X_training)):
191 | if i in val_idx:
192 | X_val.append(X_training[i])
193 | y_val.append(y_training[i])
194 | valid_frames_val.append(valid_frames_training[i])
195 | else:
196 | X_train.append(X_training[i])
197 | y_train.append(y_training[i])
198 | valid_frames_train.append(valid_frames_training[i])
199 | assert len(X_val) == val_size
200 |
201 | # Print some information
202 | logger.info('Length of train, val, test')
203 | logger.info(" train X: %s", len(X_train))
204 | logger.info(" train y: %s", len(y_train))
205 | logger.info(" train valid_frames: %s", len(valid_frames_train))
206 |
207 | logger.info(" val X: %s", len(X_val))
208 | logger.info(" val y: %s", len(y_val))
209 | logger.info(" val valid_frames: %s", len(valid_frames_val))
210 |
211 | logger.info(" test X: %s", len(X_test))
212 | logger.info(" test y: %s", len(y_test))
213 | logger.info(" test valid_frames: %s", len(valid_frames_test))
214 |
215 | ### NORMALIZE data ###
216 | logger.info('Normalizing data ...')
217 | logger.info(' Each channel mean=0, sd=1 ...')
218 |
219 | mean_val, std_val, _ = preprocessWavs.calc_norm_param(X_train)
220 |
221 | X_train = preprocessWavs.normalize(X_train, mean_val, std_val)
222 | X_val = preprocessWavs.normalize(X_val, mean_val, std_val)
223 | X_test = preprocessWavs.normalize(X_test, mean_val, std_val)
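# --- Added illustration (hypothetical usage): the MeanStd pkl written at the end of
# --- this script lets you normalize new utterances at evaluation time without
# --- reloading the whole dataset; it reuses the helpers imported above.
def normalize_new_data(X_new, meanStd_path):
    mean_val, std_val = general_tools.unpickle(meanStd_path)  # stored below as [mean_val, std_val]
    return preprocessWavs.normalize(X_new, mean_val, std_val)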
logger.debug(' %s %s', type(X_train), len(X_train))
254 | logger.debug(' %s %s', type(X_train[0]), X_train[0].shape)
255 | logger.debug(' %s %s', type(X_train[0][0]), X_train[0][0].shape)
256 | logger.debug(' %s %s', type(X_train[0][0][0]), X_train[0][0].shape)
257 | logger.debug('y train')
258 | logger.debug(' %s %s', type(y_train), len(y_train))
259 | logger.debug(' %s %s', type(y_train[0]), y_train[0].shape)
260 | logger.debug(' %s %s', type(y_train[0][0]), y_train[0][0].shape)
261 |
262 | ### STORE DATA ###
263 | logger.info('Saving data to %s', target_path)
264 | dataList = [X_train, y_train, valid_frames_train, X_val, y_val, valid_frames_val, X_test, y_test, valid_frames_test]
265 | general_tools.saveToPkl(target_path, dataList)
266 |
267 | # these can be used to evaluate new data, so you don't have to load the whole dataset just to normalize
268 | meanStd_path = os.path.dirname(outputDir) + os.sep + os.path.basename(dataRootDir) + "MeanStd.pkl"
269 | logger.info('Saving Mean and Std_val to %s', meanStd_path)
270 | dataList = [mean_val, std_val]
271 | general_tools.saveToPkl(meanStd_path, dataList)
272 |
273 | logger.info('Preprocessing complete!')
274 | logger.info('Total time: {:.3f}'.format(timeit.default_timer() - program_start_time))
275 |
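The pickles written above are all you need at evaluation time. A minimal sketch (paths are hypothetical; run it from the tools/ directory so the plain imports resolve) of loading them back and normalizing fresh audio with the stored train statistics:

import logging

import general_tools
import preprocessWavs

logger = logging.getLogger('eval')

# the pkl layout matches the dataList saved above
(X_train, y_train, valid_frames_train,
 X_val, y_val, valid_frames_val,
 X_test, y_test, valid_frames_test) = general_tools.unpickle('root/TIMIT/binary39/TIMIT/TIMIT_39_ch.pkl')

# reuse the stored mean/std instead of recomputing them from the whole train set
mean_val, std_val = general_tools.unpickle('root/TIMIT/binary39/TIMIT_MeanStd.pkl')
X_new = preprocessWavs.preprocess_unlabeled_dataset('path/to/new/wavs', nbMFCCs=39, logger=logger)
X_new = preprocessWavs.normalize(X_new, mean_val, std_val)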
-------------------------------------------------------------------------------- /tools/formatting.py: --------------------------------------------------------------------------------
1 | # from http://stackoverflow.com/questions/384076/how-can-i-color-python-logging-output?rq=1
2 | import logging
3 | import os
4 |
5 | BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8)
6 |
7 | # The background is set with 40 plus the number of the color, and the foreground with 30
8 |
9 | # These are the sequences needed to get colored output
10 | RESET_SEQ = "\033[0m"
11 | COLOR_SEQ = "\033[1;%dm"
12 | BOLD_SEQ = "\033[1m"
13 |
14 |
15 | def formatter_message(message, use_color=True):
16 | if use_color:
17 | message = message.replace("$RESET", RESET_SEQ).replace("$BOLD", BOLD_SEQ)
18 | else:
19 | message = message.replace("$RESET", "").replace("$BOLD", "")
20 | return message
21 |
22 |
23 | COLORS = {
24 | 'WARNING': YELLOW,
25 | 'INFO': WHITE,
26 | 'DEBUG': BLUE,
27 | 'CRITICAL': YELLOW,
28 | 'ERROR': RED
29 | }
30 |
31 |
32 | class ColoredFormatter(logging.Formatter):
33 | def __init__(self, msg, use_color=True):
34 | logging.Formatter.__init__(self, msg)
35 | self.use_color = use_color
36 |
37 | def format(self, record):
38 | levelname = record.levelname
39 | if self.use_color and levelname in COLORS:
40 | levelname_color = COLOR_SEQ % (30 + COLORS[levelname]) + levelname + RESET_SEQ
41 | record.levelname = levelname_color
42 | return logging.Formatter.format(self, record)
43 |
44 |
45 | # Custom logger class with multiple destinations
46 | class ColoredLogger(logging.Logger):
47 | FORMAT = '%(asctime)s - (%(filename)s:%(lineno)d) | %(message)s' # "[$BOLD%(name)-5s$RESET][%(levelname)-10s]($BOLD%(filename)s$RESET:%(lineno)d) %(message)s "
48 | COLOR_FORMAT = formatter_message(FORMAT, True)
49 |
50 | def __init__(self, name):
51 | logging.Logger.__init__(self, name, logging.WARNING)
52 |
53 | color_formatter = ColoredFormatter(self.COLOR_FORMAT)
54 |
55 | console = logging.StreamHandler()
56 | console.setFormatter(color_formatter)
57 | console.setLevel(logging.INFO)
58 |
59 | self.addHandler(console)
60 | return
61 |
62 | def addFileHandler(self, output_dir='.', log_name="logger.log"):
63 | fileHandler = logging.FileHandler(os.path.join(output_dir, log_name), "w", encoding=None, delay=True)
64 | fileHandler.setFormatter(ColoredFormatter(self.COLOR_FORMAT))
65 | fileHandler.setLevel(logging.DEBUG)
66 | self.addHandler(fileHandler)
67 |
-------------------------------------------------------------------------------- /tools/formatting.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthijsvk/TIMITspeech/4294fe4af760d19dc807c4e01d01d07662ff7bde/tools/formatting.pyc
-------------------------------------------------------------------------------- /tools/general_tools.py: --------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 | import re
5 | import numpy as np
6 | from six.moves import cPickle
7 |
8 | logger_GeneralTools = logging.getLogger('audioSR.generalTools')
9 | logger_GeneralTools.setLevel(logging.ERROR)
10 |
11 |
12 | def path_reader(filename):
13 | with open(filename) as f:
14 | path_list = f.read().splitlines()
15 | return path_list
16 |
17 |
18 | def unpickle(file_path):
19 | with open(file_path, 'rb') as cPickle_file:
20 | a = cPickle.load(cPickle_file)
21 | return a
22 |
23 |
24 | # find all files of a type under a directory, recursive
25 | def load_wavPhn(rootDir):
26 | wavs = loadWavs(rootDir)
27 | phns = loadPhns(rootDir)
28 | return wavs, phns
29 |
30 |
31 | def loadWavs(rootDir):
32 | wav_files = []
33 | for dirpath, dirs, files in os.walk(rootDir):
34 | for f in files:
35 | if (f.lower().endswith(".wav")):
36 | wav_files.append(os.path.join(dirpath, f))
37 | return sorted(wav_files)
38 |
39 |
40 | def loadPhns(rootDir):
41 | phn_files = []
42 | for dirpath, dirs, files in os.walk(rootDir):
43 | for f in files:
44 | if (f.lower().endswith(".phn")):
45 | phn_files.append(os.path.join(dirpath, f))
46 | return sorted(phn_files)
47 |
48 |
49 | def pad_sequences_X(sequences, maxlen=None, padding='post', truncating='post', value=0.):
50 | """
51 | Pad each sequence to the same length:
52 | the length of the longest sequence.
53 | If maxlen is provided, any sequence longer
54 | than maxlen is truncated to maxlen. Truncation happens off either the beginning or
55 | the end of the sequence (the end by default).
56 | Supports post-padding (default) and pre-padding.
57 | """ 58 | lengths = [len(s) for s in sequences] 59 | 60 | nb_samples = len(sequences) 61 | if maxlen is None: 62 | maxlen = np.max(lengths) 63 | 64 | # try-except to distinguish between X and y 65 | datatype = type(sequences[0][0][0]); 66 | logger_GeneralTools.debug('X data: %s, %s, %s', type(sequences[0][0]), sequences[0][0].shape, sequences[0][0]) 67 | 68 | xSize = nb_samples; 69 | ySize = maxlen; 70 | zSize = sequences[0].shape[1]; 71 | # sequences = [[np.reshape(subsequence, (subsequence.shape[0], 1)) for subsequence in sequence] for sequence in sequences] 72 | 73 | logger_GeneralTools.debug('new dimensions: %s, %s, %s', xSize, ySize, zSize) 74 | logger_GeneralTools.debug('intermediate matrix, estimated_size: %s', 75 | xSize * ySize * zSize * np.dtype(datatype).itemsize) 76 | 77 | x = (np.ones((xSize, ySize, zSize)) * value).astype(datatype) 78 | 79 | for idx, s in enumerate(sequences): 80 | if truncating == 'pre': 81 | trunc = s[-maxlen:] 82 | elif truncating == 'post': 83 | trunc = s[:maxlen] 84 | else: 85 | raise ValueError("Truncating type '%s' not understood" % padding) 86 | 87 | if padding == 'post': 88 | x[idx, :len(trunc), :] = trunc 89 | elif padding == 'pre': 90 | x[idx, -len(trunc):, :] = np.array(trunc, dtype='float32') 91 | else: 92 | raise ValueError("Padding type '%s' not understood" % padding) 93 | 94 | return x 95 | 96 | 97 | def pad_sequences_y(sequences, maxlen=None, padding='post', truncating='post', value=0.): 98 | """ 99 | Pad each sequence to the same length: 100 | the length of the longuest sequence. 101 | If maxlen is provided, any sequence longer 102 | than maxlen is truncated to maxlen. Truncation happens off either the beginning (default) or 103 | the end of the sequence. 104 | Supports post-padding and pre-padding (default). 105 | """ 106 | lengths = [len(s) for s in sequences] 107 | 108 | nb_samples = len(sequences) 109 | if maxlen is None: maxlen = np.max(lengths) 110 | 111 | datatype = type(sequences[0][0]); 112 | logger_GeneralTools.debug('Y data: %s, %s, %s', type(sequences[0]), sequences[0].shape, sequences[0]) 113 | 114 | xSize = nb_samples; 115 | ySize = maxlen; 116 | # sequences = [np.reshape(sequence, (sequence.shape[0], 1)) for sequence in sequences] 117 | 118 | logger_GeneralTools.debug('new dimensions: %s, %s', xSize, ySize) 119 | logger_GeneralTools.debug('intermediate matrix, estimated_size: %s', xSize * ySize * np.dtype(datatype).itemsize) 120 | 121 | y = (np.ones((xSize, ySize)) * value).astype(datatype) 122 | 123 | for idx, s in enumerate(sequences): 124 | if truncating == 'pre': 125 | trunc = s[-maxlen:] 126 | elif truncating == 'post': 127 | trunc = s[:maxlen] 128 | else: 129 | raise ValueError("Truncating type '%s' not understood" % padding) 130 | 131 | if padding == 'post': 132 | y[idx, :len(trunc)] = trunc 133 | elif padding == 'pre': 134 | y[idx, -len(trunc):] = np.array(trunc, dtype='float32') 135 | else: 136 | raise ValueError("Padding type '%s' not understood" % padding) 137 | 138 | return y 139 | 140 | 141 | def generate_masks(inputs, valid_frames=None, batch_size=1, max_length=1000, 142 | logger=logger_GeneralTools): # inputs = X. valid_frames = list of frames when we need to extract the phoneme 143 | ## all recurrent layers in lasagne accept a separate mask input which has shape 144 | # (batch_size, n_time_steps), which is populated such that mask[i, j] = 1 when j <= (length of sequence i) and mask[i, j] = 0 when j > (length 145 | # of sequence i). 
141 | def generate_masks(inputs, valid_frames=None, batch_size=1, max_length=1000,
142 | logger=logger_GeneralTools): # inputs = X. valid_frames = list of frames when we need to extract the phoneme
143 | ## all recurrent layers in lasagne accept a separate mask input which has shape
144 | # (batch_size, n_time_steps), which is populated such that mask[i, j] = 1 when j <= (length of sequence i) and mask[i, j] = 0 when j > (length
145 | # of sequence i). When no mask is provided, it is assumed that all sequences in the minibatch are of length n_time_steps.
146 | logger.debug("* Data information")
147 | logger.debug('%s %s', type(inputs), len(inputs))
148 | logger.debug('%s %s', type(inputs[0]), inputs[0].shape)
149 | logger.debug('%s %s', type(inputs[0][0]), inputs[0][0].shape)
150 | logger.debug('%s', type(inputs[0][0][0]))
151 |
152 | # max_input_length = max([len(inputs[i]) for i in range(len(inputs))])
153 | max_length = max([len(inputs[i]) for i in range(len(inputs))])
154 | input_dim = len(inputs[0][0])
155 |
156 | logger.debug("max_seq_len: %d", max_length)
157 | logger.debug("input_dim: %d", input_dim)
158 |
159 | # X = np.zeros((batch_size, max_input_length, input_dim))
160 | input_mask = np.zeros((batch_size, max_length), dtype='float32')
161 |
162 | for example_id in range(len(inputs)):
163 | try:
164 | if valid_frames is not None:
165 | # Sometimes phonemes are so close to each other that all are mapped to last frame -> gives error
166 | if valid_frames[example_id][-1] >= max_length: valid_frames[example_id][-1] = max_length - 1
167 | if valid_frames[example_id][-2] >= max_length: valid_frames[example_id][-2] = max_length - 1
168 | if valid_frames[example_id][-3] >= max_length: valid_frames[example_id][-3] = max_length - 1
169 |
170 | input_mask[example_id, valid_frames[example_id]] = 1
171 | else:
172 | logger.warning("NO VALID FRAMES SPECIFIED!!!")
173 | # raise Exception("NO VALID FRAMES SPECIFIED!!!")
174 |
175 | logger.debug('%d', example_id)
176 | curr_seq_len = len(inputs[example_id])
177 | logger.debug('%d', curr_seq_len)
178 | input_mask[example_id, :curr_seq_len] = 1
179 | except Exception as e:
180 | print("Couldn't do it: %s" % e)
181 | import pdb
182 | pdb.set_trace()
183 |
184 | return input_mask
185 |
186 |
187 | def query_yes_no(question, default="yes"):
188 | """Ask a yes/no question via raw_input() and return their answer.
189 |
190 | "question" is a string that is presented to the user.
191 | "default" is the presumed answer if the user just hits <Enter>.
192 | It must be "yes" (the default), "no" or None (meaning
193 | an answer is required of the user).
194 |
195 | The "answer" return value is True for "yes" or False for "no".
196 | """ 197 | valid = { 198 | "yes": True, "y": True, "ye": True, 199 | "no": False, "n": False 200 | } 201 | if default is None: 202 | prompt = " [y/n] " 203 | elif default == "yes": 204 | prompt = " [Y/n] " 205 | elif default == "no": 206 | prompt = " [y/N] " 207 | else: 208 | raise ValueError("invalid default answer: '%s'" % default) 209 | 210 | while True: 211 | sys.stdout.write(question + prompt) 212 | choice = raw_input().lower() 213 | if default is not None and choice == '': 214 | return valid[default] 215 | elif choice in valid: 216 | return valid[choice] 217 | else: 218 | sys.stdout.write("Please respond with 'yes' or 'no' " 219 | "(or 'y' or 'n').\n") 220 | 221 | 222 | def saveToPkl(target_path, data): # data can be list or dictionary 223 | if not os.path.exists(os.path.dirname(target_path)): 224 | os.makedirs(os.path.dirname(target_path)) 225 | with open(target_path, 'wb') as cPickle_file: 226 | cPickle.dump( 227 | data, 228 | cPickle_file, 229 | protocol=cPickle.HIGHEST_PROTOCOL) 230 | return 0 231 | 232 | 233 | def depth(path): 234 | return path.count(os.sep) 235 | 236 | 237 | # stuff for getting relative paths between two directories 238 | def pathsplit(p, rest=[]): 239 | (h, t) = os.path.split(p) 240 | if len(h) < 1: return [t] + rest 241 | if len(t) < 1: return [h] + rest 242 | return pathsplit(h, [t] + rest) 243 | 244 | 245 | def commonpath(l1, l2, common=[]): 246 | if len(l1) < 1: return (common, l1, l2) 247 | if len(l2) < 1: return (common, l1, l2) 248 | if l1[0] != l2[0]: return (common, l1, l2) 249 | return commonpath(l1[1:], l2[1:], common + [l1[0]]) 250 | 251 | 252 | # p1 = main path, p2= the one you want to get the relative path of 253 | def relpath(p1, p2): 254 | (common, l1, l2) = commonpath(pathsplit(p1), pathsplit(p2)) 255 | p = [] 256 | if len(l1) > 0: 257 | p = ['../' * len(l1)] 258 | p = p + l2 259 | return os.path.join(*p) 260 | 261 | 262 | # need this to traverse directories, find depth 263 | def directories(root): 264 | dirList = [] 265 | for path, folders, files in os.walk(root): 266 | for name in folders: 267 | dirList.append(os.path.join(path, name)) 268 | return sorted(dirList) 269 | 270 | 271 | def tryint(s): 272 | try: 273 | return int(s) 274 | except ValueError: 275 | return s 276 | 277 | 278 | def alphanum_key(s): 279 | return [tryint(c) for c in re.split('([0-9]+)', s)] 280 | 281 | 282 | def sort_nicely(l): 283 | return sorted(l, key=alphanum_key) 284 | 285 | 286 | def set_type(X, type): 287 | for i in range(len(X)): 288 | X[i] = X[i].astype(type) 289 | return X 290 | -------------------------------------------------------------------------------- /tools/general_tools.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthijsvk/TIMITspeech/4294fe4af760d19dc807c4e01d01d07662ff7bde/tools/general_tools.pyc -------------------------------------------------------------------------------- /tools/helpFunctions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthijsvk/TIMITspeech/4294fe4af760d19dc807c4e01d01d07662ff7bde/tools/helpFunctions/__init__.py -------------------------------------------------------------------------------- /tools/helpFunctions/copyFilesOfType.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import os.path 5 | import re 6 | import shutil 7 | import sys 8 | 9 | protocolPattern = 
re.compile('r&apos^\w+://&apos') 10 | 11 | 12 | def pathsplit(path): 13 | """ This version, in contrast to the original version, permits trailing 14 | slashes in the pathname (in the event that it is a directory). 15 | It also uses no recursion """ 16 | return path.split(os.path.sep) 17 | 18 | 19 | def commonpath(l1, l2, common=[]): 20 | if len(l1) < 1: return (common, l1, l2) 21 | if len(l2) < 1: return (common, l1, l2) 22 | if l1[0] != l2[0]: return (common, l1, l2) 23 | return commonpath(l1[1:], l2[1:], common + [l1[0]]) 24 | 25 | 26 | def relpath(p1, p2): 27 | (common, l1, l2) = commonpath(pathsplit(p1), pathsplit(p2)) 28 | p = [] 29 | if len(l1) > 0: 30 | p = ['../' * len(l1)] 31 | p = p + l2 32 | return os.path.join(*p) 33 | 34 | 35 | def isabs(string): 36 | if protocolPattern.match(string): return 1 37 | return os.path.isabs(string) 38 | 39 | 40 | def rel2abs(path, base=os.curdir): 41 | if isabs(path): return path 42 | retval = os.path.join(base, path) 43 | return os.path.abspath(retval) 44 | 45 | 46 | def abs2rel(path, base=os.curdir): # return a relative path from base to path. 47 | if protocolPattern.match(path): return path 48 | base = rel2abs(base) 49 | path = rel2abs(path) # redundant - should already be absolute 50 | return relpath(base, path) 51 | 52 | 53 | def test(p1, p2): 54 | print("from", p1, "to", p2, " -> ", 55 | relpath(p1, p2)) # this is what I need. p1 = AbsDirPath; p2 = AbsfilePath; out = filepathRelToDir 56 | print("from", p1, "to", p2, " -> ", rel2abs(p1, p2)) 57 | print("from", p1, "to", p2, " -> ", abs2rel(p1, p2)) 58 | 59 | 60 | def query_yes_no(question, default="yes"): 61 | """Ask a yes/no question via raw_input() and return their answer. 62 | 63 | "question" is a string that is presented to the user. 64 | "default" is the presumed answer if the user just hits . 65 | It must be "yes" (the default), "no" or None (meaning 66 | an answer is required of the user). 67 | 68 | The "answer" return value is True for "yes" or False for "no". 
69 | """ 70 | valid = { 71 | "yes": True, "y": True, "ye": True, 72 | "no": False, "n": False 73 | } 74 | if default is None: 75 | prompt = " [y/n] " 76 | elif default == "yes": 77 | prompt = " [Y/n] " 78 | elif default == "no": 79 | prompt = " [y/N] " 80 | else: 81 | raise ValueError("invalid default answer: '%s'" % default) 82 | 83 | while True: 84 | sys.stdout.write(question + prompt) 85 | choice = raw_input().lower() 86 | if default is not None and choice == '': 87 | return valid[default] 88 | elif choice in valid: 89 | return valid[choice] 90 | else: 91 | sys.stdout.write("Please respond with 'yes' or 'no' " 92 | "(or 'y' or 'n').\n") 93 | 94 | 95 | def copyFilesOfType(srcDir, dstDir, extension, interactive=False): 96 | print("Source Dir: %s, Destination Dir: %s, Extension: %s" % (srcDir, dstDir, extension)) 97 | 98 | src = [] 99 | dest = [] 100 | for root, dirs, files in os.walk(srcDir): 101 | for file_ in files: 102 | # print(file_) 103 | if file_.lower().endswith(extension): 104 | srcPath = os.path.join(root, file_) 105 | relSrcPath = relpath(srcDir, srcPath).lstrip("../") 106 | # print(relSrcPath) 107 | destPath = os.path.join(dstDir, relSrcPath) 108 | # print("copying from : %s to \t\t %s" % (srcPath, destPath)) 109 | src.append(srcPath) 110 | dest.append(destPath) 111 | 112 | print("Example: copying ", src[0], "to:", dest[0]) 113 | print(len(src), " files will be copied in total") 114 | 115 | if (interactive and (not query_yes_no("Are you sure you want to peform these operations?", "yes"))): 116 | print("Not doing a thing.") 117 | else: 118 | for i in range(len(src)): 119 | if (not os.path.exists(os.path.dirname(dest[i]))): 120 | os.makedirs(os.path.dirname(dest[i])) 121 | shutil.copy(src[i], dest[i]) 122 | print("Done.") 123 | 124 | return 0 125 | 126 | 127 | if __name__ == '__main__': 128 | if __name__ == '__main__': 129 | srcDir = sys.argv[1] 130 | dstDir = sys.argv[2] 131 | type = sys.argv[3] 132 | copyFilesOfType(srcDir, dstDir, type, interactive=True) 133 | 134 | # example usage: python copyFilesOfType.py ~/Documents/Dropbox/_MyDocs/_ku_leuven/Master_2/Thesis/convNets/code /media/matthijs/TOSHIBA_EXT/TCDTIMIT/zzzNPZmodels ".npz" 135 | -------------------------------------------------------------------------------- /tools/helpFunctions/progress_bar.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | 4 | # def update_progress(amtDone): 5 | # sys.stdout.write("\rProgress: [{0:50s}] {1:.1f}%".format('#' * int(amtDone * 50), amtDone * 100)) 6 | # sys.stdout.flush() 7 | 8 | def show_progress(frac_done, bar_length=20): 9 | # for i in range(end_val): 10 | hashes = '#' * int(round(frac_done * bar_length)) 11 | spaces = ' ' * (bar_length - len(hashes)) 12 | sys.stdout.write("\rProgress: [{0}] {1}% ".format(hashes + spaces, int(round(frac_done * 100)))) 13 | sys.stdout.flush() 14 | 15 | 16 | if __name__ == '__main__': 17 | show_progress(0.8) 18 | -------------------------------------------------------------------------------- /tools/helpFunctions/removeEmptyDirs.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | ''' 3 | Module to remove empty folders recursively. Can be used as standalone script or be imported into existing script. 
4 | '''
5 |
6 | import os
7 | import sys
8 |
9 |
10 | def removeEmptyFolders(path, removeRoot=True):
11 | 'Function to remove empty folders'
12 | if not os.path.isdir(path):
13 | return
14 |
15 | # remove empty subfolders
16 | files = os.listdir(path)
17 | if len(files):
18 | for f in files:
19 | fullpath = os.path.join(path, f)
20 | if os.path.isdir(fullpath):
21 | removeEmptyFolders(fullpath)
22 |
23 | # if folder empty, delete it
24 | files = os.listdir(path)
25 | if len(files) == 0 and removeRoot:
26 | print "Removing empty folder:", path
27 | os.rmdir(path)
28 |
29 |
30 | def usageString():
31 | 'Return usage string to be output in error cases'
32 | return 'Usage: %s directory [removeRoot]' % sys.argv[0]
33 |
34 |
35 | if __name__ == "__main__":
36 | removeRoot = False
37 |
38 | if len(sys.argv) < 2:
39 | print "Not enough arguments"
40 | sys.exit(usageString())
41 |
42 | if not os.path.isdir(sys.argv[1]):
43 | print "No such directory %s" % sys.argv[1]
44 | sys.exit(usageString())
45 |
46 | if len(sys.argv) > 2 and sys.argv[2] == "True":
47 | removeRoot = True
48 |
49 | removeEmptyFolders(sys.argv[1], removeRoot)
50 |
-------------------------------------------------------------------------------- /tools/helpFunctions/resample.py: --------------------------------------------------------------------------------
1 | '''
2 | Created on Apr 7, 2011
3 | by Uri Nieto
4 | uri@urinieto.com
5 |
6 | Sample Rate converter from 48kHz to 44.1kHz.
7 |
8 | USAGE:
9 | $>python resample.py -i input.wav [-o output.wav -q [0.0-1.0]]
10 |
11 | EXAMPLES:
12 | $>python resample.py -i onades.wav
13 | $>python resample.py -i onades.wav -o out.wav
14 | $>python resample.py -i onades-mono.wav -q 0.8
15 | $>python resample.py -i onades.wav -o out3.wav -q 0.5
16 |
17 | DESCRIPTION:
18 | The input has to be a WAV file sampled at 48 kHz/sec
19 | with a resolution of 16 bits/sample. It can have n>0
20 | number of channels (i.e. 1=mono, 2=stereo, ...).
21 |
22 | The output will be a WAV file sampled at 44.1 kHz/sec
23 | with a resolution of 16 bits/sample, and the same
24 | number of channels as the input.
25 |
26 | A quality parameter q can be provided (>0.0 to 1.0), and
27 | it will modify the number of zero crossings of the filter,
28 | making the output quality best when q = 1.0 and very bad
29 | as q tends to 0.0
30 |
31 | The sample rate factor is:
32 |
33 | 44100 147
34 | ------- = -----
35 | 48000 160
36 |
37 | To do the conversion, we upsample by 147, low pass filter,
38 | and downsample by 160 (in this order). This is done by
39 | using an efficient polyphase filter bank with resampling
40 | algorithm proposed by Vaidyanathan in [2].
41 |
42 | The low pass filter is an impulse response windowed
43 | by a Kaiser window to have a better filtering
44 | (around -60dB in the rejection band) [1].
45 |
46 | As a comparison between the Kaiser Window and the
47 | Rectangular Window, this algorithm plotted the following
48 | images included with this package:
49 |
50 | KaiserIR.png
51 | KaiserFR.png
52 | RectIR.png
53 | RectFR.png
54 |
55 | The images show the Impulse Responses and the Frequency
56 | Responses of the Kaiser and Rectangular Windows. As it can be
57 | clearly seen, the Kaiser window has a gain of around -60dB in the
58 | rejection band, whereas the Rect window has a gain of around -20dB
59 | and smoothly decreasing to -40dB. Thus, the Kaiser window
60 | method is rejecting the aliasing much better than the Rect window.
61 |
62 | The Filter Design is performed in the function:
63 | designFIR()
64 |
65 | The Upsampling, Filtering, and Downsampling is
66 | performed in the function:
67 | upSampleFilterDownSample()
68 |
69 | Also included in the package are two wav files sampled at 48kHz
70 | with 16bits/sample resolution. One is stereo and the other mono:
71 |
72 | onades.wav
73 | onades-mono.wav
74 |
75 | NOTES:
76 | You need numpy and scipy installed to run this script.
77 | You can find them here:
78 | http://numpy.scipy.org/
79 |
80 | You may want to have matplotlib as well if you want to
81 | print the plots by yourself (commented right now)
82 |
83 | This code would be much faster on C or C++, but my decision
84 | on using Python was to make the code more readable (yet useful)
85 | rather than focusing on performance.
86 |
87 | @author: uri
88 |
89 | REFERENCES:
90 | [1]: Smith, J.O., "Spectral Audio Signal Processing",
91 | W3K Publishing, March 2010
92 |
93 | [2] Vaidyanathan, P.P., "Multirate Systems and Filter Banks",
94 | Prentice Hall, 1993.
95 |
96 | COPYRIGHT NOTES:
97 |
98 | Copyright (C) 2011, Uri Nieto
99 |
100 | This program is free software: you can redistribute it and/or modify
101 | it under the terms of the GNU General Public License as published by
102 | the Free Software Foundation, either version 3 of the License, or
103 | any later version.
104 |
105 | This program is distributed in the hope that it will be useful,
106 | but WITHOUT ANY WARRANTY; without even the implied warranty of
107 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
108 | GNU General Public License for more details.
109 |
110 | You should have received a copy of the GNU General Public License
111 | along with this program. If not, see <http://www.gnu.org/licenses/>.
112 |
113 | '''
114 | from __future__ import print_function  # so the multi-argument print(...) calls below behave the same under Python 2
115 |
116 | import sys
117 | import time as tm
118 | from decimal import Decimal
119 | from fractions import Fraction
120 | from shutil import copyfile
121 |
122 | import numpy as np
123 | from scipy.io import wavfile
124 |
125 | # import matplotlib.pyplot as plt #Uncomment to plot
126 |
127 |
128 | '''
129 | h = flipTranspose(h, L, cpp)
130 | ...
131 | Desc: Flips and Transposes the impulse response h, dividing
132 | it into L different phases with cpp coefficients per phase.
133 | Following as described in [2].
134 | ...
135 | h: Impulse response to flip and Transpose
136 | L: Upsampling Factor
137 | cpp: Coefficients per Phase
138 | return hh: h flipped and transposed following the description
139 | '''
140 |
141 | def flipTranspose(h, L, cpp):
142 | # Get the impulse response size
143 | N = len(h)
144 |
145 | # Init the output to 0
146 | hh = np.zeros(N)
147 |
148 | # Flip and Transpose:
149 | for i in range(L):
150 | hh[cpp - 1 + i * cpp:-N - 1 + i * cpp:-1] = h[i:cpp * L:L]
151 |
152 | return hh
153 |
154 |
155 | '''
156 | h = upSampleFilterDownSample(x, h, L, M)
157 | ...
158 | Desc: Upsamples the input x by L, filters it out using h, and
159 | downsamples it by M.
160 |
161 | The algorithm is based on the "efficient polyphase filter bank
162 | with resampling" found on page 129 of the book [2] (Figure 4.3-8d).
163 |
164 | ...
165 | x: input signal
166 | h: impulse response (assumes it has the correct cut-off freq)
167 | L: Upsampling Factor
168 | M: Downsampling Factor
169 | returns y: output signal (x upsampled, filtered, and downsampled)
170 | '''
171 |
172 |
173 | def upSampleFilterDownSample(x, h, L, M, printing=False):
174 | # Number of samples to convert
175 | N = len(x)
176 |
177 | # Compute the number of coefficients per phase
178 | cpp = len(h) // L
179 |
180 | # Flip and Transpose the impulse response
181 | h = flipTranspose(h, L, cpp)
182 |
183 | # Check number of channels
184 | if (np.shape(np.shape(x)) == (2,)):
185 | nchan = np.shape(x)[1]
186 | y = np.zeros((int(np.ceil(N * L / float(M))), nchan))
187 | else:
188 | nchan = 1
189 | y = np.zeros(int(np.ceil(N * L / float(M))))
190 |
191 | # Init the output index
192 | y_i = 0
193 |
194 | # Init the phase index
195 | phase_i = 0
196 |
197 | # Init the main loop index
198 | i = 0
199 |
200 | # Main Loop
201 | while i < N:
202 |
203 | # Print % completed
204 | if (printing and (i % 30000 == 0)):
205 | print("%.2f %% completed" % float(100 * i / float(len(x))))
206 |
207 | # Compute the filter index
208 | h_i = phase_i * cpp
209 |
210 | # Compute the input index
211 | x_i = i - cpp + 1
212 |
213 | # Update impulse index if needed (offset)
214 | if x_i < 0:
215 | h_i -= x_i
216 | x_i = 0
217 |
218 | # Compute the current output sample
219 | rang = i - x_i + 1
220 | if nchan == 1:
221 | y[y_i] = np.sum(x[x_i:x_i + rang] * h[h_i:h_i + rang])
222 | else:
223 | for c in range(nchan):
224 | y[y_i, c] = np.sum(x[x_i:x_i + rang, c] * h[h_i:h_i + rang])
225 |
226 | # Add the downsampling factor to the phase index
227 | phase_i += M
228 |
229 | # Compute the increment for the index of x with the new phase
230 | x_incr = phase_i // L
231 |
232 | # Update phase index
233 | phase_i %= L
234 |
235 | # Update the main loop index
236 | i += x_incr
237 |
238 | # Update the output index
239 | y_i += 1
240 |
241 | return y
242 |
243 |
244 | '''
245 | h = impulse(M, L)
246 | ...
247 | M: Impulse Response Size
248 | T: Sampling Period
249 | returns h: The impulse response
250 | '''
251 |
252 |
253 | def impulse(M, T):
254 | # Create time array
255 | n = np.arange(-(M - 1) / 2, (M - 1) / 2 + 1)
256 |
257 | # Compute the impulse response using the sinc function
258 | h = (1 / T) * np.sinc((1 / T) * n)
259 |
260 | return h
261 |
262 |
263 | '''
264 | b = bessel(x)
265 | ...
266 | Desc: Zero-order modified Bessel function of the first kind, with
267 | approximation using the Maclaurin series, as described in [1]
268 | ...
269 | x: Input sample
270 | b: Zero-order modified Bessel function of the first kind
271 | '''
272 |
273 |
274 | def bessel(x):
275 | return np.power(np.exp(x), 2)  # note: this computes e^(2x), a rough surrogate; only the ratio bessel(samp)/bessel(beta) is used below
276 |
277 |
278 | '''
279 | k = kaiser(M, beta)
280 | ...
281 | Desc: Generates an M length Kaiser window with the
282 | specified beta parameter. Following instructions in [1]
283 | ...
284 | M: Number of samples of the window
285 | beta: Beta parameter of the Kaiser Window
286 | k: array(M,1) containing the Kaiser window with the specified beta
287 | '''
288 |
289 |
290 | def kaiser(M, beta):
291 | # Init Kaiser Window
292 | k = np.zeros(M)
293 |
294 | # Compute each sample of the Kaiser Window
295 | i = 0
296 | for n in np.arange(-(M - 1) / 2, (M - 1) / 2 + 1):
297 | samp = beta * np.sqrt(1 - np.power((n / (M / 2.0)), 2))
298 | samp = bessel(samp) / float(bessel(beta))
299 | k[i] = samp
300 | i = i + 1
301 |
302 | return k
303 |
304 |
305 | '''
306 | h = designFIR(N, L, M)
307 | ...
308 | Desc: Designs a low pass filter to perform the conversion of
309 | sampling frequencies given the upsampling and downsampling factors.
310 | It uses the Kaiser window to better filter out aliasing.
311 | ...
312 | N: Maximum size of the Impulse Response of the FIR
313 | L: Upsampling Factor
314 | M: Downsampling Factor
315 | returns h: Impulse Response of the FIR
316 | '''
317 |
318 |
319 | def designFIR(N, L, M):
320 | # Get the impulse response with the right Sampling Period
321 | h0 = impulse(N, float(M))
322 |
323 | # Compute a Kaiser Window
324 | alpha = 2.5 # Alpha factor for the Kaiser Window
325 | k = kaiser(N, alpha * np.pi)
326 |
327 | # Window the impulse response with the Kaiser window
328 | h = h0 * k
329 |
330 | # Filter Gain
331 | h = h * L
332 |
333 | # Reduce window by removing almost 0 values to improve filtering
334 | for i in range(len(h)):
335 | if abs(h[i]) > 1e-3:
336 | for j in range(i, 0, -1):
337 | if abs(h[j]) < 1e-7:
338 | h = h[j:len(h) - j]
339 | break
340 | break
341 |
342 | '''
343 | # For plotting purposes:
344 | N = len(h)
345 | Hww = fft(h, N)
346 | Hww = Hww[0:N/2.0]
347 | Hwwdb = 20.0*np.log10(np.abs(Hww))
348 | Hw = fft(h0, N)
349 | Hw = Hw[0:N/2.0]
350 | Hwdb = 20.0*np.log10(np.abs(Hw))
351 |
352 | plt.figure(1)
353 | plt.plot(h)
354 | plt.title('Kaiser Windowed Impulse Response of the low pass filter')
355 | plt.show()
356 |
357 | plt.figure(2)
358 | plt.plot(h0)
359 | plt.title('Rect Windowed Impulse Response of the low pass filter')
360 | plt.show()
361 |
362 | #print np.shape(np.arange(0, N/2.0-1)/float(N)), np.shape(Hwwdb)
363 |
364 | plt.figure(3)
365 | plt.plot(np.arange(0, N/2.0-1)/float(N),Hwwdb)
366 | plt.xlabel('Normalized Frequency');
367 | plt.ylabel('Magnitude (dB)');
368 | plt.title('Amplitude Response, Kaiser window with beta = ' + str(2.5*np.pi));
369 | plt.show()
370 |
371 | plt.figure(4)
372 | plt.plot(np.arange(0, N/2.0-1)/float(N),Hwdb)
373 | plt.xlabel('Normalized Frequency');
374 | plt.ylabel('Magnitude (dB)');
375 | plt.title('Amplitude Response using a Rect Window');
376 | plt.show()
377 | '''
378 |
379 | return h
380 |
381 |
382 | '''
383 | dieWithUsage()
384 | ...
385 | Desc: Stops program and prints usage
386 | '''
387 |
388 |
389 | def dieWithUsage():
390 | usage = """
391 | USAGE: $>python resample.py -i input.wav [-o output.wav -q [0.0-1.0]]
392 |
393 | input.wav has to be sampled at 48kHz with 16bits/sample
394 | If no output file is provided, file will be written in "./output.wav"
395 | If no quality param is provided, 1.0 (max) will be used
396 |
397 | Description: Converts sampling frequency of input.wav from 48kHz to 16kHz and writes it into output.wav
398 | """
399 | print(usage)
400 |
401 | sys.exit(1)
402 |
403 |
404 | '''
405 | Main
406 | ...
407 | Desc: Reads the input wave file, designs the filter,
408 | upsamples, filters, and downsamples the input, and
409 | finally writes it to the output wave file.
410 | '''
411 |
412 |
413 | def resampleWAV(inFile, outFile="output.wav", out_fr=16000.0, q=1.0):
414 | # Parse arguments
415 | inPath, outPath = inFile, outFile
416 |
417 | # Read input wave
418 | in_fr, in_data = wavfile.read(inPath)
419 | in_nbits = in_data.dtype
420 |
421 | # Time it
422 | start_time = tm.time()
423 |
424 | # Set output wave parameters
425 | out_nbits = in_nbits
426 |
427 | frac = Fraction(Decimal(str(float(in_fr) / out_fr))).limit_denominator(1000)
428 |
429 | if (float(frac) < 1.0):
430 | print("input file smaller sampling rate than output..")
431 | print("input: ", in_fr, "Hz. Output: ", out_fr, "Hz")
432 |
433 | elif (float(frac) == 1.0):
434 | try:
435 | copyfile(inPath, outPath)
436 | print("Input file", inPath, "already at correct sampling rate; just copying")
437 | except Exception as e:
438 | print(e.args)
439 | else:
440 | print("some weird error", inPath, outPath)
441 | return -1
442 | else:
443 | L = frac.denominator # Upsampling Factor
444 | M = frac.numerator # Downsampling Factor
445 | Nz = int(25 * q) # Max Number of Zero Crossings (depending on quality)
446 | h = designFIR(Nz * L, L, M)
447 |
448 | # Upsample, Filter, and Downsample
449 | out_data = upSampleFilterDownSample(in_data, h, L, M, printing=False) # control progression output
450 |
451 | # Make sure the output is 16 bits
452 | out_data = out_data.astype(out_nbits)
453 |
454 | # Write the output wave
455 | wavfile.write(outPath, out_fr, out_data)
456 |
457 | # Print Results
458 | duration = float(tm.time() - start_time)
459 | print(
460 | "File", outPath, " was successfully written. Resampled from ", in_fr, "to", out_fr, "in", duration,
461 | " seconds")
462 | return 0
463 |
-------------------------------------------------------------------------------- /tools/helpFunctions/resampleExperiment.py: --------------------------------------------------------------------------------
1 | ######################################
2 | ########## RESAMPLING ###############
3 | ######################################
4 | from resample import *
5 | from wavToPng import *
6 |
7 | inFile = "sa1.wav"
8 | outFile44 = "sa1_out44.wav"
9 | outFile20 = "sa1_out20.wav"
10 | outFile16 = "sa1_out16.wav"
11 | outFile8 = "sa1_out8.wav"
12 | outFile4 = "sa1_out4.wav"
13 | outFile16_from44 = "sa1_out16_from44.wav"
14 |
15 | resampleWAV(inFile, outFile44, out_fr=44100, q=1.0)
16 | resampleWAV(inFile, outFile20, out_fr=20000, q=1.0)
17 | resampleWAV(inFile, outFile16, out_fr=16000, q=1.0)
18 |
19 | resampleWAV(inFile, outFile8, out_fr=8000, q=1.0)
20 | resampleWAV(inFile, outFile4, out_fr=4000, q=1.0)
21 | resampleWAV(outFile44, outFile16_from44, out_fr=16000, q=1.0)
22 |
23 | wavToPng("sa1.wav")
24 |
25 | ######################################
26 | ############## Fractions #############
27 | ######################################
28 | # from fractions import Fraction
29 | # from decimal import Decimal
30 | # print Fraction(Decimal('1.4'))
31 | #
32 | # a= Fraction(Decimal(str(48000/44100.0))).limit_denominator(1000)
33 | # print a.numerator
34 | # print a.denominator
35 | # print(type(a.numerator))
36 |
-------------------------------------------------------------------------------- /tools/helpFunctions/sa1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthijsvk/TIMITspeech/4294fe4af760d19dc807c4e01d01d07662ff7bde/tools/helpFunctions/sa1.wav
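The up/down factors in resampleWAV come from reducing in_fr/out_fr to a fraction (denominator = upsampling L, numerator = downsampling M). A standalone, stdlib-only check of the two cases mentioned in resample.py:

from decimal import Decimal
from fractions import Fraction

# 48 kHz -> 44.1 kHz: 48000/44100 reduces to 160/147, so upsample by L=147, downsample by M=160
frac = Fraction(Decimal(str(48000 / 44100.0))).limit_denominator(1000)
print(frac.denominator, frac.numerator)   # 147 160

# 48 kHz -> 16 kHz: the ratio is exactly 3, so L=1, M=3 (a pure low-pass + decimation)
frac = Fraction(Decimal(str(48000 / 16000.0))).limit_denominator(1000)
print(frac.denominator, frac.numerator)   # 1 3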
-------------------------------------------------------------------------------- /tools/helpFunctions/visualizeMFC.m: --------------------------------------------------------------------------------
1 | % EXAMPLE Simple demo of the MFCC function usage.
2 | %
3 | % This script is a step by step walk-through of computation of the
4 | % mel frequency cepstral coefficients (MFCCs) from a speech signal
5 | % using the MFCC routine.
6 | %
7 | % See also MFCC, COMPARE.
8 |
9 | % Author: Kamil Wojcicki, September 2011
10 | % see https://nl.mathworks.com/matlabcentral/fileexchange/32849-htk-mfcc-matlab
11 | % installation of toolbox:
12 | % Download zip, extract to Matlab bin/MATLAB/R2016b/toolbox
13 | % Then click 'Home', 'Set Path'; Add the 'toolbox/mfcc' folder;
14 | % Click 'Save'.
15 |
16 | format short g
17 |
18 | % Clean-up MATLAB's environment
19 | clear all; close all; clc;
20 |
21 | % Define variables
22 | Tw = 25; % analysis frame duration (ms)
23 | Ts = 10; % analysis frame shift (ms)
24 | alpha = 0.97; % preemphasis coefficient
25 | M = 20; % number of filterbank channels
26 | C = 12; % number of cepstral coefficients
27 | L = 22; % cepstral sine lifter parameter
28 | LF = 300; % lower frequency limit (Hz)
29 | HF = 3700; % upper frequency limit (Hz)
30 | wav_file = 'si650.wav'; % input audio filename
31 | mfc_file = 'si650.mfc';
32 |
33 | % Read MFC file for comparison (see bottom)
34 | htkmfc = readhtk(mfc_file);
35 |
36 | % Read speech samples, sampling rate and precision from file
37 | [ speech, fs] = audioread( wav_file );
38 | info = audioinfo(wav_file);
39 | nbits = info.BitsPerSample
40 |
41 | % Feature extraction (feature vectors as columns)
42 | [ MFCCs, FBEs, frames ] = ...
43 | mfcc( speech, fs, Tw, Ts, alpha, @hamming, [LF HF], M, C+1, L );
44 |
45 | disp(MFCCs)
46 | disp(frames)
47 |
48 | % Generate data needed for plotting
49 | [ Nw, NF ] = size( frames ); % frame length and number of frames
50 | time_frames = [0:NF-1]*Ts*0.001+0.5*Nw/fs; % time vector (s) for frames
51 | time = [ 0:length(speech)-1 ]/fs; % time vector (s) for signal samples
52 | logFBEs = 20*log10( FBEs ); % compute log FBEs for plotting
53 | logFBEs_floor = max(logFBEs(:))-50; % get logFBE floor 50 dB below max
54 | logFBEs( logFBEs<logFBEs_floor ) = logFBEs_floor; % limit logFBE dynamic range
/20
53 | while sound2.rms > targetRMS + min_acc:
54 | sound2 -= min_acc / 20.0
55 |
56 | # print(sound1.rms, targetRMS, sound2.rms)
57 |
58 | combined = sound1.overlay(sound2, loop=True)
59 |
60 | combined.export(out_path, format='wav')
61 |
62 |
63 | def generateBadAudio(outType, srcDir, dstDir, ratio_dB):
64 | # copy phoneme files
65 | copyFilesOfType(srcDir, dstDir, ".phn")
66 |
67 | # copy merged wav files
68 | noiseFile = createNoiseFile(ratio_dB)
69 | src_wavs = loadWavs(srcDir)
70 | for i in tqdm(range(len(src_wavs))):
71 | relSrcPath = relpath(srcDir, src_wavs[i]).lstrip("../")
72 | # print(relSrcPath)
73 | destPath = os.path.join(dstDir, relSrcPath)
74 | if outType == 'voices':
75 | # index of voice to merge
76 | j = random.randint(0, len(src_wavs) - 1)
77 | mergeAudioFiles(src_wavs[i], src_wavs[j], destPath, ratio_dB)
78 | else:
79 | mergeAudioFiles(src_wavs[i], noiseFile, destPath, ratio_dB)
80 |
81 |
82 | import random
83 |
84 |
85 | def createNoiseFile(ratio_dB, noise_path='noise.wav'):
86 | rate = 16000
87 | noise = np.random.normal(0, 1, rate * 3) # generate 3 seconds of white noise
88 | wav.write(noise_path, rate, noise)
89 |
90 | # change the volume to be ~ the TIMIT volume
91 | sound = AudioSegment.from_file(noise_path)
92 | loud2 = sound.dBFS
93 | while sound.dBFS > -30 + ratio_dB:
94 | sound -= 1
95 |
96 | sound.export(noise_path, format='wav')
97 | return os.path.abspath(noise_path)
98 |
99 |
100 | if __name__ == "__main__":
101 | main()
102 |
-------------------------------------------------------------------------------- /tools/phoneme_set.py: --------------------------------------------------------------------------------
1 | # using the 39 phone set proposed in (Lee & Hon, 1989)
2 | # Table 3. Mapping from 61 classes to 39 classes, as proposed by Lee and Hon, (Lee & Hon,
3 | # 1989). The phones in the left column are folded into the labels of the right column. The
4 | # remaining phones are left intact.
5 | import logging
6 |
7 | logger_phonemeSet = logging.getLogger('phonemeSet')
8 | logger_phonemeSet.setLevel(logging.ERROR)
9 |
10 | phoneme_set_61_39 = {
11 | 'ao': 'aa', # 1
12 | 'ax': 'ah', # 2
13 | 'ax-h': 'ah',
14 | 'axr': 'er', # 3
15 | 'hv': 'hh', # 4
16 | 'ix': 'ih', # 5
17 | 'el': 'l', # 6
18 | 'em': 'm', # 6
19 | 'en': 'n', # 7
20 | 'nx': 'n',
21 | 'eng': 'ng', # 8
22 | 'zh': 'sh', # 9
23 | "ux": "uw", # 10
24 | "pcl": "sil", # 11
25 | "tcl": "sil",
26 | "kcl": "sil",
27 | "qcl": "sil",
28 | "bcl": "sil",
29 | "dcl": "sil",
30 | "gcl": "sil",
31 | "h#": "sil",
32 | "#h": "sil",
33 | "pau": "sil",
34 | "epi": "sil",
35 | "q": "sil",
36 | }
37 |
38 | # from https://www.researchgate.net/publication/275055833_TCD-TIMIT_An_audio-visual_corpus_of_continuous_speech
39 | phoneme_set_39_list = [
40 | 'iy', 'ih', 'eh', 'ae', 'ah', 'uw', 'uh', 'aa', 'ey', 'ay', 'oy', 'aw', 'ow', # 13 phns
41 | 'l', 'r', 'y', 'w', 'er', 'm', 'n', 'ng', 'ch', 'jh', 'dh', 'b', 'd', 'dx', # 14 phns
42 | 'g', 'p', 't', 'k', 'z', 'v', 'f', 'th', 's', 'sh', 'hh', 'sil' # 12 phns
43 | ]
44 | values = [i for i in range(0, len(phoneme_set_39_list))]
45 | phoneme_set_39 = dict(zip(phoneme_set_39_list, values))
46 | classToPhoneme39 = dict((v, k) for k, v in phoneme_set_39.items())
47 |
48 | # from http://www.intechopen.com/books/speech-technologies/phoneme-recognition-on-the-timit-database, page 5
49 | phoneme_set_61_list = [
50 | 'iy', 'ih', 'eh', 'ey', 'ae', 'aa', 'aw', 'ay', 'ah', 'ao', 'oy', 'ow', 'uh', 'uw', 'ux', 'er', 'ax', 'ix', 'axr',
51 | 'ax-h', 'jh',
52 | 'ch', 'b', 'd', 'g', 'p', 't', 'k', 'dx', 's', 'sh', 'z', 'zh', 'f', 'th', 'v', 'dh', 'm', 'n', 'ng', 'em', 'nx',
53 | 'en', 'eng', 'l', 'r', 'w', 'y', 'hh', 'hv', 'el', 'bcl', 'dcl', 'gcl', 'pcl', 'tcl', 'kcl', 'q', 'pau', 'epi',
54 | 'h#',
55 | ]
56 | values = [i for i in range(0, len(phoneme_set_61_list))]
57 | phoneme_set_61 = dict(zip(phoneme_set_61_list, values))
58 |
59 |
60 | def convertPredictions(predictions, phoneme_list=classToPhoneme39, valid_frames=None, outputType="phonemes"):
61 | # b is straight conversion to phoneme chars
62 | predictedPhonemes = [phoneme_list[predictedClass] for predictedClass in predictions]
63 |
64 | # c is reduced set of b: duplicates following each other are removed until only 1 is left
65 | reducedPhonemes = []
66 | for j in range(len(predictedPhonemes) - 1):
67 | if predictedPhonemes[j] != predictedPhonemes[j + 1]:
68 | reducedPhonemes.append(predictedPhonemes[j])
69 | reducedPhonemes.append(predictedPhonemes[-1]) # also keep the final run's phoneme
70 | # get only the outputs for valid frames
71 | validPredictions = [predictedPhonemes[frame] for frame in valid_frames]
72 |
73 | # return class outputs
74 | if outputType != "phonemes":
75 | predictedPhonemes = [phoneme_set_39[phoneme] for phoneme in predictedPhonemes]
76 | reducedPhonemes = [phoneme_set_39[phoneme] for phoneme in reducedPhonemes]
77 | validPredictions = [phoneme_set_39[phoneme] for phoneme in validPredictions]
78 |
79 | return predictedPhonemes, reducedPhonemes, validPredictions
80 |
-------------------------------------------------------------------------------- /tools/preprocessWavs.py: --------------------------------------------------------------------------------
1 | import timeit
2 | import os
3 | import numpy as np
4 | import scipy.io.wavfile as wav
5 | from tqdm import tqdm
6 |
7 | program_start_time = timeit.default_timer()
8 | import pdb
9 | import python_speech_features
10 |
11 | from phoneme_set import phoneme_set_39_list
12 | import transform
13 |
14 | nbPhonemes = 39
15 | phoneme_set_list = phoneme_set_39_list # import list of phonemes,
16 | # convert to dictionary with number mappings (see phoneme_set.py)
17 | values = [i for i in range(0, len(phoneme_set_list))]
18 | phoneme_classes = dict(zip(phoneme_set_list, values))
19 |
20 |
21 | ## Functions ##
22 | def get_total_duration(file):
23 | """Get the length of the phoneme file, i.e. the 'time stamp' of the last phoneme"""
24 | for line in reversed(list(open(file))):
25 | [_, val, _] = line.split()
26 | return int(val)
27 |
28 |
29 | def create_mfcc(method, filename, type=2):
30 | """Perform standard preprocessing, as described by Alex Graves (2012)
31 | http://www.cs.toronto.edu/~graves/preprint.pdf
32 | Output consists of 12 MFCC and 1 energy, as well as the first derivative of these.
33 | [1 energy, 12 MFCC, 1 diff(energy), 12 diff(MFCC)]
34 |
35 | method is a dummy input!!"""
36 |
37 | (rate, sample) = wav.read(filename)
38 |
39 | mfcc = python_speech_features.mfcc(sample, rate, winlen=0.025, winstep=0.01, numcep=13, nfilt=26,
40 | preemph=0.97, appendEnergy=True)
41 | out = mfcc
42 | if type > 13:
43 | derivative = np.zeros(mfcc.shape)
44 | for i in range(1, mfcc.shape[0] - 1):
45 | derivative[i, :] = mfcc[i + 1, :] - mfcc[i - 1, :]
46 |
47 | mfcc_derivative = np.concatenate((mfcc, derivative), axis=1)
48 | out = mfcc_derivative
49 | if type > 26:
50 | derivative2 = np.zeros(derivative.shape)
51 | for i in range(1, derivative.shape[0] - 1):
52 | derivative2[i, :] = derivative[i + 1, :] - derivative[i - 1, :]
53 |
54 | out = np.concatenate((mfcc, derivative, derivative2), axis=1)
55 | if type > 39:
56 | derivative3 = np.zeros(derivative2.shape)
57 | for i in range(1, derivative2.shape[0] - 1):
58 | derivative3[i, :] = derivative2[i + 1, :] - derivative2[i - 1, :]
59 |
60 | out = np.concatenate((mfcc, derivative, derivative2, derivative3), axis=1)
61 |
62 | return out, out.shape[0]
63 |
64 |
65 | def calc_norm_param(X):
66 | """Assumes X to be a list of arrays (of differing sizes)"""
67 | total_len = 0
68 | mean_val = np.zeros(X[0].shape[1])
69 | std_val = np.zeros(X[0].shape[1])
70 | for obs in X:
71 | obs_len = obs.shape[0]
72 | mean_val += np.mean(obs, axis=0) * obs_len
73 | std_val += np.std(obs, axis=0) * obs_len
74 | total_len += obs_len
75 |
76 | mean_val /= total_len
77 | std_val /= total_len
78 |
79 | return mean_val, std_val, total_len
80 |
81 |
82 | def normalize(X, mean_val, std_val):
83 | for i in range(len(X)):
84 | X[i] = (X[i] - mean_val) / std_val
85 | return X
86 |
87 |
88 | def set_type(X, type):
89 | for i in range(len(X)):
90 | X[i] = X[i].astype(type)
91 | return X
92 |
93 |
94 | def preprocess_dataset(source_path, nbMFCCs=39, logger=None, debug=None, verbose=False):
95 | """Preprocess data, ignoring compressed files and files starting with 'SA'"""
96 | X = []
97 | y = []
98 | valid_frames = []
99 |
100 | # source_path is the root dir of all the wav/phn files
101 | wav_files = transform.loadWavs(source_path)
102 | label_files = transform.loadPhns(source_path)
103 |
104 | logger.debug("Found %d WAV files" % len(wav_files))
105 | logger.debug("Found %d PHN files" % len(label_files))
106 | assert len(wav_files) == len(label_files)
107 | assert len(wav_files) != 0
108 |
109 | processed = 0
110 | for i in tqdm(range(len(wav_files))):
111 | phn_name = str(label_files[i])
112 | wav_name = str(wav_files[i])
113 |
114 | if (os.path.basename(wav_name).startswith("SA")): # specific for TIMIT: these files contain strong dialects; don't use them
115 | continue
116 |
117 | # Get MFCC of the WAV
118 | X_val, total_frames = create_mfcc('DUMMY', wav_name,
119 | nbMFCCs) # get 3 levels: 0th, 1st and 2nd derivative (=> 3*13 = 39 coefficients)
120 | total_frames = int(total_frames)
121 |
122 | X.append(X_val)
123 |
124 | # Get phonemes and valid frame numbers out of .phn files
125 | total_duration = get_total_duration(phn_name)
126 | fr = open(phn_name)
127 |
128 | # some .PHN files don't start at 0. Set default phoneme to silence (expected at the end of phoneme_set_list)
129 | y_vals = np.zeros(total_frames) + phoneme_classes[phoneme_set_list[-1]]
130 | valid_frames_vals = []
131 |
132 | for line in fr:
133 | [start_time, end_time, phoneme] = line.rstrip('\n').split()
134 | start_time = int(start_time)
135 | end_time = int(end_time)
136 | start_ind = int(np.round(start_time / (total_duration / float(total_frames))))
137 | end_ind = int(np.round(end_time / (total_duration / float(total_frames))))
138 |
139 | valid_ind = int((start_ind + end_ind) / 2)
140 | valid_frames_vals.append(valid_ind)
141 |
142 | phoneme_num = phoneme_classes.get(phoneme, -1)
143 | # check that phoneme is found in dict
144 | if (phoneme_num == -1):
145 | logger.error("In file: %s, phoneme not found: %s", phn_name, phoneme)
146 | pdb.set_trace()
147 | y_vals[start_ind:end_ind] = phoneme_num
148 |
149 | if verbose:
150 | logger.debug('%s', (total_frames / float(total_duration)))
151 | logger.debug('TIME start: %s end: %s, phoneme: %s, class: %s', start_time, end_time, phoneme,
152 | phoneme_num)
153 | logger.debug('FRAME start: %s end: %s, phoneme: %s, class: %s', start_ind, end_ind, phoneme,
154 | phoneme_num)
155 | fr.close()
156 |
157 | # append the target array to our y
158 | y.append(y_vals.astype('int32'))
159 |
160 | # append the valid_frames array to our valid_frames
161 | valid_frames_vals = np.array(valid_frames_vals)
162 | valid_frames.append(valid_frames_vals.astype('int32'))
163 |
164 | if verbose:
165 | logger.debug('(%s) create_target_vector: %s', i, phn_name[:-4])
166 | logger.debug('type(X_val): \t\t %s', type(X_val))
167 | logger.debug('X_val.shape: \t\t %s', X_val.shape)
168 | logger.debug('type(X_val[0][0]):\t %s', type(X_val[0][0]))
169 |
170 | logger.debug('type(y_val): \t\t %s', type(y_vals))
171 | logger.debug('y_val.shape: \t\t %s', y_vals.shape)
172 | logger.debug('type(y_val[0]):\t %s', type(y_vals[0]))
173 | logger.debug('y_val: \t\t %s', (y_vals))
174 |
175 | processed += 1
176 | if debug is not None and processed >= debug:
177 | break
178 |
179 | return X, y, valid_frames
180 |
181 |
182 | def preprocess_unlabeled_dataset(source_path, nbMFCCs=39, verbose=False, logger=None): # TODO
183 | wav_files = transform.loadWavs(source_path)
184 | logger.debug("Found %d WAV files" % len(wav_files))
185 | assert len(wav_files) != 0
186 |
187 | X = []
188 | for i in tqdm(range(len(wav_files))):
189 | wav_name = str(wav_files[i])
190 | X_val, total_frames = create_mfcc('DUMMY', wav_name, nbMFCCs)
191 | X.append(X_val)
192 |
193 | if verbose:
194 | logger.debug('type(X_val): \t\t %s', type(X_val))
195 | logger.debug('X_val.shape: \t\t %s', X_val.shape)
196 | logger.debug('type(X_val[0][0]):\t %s', type(X_val[0][0]))
197 | return X
198 |
-------------------------------------------------------------------------------- /tools/transform.py: --------------------------------------------------------------------------------
1 | import argparse
2 | import csv
3 | import os
4 | import subprocess
5 | import sys
6 |
7 | from tqdm import tqdm
8 |
9 | from helpFunctions import resample
10 | from helpFunctions.writeToTxt import writeToTxt
11 | from phoneme_set import phoneme_set_39, phoneme_set_61_39
12 |
13 | debug = False
14 |
15 |
16 | # input dir should be the dir just above 'TRAIN' and 'TEST'. It expects the TIMIT folder structure (dataset/TRAIN|TEST/region/speaker/files).
17 | # example usage: python transform.py phonemes -i /home/matthijs/TCDTIMIT/TIMIT/original/TIMIT/ -o /home/matthijs/TCDTIMIT/TIMIT/processed
18 |
19 |
20 | ##### Load Data ######
21 | # find all files of a type under a directory, recursive
22 | def load_wavPhn(rootDir):
23 | wavs = loadWavs(rootDir)
24 | phns = loadPhns(rootDir)
25 | return wavs, phns
26 |
27 |
28 | def loadWavs(rootDir):
29 | wav_files = []
30 | for dirpath, dirs, files in os.walk(rootDir):
31 | for f in files:
32 | if (f.lower().endswith(".wav")):
33 | wav_files.append(os.path.join(dirpath, f))
34 | return sorted(wav_files)
35 |
36 |
37 | def loadPhns(rootDir):
38 | phn_files = []
39 | for dirpath, dirs, files in os.walk(rootDir):
40 | for f in files:
41 | if (f.lower().endswith(".phn")):
42 | phn_files.append(os.path.join(dirpath, f))
43 | return sorted(phn_files)
44 |
45 |
46 | # generates for example: dstDir/TIMIT/TRAIN/DR2/MTAT1/SX239.PHN
47 | # from srcPath = someDir/TIMIT/TRAIN/DR2/MTAT1/SX239.PHN
48 | def getDestPath(srcPath, dstDir):
49 | filename = os.path.basename(srcPath)
50 |
51 | speakerPath = os.path.dirname(srcPath)
52 | speaker = os.path.basename(speakerPath)
53 |
54 | regionPath = os.path.dirname(speakerPath)
55 | region = os.path.basename(regionPath)
56 |
57 | setPath = os.path.dirname(regionPath)
58 | setName = os.path.basename(setPath)
59 |
60 | timitPath = os.path.dirname(setPath)
61 | timit = os.path.basename(timitPath)
62 |
63 | dstPath = os.path.join(dstDir, timit, setName, region, speaker, filename)
64 | return dstPath
65 |
66 |
67 | ### TRANSFORM FUNCTIONS ###
68 | # create a wav file with NIST headers
69 | def transformWav(wav_file, dstPath):
70 | output_dir = os.path.dirname(dstPath)
71 | if not os.path.exists(output_dir):
72 | os.makedirs(output_dir)
73 |
74 | if not os.path.exists(dstPath):
75 | command = ['mplayer',
76 | '-quiet',
77 | '-vo', 'null',
78 | '-vc', 'dummy',
79 | '-ao', 'pcm:waveheader:file=' + dstPath,
80 | wav_file]
81 |
82 | # actually run the command, only show stderror on terminal, close the processes (don't wait for user input)
83 | FNULL = open(os.devnull, 'w')
84 | p = subprocess.Popen(command, stdout=FNULL, stderr=subprocess.STDOUT, close_fds=True) # stdout=subprocess.PIPE
85 |
86 | # TODO this line is commented out to enable parallel file processing; uncomment if you need to access the file directly after creation
87 | # subprocess.Popen.wait(p) # wait for completion
88 | return 1
89 | else:
90 | return 0
91 |
92 | # generate new .phn file with mapped phonemes (from 61, to 39 -> see dictionary in phoneme_set.py)
93 | def transformPhn(phn_file, dstPath):
94 | output_dir = os.path.dirname(dstPath)
95 | if not os.path.exists(output_dir):
96 | os.makedirs(output_dir)
97 |
98 | if not os.path.exists(dstPath):
99 | # extract label from phn
100 | phn_labels = []
101 | with open(phn_file, 'rb') as csvfile:
102 | phn_reader = csv.reader(csvfile, delimiter=' ')
103 | for row in phn_reader:
104 | start, stop, label = row[0], row[1], row[2]
105 |
106 | if label not in phoneme_set_39: # map from 61 to 39 phonemes using dict
107 | label = phoneme_set_61_39.get(label)
108 |
109 | classNumber = label # keep the phoneme string; use phoneme_set_39[label] - 1 here to store class numbers instead
110 | phn_labels.append([start, stop, classNumber])
111 |
112 | # print phn_labels
113 |
114 | writeToTxt(phn_labels, dstPath)
115 |
116 |
117 | ########### High Level Functions ##########
118 | # just loop over all the found files
119 | def transformWavs(args):
120 | srcDir = args.srcDir
121 | dstDir = args.dstDir
122 |
123 | print("src: ", srcDir)
124 | print("dst: ", dstDir)
125 | srcWavs = loadWavs(srcDir)
126 | srcWavs.sort()
127 |
128 | if not os.path.exists(dstDir):
129 | os.makedirs(dstDir)
130 |
131 | resampled = []
132 | # transform: fix headers and resample. Use 2 separate loops to prevent having to wait for the fixed file to be written
133 | # therefore also don't wait until completion in transformWav (the Popen.wait(p) line)
134 | print("FIXING WAV HEADERS AND COPYING TO ", dstDir, "...")
135 | for srcPath in tqdm(srcWavs, total=len(srcWavs)):
136 | dstPath = getDestPath(srcPath, dstDir)
137 | resampled.append(dstPath)
138 | transformWav(srcPath, dstPath)
139 | if debug: print(srcPath, dstPath)
140 |
141 | print("RESAMPLING TO 16kHz...")
142 | for dstPath in tqdm(resampled, total=len(resampled)):
143 | resample.resampleWAV(dstPath, dstPath, out_fr=16000.0, q=1.0) # resample to 16 kHz from 48kHz
144 |
145 | ## TODO USING resampy library: about 4x faster, but sometimes weird crashes...
146 | # in_fr, in_data = wavfile.read(dstPath)
147 | # in_type = in_data.dtype
148 | # in_data = in_data.astype(float)
149 | # # x is now a 1-d numpy array, with `sr_orig` audio samples per second
150 | # # We can resample this to any sampling rate we like, say 16000 Hz
151 | # y_low = resampy.resample(in_data, in_fr, 16000)
152 | # y_low = y_low.astype(in_type)
153 | # wavfile.write(dstPath, 16000, y_low)
154 |
155 |
156 | def transformPhns(args):
157 | srcDir = args.srcDir
158 | dstDir = args.dstDir
159 | srcPhns = loadPhns(srcDir)
160 |
161 | print("Source Directory: ", srcDir)
162 | print("Destination Directory: ", dstDir)
163 |
164 | if not os.path.exists(dstDir):
165 | os.makedirs(dstDir)
166 |
167 | for srcPath in tqdm(srcPhns, total=len(srcPhns)):
168 | dstPath = getDestPath(srcPath, dstDir)
169 | # print("reading from: ", srcPath)
170 | # print("writing to: ", dstPath)
171 | transformPhn(srcPath, dstPath)
172 |
173 |
174 | ## help functions ###
175 | def readPhonemeDict(filePath):
176 | d = {}
177 | with open(filePath) as f:
178 | for line in f:
179 | (key, val) = line.split()
180 | d[int(key)] = val
181 | return d
182 |
183 |
184 | def checkDirs(args):
185 | if 'dstDir' in args and not os.path.exists(args.dstDir):
186 | os.makedirs(args.dstDir)
187 | if 'srcDir' in args and not os.path.exists(args.srcDir):
188 | raise Exception('Can not find source data path')
189 |
190 |
191 | ################# PARSER #####################
192 | def prepare_parser():
193 | parser = argparse.ArgumentParser()
194 | sub_parsers = parser.add_subparsers()
195 |
196 | ## PHONEMES ##
197 | phn_parser = sub_parsers.add_parser('phonemes')
198 | phn_parser.set_defaults(func=transformPhns)
199 | phn_parser.add_argument('-i', '--srcDir',
200 | help="the directory storing the source data",
201 | required=True)
202 | phn_parser.add_argument('-o', '--dstDir',
203 | help="the directory to store the output data",
204 | required=True)
205 | ## WAVS ##
206 | wav_parser = sub_parsers.add_parser('wavs')
207 | wav_parser.set_defaults(func=transformWavs)
208 | wav_parser.add_argument('-i', '--srcDir',
209 | help="the directory storing the source data",
210 | required=True)
211 | wav_parser.add_argument('-o', '--dstDir',
212 | help="the directory to store the output data",
213 | required=True)
214 | return parser
215 |
216 |
217 | if __name__ == '__main__':
218 | arg_parser = prepare_parser()
219 | args = arg_parser.parse_args(sys.argv[1:])
220 | checkDirs(args)
221 | args.func(args)
222 |
--------------------------------------------------------------------------------
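For reference, a minimal programmatic driver for transform.py, equivalent to the two CLI calls shown in its header comment (paths are hypothetical, argparse.Namespace stands in for parsed CLI args, mplayer must be on PATH for the wav step, and it should be run from the tools/ directory):

import argparse

import transform

args = argparse.Namespace(srcDir='/data/TIMIT/original/TIMIT/', dstDir='/data/TIMIT/fixed39/TIMIT/')
transform.checkDirs(args)
transform.transformWavs(args)  # fix NIST wav headers (via mplayer) and resample to 16 kHz
transform.transformPhns(args)  # rewrite the .phn files with the 61 -> 39 phoneme mapping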