├── .idea
│   ├── TIMITspeech.iml
│   ├── misc.xml
│   ├── modules.xml
│   ├── preferred-vcs.xml
│   └── workspace.xml
├── README.md
├── RNN.py
├── RNN_implementation.py
├── background
│   ├── __init__.py
│   ├── htkbook-3.5.pdf
│   ├── plot_data.py
│   ├── timit_phones.txt
│   └── timit_words.txt
├── environment.yml
├── getResults.py
└── tools
    ├── __init__.py
    ├── __init__.pyc
    ├── createMLF.py
    ├── datasetToPkl.py
    ├── formatting.py
    ├── formatting.pyc
    ├── general_tools.py
    ├── general_tools.pyc
    ├── helpFunctions
    │   ├── __init__.py
    │   ├── copyFilesOfType.py
    │   ├── progress_bar.py
    │   ├── removeEmptyDirs.py
    │   ├── resample.py
    │   ├── resampleExperiment.py
    │   ├── sa1.wav
    │   ├── visualizeMFC.m
    │   ├── wavToPng.py
    │   └── writeToTxt.py
    ├── mergeAudioFiles.py
    ├── phoneme_set.py
    ├── preprocessWavs.py
    └── transform.py
/README.md:
--------------------------------------------------------------------------------
1 | With this repo you can preprocess an audio dataset (modify phoneme classes, resample audio, etc.) and train LSTM networks for framewise phoneme classification.
2 | You can achieve 82% accuracy on the TIMIT dataset, similar to the results from [Graves et al. (2013)](https://arxiv.org/abs/1303.5778), although CTC is not used here.
3 | Instead, the network generates a prediction at the middle of each phoneme interval, as specified by the labels. This keeps things simple, but adding CTC shouldn't be too much trouble.
4 |
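To make the framewise setup concrete, here is a minimal sketch of how a labelled frame can be derived from each phoneme interval (assuming 16 kHz audio and a 10 ms MFCC frame shift; the function and constants are illustrative, not the repo's exact code):

```python
SAMPLE_RATE = 16000
FRAME_SHIFT_S = 0.01  # 10 ms between successive MFCC frames

def center_frames(phn_path):
    """Yield (frame_index, phoneme): one labelled frame per phoneme interval."""
    with open(phn_path) as f:
        for line in f:
            start, end, phoneme = line.split()  # start/end are sample indices
            center_sample = (int(start) + int(end)) // 2
            yield int(center_sample / (SAMPLE_RATE * FRAME_SHIFT_S)), phoneme
```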
5 | In order to create and train a network on a dataset, you need to:
6 |
7 | 0. Install software. I recommend using [Anaconda](https://www.anaconda.com/download/) and a virtual environment.
8 | - create an environment from the provided file: `conda env create -f environment.yml`
9 |
10 | 1. Generate a binary file from the source dataset. It's easiest if you structure your data as in the TIMIT dataset, although that's not strictly needed. Just make sure that each wav and its corresponding phn file have the same path except for the extension; otherwise they won't get matched and your labels will be off.
11 | - WAV files, 16kHz sampling rate, folder structure `dataset/TRAIN/speakerName/videoName/`.
12 | Each videoName/ directory contains a `videoName.wav` and `videoName.phn`.
13 | The .phn file lists, for each phoneme, the sample indices (at 16 kHz) where it starts and ends; see the example below.
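A `.phn` file is plain text with one line per phoneme: start sample, end sample, label (the values below are illustrative):

```
0 3050 h#
3050 4559 sh
4559 5723 iy
```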
14 | - If your files are in a different format, you can use transform.py (under tools/) with the appropriate arguments (see the bottom of that file) to:
15 | - fix wav headers, resample wavs. Store them under `dataRoot/dataset/fixed(nbPhonemes)/`
16 | `transform.py phonemes -i dataRoot/TIMIT/original/ -o dataRoot/TIMIT/fixed`
17 | - fix label files: replace phonemes (e.g. to use a reduced phoneme set; I used the 39 phonemes from Lee and Hon (1989)). Stored next to the fixed wavs, under `root/dataset/fixed(nbPhonemes)/`
18 | - create an MLF file (as produced by the HTK tools, and as used in the TCDTIMIT dataset)
19 | - the scripts should be case-agnostic, but you can convert upper to lower case and vice versa by running `find . -depth -print0 | xargs -0 rename '$_ = lc $_'` in the root dataset directory (change 'lc' to 'uc' to convert to upper case). Repeat until you get no more output.
20 | - Then set the variables in datasetToPkl.py (source and target dir, nbMFCCs to use, etc.), and run the file
21 | - the result is stored as `root/dataset/binary(nbPhonemes)/dataset/dataset_nbPhonemes_ch.pkl`, e.g. root/TIMIT/binary39/TIMIT/TIMIT_39_ch.pkl
22 | - the mean and std_dev of the train data are stored as `root/dataset/binary_nbPhonemes/dataset_MeanStd.pkl`. These are useful for normalization when evaluating; see the sketch below.
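As a sketch of how to consume these pkl files (the nine-element layout follows `load_data()` in `RNN.py`; the paths and the `(mean, std_dev)` layout of the MeanStd pkl are assumptions for illustration):

```python
import pickle

def unpickle(path):
    with open(path, 'rb') as f:
        return pickle.load(f)

# The dataset pkl holds nine lists: features, labels and valid-frame
# indices for the train, validation and test splits.
(x_train, y_train, valid_frames_train,
 x_val, y_val, valid_frames_val,
 x_test, y_test, valid_frames_test) = unpickle(
    'root/TIMIT/binary39/TIMIT/TIMIT_39_ch.pkl')

# Normalize data with the stored train statistics before evaluating.
mean, std_dev = unpickle('root/TIMIT/binary39/TIMIT_MeanStd.pkl')
x_test = [(x - mean) / std_dev for x in x_test]
```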
23 | 1. Use RNN.py to start training. Its functions are implemented in RNN_implementation.py, but you can set the parameters from RNN.py:
24 | - set location of pkl generated by datasetToPkl.py
25 | - specify number of LSTM layers and number of units per layer
26 | - use bidirectional LSTM layers
27 | - add some dense layers (though it did not improve performance for me)
28 | - set the learning rate and decay (the LR is updated at the end of each epoch in RNN_implementation.py); it is decreased when performance hasn't improved for some time (see the sketch after this list)
29 |
30 | - it will automatically give the model a name based on the specified parameters. A log file, the model parameters and a pkl file containing training info (accuracy, error etc. for each epoch) are stored as well.
31 | The storage location is `root/dataset/results`.
32 |
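The LR schedule is multiplicative; a minimal sketch of the plateau-based decay (the constants mirror the defaults at the top of `RNN.py`; `train_one_epoch` and `patience` are hypothetical placeholders):

```python
def train_one_epoch(LR):
    # hypothetical stand-in for one epoch of training; returns validation accuracy
    return 0.0

LR_start = 0.01  # divided by 10 when retraining an existing model
LR_decay = 0.5   # each decay step: LR = LR * LR_decay
num_epochs = 20
patience = 3     # illustrative: epochs without improvement before decaying

LR, best_val_acc, epochs_not_improved = LR_start, 0.0, 0
for epoch in range(num_epochs):
    val_acc = train_one_epoch(LR)
    if val_acc > best_val_acc:
        best_val_acc, epochs_not_improved = val_acc, 0
    else:
        epochs_not_improved += 1
        if epochs_not_improved >= patience:
            LR *= LR_decay  # performance plateaued: decay the learning rate
            epochs_not_improved = 0
```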
33 | 1. To evaluate on another dataset, change the test_dataset variable to whatever you want (TIMIT/TCDTIMIT/combined).
34 | 1. You can generate test datasets with noise (either white noise or simultaneous speakers) at a given level, using mergeAudioFiles.py to create the wavs and testdataToPkl.py to convert those to pkl files (see the sketch below).
35 | 1. If this noisy audio is to be used for combinedSR, you need to generate the pkl files a bit differently, using audioToPkl_perVideo.py. That pkl file can then be combined with the images and labels generated by combinedSR/datasetToPkl. You can enable this by setting some parameters in `combinedSR/combinedNN.py`
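Mixing in noise at a target level boils down to scaling the noise relative to the signal power; a minimal sketch of the idea for white noise (not the exact code in `mergeAudioFiles.py`):

```python
import numpy as np

def add_white_noise(signal, ratio_dB):
    """Mix white noise into `signal` so the result has an SNR of ratio_dB (illustrative)."""
    signal = signal.astype(np.float64)
    noise = np.random.randn(len(signal))
    # scale the noise so that 10*log10(P_signal / P_noise) == ratio_dB
    target_noise_power = np.mean(signal ** 2) / (10.0 ** (ratio_dB / 10.0))
    noise *= np.sqrt(target_noise_power / np.mean(noise ** 2))
    return signal + noise
```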
36 |
37 | On TIMIT, you should get about 82% accuracy using a 2-layer, 256 units/layer bidirectional LSTM network.
38 | You should get about 67% on TCD-TIMIT.
39 |
40 | The TIMIT dataset is non-free and available from [https://catalog.ldc.upenn.edu/LDC93S1](https://catalog.ldc.upenn.edu/LDC93S1).
41 | The TCD-TIMIT dataset is free for research and available from [https://sigmedia.tcd.ie/TCDTIMIT/](https://sigmedia.tcd.ie/TCDTIMIT/).
42 | If you want to use TCD-TIMIT, I recommend using my repo [TCDTIMITprocessing](https://github.com/matthijsvk/TCDTIMITprocessing) to download and extract the database; it's quite a nasty job otherwise. You can use `extractTCDTIMITaudio.py` to get the phoneme and wav files.
43 |
44 | If you want to do lipreading or audio-visual speech recognition, check out my other repository [MultimodalSR](https://github.com/matthijsvk/multimodalSR).
45 |
--------------------------------------------------------------------------------
/RNN.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import warnings
4 | from time import gmtime, strftime
5 |
6 | warnings.simplefilter("ignore", UserWarning) # cuDNN warning
7 |
8 | import logging
9 | import tools.formatting as formatting
10 |
11 | import time
12 | program_start_time = time.time()
13 |
14 | print("\n * Importing libraries...")
15 | from RNN_implementation import *
16 | from pprint import pprint # printing properties of networkToRun objects
17 | from tqdm import tqdm # % done bar
18 |
19 |
20 | ##### SCRIPT META VARIABLES #####
21 | VERBOSE = True
22 | num_epochs = 20
23 | nbMFCCs = 39 # num of features to use -> see 'utils.py' in convertToPkl under processDatabase
24 | nbPhonemes = 39 # number output neurons
25 |
26 | # Root path: data is stored here, models and results will be stored here as well
27 | # see README.md for info about where you should store your data
28 | root = os.path.expanduser("~"+os.sep+"TCDTIMIT"+os.sep)
29 |
30 | # Choose which datasets to run
31 | datasets = ["TIMIT"] # "TCDTIMIT", "combined"
32 |
33 | # maps each dataset name (key) to the list of network architectures to train on that dataset (value)
34 | many_n_hidden_lists = {}
35 | many_n_hidden_lists['TIMIT'] = [[50]]#[8,8], [32], [64], [256], [512],[1024]]
36 | #[8, 8], [32, 32], [64, 64], [256, 256], [512, 512]]#, [1024,1024]]
37 | #[8,8,8],[32,32,32],[64,64,64],[256,256,256],[512,512,512],
38 | #[8,8,8,8],[32,32,32,32],[64,64,64,64],[256,256,256,256],[512,512,512,512]]
39 | #[512,512,512,512],
40 | #[1024,1024,1024],[1024,1024,1024,1024]]
41 | many_n_hidden_lists['TCDTIMIT'] = [[256, 256], [512, 512], [256, 256, 256, 256]]
42 | # combined is TIMIT and TCDTIMIT put together
43 | many_n_hidden_lists['combined'] = [[256, 256]] # [32, 32], [64, 64], [256, 256]]#, [512, 512]]
44 |
45 | #######################
46 |
47 | bidirectional = True
48 | add_dense_layers = False
49 |
50 | run_test = True # if network exists, just test, don't retrain but just evaluate on test set
51 | autoTrain = False # if network doesn't exist, train a new one automatically
52 | round_params = False
53 |
54 | # Decaying LR: each epoch LR = LR * LR_decay
55 | LR_start = 0.01 # will be divided by 10 if retraining existing model
56 | LR_fin = 0.0000001
57 | # LR_decay = (LR_fin / LR_start) ** (1. / num_epochs)
58 | LR_decay = 0.5
59 |
60 | print("LR_start = %s", str(LR_start))
61 | print("LR_fin = %s", str(LR_fin))
62 | print("LR_decay = %s", str(LR_decay))
63 |
64 |
65 | # quickly generate many networks
66 | def create_network_list(dataset, networkArchs):
67 | network_list = []
68 | for networkArch in networkArchs:
69 | network_list.append(NetworkToRun(run_type="audio", n_hidden_list=networkArch,
70 | dataset=dataset, test_dataset=dataset, run_test=True))
71 | return network_list
72 |
73 | def main():
74 | # Use this if you only want to start training when the network doesn't exist yet
75 | network_list = create_network_list(dataset="TIMIT", networkArchs=many_n_hidden_lists['TIMIT'])
76 |
77 | net_runner = NetworkRunner(network_list)
78 | net_runner.get_network_results()
79 | print("\n got all results")
80 | net_runner.exportResultsToExcel()
81 |
82 |
83 | class NetworkToRun:
84 | def __init__(self,
85 | n_hidden_list=(256, 256,), nbMFCCs=39, audio_bidirectional=True,
86 | LR_start=0.001, round_params=False, run_test=False, force_train=False,
87 | run_type='audio', dataset="TCDTIMIT", test_dataset=None,
88 | with_noise=False, noise_types=('white',), ratio_dBs=(0, -3, -5, -10)):
89 | # Audio
90 | self.n_hidden_list = n_hidden_list # LSTM architecture for audio part
91 | self.nbMFCCs = nbMFCCs
92 | self.audio_bidirectional = audio_bidirectional
93 |
94 | # Others
95 | self.run_type = run_type
96 | self.LR_start = LR_start
97 | self.run_test = run_test
98 | self.force_train = force_train # If False, just test the network outputs when the network already exists.
99 | # If force_train == True, train it anyway before testing
100 | # If True, set the LR_start low enough so you don't move too far out of the objective minimum
101 |
102 | self.dataset = dataset
103 | if test_dataset is None:
104 | self.test_dataset = self.dataset
105 | else:
106 | self.test_dataset = test_dataset
107 |
108 | self.round_params = round_params
109 | self.with_noise = with_noise
110 | self.noise_types = noise_types
111 | self.ratio_dBs = ratio_dBs
112 |
113 | self.model_name, self.nice_name = self.get_model_name()
114 | self.model_path = self.get_model_path()
115 | self.model_path_noNPZ = self.model_path.replace('.npz','')
116 |
117 | # this generates the correct path based on the chosen parameters, and gets the train/val/test data
118 | def load_data(self, noise_type='white', ratio_dB='0'):
119 | dataset = self.dataset
120 | test_dataset = self.test_dataset
121 | with_noise = self.with_noise
122 |
123 | data_dir = os.path.join(root, self.run_type+"SR", dataset, "binary" + str(nbPhonemes),
124 | dataset) # output dir from datasetToPkl.py
125 | data_path = os.path.join(data_dir, dataset + '_' + str(nbMFCCs) + '_ch.pkl')
126 | if self.run_test:
127 | test_data_dir = os.path.join(root, self.run_type+"SR", test_dataset, "binary" + str(nbPhonemes) + \
128 | ('_'.join([noise_type, os.sep, "ratio", str(ratio_dB)]) if with_noise else ""),test_dataset)
129 | test_data_path = os.path.join(test_data_dir, test_dataset + '_' + str(nbMFCCs) + '_ch.pkl')
130 |
131 | self.logger.info(' data source: %s', data_path)
132 |
133 | dataset = unpickle(data_path)
134 | x_train, y_train, valid_frames_train, x_val, y_val, valid_frames_val, x_test, y_test, valid_frames_test = dataset
135 |
136 | # if run_test, you can use another dataset than the one used for training for evaluation
137 | if self.run_test:
138 | self.logger.info(" test data source: %s", test_data_path)
139 | if with_noise:
140 | x_test, y_test, valid_frames_test = unpickle(test_data_path)
141 | else:
142 | _, _, _, _, _, _, x_test, y_test, valid_frames_test = unpickle(test_data_path)
143 |
144 | datasetFiles = [x_train, y_train, valid_frames_train, x_val, y_val, valid_frames_val, x_test, y_test,
145 | valid_frames_test]
146 | # Print some information
147 | debug = False
148 | if debug:
149 | self.logger.info("\n* Data information")
150 | self.logger.info('X train')
151 | self.logger.info(' %s %s', type(x_train), len(x_train))
152 | self.logger.info(' %s %s', type(x_train[0]), x_train[0].shape)
153 | self.logger.info(' %s %s', type(x_train[0][0]), x_train[0][0].shape)
154 | self.logger.info(' %s', type(x_train[0][0][0]))
155 |
156 | self.logger.info('y train')
157 | self.logger.info(' %s %s', type(y_train), len(y_train))
158 | self.logger.info(' %s %s', type(y_train[0]), y_train[0].shape)
159 | self.logger.info(' %s %s', type(y_train[0][0]), y_train[0][0].shape)
160 |
161 | self.logger.info('valid_frames train')
162 | self.logger.info(' %s %s', type(valid_frames_train), len(valid_frames_train))
163 | self.logger.info(' %s %s', type(valid_frames_train[0]), valid_frames_train[0].shape)
164 | self.logger.info(' %s %s', type(valid_frames_train[0][0]), valid_frames_train[0][0].shape)
165 |
166 | return datasetFiles
167 |
168 | # this builds the chosen network architecture, loads network weights and compiles the functions
169 | def setup_network(self, batch_size):
170 | dataset = self.dataset
171 | test_dataset = self.test_dataset
172 | n_hidden_list = self.n_hidden_list
173 | round_params = self.round_params
174 |
175 | store_dir = root + dataset + "/results"
176 | if not os.path.exists(store_dir): os.makedirs(store_dir)
177 |
178 | # log file
179 | fh = self.setupLogging(store_dir)
180 | #############################################################
181 |
182 |
183 | self.logger.info("\n\n\n\n STARTING NEW TRAINING SESSION AT " + strftime("%Y-%m-%d %H:%M:%S", gmtime()))
184 |
185 | ##### IMPORTING DATA #####
186 |
187 | self.logger.info(' model target: %s', self.model_name)
188 |
189 | self.logger.info('\n* Building network using batch size: %s...', batch_size)
190 | RNN_network = NeuralNetwork('RNN', None, batch_size=batch_size,
191 | num_features=nbMFCCs, n_hidden_list=n_hidden_list,
192 | num_output_units=nbPhonemes,
193 | bidirectional=bidirectional, addDenseLayers=add_dense_layers,
194 | debug=False,
195 | dataset=dataset, test_dataset=test_dataset, logger=self.logger)
196 |
197 | # print number of parameters
198 | nb_params = lasagne.layers.count_params(RNN_network.network_lout_batch)
199 | self.logger.info(" Number of parameters of this network: %s", nb_params)
200 |
201 | # Try to load stored model
202 | self.logger.info(' Network built. \nTrying to load stored model: %s', self.model_name + '.npz')
203 | success = RNN_network.load_model(self.model_path_noNPZ, round_params=round_params)
204 |
205 | RNN_network.loadPreviousResults(self.model_path_noNPZ)
206 |
207 | ##### COMPILING FUNCTIONS #####
208 | self.logger.info("\n* Compiling functions ...")
209 | RNN_network.build_functions(train=True, debug=False)
210 |
211 | return RNN_network, success, fh
212 |
213 | def setupLogging(self, store_dir):
214 | self.logger = logging.getLogger(self.model_name)
215 | self.logger.setLevel(logging.DEBUG)
216 | FORMAT = '[$BOLD%(filename)s$RESET:%(lineno)d][%(levelname)-5s]: %(message)s '
217 | formatter = logging.Formatter(formatting.formatter_message(FORMAT, False))
218 | # formatter2 = logging.Formatter(
219 | # '%(asctime)s - %(name)-5s - %(levelname)-10s - (%(filename)s:%(lineno)d): %(message)s')
220 |
221 | # create console handler with a higher log level
222 | ch = logging.StreamHandler()
223 | ch.setLevel(logging.DEBUG)
224 | ch.setFormatter(formatter)
225 | self.logger.addHandler(ch)
226 |
227 | logFile = store_dir + os.sep + self.model_name + '.log'
228 | if os.path.exists(logFile):
229 | self.fh = logging.FileHandler(logFile) # append to existing log
230 | else:
231 | self.fh = logging.FileHandler(logFile, 'w') # create new logFile
232 | self.fh.setLevel(logging.DEBUG)
233 | self.fh.setFormatter(formatter)
234 | self.logger.addHandler(self.fh)
235 |
236 | def get_model_name(self):
237 | n_hidden_list = self.n_hidden_list
238 | dataset = self.dataset
239 | model_name = str(len(n_hidden_list)) + "_LSTMLayer" + '_'.join([str(layer) for layer in n_hidden_list]) \
240 | + "_nbMFCC" + str(nbMFCCs) + ("_bidirectional" if bidirectional else "_unidirectional") + \
241 | ("_withDenseLayers" if add_dense_layers else "") + "_" + dataset
242 |
243 | nice_name = "Audio:" + ' '.join(
244 | ["LSTM", str(n_hidden_list[0]), "/", str(len(n_hidden_list))])
245 |
246 | return model_name, nice_name
247 |
248 | def get_model_path(self):
249 | model_name, nice_name = self.get_model_name()
250 | model_path = os.path.join(root, self.run_type + 'SR', self.dataset, 'results', model_name + '.npz')
251 | return model_path
252 |
253 | # this takes the prepared data, built network and some parameters, and trains/evaluates the network
254 | def executeNetwork(self, RNN_network, load_params_success, batch_size, datasetFiles,
255 | noiseType='white', ratio_dB=0, fh=None):
256 | with_noise = self.with_noise
257 | run_test = self.run_test
258 |
259 | LR = LR_start
260 | if load_params_success == 0: LR = LR_start / 10.0
261 |
262 | ##### TRAINING #####
263 | self.logger.info("\n* Training ...")
264 | results = RNN_network.train(datasetFiles, self.model_path_noNPZ, num_epochs=num_epochs,
265 | batch_size=batch_size, LR_start=LR, LR_decay=LR_decay,
266 | compute_confusion=True, justTest=run_test, debug=False,
267 | withNoise=with_noise, noiseType=noiseType, ratio_dB=ratio_dB)
268 |
269 | self.fh.close()
270 | self.logger.removeHandler(self.fh)
271 |
272 | return results
273 |
274 | # estimate a good batchsize based on the size of the network
275 | def getBatchSizes(self):
276 | n_hidden_list = self.n_hidden_list
277 | if n_hidden_list[0] > 128:
278 | batch_sizes = [64, 32, 16, 8, 4]
279 | elif n_hidden_list[0] > 64:
280 | batch_sizes = [128, 64, 32, 16, 8, 4]
281 | else:
282 | batch_sizes = [256, 128, 64, 32, 16, 8, 4]
283 | return batch_sizes
284 |
285 | # run the network at the maximum batch size
286 | def runNetwork(self):
287 |
288 | batch_sizes = self.getBatchSizes()
289 | results = [0, 0, 0]
290 | for batch_size in batch_sizes:
291 | try:
292 | RNN_network, load_params_success, fh = self.setup_network(batch_size)
293 | datasetFiles = self.load_data()
294 | results = self.executeNetwork(RNN_network, load_params_success, batch_size=batch_size,
295 | datasetFiles=datasetFiles, fh=fh)
296 | break
297 | except:
298 | print('caught this error: ' + traceback.format_exc())
299 | self.logger.info("batch size too large; trying again with lower batch size")
300 | pass # just try again with the next batch_size
301 |
302 | return results
303 |
304 | # run a network for all noise types
305 | def testAudio(self, batch_size, fh, load_params_success, RNN_network,
306 | with_noise=False, noiseTypes=('white',), ratio_dBs=(0, -3, -5, -10,)):
307 | for noiseType in noiseTypes:
308 | for ratio_dB in ratio_dBs:
309 | datasetFiles = self.load_data()
310 | self.executeNetwork(RNN_network=RNN_network, load_params_success=load_params_success,
311 | batch_size=batch_size, datasetFiles=datasetFiles,
312 | noiseType=noiseType, ratio_dB=ratio_dB, fh=fh)
313 | if not with_noise: # only need to run once, not for all noise types as we're not testing on noisy audio anyway
314 | return 0
315 | return 0
316 |
317 | # try different batch sizes for testing a network
318 | def testNetwork(self, with_noise, noiseTypes, ratio_dBs):
319 |
320 | batch_sizes = self.getBatchSizes()
321 | for batch_size in batch_sizes:
322 | try:
323 | RNN_network, load_params_success, fh = self.setup_network(batch_size)
324 | # evaluate on test dataset for all noise types
325 | self.testAudio(batch_size, fh, load_params_success, RNN_network,
326 | with_noise, noiseTypes, ratio_dBs)
327 | except:
328 | print('caught this error: ' + traceback.format_exc())
329 | self.logger.info("batch size too large; trying again with lower batch size")
330 | pass # just try again with the next batch_size
331 |
332 | def get_clean_results(self, network_train_info, nice_name, noise_type='white', ratio_dB='0'):
333 | with_noise = self.with_noise
334 | results_type = ("round_params" if self.round_params else "") + (
335 | "_Noise" + noise_type + "_" + str(ratio_dB) if with_noise else "")
336 |
337 | this_results = {'results_type': results_type}
338 | this_results['values'] = []
339 | this_results['dataset'] = self.dataset
340 | this_results['test_dataset'] = self.test_dataset
341 | this_results['audio_dataset'] = self.dataset
342 |
343 | # audio networks can be run on TIMIT or combined as well
344 | if self.run_type != 'audio' and self.test_dataset != self.dataset:
345 | test_type = "_" + self.test_dataset
346 | else:
347 | test_type = ""
348 | if self.round_params:
349 | test_type = "_round_params" + test_type
350 | if self.run_type != 'lipreading' and with_noise:
351 | this_results['values'] = [
352 | network_train_info['final_test_cost_' + noise_type + "_" + "ratio" + str(ratio_dB) + test_type],
353 | network_train_info['final_test_acc_' + noise_type + "_" + "ratio" + str(ratio_dB) + test_type],
354 | network_train_info['final_test_top3_acc_' + noise_type + "_" + "ratio" + str(ratio_dB) + test_type]]
355 | else:
356 | try:
357 | val_acc = max(network_train_info['val_acc'])
358 | except:
359 | try:
360 | val_acc = max(network_train_info['test_acc'])
361 | except:
362 | val_acc = network_train_info['final_test_acc']
363 | this_results['values'] = [network_train_info['final_test_cost' + test_type],
364 | network_train_info['final_test_acc' + test_type],
365 | network_train_info['final_test_top3_acc' + test_type], val_acc]
366 | this_results['nb_params'] = network_train_info['nb_params']
367 | this_results['niceName'] = nice_name
368 |
369 | return this_results
370 |
371 | def get_network_train_info(self, save_path):
372 | save_name = save_path.replace('.npz','')
373 | if os.path.exists(save_path) and os.path.exists(save_name + "_trainInfo.pkl"):
374 | network_train_info = unpickle(save_name + '_trainInfo.pkl')
375 |
376 | if not 'final_test_cost' in network_train_info.keys():
377 | network_train_info['final_test_cost'] = min(network_train_info['test_cost'])
378 | if not 'final_test_acc' in network_train_info.keys():
379 | network_train_info['final_test_acc'] = max(network_train_info['test_acc'])
380 | if not 'final_test_top3_acc' in network_train_info.keys():
381 | network_train_info['final_test_top3_acc'] = max(network_train_info['test_topk_acc'])
382 | return network_train_info
383 | else:
384 | return -1
385 |
386 |
387 | # This class runs many networks, storing the weight files, training data and log in the appropriate location
388 | # It can also just evaluate a network on a test set, or simply load the results from previous runs
389 | # All results are stored in an Excel file.
390 | # Audio networks can also be evaluated on audio data which has been polluted with noise (set the appropriate parameters in the declaration of the networks you want to test)
391 | # It receives a network_list as input, containing NetworkToRun objects that specify all the relevant parameters to define a network and run it.
392 | class NetworkRunner:
393 | def __init__(self, network_list):
394 | self.network_list = network_list
395 | self.results = {}
396 | self.setup_logging(root)
397 |
398 | def setup_logging(self,store_dir):
399 | self.logger = logging.getLogger('audioNetworkRunner')
400 | self.logger.setLevel(logging.DEBUG)
401 | FORMAT = '[$BOLD%(filename)s$RESET:%(lineno)d][%(levelname)-5s]: %(message)s '
402 | formatter = logging.Formatter(formatting.formatter_message(FORMAT, False))
403 | # formatter2 = logging.Formatter(
404 | # '%(asctime)s - %(name)-5s - %(levelname)-10s - (%(filename)s:%(lineno)d): %(message)s')
405 |
406 | # create console handler with a higher log level
407 | ch = logging.StreamHandler()
408 | ch.setLevel(logging.DEBUG)
409 | ch.setFormatter(formatter)
410 | self.logger.addHandler(ch)
411 |
412 | logFile = store_dir + os.sep + 'audioNetworkRunner' + '.log'
413 | if os.path.exists(logFile):
414 | self.fh = logging.FileHandler(logFile) # append to existing log
415 | else:
416 | self.fh = logging.FileHandler(logFile, 'w') # create new logFile
417 | self.fh.setLevel(logging.DEBUG)
418 | self.fh.setFormatter(formatter)
419 | self.logger.addHandler(self.fh)
420 |
421 | # this loads the specified results from networks in network_list
422 | # more efficient if we don't have to reload each network file; just retrieve the data you found earlier
423 | def get_network_results(self):
424 | self.results_path = root + 'storedResults' + ".pkl"
425 | try:
426 | prev_results = unpickle(self.results_path)
427 | except:
428 | prev_results = {}
429 |
430 | # get results for networks that were trained before. If a network has run_test=True, run it on the test set
431 | results, to_retrain, failures = self.get_trained_network_results(self.network_list)
432 |
433 | # failures mean that the network does not exist yet. We have to generate and train it.
434 | if len(failures) > 0:
435 | results2 = self.train_networks(failures)
436 | results.update(results2)
437 |
438 | # to_retrain are networks that had forceTrain==True. We have to train them.
439 | if len(to_retrain) > 0:
440 | results3 = self.train_networks(to_retrain)
441 | results.update(results3)
442 |
443 | # update and store the results
444 | prev_results.update(results)
445 | saveToPkl(self.results_path, prev_results)
446 | self.results = prev_results
447 |
448 | # train the networks
449 | def train_networks(self, networks):
450 | results = []
451 | self.logger.info("Couldn't get results from %s networks...", len(networks))
452 | for network in networks:
453 | pprint(vars(network))
454 | if autoTrain or query_yes_no("\nWould you like to train the networks now?\n\n"):
455 | self.logger.info("Running networks...")
456 |
457 | failures = []
458 | for network in tqdm(networks, total=len(networks)):
459 | print("\n\n\n\n ################################")
460 | print("Training new network...")
461 | print("Network properties: ")
462 | pprint(vars(network))
463 | try:
464 | network.runNetwork()
465 | except:
466 | print('caught this error: ' + traceback.format_exc())
467 | pprint(vars(network))
468 | failures.append(network)
469 |
470 | if len(failures) > 0:
471 | print("Some networks failed to train...")
472 | import pdb; pdb.set_trace()
473 |
474 | # now the networks are trained, we can load their results
475 | results, _, failures = self.get_trained_network_results(networks)
476 |
477 | return results
478 |
479 | # get the stored results from networks that were trained previously
480 | # for networks that have run_test=True, run the network on the test set before getting the results
481 | # for networks that have forceTrain=True, add to 'to_retrain' list
482 | # for networks that fail to give results (most probably because they haven't been trained yet), add to 'failures' list
483 | def get_trained_network_results(self, networks):
484 | results = {}
485 | results['audio'] = {}
486 | results['lipreading'] = {}
487 | results['combined'] = {}
488 |
489 | failures = []
490 | to_retrain = []
491 |
492 | for network in tqdm(networks, total=len(networks)):
493 | self.logger.info("\n\n\n\n ################################")
494 | # pprint(vars(network_params))
495 | try:
496 | # if forced test evaluation, test the network before trying to load the results
497 | if network.run_test:
498 | network.testNetwork(network.with_noise, network.noise_types, network.ratio_dBs)
499 |
500 | # if forced training, we don't need to get the results. That will only be done after retraining
501 | if network.force_train == True:
502 | to_retrain.append(network)
503 | # by leaving here we make sure the network is not appended to the failures list
504 | # (that happens when we try to get its stored results while it doesn't exist, and we haven't done that yet here)
505 | continue
506 |
507 | model_name, nice_name = network.get_model_name()
508 | model_path = network.get_model_path()
509 |
510 | self.logger.info("Getting results for %s", model_path)
511 | network_train_info = network.get_network_train_info(model_path)
512 | if network_train_info == -1:
513 | raise IOError("this model doesn't have any stored results")
514 |
515 | if network.with_noise:
516 | for noise_type in network.noise_types:
517 | for ratio_dB in network.ratio_dBs:
518 | this_results = network.get_clean_results(network_train_info=network_train_info,
519 | nice_name=nice_name,
520 | noise_type=noise_type,
521 | ratio_dB=ratio_dB)
522 |
523 | results[network.run_type][model_name] = this_results
524 | else:
525 | this_results = network.get_clean_results(network_train_info=network_train_info,
526 | nice_name=nice_name)
527 |
528 | # eg results['audio']['2Layer_256_256_TIMIT']['values'] = [0.8, 79.5, 92.6] # test cost, acc, top3 acc
529 | results[network.run_type][model_name] = this_results
530 |
531 | except:
532 | self.logger.info('caught this error: ' + traceback.format_exc())
533 | failures.append(network)
534 |
535 | self.logger.info("\n\nDONE getting stored results from networks")
536 | self.logger.info("####################################################")
537 |
538 | return results, to_retrain, failures
539 |
540 |
541 | def exportResultsToExcel(self):
542 | path = self.results_path
543 | results = self.results
544 |
545 | storePath = path.replace(".pkl", ".xlsx")
546 | import xlsxwriter
547 | workbook = xlsxwriter.Workbook(storePath)
548 |
549 | for run_type in ('audio', 'lipreading', 'combined'):
550 | worksheet = workbook.add_worksheet(run_type) # one worksheet per run_type, but then everything is spread out...
551 | row = 0
552 |
553 | allNets = results[run_type]
554 |
555 | # get and write the column titles
556 | # get the number of parameters. For audio, only 1 value; for combined/lipreading, a dictionary of values
557 | try:
558 | nb_paramNames = allNets.items()[0][1][
559 | 'nb_params'].keys() # first key-value pair, get the value ([1]), then get names of nbParams (=the keys)
560 | except:
561 | nb_paramNames = ['nb_params']
562 | startVals = 4 + len(nb_paramNames) # column number of first value
563 |
564 | colNames = ['Network Full Name', 'Network Name', 'Dataset', 'Test Dataset'] + nb_paramNames + ['Test Cost',
565 | 'Test Accuracy',
566 | 'Test Top 3 Accuracy',
567 | 'Validation accuracy']
568 | for i in range(len(colNames)):
569 | worksheet.write(0, i, colNames[i])
570 |
571 | # write the data for each network
572 | for netName in allNets.keys():
573 | row += 1
574 |
575 | thisNet = allNets[netName]
576 | # write the path and name
577 | worksheet.write(row, 0, os.path.basename(netName)) # netName)
578 | worksheet.write(row, 1, thisNet['niceName'])
579 | if run_type == 'audio':
580 | worksheet.write(row, 2, thisNet['audio_dataset'])
581 | worksheet.write(row, 3, thisNet['test_dataset'])
582 | else:
583 | worksheet.write(row, 2, thisNet['dataset'])
584 | worksheet.write(row, 3, thisNet['test_dataset'])
585 |
586 | # now write the params
587 | try:
588 | vals = thisNet['nb_params'].values() # nb_params is a dict for combined/lipreading nets, a single value for audio
589 | except:
590 | vals = [thisNet['nb_params']]
591 | for i in range(len(vals)):
592 | worksheet.write(row, 4 + i, vals[i])
593 |
594 | # now write the values
595 | vals = thisNet['values'] # vals is list of [test_cost, test_acc, test_top3_acc]
596 | for i in range(len(vals)):
597 | worksheet.write(row, startVals + i, vals[i])
598 |
599 | workbook.close()
600 |
601 | self.logger.info("Excel file stored in %s", storePath)
602 | self.fh.close()
603 | self.logger.removeHandler(self.fh)
604 |
605 | def exportResultsToExcelManyNoise(self, resultsList, path):
606 | storePath = path.replace(".pkl", ".xlsx")
607 | import xlsxwriter
608 | workbook = xlsxwriter.Workbook(storePath)
609 |
613 |
614 | row = 0
615 |
616 | if len(resultsList[0]['audio'].keys()) > 0: thisrun_type = 'audio'
617 | if len(resultsList[0]['lipreading'].keys()) > 0: thisrun_type = 'lipreading'
618 | if len(resultsList[0]['combined'].keys()) > 0: thisrun_type = 'combined'
619 | worksheetAudio = workbook.add_worksheet('audio');
620 | audioRow = 0
621 | worksheetLipreading = workbook.add_worksheet('lipreading');
622 | lipreadingRow = 0
623 | worksheetCombined = workbook.add_worksheet('combined');
624 | combinedRow = 0
625 |
626 | for r in range(len(resultsList)):
627 | results = resultsList[r]
628 |
629 | for run_type in ('audio', 'lipreading', 'combined'):
630 | if len(results[run_type]) == 0: continue
631 | if run_type == 'audio': worksheet = worksheetAudio; row = audioRow
632 | if run_type == 'lipreading': worksheet = worksheetLipreading; row = lipreadingRow
633 | if run_type == 'combined': worksheet = worksheetCombined; row = combinedRow
634 |
635 | allNets = results[run_type]
636 |
637 | # write the column titles
638 | startVals = 5
639 | colNames = ['Network Full Name', 'Network Name', 'Dataset', 'Test Dataset', 'Noise Type', 'Test Cost',
640 | 'Test Accuracy', 'Test Top 3 Accuracy']
641 | for i in range(len(colNames)):
642 | worksheet.write(0, i, colNames[i])
643 |
644 | # write the data for each network
645 | for netName in allNets.keys():
646 | row += 1
647 |
648 | thisNet = allNets[netName]
649 | # write the path and name
650 | worksheet.write(row, 0, os.path.basename(netName)) # netName)
651 | worksheet.write(row, 1, thisNet['niceName'])
652 | worksheet.write(row, 2, thisNet['dataset'])
653 | worksheet.write(row, 3, thisNet['test_dataset'])
654 | worksheet.write(row, 4, thisNet['results_type'])
655 |
656 | # now write the values
657 | vals = thisNet['values'] # vals is list of [test_cost, test_acc, test_top3_acc]
658 | for i in range(len(vals)):
659 | worksheet.write(row, startVals + i, vals[i])
660 |
661 | if run_type == 'audio': audioRow = row
662 | if run_type == 'lipreading': lipreadingRow = row
663 | if run_type == 'combined': combinedRow = row
664 |
665 | row += 1
666 |
667 | workbook.close()
668 |
669 | self.logger.info("Excel file stored in %s", storePath)
670 |
671 | self.fh.close()
672 | self.logger.removeHandler(self.fh)
673 |
674 |
675 | if __name__ == "__main__":
676 | main()
677 |
--------------------------------------------------------------------------------
/RNN_implementation.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import logging # debug < info < warn < error < critical # from https://docs.python.org/3/howto/logging-cookbook.html
4 | import time
5 | import traceback
6 |
7 | import lasagne
8 | import lasagne.layers as L
9 | import theano
10 | import theano.tensor as T
11 | from tqdm import tqdm
12 |
13 | logger_RNNtools = logging.getLogger('audioSR.tools')
14 | logger_RNNtools.setLevel(logging.DEBUG)
15 |
16 | from tools.general_tools import *
17 |
18 |
19 | class NeuralNetwork:
20 | network = None
21 | training_fn = None
22 | best_param = None
23 | best_error = 100
24 | curr_epoch, best_epoch = 0, 0
25 | X = None
26 | Y = None
27 |
28 | def __init__(self, architecture, data=None, batch_size=1, max_seq_length=1000, num_features=26,
29 | n_hidden_list=(100,), num_output_units=61,
30 | bidirectional=False, addDenseLayers=False, seed=int(time.time()), debug=False, logger=logger_RNNtools,
31 | dataset="", test_dataset=""):
32 | self.num_output_units = num_output_units
33 | self.num_features = num_features
34 | self.batch_size = batch_size
35 | self.max_seq_length = max_seq_length # currently unused
36 | self.epochsNotImproved = 0 # keep track, to know when to stop training
37 | self.updates = {}
38 | # self.network_train_info = [[], [], [], [], []] # train cost, val cost, val acc, test cost, test acc
39 | self.network_train_info = {
40 | 'train_cost': [],
41 | 'val_cost': [], 'val_acc': [], 'val_topk_acc': [],
42 | 'test_cost': [], 'test_acc': [], 'test_topk_acc': []
43 | }
44 | self.dataset = dataset
45 | self.test_dataset = test_dataset
46 | self.logger= logger
47 |
48 | if architecture == 'RNN':
49 | if data is not None:
50 | X_train, y_train, valid_frames_train, X_val, y_val, valid_frames_val, X_test, y_test, valid_frames_test = data
51 |
52 | X = X_train[:batch_size]
53 | y = y_train[:batch_size]
54 | self.valid_frames = valid_frames_train[:batch_size]
55 | self.masks = generate_masks(X, valid_frames=self.valid_frames, batch_size=len(X))
56 |
57 | self.X = pad_sequences_X(X)
58 | self.Y = pad_sequences_y(y)
59 | # self.valid_frames = pad_sequences_y(self.valid_frames)
60 |
61 | self.logger.debug('X.shape: %s', self.X.shape)
62 | self.logger.debug('X[0].shape: %s', self.X[0].shape)
63 | self.logger.debug('X[0][0][0].type: %s', type(self.X[0][0][0]))
64 | self.logger.debug('y.shape: %s', self.Y.shape)
65 | self.logger.debug('y[0].shape: %s', self.Y[0].shape)
66 | self.logger.debug('y[0][0].type: %s', type(self.Y[0][0]))
67 | self.logger.debug('masks.shape: %s', self.masks.shape)
68 | self.logger.debug('masks[0].shape: %s', self.masks[0].shape)
69 | self.logger.debug('masks[0][0].type: %s', type(self.masks[0][0]))
70 |
71 | self.logger.info("NUM FEATURES: %s", num_features)
72 |
73 | self.audio_inputs_var = T.tensor3('audio_inputs')
74 | self.audio_masks_var = T.matrix(
75 | 'audio_masks') # use T.matrix, not T.imatrix! Otherwise all mask calculations are done on the CPU and everything slows down ~2x. The same holds in general_tools.generate_masks()
76 | self.audio_valid_indices_var = T.imatrix('audio_valid_indices')
77 | self.audio_targets_var = T.imatrix('audio_targets')
78 |
79 | self.build_RNN(n_hidden_list=n_hidden_list, bidirectional=bidirectional, addDenseLayers=addDenseLayers,
80 | seed=seed, debug=debug)
81 | else:
82 | print("ERROR: Invalid argument: The valid architecture arguments are: 'RNN'")
83 |
84 | def build_RNN(self, n_hidden_list=(100,), bidirectional=False, addDenseLayers=False,
85 | seed=int(time.time()), debug=False):
86 | # some inspiration from http://colinraffel.com/talks/hammer2015recurrent.pdf
87 |
88 | # if debug:
89 | # self.self.logger.debug('\nInputs:');
90 | # self.self.logger.debug(' X.shape: %s', self.X[0].shape)
91 | # self.self.logger.debug(' X[0].shape: %s %s %s \n%s', self.X[0][0].shape, type(self.X[0][0]),
92 | # type(self.X[0][0][0]), self.X[0][0][:5])
93 | #
94 | # self.self.logger.debug('Targets: ');
95 | # self.self.logger.debug(' Y.shape: %s', self.Y.shape)
96 | # self.self.logger.debug(' Y[0].shape: %s %s %s \n%s', self.Y[0].shape, type(self.Y[0]), type(self.Y[0][0]),
97 | # self.Y[0][:5])
98 | # self.self.logger.debug('Layers: ')
99 |
100 | # fix these at initialization because it allows for compiler optimizations
101 | num_output_units = self.num_output_units
102 | num_features = self.num_features
103 | batch_size = self.batch_size
104 |
105 | audio_inputs = self.audio_inputs_var
106 | audio_masks = self.audio_masks_var # set MATRIX, not iMatrix!! Otherwise all mask calculations are done by CPU, and everything will be ~2x slowed down!! Also in general_tools.generate_masks()
107 | valid_indices = self.audio_valid_indices_var
108 |
109 | net = {}
110 | # net['l1_in_valid'] = L.InputLayer(shape=(batch_size, None), input_var=valid_indices)
111 |
112 | # shape = (batch_size, batch_max_seq_length, num_features)
113 | net['l1_in'] = L.InputLayer(shape=(batch_size, None, num_features), input_var=audio_inputs)
114 | # We could do this and set all input_vars to None, but that is slower -> fix batch_size and num_features at initialization
115 | # batch_size, n_time_steps, n_features = net['l1_in'].input_var.shape
116 |
117 | # This input will be used to provide the network with masks.
118 | # Masks are matrices of shape (batch_size, n_time_steps);
119 | net['l1_mask'] = L.InputLayer(shape=(batch_size, None), input_var=audio_masks)
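# For illustration: with batch_size=2 and sequences of 3 and 5 frames padded
# to length 5, the mask would be [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]] --
# ones mark real frames, zeros mark padding (the usual lasagne mask
# convention, matching generate_masks in tools.general_tools).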
120 |
121 | if debug:
122 | get_l_in = L.get_output(net['l1_in'])
123 | l_in_val = get_l_in.eval({net['l1_in'].input_var: self.X})
124 | # self.self.logger.debug(l_in_val)
125 | self.logger.debug(' l_in size: %s', l_in_val.shape)
126 |
127 | get_l_mask = L.get_output(net['l1_mask'])
128 | l_mask_val = get_l_mask.eval({net['l1_mask'].input_var: self.masks})
129 | # self.self.logger.debug(l_in_val)
130 | self.logger.debug(' l_mask size: %s', l_mask_val.shape)
131 |
132 | n_batch, n_time_steps, n_features = net['l1_in'].input_var.shape
133 | self.self.logger.debug(" n_batch: %s | n_time_steps: %s | n_features: %s", n_batch, n_time_steps,
134 | n_features)
135 |
136 | ## LSTM parameters
137 | # All gates have initializers for the input-to-gate and hidden state-to-gate
138 | # weight matrices, the cell-to-gate weight vector, the bias vector, and the nonlinearity.
139 | # The convention is that gates use the standard sigmoid nonlinearity,
140 | # which is the default for the Gate class.
141 | gate_parameters = L.recurrent.Gate(
142 | W_in=lasagne.init.Orthogonal(), W_hid=lasagne.init.Orthogonal(),
143 | b=lasagne.init.Constant(0.))
144 | cell_parameters = L.recurrent.Gate(
145 | W_in=lasagne.init.Orthogonal(), W_hid=lasagne.init.Orthogonal(),
146 | # Setting W_cell to None denotes that no cell connection will be used.
147 | W_cell=None, b=lasagne.init.Constant(0.),
148 | # By convention, the cell nonlinearity is tanh in an LSTM.
149 | nonlinearity=lasagne.nonlinearities.tanh)
150 |
151 | # generate layers of stacked LSTMs, possibly bidirectional
152 | net['l2_lstm'] = []
153 |
154 | for i in range(len(n_hidden_list)):
155 | n_hidden = n_hidden_list[i]
156 |
157 | if i == 0: input = net['l1_in']
158 | else: input = net['l2_lstm'][-1]
159 |
160 | nextForwardLSTMLayer = L.recurrent.LSTMLayer(
161 | input, n_hidden,
162 | # We need to specify a separate input for masks
163 | mask_input=net['l1_mask'],
164 | # Here, we supply the gate parameters for each gate
165 | ingate=gate_parameters, forgetgate=gate_parameters,
166 | cell=cell_parameters, outgate=gate_parameters,
167 | # We'll learn the initialization and use gradient clipping
168 | learn_init=True, grad_clipping=100.)
169 | net['l2_lstm'].append(nextForwardLSTMLayer)
170 |
171 | if bidirectional:
172 | if i == 0: input = net['l1_in']
173 | else: input = net['l2_lstm'][-2]
174 | # Use backward LSTM
175 | # The "backwards" layer is the same as the first,
176 | # except that the backwards argument is set to True.
177 | nextBackwardLSTMLayer = L.recurrent.LSTMLayer(
178 | input, n_hidden, ingate=gate_parameters,
179 | mask_input=net['l1_mask'], forgetgate=gate_parameters,
180 | cell=cell_parameters, outgate=gate_parameters,
181 | learn_init=True, grad_clipping=100., backwards=True)
182 | net['l2_lstm'].append(nextBackwardLSTMLayer)
183 |
184 | # if debug:
185 | # # Backwards LSTM
186 | # get_l_lstm_back = theano.function([net['l1_in'].input_var, net['l1_mask'].input_var],
187 | # L.get_output(net['l2_lstm'][-1]))
188 | # l_lstmBack_val = get_l_lstm_back(self.X, self.masks)
189 | # self.self.logger.debug(' l_lstm_back size: %s', l_lstmBack_val.shape)
190 |
191 | # We'll combine the forward and backward layer output by summing.
192 | # Merge layers take in lists of layers to merge as input.
193 | # The output of l_sum will be of shape (n_batch, max_n_time_steps, n_features)
194 | net['l2_lstm'].append(L.ElemwiseSumLayer([net['l2_lstm'][-2], net['l2_lstm'][-1]]))
195 |
196 | # we need to convert (batch_size, seq_length, num_features) to (batch_size * seq_length, num_features) because Dense networks can't deal with 2 unknown sizes
197 | net['l3_reshape'] = L.ReshapeLayer(net['l2_lstm'][-1], (-1, n_hidden_list[-1]))
198 |
199 | if debug:
200 | get_l_reshape = theano.function([net['l1_in'].input_var, net['l1_mask'].input_var],
201 | L.get_output(net['l3_reshape']))
202 | l_reshape_val = get_l_reshape(self.X, self.masks)
203 | self.logger.debug(' l_reshape size: %s', l_reshape_val.shape)
204 |
205 | # Forwards LSTM
206 | get_l_lstm = theano.function([net['l1_in'].input_var, net['l1_mask'].input_var],
207 | L.get_output(net['l2_lstm'][-1]))
208 | l_lstm_val = get_l_lstm(self.X, self.masks)
209 | self.logger.debug(' l2_lstm size: %s', l_lstm_val.shape)
210 |
211 | if addDenseLayers:
212 | net['l4_dense'] = L.DenseLayer(net['l3_reshape'], nonlinearity=lasagne.nonlinearities.rectify,
213 | num_units=256)
214 | dropoutLayer = L.DropoutLayer(net['l4_dense'], p=0.3)
215 | net['l5_dense'] = L.DenseLayer(dropoutLayer, nonlinearity=lasagne.nonlinearities.rectify, num_units=64)
216 | # Now we can apply feed-forward layers as usual for classification
217 | net['l6_dense'] = L.DenseLayer(net['l5_dense'], num_units=num_output_units,
218 | nonlinearity=lasagne.nonlinearities.softmax)
219 | else:
220 | # Now we can apply feed-forward layers as usual for classification
221 | net['l6_dense'] = L.DenseLayer(net['l3_reshape'], num_units=num_output_units,
222 | nonlinearity=lasagne.nonlinearities.softmax)
223 |
224 | # # Now, the shape will be (n_batch * n_timesteps, num_output_units). We can then reshape to
225 | # # n_batch to get num_output_units values for each timestep from each sequence
226 | net['l7_out_flattened'] = L.ReshapeLayer(net['l6_dense'], (-1, num_output_units))
227 | net['l7_out'] = L.ReshapeLayer(net['l6_dense'], (batch_size, -1, num_output_units))
228 |
229 | # we only want predictions for the valid frames, not for every frame. For audio-only nets we could do this with theano mask functions,
230 | # but that's not possible when there are subsequent networks, as lasagne needs lasagne layers, not theano functions.
231 | # -> use a lasagne slice layer to extract the valid predictions
232 | net['l7_out_valid_basic'] = L.SliceLayer(net['l7_out'], indices=valid_indices, axis=1)
233 | net['l7_out_valid'] = L.ReshapeLayer(net['l7_out_valid_basic'], (batch_size, -1, num_output_units))
234 | net['l7_out_valid_flattened'] = L.ReshapeLayer(net['l7_out_valid_basic'], (-1, num_output_units))
235 |
236 | if debug:
237 | get_l_out = theano.function([net['l1_in'].input_var, net['l1_mask'].input_var], L.get_output(net['l7_out']))
238 | l_out = get_l_out(self.X, self.masks)
239 |
240 | # this only works for batch_size == 1
241 | get_l_out_valid = theano.function([audio_inputs, audio_masks, valid_indices],
242 | L.get_output(net['l7_out_valid']))
243 | try:
244 | l_out_valid = get_l_out_valid(self.X, self.masks, self.valid_frames)
245 | self.logger.debug('\n\n\n l_out: %s | l_out_valid: %s', l_out.shape, l_out_valid.shape)
246 | except:
247 | self.logger.warning("batch size not 1, get_valid not working")
248 |
249 | if debug: self.print_network_structure(net)
250 | self.network_lout = net['l7_out_flattened']
251 | self.network_lout_batch = net['l7_out']
252 | self.network_lout_valid = net['l7_out_valid']
253 | self.network_lout_valid_flattened = net['l7_out_valid_flattened']
254 |
255 | self.network = net
256 |
257 | def print_network_structure(self, net=None):
258 | if net is None: net = self.network
259 |
260 | self.logger.debug("\n PRINTING Network structure: \n %s ", sorted(net.keys()))
261 | for key in sorted(net.keys()):
262 | if 'lstm' in key:
263 | for layer in net['l2_lstm']:
264 | try:
265 | self.logger.debug('Layer: %12s | in: %s | out: %s', key, layer.input_shape, layer.output_shape)
266 | except:
267 | self.logger.debug('Layer: %12s | out: %s', key, layer.output_shape)
268 | else:
269 | try:
270 | self.logger.debug('Layer: %12s | in: %s | out: %s', key, net[key].input_shape, net[key].output_shape)
271 | except:
272 | self.logger.debug('Layer: %12s | out: %s', key, net[key].output_shape)
273 | return 0
274 |
275 | def use_best_param(self):
276 | L.set_all_param_values(self.network, self.best_param)
277 | self.curr_epoch = self.best_epoch
278 | # Remove the network_train_info entries newer than self.best_epoch
279 | for key in self.network_train_info:
280 |     del self.network_train_info[key][self.best_epoch:]
282 |
283 | def load_model(self, model_name, round_params=False):
284 | model_path = model_name + '.npz'
285 | if self.network is not None:
286 | try:
287 | #self.logger.info("Loading stored model %s...", model_path)
288 |
289 | # restore network weights
290 | with np.load(model_path) as f:
291 | param_values = [f['arr_%d' % i] for i in range(len(f.files))][0] # save_model stored the whole param list as a single object (arr_0), so unwrap it
292 | if round_params:
293 | self.logger.info("ROUND PARAMS")
294 | param_values = self.roundParams(param_values)
295 | L.set_all_param_values(self.network_lout, param_values)
296 |
297 | self.logger.info("Loading parameters successful.")
298 | return 0
299 |
300 | except IOError as e:
301 | print(os.strerror(e.errno))
302 | #self.logger.warning('Model: {} not found. No weights loaded'.format(model_path))
303 | return -1
304 | else:
305 | raise IOError('You must build the network before loading the weights.')
306 | return -1
307 |
308 | def roundParams(self, param_values):
309 | # round by converting to float16 and back to float32
310 | # up to 12 bit should be possible without accuracy loss, but I don't know how to do this
311 | for i in range(len(param_values)):
312 | param_values[i] = param_values[i].astype(np.float16)
313 | param_values[i] = param_values[i].astype(np.float32)
314 |
315 | return param_values
316 |
317 | def save_model(self, model_name):
318 | if not os.path.exists(os.path.dirname(model_name)):
319 | os.makedirs(os.path.dirname(model_name))
320 | np.savez(model_name + '.npz', self.best_param)
321 |
322 | # also restore the updates variables to continue training. LR should also be saved and restored...
323 | # updates_vals = [p.get_value() for p in self.best_updates.keys()]
324 | # np.savez(model_name + '_updates.npz', updates_vals)
325 |
326 | # TODO use combinedSR.py for a working version
327 | def create_confusion(self, X, y, debug=False):
328 | argmax_fn = self.training_fn[1]
329 |
330 | y_pred = []
331 | for X_obs in X:
332 | for x in argmax_fn(X_obs):
333 | for j in x:
334 | y_pred.append(j)
335 |
336 | y_actu = []
337 | for Y in y:
338 | for y in Y:
339 | y_actu.append(y)
340 |
341 | conf_img = np.zeros([61, 61])
342 | assert (len(y_pred) == len(y_actu))
343 |
344 | for i in range(len(y_pred)):
345 | row_idx = y_actu[i]
346 | col_idx = y_pred[i]
347 | conf_img[row_idx, col_idx] += 1
348 |
349 | return conf_img, y_pred, y_actu
350 |
351 | def build_functions(self, train=False, debug=False):
352 |
353 | # LSTM in lasagne: see https://github.com/craffel/Lasagne-tutorial/blob/master/examples/recurrent.py
354 | # and also http://colinraffel.com/talks/hammer2015recurrent.pdf
355 | target_var = self.audio_targets_var # T.imatrix('audio_targets')
356 |
357 | if debug: self.print_network_structure()
358 |
359 | network_output = L.get_output(self.network_lout_batch)
360 | network_output_flattened = L.get_output(self.network_lout) # (batch_size * batch_max_seq_length, nb_phonemes)
361 |
362 | # compare targets with highest output probability. Take maximum of all probs (3rd axis (index 2) of output:
363 | # 1=batch_size (input files), 2 = time_seq (frames), 3 = n_features (phonemes)
364 | # network_output.shape = (len(X), 39) -> (nb_inputs, nb_classes)
365 | predictions = (T.argmax(network_output, axis=2))
366 |
367 | if debug:
368 | self.predictions_fn = theano.function([self.audio_inputs_var, self.audio_masks_var], predictions,
369 | name='predictions_fn')
370 |
371 | predicted = self.predictions_fn(self.X, self.masks)
372 | self.logger.debug('predictions_fn(X).shape: %s', predicted.shape)
373 | # self.logger.debug('predictions_fn(X)[0], value: %s', predicted[0])
374 |
375 | self.output_fn = theano.function([self.audio_inputs_var, self.audio_masks_var], network_output,
376 | name='output_fn')
377 | n_out = self.output_fn(self.X, self.masks)
378 | self.logger.debug('network_output.shape: \t%s', n_out.shape)
379 | # self.logger.debug('network_output[0]: \n%s', n_out[0]);
380 |
381 | # # Function to determine the number of correct classifications
382 | # which video, and which frames in the video
383 | valid_indices_example, valid_indices_seqNr = self.audio_masks_var.nonzero()
384 | valid_indices_fn = theano.function([self.audio_masks_var], [valid_indices_example, valid_indices_seqNr],
385 | name='valid_indices_fn')
386 |
387 | # this gets a FLATTENED array of all the valid predictions of all examples of this batch (so not one row per example)
388 | # if you want to get the valid predictions per example, you need to use the valid_frames list (it tells you the number of valid frames per wav, so where to split this valid_predictions array)
389 | # of course this is trivial for batch_size_audio = 1, as all valid_predictions will belong to the one input wav
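# Example: if a batch holds two wavs with 4 and 6 valid frames, valid_predictions
# is a flat array of 10 entries; split it at index 4 (known from valid_frames)
# to recover the per-wav predictions.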
390 | valid_predictions = predictions[valid_indices_example, valid_indices_seqNr]
391 | valid_targets = target_var[valid_indices_example, valid_indices_seqNr]
392 | self.valid_targets_fn = theano.function([self.audio_masks_var, target_var], valid_targets,
393 | name='valid_targets_fn')
394 | self.valid_predictions_fn = theano.function([self.audio_inputs_var, self.audio_masks_var], valid_predictions,
395 | name='valid_predictions_fn')
396 |
397 | # get valid network output
398 | valid_network_output = network_output[valid_indices_example, valid_indices_seqNr]
399 | if debug:
400 | self.valid_network_output_fn = theano.function([self.audio_inputs_var, self.audio_masks_var],
401 | valid_network_output)
402 |
403 | # Functions for computing cost and training
404 | top1_acc = T.mean(lasagne.objectives.categorical_accuracy(valid_network_output, valid_targets, top_k=1))
405 | self.top1_acc_fn = theano.function(
406 | [self.audio_inputs_var, self.audio_masks_var, self.audio_targets_var], top1_acc)
407 | top3_acc = T.mean(lasagne.objectives.categorical_accuracy(valid_network_output, valid_targets, top_k=3))
408 | self.top3_acc_fn = theano.function(
409 | [self.audio_inputs_var, self.audio_masks_var, self.audio_targets_var], top3_acc)
410 |
411 |
412 | if debug:
413 | try:
414 | # only works with batch_size == 1
415 | # valid_preds2 = self.valid_predictions2_fn(self.X, self.masks, self.valid_frames)
416 | # self.logger.debug("all valid predictions of this batch: ")
417 | # self.logger.debug('valid_preds2.shape: %s', valid_preds2.shape)
418 | # self.logger.debug('valid_preds2, value: \n%s', valid_preds2)
419 |
420 | # valid_out = self.valid_network_fn(self.X, self.masks, self.valid_frames)
421 | # self.logger.debug('valid_out.shape: %s', valid_out.shape)
422 | # # self.logger.debug('valid_out, value: \n%s', valid_out)
423 |
424 | valid_example, valid_seqNr = valid_indices_fn(self.masks)
425 | self.logger.debug('valid_inds(masks).shape: %s', valid_example.shape)
426 |
427 | valid_output = self.valid_network_output_fn(self.X, self.masks)
428 | self.logger.debug("all valid outputs of this batch: ")
429 | self.logger.debug('valid_output.shape: %s', valid_output.shape)
430 |
431 | valid_preds = self.valid_predictions_fn(self.X, self.masks)
432 | self.logger.debug("all valid predictions of this batch: ")
433 | self.logger.debug('valid_preds.shape: %s', valid_preds.shape)
434 | self.logger.debug('valid_preds, value: \n%s', valid_preds)
435 |
436 | valid_targs = self.valid_targets_fn(self.masks, self.Y)
437 | self.logger.debug('valid_targets.shape: %s', valid_targs.shape)
438 | self.logger.debug('valid_targets, value: \n%s', valid_targs)
439 |
440 | top1 = self.top1_acc_fn(self.X, self.masks, self.Y)
441 | self.logger.debug("top 1 accuracy: %s", top1 * 100.0)
442 |
443 | top3 = self.top3_acc_fn(self.X, self.masks, self.Y)
444 | self.logger.debug("top 3 accuracy: %s", top3 * 100.0)
445 |
446 | except Exception as error:
447 |                 print('caught this error: ' + traceback.format_exc())
448 |                 import pdb
449 |                 pdb.set_trace()
450 | # pdb.set_trace()
451 |
452 | ## from https://groups.google.com/forum/#!topic/lasagne-users/os0j3f_Th5Q
453 | # Pad your vector of labels and then mask the cost:
454 | # It's important to pad the label vectors with something valid such as zeros,
455 | # since they will still have to give valid costs that can be multiplied by the mask.
456 | # The shape of predictions, targets and mask should match:
457 | # (predictions as (batch_size*max_seq_len, n_features), the other two as (batch_size*max_seq_len,)) -> we need to get the flattened output of the network for this
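# Added note on the aggregation used below (lasagne API): aggregate(loss, weights)
# with the default mode='mean' returns mean(loss * weights), so padded frames add
# zero to the numerator but still count in the denominator; pass mode='normalized_sum'
# to divide by weights.sum() (the number of valid frames) instead.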
458 |
459 |
460 | # this works, using theano masks
461 | cost_pointwise = lasagne.objectives.categorical_crossentropy(network_output_flattened, target_var.flatten())
462 | cost = lasagne.objectives.aggregate(cost_pointwise, self.audio_masks_var.flatten())
463 | weight_decay = 1e-5
464 | weightsl2 = lasagne.regularization.regularize_network_params(self.network_lout, lasagne.regularization.l2)
465 | cost += weight_decay * weightsl2
466 |
467 | self.validate_fn = theano.function([self.audio_inputs_var, self.audio_masks_var,
468 | self.audio_targets_var],
469 | [cost, top1_acc, top3_acc], name='validate_fn')
470 | self.cost_pointwise_fn = theano.function([self.audio_inputs_var, self.audio_masks_var, target_var],
471 | cost_pointwise, name='cost_pointwise_fn')
472 |
473 | if debug:
474 | self.logger.debug('cost pointwise: %s', self.cost_pointwise_fn(self.X, self.masks, self.Y))
475 |
476 | try:
477 | evaluate_cost = self.validate_fn(self.X, self.masks, self.Y)
478 |             except Exception:
479 |                 print('caught this error: ' + traceback.format_exc())
480 | pdb.set_trace()
481 | self.logger.debug('cost: {:.3f}'.format(float(evaluate_cost[0])))
482 | self.logger.debug('accuracy: {:.3f}'.format(float(evaluate_cost[1] * 100.0)))
483 | self.logger.debug('top 3 accuracy: {:.3f}'.format(float(evaluate_cost[2] * 100.0)))
484 |
485 | # pdb.set_trace()
486 |
487 | if train:
488 | LR = T.scalar('LR', dtype=theano.config.floatX)
489 | # Retrieve all trainable parameters from the network
490 | all_params = L.get_all_params(self.network_lout, trainable=True)
491 | self.updates = lasagne.updates.adam(loss_or_grads=cost, params=all_params, learning_rate=LR)
492 | self.train_fn = theano.function([self.audio_inputs_var, self.audio_masks_var,
493 | target_var, LR],
494 | [cost, top1_acc, top3_acc], updates=self.updates, name='train_fn')
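# Usage sketch (illustrative; names taken from run_epoch below): one update step
#   cst, acc, top3 = self.train_fn(batch_X, batch_masks, batch_y, 0.001)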
495 |
496 |     def shuffle(self, X, y, valid_frames):
497 |
498 | chunk_size = len(X)
499 | shuffled_range = range(chunk_size)
500 |
501 | X_buffer = np.copy(X[0:chunk_size])
502 | y_buffer = np.copy(y[0:chunk_size])
503 | valid_frames_buffer = np.copy(valid_frames[0:chunk_size])
504 |
505 | np.random.shuffle(shuffled_range)
506 |
507 | for i in range(chunk_size):
508 | X_buffer[i] = X[shuffled_range[i]]
509 | y_buffer[i] = y[shuffled_range[i]]
510 | valid_frames_buffer[i] = valid_frames[shuffled_range[i]]
511 |
512 | X[0: chunk_size] = X_buffer
513 | y[0: chunk_size] = y_buffer
514 | valid_frames[0: chunk_size] = valid_frames_buffer
515 |
516 | return X, y, valid_frames
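# Note: range() returns a list under Python 2 (see environment.yml), so it can be
# shuffled in place above. A simpler equivalent sketch using one permutation:
#   perm = np.random.permutation(len(X))
#   X[:] = [X[i] for i in perm]; y[:] = [y[i] for i in perm]
#   valid_frames[:] = [valid_frames[i] for i in perm]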
517 |
518 |     # Runs the model over the whole dataset once (one epoch): trains when LR is given, otherwise only evaluates
519 | def run_epoch(self, X, y, valid_frames, get_predictions=False, LR=None, batch_size=-1):
520 | if batch_size == -1: batch_size = self.batch_size
521 |
522 |         cost = 0
523 |         accuracy = 0
524 |         top3_accuracy = 0
525 |         nb_batches = len(X) // batch_size  # integer division: any leftover examples are dropped
526 |
527 | predictions = [] # only used if get_predictions = True
528 | for i in tqdm(range(nb_batches), total=nb_batches):
529 | batch_X = X[i * batch_size:(i + 1) * batch_size]
530 | batch_y = y[i * batch_size:(i + 1) * batch_size]
531 | batch_valid_frames = valid_frames[i * batch_size:(i + 1) * batch_size]
532 | batch_masks = generate_masks(batch_X, valid_frames=batch_valid_frames, batch_size=batch_size)
533 | # now pad inputs and target to maxLen
534 | batch_X = pad_sequences_X(batch_X)
535 | batch_y = pad_sequences_y(batch_y)
536 | # batch_valid_frames = pad_sequences_y(batch_valid_frames)
537 | # print("batch_X.shape: ", batch_X.shape)
538 | # print("batch_y.shape: ", batch_y.shape)
539 | # import pdb;pdb.set_trace()
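# Shape sketch after padding (illustrative): batch_X -> (batch_size, max_len, nbMFCCs),
# batch_y -> (batch_size, max_len); batch_masks flags the frames that count toward
# the cost (built from valid_frames above).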
540 |             if LR is not None:
541 | cst, acc, top3_acc = self.train_fn(batch_X, batch_masks, batch_y, LR) # training
542 | else:
543 | cst, acc, top3_acc = self.validate_fn(batch_X, batch_masks, batch_y) # validation
544 |             cost += cst
545 |             accuracy += acc
546 |             top3_accuracy += top3_acc
547 |
548 | if get_predictions:
549 | prediction = self.predictions_fn(batch_X, batch_masks)
550 | # prediction = np.reshape(prediction, (nb_inputs, -1)) #only needed if predictions_fn is the flattened and not the batched version (see RNN_implementation.py)
551 | prediction = list(prediction)
552 | predictions = predictions + prediction
553 | # # some tests of valid predictions functions (this works :) )
554 | valid_predictions = self.valid_predictions_fn(batch_X, batch_masks)
555 |                 self.logger.debug("valid predictions: %s", valid_predictions.shape)
556 |
557 | # # get valid predictions for video 0
558 | # self.get_validPredictions_video(valid_predictions, valid_frames, 0)
559 | # # and the targets for video 0
560 | # targets[0][valid_frames[0]]
561 | #
562 |         cost /= nb_batches
563 |         accuracy /= nb_batches
564 |         top3_accuracy /= nb_batches
565 | if get_predictions:
566 | return cost, accuracy * 100.0, top3_accuracy * 100.0, predictions
567 | return cost, accuracy * 100.0, top3_accuracy * 100.0
568 |
569 | def train(self, dataset, save_name='Best_model', num_epochs=100, batch_size=1, LR_start=1e-4, LR_decay=1,
570 | compute_confusion=False, justTest=False, debug=False, roundParams=False,
571 | withNoise=False, noiseType='white', ratio_dB=0):
572 |
573 | X_train, y_train, valid_frames_train, X_val, y_val, valid_frames_val, X_test, y_test, valid_frames_test = dataset
574 |
575 | confusion_matrices = []
576 |
577 | # try to load performance metrics of stored model
578 | best_val_acc, test_acc, old_train_info = self.loadPreviousResults(
579 | save_name) # stores old_train_info into self.network_train_info
580 |
581 | self.logger.info("Initial best Val acc: %s", best_val_acc)
582 | self.logger.info("Initial best test acc: %s\n", test_acc)
583 | self.best_val_acc = best_val_acc
584 |
585 | self.logger.info("Pass over Test Set")
586 | test_cost, test_acc, test_topk_acc = self.run_epoch(X=X_test, y=y_test,
587 | valid_frames=valid_frames_test)
588 | self.logger.info("Test cost:\t\t{:.6f} ".format(test_cost))
589 | self.logger.info("Test accuracy:\t\t{:.6f} %".format(test_acc))
590 | self.logger.info("Test Top 3 accuracy:\t{:.6f} %".format(test_topk_acc))
591 |
592 | self.network_train_info['nb_params'] = lasagne.layers.count_params(self.network_lout_batch)
593 | if justTest:
594 | if os.path.exists(save_name + ".npz"):
595 | self.saveFinalResults(noiseType, ratio_dB, roundParams, save_name, test_acc, test_cost,
596 | test_topk_acc, withNoise)
597 | return 0
598 |             # if the model file doesn't exist yet, fall through and train it anyway
599 | else:
600 | self.network_train_info['test_cost'].append(test_cost)
601 | self.network_train_info['test_acc'].append(test_acc)
602 | self.network_train_info['test_topk_acc'].append(test_topk_acc)
603 |
604 | self.logger.info("\n* Starting training...")
605 | LR = LR_start
606 | self.best_cost = 100
607 | for epoch in range(num_epochs):
608 | self.curr_epoch += 1
609 | epoch_time = time.time()
610 | self.logger.info("\n\nCURRENT EPOCH: %s", self.curr_epoch)
611 |
612 | self.logger.info("Pass over Training Set")
613 | train_cost, train_acc, train_topk_acc = self.run_epoch(X=X_train, y=y_train,
614 | valid_frames=valid_frames_train, LR=LR)
615 |
616 | self.logger.info("Pass over Validation Set")
617 | val_cost, val_acc, val_topk_acc = self.run_epoch(X=X_val, y=y_val, valid_frames=valid_frames_val)
618 |
619 | # Print epoch summary
620 | self.logger.info("Epoch {} of {} took {:.3f}s.".format(
621 | epoch + 1, num_epochs, time.time() - epoch_time))
622 | self.logger.info("Learning Rate:\t\t{:.6f} %".format(LR))
623 | self.logger.info("Training cost:\t{:.6f}".format(train_cost))
624 |             self.logger.info("Training accuracy:\t{:.6f} %".format(train_acc))
625 |
626 | self.logger.info("Validation cost:\t{:.6f} ".format(val_cost))
627 | self.logger.info("Validation accuracy:\t\t{:.6f} %".format(val_acc))
628 | self.logger.info("Validation Top 3 accuracy:\t{:.6f} %".format(val_topk_acc))
629 |
630 | # better model, so save parameters
631 | if val_acc > self.best_val_acc:
632 | # only reset if significant improvement
633 | if val_acc - self.best_val_acc > 0.2:
634 | self.epochsNotImproved = 0
635 | # store new parameters
636 | self.best_cost = val_cost
637 | self.best_val_acc = val_acc
638 | self.best_epoch = self.curr_epoch
639 | self.best_param = L.get_all_param_values(self.network_lout)
640 | self.best_updates = [p.get_value() for p in self.updates.keys()]
641 | self.logger.info("New best model found!")
642 | if save_name is not None:
643 | self.logger.info("Model saved as " + save_name)
644 | self.save_model(save_name)
645 |
646 | self.logger.info("Pass over Test Set")
647 | test_cost, test_acc, test_topk_acc = self.run_epoch(X=X_test, y=y_test,
648 | valid_frames=valid_frames_test)
649 | self.logger.info("Test cost:\t\t{:.6f} ".format(test_cost))
650 | self.logger.info("Test accuracy:\t\t{:.6f} %".format(test_acc))
651 | self.logger.info("Test Top 3 accuracy:\t{:.6f} %".format(test_topk_acc))
652 |
653 | # save the training info
654 | self.network_train_info['train_cost'].append(train_cost)
655 | self.network_train_info['val_cost'].append(val_cost)
656 | self.network_train_info['val_acc'].append(val_acc)
657 | self.network_train_info['val_topk_acc'].append(val_topk_acc)
658 | self.network_train_info['test_cost'].append(test_cost)
659 | self.network_train_info['test_acc'].append(test_acc)
660 | self.network_train_info['test_topk_acc'].append(test_topk_acc)
661 |
662 | saveToPkl(save_name + '_trainInfo.pkl', self.network_train_info)
663 | self.logger.info("Train info written to:\t %s", save_name + '_trainInfo.pkl')
664 |
665 | if compute_confusion:
666 |                 print("confusion matrix computation is not implemented here")
667 |
668 | # update LR, see if we can stop training
669 | LR = self.updateLR(LR, LR_decay)
670 |
671 | if self.epochsNotImproved >= 3:
672 | logging.warning("\n\nNo more improvements, stopping training...")
673 | self.logger.info("Pass over Test Set")
674 | test_cost, test_acc, test_topk_acc = self.run_epoch(X=X_test, y=y_test,
675 | valid_frames=valid_frames_test)
676 | self.logger.info("Test cost:\t\t{:.6f} ".format(test_cost))
677 | self.logger.info("Test accuracy:\t\t{:.6f} %".format(test_acc))
678 | self.logger.info("Test Top 3 accuracy:\t{:.6f} %".format(test_topk_acc))
679 |
680 | self.network_train_info['test_cost'][-1] = test_cost
681 | self.network_train_info['test_acc'][-1] = test_acc
682 | self.network_train_info['test_topk_acc'][-1] = test_topk_acc
683 |
684 | self.saveFinalResults(noiseType, ratio_dB, roundParams, save_name, test_acc, test_cost,
685 | test_topk_acc, withNoise)
686 |
687 | return test_cost, test_acc, test_topk_acc
688 | return test_cost, test_acc, test_topk_acc
689 |
690 |
691 | def saveFinalResults(self, noiseType, ratio_dB, roundParams, save_name, test_acc, test_cost, test_topk_acc,
692 | withNoise):
693 | if self.test_dataset != self.dataset:
694 | testType = "_" + self.test_dataset
695 | else:
696 | testType = ""
697 | if roundParams:
698 | testType = "_roundParams" + testType
699 |
700 | if withNoise:
701 | self.network_train_info[
702 | 'final_test_cost_' + noiseType + "_" + "ratio" + str(ratio_dB) + testType] = test_cost
703 | self.network_train_info['final_test_acc_' + noiseType + "_" + "ratio" + str(ratio_dB) + testType] = test_acc
704 | self.network_train_info[
705 | 'final_test_top3_acc_' + noiseType + "_" + "ratio" + str(ratio_dB) + testType] = test_topk_acc
706 | else:
707 | self.network_train_info['final_test_cost' + testType] = test_cost
708 | self.network_train_info['final_test_acc' + testType] = test_acc
709 | self.network_train_info['final_test_top3_acc' + testType] = test_topk_acc
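# Example resulting keys (illustrative): 'final_test_acc' for a clean in-domain test,
# 'final_test_acc_roundParams_TIMIT' for rounded weights on another test set, and
# 'final_test_acc_white_ratio-3' when white noise at -3 dB was added.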
710 |
711 | saveToPkl(save_name + '_trainInfo.pkl', self.network_train_info)
712 | self.logger.info("Train info written to:\t %s", save_name + '_trainInfo.pkl')
713 |
714 | def get_validPredictions_video(self, valid_predictions, valid_frames, videoIndexInBatch):
715 | # get indices of the valid frames for each video, using the valid_frames
716 | nbValidsPerVideo = [len(el) for el in valid_frames]
717 |
718 | # each el is the sum of the els before. -> for example video 3, you need valid_predictions from indices[2] (inclusive) till indices[3] (not inclusive)
719 | indices = [0] + [np.sum(nbValidsPerVideo[:i + 1]) for i in range(len(nbValidsPerVideo))]
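# Worked example (illustrative): valid_frames with lengths [3, 2, 4] gives
#   nbValidsPerVideo = [3, 2, 4] and indices = [0, 3, 5, 9],
#   so video 1 owns valid_predictions[3:5].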
720 |
721 | # make a 2D list. Each el of the list is a list with the valid frames per video.
722 | videoPreds = [range(indices[videoIndex], indices[videoIndex + 1]) for videoIndex in range(
723 | len(valid_frames))]
724 | # assert len(videoPreds) == len(inputs) == len(valid_frames)
725 |
726 | # now you can get the frames for a specific video:
727 | return valid_predictions[videoPreds[videoIndexInBatch]]
728 |
729 | def loadPreviousResults(self, save_name):
730 | # try to load performance metrics of stored model
731 | best_val_acc = 0
732 | test_topk_acc = 0
733 | test_cost = 0
734 | test_acc = 0
735 | old_train_info = {}
736 | try:
737 | if os.path.exists(save_name + "_trainInfo.pkl"):
738 | old_train_info = unpickle(save_name + '_trainInfo.pkl')
739 | best_val_acc = max(old_train_info['val_acc'])
740 | test_cost = min(old_train_info['test_cost'])
741 | test_acc = max(old_train_info['test_acc'])
742 | self.network_train_info = old_train_info
743 | try:
744 | test_topk_acc = max(old_train_info['test_topk_acc'])
745 |                 except KeyError:
746 |                     pass  # older train-info files may not contain 'test_topk_acc'
747 |         except Exception:
748 |             pass  # no (or unreadable) stored results; start from the defaults above
749 | return best_val_acc, test_acc, old_train_info
750 |
751 | def updateLR(self, LR, LR_decay):
752 | this_acc = self.network_train_info['val_acc'][-1]
753 | this_cost = self.network_train_info['val_cost'][-1]
754 | try:
755 | last_acc = self.network_train_info['val_acc'][-2]
756 | last_cost = self.network_train_info['val_cost'][-2]
757 | except:
758 | last_acc = -10
759 | last_cost = 10 * this_cost # first time it will fail because there is only 1 result stored
760 |
761 |         # only reduce LR if there is not much improvement anymore
762 |         if this_cost / float(last_cost) >= 0.98 or this_acc - last_acc < 0.2:
763 |             self.logger.info(" Cost not reduced enough: %s vs %s. Reducing LR to %s", this_cost, last_cost, LR * LR_decay)
764 | self.epochsNotImproved += 1
765 | return LR * LR_decay
766 | else:
767 | self.epochsNotImproved = max(self.epochsNotImproved - 1, 0) # reduce by 1, minimum 0
768 | return LR
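# Numeric example (illustrative): with LR_decay=0.5, a val cost going from 1.00 to
# 0.99 gives a ratio of 0.99 >= 0.98, so the LR is halved and epochsNotImproved is
# incremented; after 3 such epochs, train() stops early.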
769 |
--------------------------------------------------------------------------------
/background/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthijsvk/TIMITspeech/4294fe4af760d19dc807c4e01d01d07662ff7bde/background/__init__.py
--------------------------------------------------------------------------------
/background/htkbook-3.5.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthijsvk/TIMITspeech/4294fe4af760d19dc807c4e01d01d07662ff7bde/background/htkbook-3.5.pdf
--------------------------------------------------------------------------------
/background/plot_data.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 |
3 | from general_tools import *  # this star import also has to supply os, np and load_dataset used below
4 |
5 | store_dir = os.path.expanduser('~/TCDTIMIT/audioSR/TIMIT/binary39/TIMIT')
6 | dataset_path = os.path.join(store_dir, 'TIMIT_39_ch.pkl')
7 |
8 | X_train, y_train, valid_frames_train, X_val, y_val, valid_frames_val, X_test, y_test, valid_frames_test = load_dataset(
9 | dataset_path)
10 |
11 | plt.figure(1)
12 | plt.title('Preprocessed data visualization')
13 | for i in range(1, 5):
14 | plt.subplot(2, 2, i)
15 | plt.axis('off')
16 |     # plt.imshow(X_train[i].T)  # redundant: immediately overwritten by the log-scale plot below
17 |     plt.imshow(np.log(X_train[i].T))
18 | print(X_train[i].shape)
19 |
20 | plt.tight_layout()
21 | plt.show()
22 |
--------------------------------------------------------------------------------
/background/timit_phones.txt:
--------------------------------------------------------------------------------
1 | iy
2 | ih
3 | eh
4 | ae
5 | ah
6 | uw
7 | uh
8 | aa
9 | ey
10 | ay
11 | oy
12 | aw
13 | ow
14 | l
15 | r
16 | y
17 | w
18 | er
19 | m
20 | n
21 | ng
22 | ch
23 | jh
24 | dh
25 | b
26 | d
27 | dx
28 | g
29 | p
30 | t
31 | k
32 | z
33 | v
34 | f
35 | th
36 | s
37 | sh
38 | hh
39 | sil
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: DeepLearning
2 | channels:
3 | - conda-forge
4 | - defaults
5 | dependencies:
6 | - libgpuarray=0.7.1=vc9_0
7 | - pygpu=0.7.1=np113py27_0
8 | - _license=1.1=py27_1
9 | - alabaster=0.7.10=py27_0
10 | - anaconda-client=1.6.3=py27_0
11 | - anaconda=custom=py27_0
12 | - anaconda-navigator=1.6.2=py27_0
13 | - anaconda-project=0.6.0=py27_0
14 | - asn1crypto=0.22.0=py27_0
15 | - astroid=1.4.9=py27_0
16 | - astropy=2.0.1=np113py27_0
17 | - babel=2.4.0=py27_0
18 | - backports=1.0=py27_0
19 | - backports_abc=0.5=py27_0
20 | - beautifulsoup4=4.6.0=py27_0
21 | - bitarray=0.8.1=py27_1
22 | - blaze=0.10.1=py27_0
23 | - bleach=1.5.0=py27_0
24 | - bokeh=0.12.5=py27_1
25 | - boto=2.46.1=py27_0
26 | - bottleneck=1.2.1=np113py27_0
27 | - bzip2=1.0.6=vc9_3
28 | - cdecimal=2.3=py27_2
29 | - certifi=2016.2.28=py27_0
30 | - cffi=1.10.0=py27_0
31 | - chardet=3.0.3=py27_0
32 | - click=6.7=py27_0
33 | - cloudpickle=0.2.2=py27_0
34 | - clyent=1.2.2=py27_0
35 | - colorama=0.3.9=py27_0
36 | - comtypes=1.1.2=py27_0
37 | - configparser=3.5.0=py27_0
38 | - console_shortcut=0.1.1=py27_1
39 | - contextlib2=0.5.5=py27_0
40 | - cryptography=1.8.1=py27_0
41 | - cudatoolkit=8.0=1
42 | - curl=7.52.1=vc9_0
43 | - cycler=0.10.0=py27_0
44 | - cython=0.25.2=py27_0
45 | - cytoolz=0.8.2=py27_0
46 | - dask=0.14.3=py27_1
47 | - datashape=0.5.4=py27_0
48 | - decorator=4.0.11=py27_0
49 | - distributed=1.16.3=py27_0
50 | - docutils=0.13.1=py27_0
51 | - entrypoints=0.2.2=py27_1
52 | - enum34=1.1.6=py27_0
53 | - et_xmlfile=1.0.1=py27_0
54 | - fastcache=1.0.2=py27_1
55 | - flask=0.12.2=py27_0
56 | - flask-cors=3.0.2=py27_0
57 | - freetype=2.5.5=vc9_2
58 | - funcsigs=1.0.2=py27_0
59 | - functools32=3.2.3.2=py27_0
60 | - futures=3.1.1=py27_0
61 | - get_terminal_size=1.0.0=py27_0
62 | - gevent=1.2.1=py27_0
63 | - greenlet=0.4.12=py27_0
64 | - grin=1.2.1=py27_3
65 | - h5py=2.7.0=np113py27_0
66 | - hdf5=1.8.15.1=vc9_4
67 | - heapdict=1.0.0=py27_1
68 | - html5lib=0.999=py27_0
69 | - icu=57.1=vc9_0
70 | - idna=2.5=py27_0
71 | - imagesize=0.7.1=py27_0
72 | - ipaddress=1.0.18=py27_0
73 | - ipykernel=4.6.1=py27_0
74 | - ipython=5.3.0=py27_0
75 | - ipython_genutils=0.2.0=py27_0
76 | - ipywidgets=6.0.0=py27_0
77 | - isort=4.2.5=py27_0
78 | - itsdangerous=0.24=py27_0
79 | - jdcal=1.3=py27_0
80 | - jedi=0.10.2=py27_2
81 | - jinja2=2.9.6=py27_0
82 | - jpeg=9b=vc9_0
83 | - jsonschema=2.6.0=py27_0
84 | - jupyter=1.0.0=py27_3
85 | - jupyter_client=5.0.1=py27_0
86 | - jupyter_console=5.1.0=py27_0
87 | - jupyter_core=4.3.0=py27_0
88 | - lasagne=0.1=py27_0
89 | - lazy-object-proxy=1.2.2=py27_0
90 | - libpng=1.6.27=vc9_0
91 | - libpython=2.0=py27_0
92 | - libtiff=4.0.6=vc9_3
93 | - llvmlite=0.20.0=py27_0
94 | - locket=0.2.0=py27_1
95 | - lxml=3.7.3=py27_0
96 | - m2w64-binutils=2.25.1=5
97 | - m2w64-bzip2=1.0.6=6
98 | - m2w64-crt-git=5.0.0.4636.2595836=2
99 | - m2w64-gcc=5.3.0=6
100 | - m2w64-gcc-ada=5.3.0=6
101 | - m2w64-gcc-fortran=5.3.0=6
102 | - m2w64-gcc-libgfortran=5.3.0=6
103 | - m2w64-gcc-libs=5.3.0=7
104 | - m2w64-gcc-libs-core=5.3.0=7
105 | - m2w64-gcc-objc=5.3.0=6
106 | - m2w64-gmp=6.1.0=2
107 | - m2w64-headers-git=5.0.0.4636.c0ad18a=2
108 | - m2w64-isl=0.16.1=2
109 | - m2w64-libiconv=1.14=6
110 | - m2w64-libmangle-git=5.0.0.4509.2e5a9a2=2
111 | - m2w64-libwinpthread-git=5.0.0.4634.697f757=2
112 | - m2w64-make=4.1.2351.a80a8b8=2
113 | - m2w64-mpc=1.0.3=3
114 | - m2w64-mpfr=3.1.4=4
115 | - m2w64-pkg-config=0.29.1=2
116 | - m2w64-toolchain=5.3.0=7
117 | - m2w64-tools-git=5.0.0.4592.90b8472=2
118 | - m2w64-windows-default-manifest=6.4=3
119 | - m2w64-winpthreads-git=5.0.0.4634.697f757=2
120 | - m2w64-zlib=1.2.8=10
121 | - mako=1.0.6=py27_0
122 | - markupsafe=0.23=py27_2
123 | - matplotlib=2.0.2=np113py27_0
124 | - menuinst=1.4.7=py27_0
125 | - mistune=0.7.4=py27_0
126 | - mkl=2017.0.1=0
127 | - mkl-service=1.1.2=py27_3
128 | - mpmath=0.19=py27_1
129 | - msgpack-python=0.4.8=py27_0
130 | - msys2-conda-epoch=20160418=1
131 | - multipledispatch=0.4.9=py27_0
132 | - navigator-updater=0.1.0=py27_0
133 | - nbconvert=5.1.1=py27_0
134 | - nbformat=4.3.0=py27_0
135 | - networkx=1.11=py27_0
136 | - nltk=3.2.3=py27_0
137 | - nose=1.3.7=py27_1
138 | - nose-parameterized=0.6.0=py27_0
139 | - notebook=5.0.0=py27_0
140 | - numba=0.35.0=np113py27_0
141 | - numexpr=2.6.2=np113py27_0
142 | - numpy=1.13.1=py27_0
143 | - numpydoc=0.6.0=py27_0
144 | - odo=0.5.0=py27_1
145 | - olefile=0.44=py27_0
146 | - openpyxl=2.4.7=py27_0
147 | - openssl=1.0.2l=vc9_0
148 | - packaging=16.8=py27_0
149 | - pandas=0.20.3=py27_0
150 | - pandocfilters=1.4.1=py27_0
151 | - partd=0.3.8=py27_0
152 | - path.py=10.3.1=py27_0
153 | - pathlib2=2.2.1=py27_0
154 | - patsy=0.4.1=py27_0
155 | - pep8=1.7.0=py27_0
156 | - pickleshare=0.7.4=py27_0
157 | - pillow=4.1.1=py27_0
158 | - pip=9.0.1=py27_1
159 | - ply=3.10=py27_0
160 | - progressbar=2.3=py27_0
161 | - prompt_toolkit=1.0.14=py27_0
162 | - psutil=5.2.2=py27_0
163 | - py=1.4.33=py27_0
164 | - pycosat=0.6.2=py27_0
165 | - pycparser=2.17=py27_0
166 | - pycrypto=2.6.1=py27_6
167 | - pyculib=1.0.2=np113py27_2
168 | - pyculib_sorting=1.0.0=8
169 | - pycurl=7.43.0=py27_2
170 | - pydot-ng=1.0.0.15=py27_0
171 | - pyflakes=1.5.0=py27_0
172 | - pygments=2.2.0=py27_0
173 | - pylint=1.6.4=py27_1
174 | - pyodbc=4.0.16=py27_0
175 | - pyopenssl=17.0.0=py27_0
176 | - pyparsing=2.1.4=py27_0
177 | - pyqt=5.6.0=py27_2
178 | - pytables=3.2.2=np113py27_4
179 | - pytest=3.0.7=py27_0
180 | - python=2.7.13=1
181 | - python-dateutil=2.6.0=py27_0
182 | - pytz=2017.2=py27_0
183 | - pywavelets=0.5.2=np113py27_0
184 | - pywin32=220=py27_2
185 | - pyyaml=3.12=py27_0
186 | - pyzmq=16.0.2=py27_0
187 | - qt=5.6.2=vc9_4
188 | - qtawesome=0.4.4=py27_0
189 | - qtconsole=4.3.0=py27_0
190 | - qtpy=1.2.1=py27_0
191 | - requests=2.14.2=py27_0
192 | - rope=0.9.4=py27_1
193 | - ruamel_yaml=0.11.14=py27_1
194 | - scandir=1.5=py27_0
195 | - scikit-image=0.13.0=np113py27_0
196 | - scikit-learn=0.19.0=np113py27_0
197 | - scipy=0.19.1=np113py27_0
198 | - seaborn=0.7.1=py27_0
199 | - setuptools=36.4.0=py27_1
200 | - simplegeneric=0.8.1=py27_1
201 | - singledispatch=3.4.0.3=py27_0
202 | - sip=4.18=py27_0
203 | - six=1.10.0=py27_0
204 | - snowballstemmer=1.2.1=py27_0
205 | - sortedcollections=0.5.3=py27_0
206 | - sortedcontainers=1.5.7=py27_0
207 | - sphinx=1.6.3=py27_0
208 | - sphinxcontrib=1.0=py27_0
209 | - sphinxcontrib-websupport=1.0.1=py27_0
210 | - spyder=3.1.4=py27_0
211 | - sqlalchemy=1.1.9=py27_0
212 | - ssl_match_hostname=3.4.0.2=py27_1
213 | - statsmodels=0.8.0=np113py27_0
214 | - subprocess32=3.2.7=py27_0
215 | - sympy=1.0=py27_0
216 | - tblib=1.3.2=py27_0
217 | - testpath=0.3=py27_0
218 | - theano=0.9.0=py27_0
219 | - tk=8.5.18=vc9_0
220 | - toolz=0.8.2=py27_0
221 | - tornado=4.5.1=py27_0
222 | - traitlets=4.3.2=py27_0
223 | - typing=3.6.2=py27_0
224 | - unicodecsv=0.14.1=py27_0
225 | - vs2008_runtime=9.00.30729.5054=0
226 | - vs2015_runtime=14.0.25123=0
227 | - wcwidth=0.1.7=py27_0
228 | - werkzeug=0.12.2=py27_0
229 | - wheel=0.29.0=py27_0
230 | - widgetsnbextension=2.0.0=py27_0
231 | - win_unicode_console=0.5=py27_0
232 | - wincertstore=0.2=py27_0
233 | - wrapt=1.10.10=py27_0
234 | - xlrd=1.0.0=py27_0
235 | - xlsxwriter=0.9.6=py27_0
236 | - xlwings=0.10.4=py27_0
237 | - xlwt=1.2.0=py27_0
238 | - zict=0.1.2=py27_0
239 | - zlib=1.2.8=vc9_3
240 | - pip:
241 | - backports-abc==0.5
242 | - backports.shutil-get-terminal-size==1.0.0
243 | - backports.ssl-match-hostname==3.4.0.2
244 | - et-xmlfile==1.0.1
245 | - gpustat==0.3.2
246 | - ipython-genutils==0.2.0
247 | - jupyter-client==5.0.1
248 | - jupyter-console==5.1.0
249 | - jupyter-core==4.3.0
250 | - keras==2.0.8
251 | - lmdb==0.93
252 | - msgpack-numpy==0.4.1
253 | - prompt-toolkit==1.0.14
254 | - pypet==0.3.0
255 | - python-speech-features==0.6
256 | - tables==3.2.2
257 | - tabulate==0.7.7
258 | - tensorpack==0.5.0
259 | - termcolor==1.1.0
260 | - tqdm==4.17.1
261 | - win-unicode-console==0.5
262 | prefix: C:\Users\matthijsv\AppData\Local\Continuum\Anaconda2\envs\DeepLearning
263 |
264 |
--------------------------------------------------------------------------------
/getResults.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import warnings
4 | from pprint import pprint # printing properties of networkToRun objects
5 | from time import gmtime, strftime
6 |
7 | # pprint(vars(a))
8 |
9 | warnings.simplefilter("ignore", UserWarning) # cuDNN warning
10 |
11 | import logging
12 | import formatting
13 | from tqdm import tqdm
14 |
15 | logger_combined = logging.getLogger('combined')
16 | logger_combined.setLevel(logging.DEBUG)
17 | FORMAT = '[$BOLD%(filename)s$RESET:%(lineno)d][%(levelname)-5s]: %(message)s '
18 | formatter = logging.Formatter(formatting.formatter_message(FORMAT, False))
19 |
20 | # create console handler with a higher log level
21 | ch = logging.StreamHandler()
22 | ch.setLevel(logging.DEBUG)
23 | ch.setFormatter(formatter)
24 | logger_combined.addHandler(ch)
25 |
26 | # File logger: see below META VARIABLES
27 |
28 | import time
29 |
30 | program_start_time = time.time()
31 |
32 | print("\n * Importing libraries...")
33 | from combinedNN_tools import *
34 | from general_tools import *
35 | import traceback
36 |
37 | ###################### Script settings #######################################
38 | root = os.path.expanduser('~/TCDTIMIT/')
39 | resultsPath = root + 'combinedSR/TCDTIMIT/results/allEvalResults.pkl'
40 |
41 | logToFile = True
42 | overwriteResults = False
43 |
44 | # if you wish to force retraining of networks: set justTest to False, forceTrain in main() to True, and overwriteSubnets to True.
45 | # if forceTrain=True and justTest=False, a network keeps training even if it already exists. If False, an existing network is just evaluated.
46 | forceTrain = False
47 |
48 | # justTest: if True, mainGetResults only runs over the already-trained networks; networks that don't exist are skipped.
49 | #           if False, the user is asked to confirm, and the networks that don't exist yet are then trained.
50 | justTest = False; autoTrain = False  # autoTrain: when results are missing, retrain/evaluate without asking first (used in getManyNetworkResults)
51 |
52 | getConfusionMatrix = True # if True, stores confusionMatrix where the .npz and train_info.pkl are stored
53 | # use this for testing with reduced precision. It converts the network weights to float16, then back to float32 for execution.
54 | # This amounts to rounding. Performance should hardly be impacted.
55 | ROUND_PARAMS = False
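# Sketch of the rounding idea (an assumption about how it is implemented elsewhere):
#   params = lasagne.layers.get_all_param_values(net)
#   rounded = [p.astype('float16').astype('float32') for p in params]
#   lasagne.layers.set_all_param_values(net, rounded)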
56 |
57 | # use this to TEST trained networks on the test dataset with noise added.
58 | # This data is generated using audioSR/fixDataset/mergeAudiofiles.py + audioToPkl_perVideo.py and combinedSR/dataToPkl_lipspeakers.py
59 | # It is loaded in combinedNN_tools/finalEvaluation (and also just before training in 'train'). You could also generate noisy training data and train on that, but I haven't tried it.
60 | withNoise = False
61 | noiseTypes = ['white', 'voices']
62 | ratio_dBs = [0, -3, -5, -10]
63 |
64 |
65 | ###################### Script code #######################################
66 |
67 | # quickly generate many networks
68 | def createNetworkList(dataset, networkArchs):
69 | networkList = []
70 |     for networkArch in networkArchs:
71 | networkList.append(networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=networkArch,
72 | audio_dataset=dataset, test_dataset=dataset))
73 | return networkList
74 |
75 | networkList = [
76 | # # # # # ### AUDIO ### -> see audioSR/RNN.py, there it can run in batch mode which is much faster
77 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[64,64],
78 | # audio_dataset="TCDTIMIT", test_dataset="TCDTIMIT"),#,forceTrain=True),
79 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[256, 256],
80 | # audio_dataset="combined", test_dataset="TCDTIMIT"),#, forceTrain=True)
81 | networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[256, 256],
82 | audio_dataset="TCDTIMIT", test_dataset="TCDTIMIT"), # , forceTrain=True)
83 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[512,512],
84 | # audio_dataset="TCDTIMIT", test_dataset="TCDTIMIT"),#,forceTrain=True),
85 |
86 | # run TCDTIMIT-trained network on lipspeakers, and on the real TCDTIMIT test set (volunteers)
87 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[256, 256],
88 | # audio_dataset="TCDTIMIT", test_dataset="TIMIT"), # , forceTrain=True) % 66.75 / 89.19
89 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[256, 256],
90 | # audio_dataset="TCDTIMIT", test_dataset="TCDTIMITvolunteers"), # ,forceTrain=True),
91 |
92 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[64,64],
93 | # audio_dataset="combined", test_dataset="TCDTIMIT"), # ,forceTrain=True),
94 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[256, 256],
95 | # audio_dataset="combined", test_dataset="TCDTIMIT"), # ,forceTrain=True),
96 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[512, 512],
97 | # audio_dataset="combined", test_dataset="TCDTIMIT"), # ,forceTrain=True),
98 |
99 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[256, 256],
100 | # audio_dataset="combined", test_dataset="TCDTIMIT"), # ,forceTrain=True),
101 |
102 |
103 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[8], audio_dataset="TIMIT", test_dataset="TIMIT"),
104 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[32], audio_dataset="TIMIT", test_dataset="TIMIT"),
105 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[64], audio_dataset="TIMIT", test_dataset="TIMIT"),
106 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[256], dataset="TIMIT", audio_dataset="TIMIT", test_dataset="TIMIT"),
107 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[512], audio_dataset="TIMIT", test_dataset="TIMIT"),
108 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[1024], audio_dataset="TIMIT", test_dataset="TIMIT"),
109 | #
110 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[8, 8], audio_dataset="TIMIT", test_dataset="TIMIT"),
111 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[32, 32], audio_dataset="TIMIT", test_dataset="TIMIT"),
112 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[64, 64], audio_dataset="TIMIT", test_dataset="TIMIT"),
113 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[256, 256], audio_dataset="TIMIT", test_dataset="TIMIT"),
114 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[512, 512], audio_dataset="TIMIT", test_dataset="TIMIT"),
115 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[1024,1024], audio_dataset="TIMIT", test_dataset="TIMIT"),
116 | #
117 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[8,8,8], audio_dataset="TIMIT", test_dataset="TIMIT"),
118 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[32,32,32], audio_dataset="TIMIT", test_dataset="TIMIT"),
119 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[64,64,64], audio_dataset="TIMIT", test_dataset="TIMIT"),
120 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[256,256,256], audio_dataset="TIMIT", test_dataset="TIMIT"),
121 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[512,512,512], audio_dataset="TIMIT", test_dataset="TIMIT"),
122 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[1024, 1024, 1024], audio_dataset="TIMIT", test_dataset="TIMIT"),
123 | #
124 | #
125 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[8,8,8,8], audio_dataset="TIMIT", test_dataset="TIMIT"),
126 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[32,32,32,32], audio_dataset="TIMIT", test_dataset="TIMIT"),
127 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[64,64,64,64], audio_dataset="TIMIT", test_dataset="TIMIT"),
128 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[256,256,256,256], audio_dataset="TIMIT", test_dataset="TIMIT"),
129 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[512,512,512,512], audio_dataset="TIMIT", test_dataset="TIMIT"),
130 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[1024,1024,1024,1024], audio_dataset="TIMIT", test_dataset="TIMIT"),
131 |
132 | # #get the MFCC results
133 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[64, 64], audio_dataset="TIMIT", test_dataset="TIMIT",nbMFCCs=13),
134 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[64, 64], audio_dataset="TIMIT", test_dataset="TIMIT",nbMFCCs=26),
135 | # networkToRun(runType="audio", AUDIO_LSTM_HIDDEN_LIST=[64, 64], audio_dataset="TIMIT", test_dataset="TIMIT",nbMFCCs=39),
136 | ]
137 |
138 | def main():
139 |
140 |     # Use this if you only want to start training when the network doesn't exist yet
141 | if withNoise:
142 | allResults = []
143 | for noiseType in noiseTypes:
144 | for ratio_dB in ratio_dBs:
145 | results, resultsPath = mainGetResults(networkList, withNoise, noiseType, ratio_dB)
146 | allResults.append(results)
147 | # print(allResults)
148 | # import pdb;pdb.set_trace()
149 |
150 | allNoisePath = root + 'resultsNoisy.pkl'
151 | exportResultsToExcelManyNoise(allResults, allNoisePath)
152 | else:
153 | results, resultsPath = mainGetResults(networkList)
154 | print("\n got all results")
155 | exportResultsToExcel(results, resultsPath)
156 |
157 |     # Use this to force-run the networks on the train sets. If justTest == True, only performance on the test set is evaluated
158 | runManyNetworks(networkList, withNoise=withNoise)
159 |
160 | class networkToRun:
161 | def __init__(self,
162 | AUDIO_LSTM_HIDDEN_LIST=[256, 256], audio_dataset="TCDTIMIT", nbMFCCs=39, audio_bidirectional=True,
163 | LR_start=0.001,
164 | forceTrain=False, runType='audio',
165 | dataset="TCDTIMIT", test_dataset=None):
166 | # Audio
167 | self.AUDIO_LSTM_HIDDEN_LIST = AUDIO_LSTM_HIDDEN_LIST # LSTM architecture for audio part
168 | self.audio_dataset = audio_dataset # training here only works for TCDTIMIT at the moment; for that go to audioSR/RNN.py. This variable is used to get the stored results from that python script
169 | self.nbMFCCs = nbMFCCs
170 | self.audio_bidirectional = audio_bidirectional
171 |
172 | # Others
173 | self.runType = runType
174 | self.LR_start = LR_start
175 | self.forceTrain = forceTrain # If False, just test the network outputs when the network already exists.
176 | # If forceTrain == True, train it anyway before testing
177 | # If True, set the LR_start low enough so you don't move too far out of the objective minimum
178 |
179 | self.dataset = dataset
180 |         if test_dataset is None: self.test_dataset = self.dataset
181 | else: self.test_dataset = test_dataset
182 |
183 |
184 |
185 |
186 | def exportResultsToExcel(results, path):
187 | storePath = path.replace(".pkl", ".xlsx")
188 | import xlsxwriter
189 | workbook = xlsxwriter.Workbook(storePath)
190 |
191 |     for runType in ['audio', 'lipreading', 'combined']:  # don't rely on dict key order to skip the 'resultsType' entry
192 | worksheet = workbook.add_worksheet(runType) # one worksheet per runType, but then everything is spread out...
193 | row = 0
194 |
195 | allNets = results[runType]
196 |
197 | # get and write the column titles
198 |
199 |         # get the number of parameters: for audio a single value; for combined/lipreading a dictionary of values
200 | try:
201 | nb_paramNames = allNets.items()[0][1][
202 | 'nb_params'].keys() # first key-value pair, get the value ([1]), then get names of nbParams (=the keys)
203 | except:
204 | nb_paramNames = ['nb_params']
205 | startVals = 4 + len(nb_paramNames) # column number of first value
206 |
207 | colNames = ['Network Full Name', 'Network Name', 'Dataset', 'Test Dataset'] + nb_paramNames + ['Test Cost',
208 | 'Test Accuracy',
209 | 'Test Top 3 Accuracy',
210 | 'Validation accuracy']
211 | for i in range(len(colNames)):
212 | worksheet.write(0, i, colNames[i])
213 |
214 | # write the data for each network
215 | for netName in allNets.keys():
216 | row += 1
217 |
218 | thisNet = allNets[netName]
219 | # write the path and name
220 | worksheet.write(row, 0, os.path.basename(netName)) # netName)
221 | worksheet.write(row, 1, thisNet['niceName'])
222 | if runType == 'audio':
223 | worksheet.write(row, 2, thisNet['audio_dataset'])
224 | worksheet.write(row, 3, thisNet['test_dataset'])
225 | else:
226 | worksheet.write(row, 2, thisNet['dataset'])
227 | worksheet.write(row, 3, thisNet['test_dataset'])
228 |
229 | # now write the params
230 | try:
231 |             vals = thisNet['nb_params'].values()  # for combined/lipreading nets, nb_params is a dictionary
232 | except:
233 | vals = [thisNet['nb_params']]
234 | for i in range(len(vals)):
235 | worksheet.write(row, 4 + i, vals[i])
236 |
237 | # now write the values
238 | vals = thisNet['values'] # vals is list of [test_cost, test_acc, test_top3_acc]
239 | for i in range(len(vals)):
240 | worksheet.write(row, startVals + i, vals[i])
241 |
242 | workbook.close()
243 |
244 | logger_combined.info("Excel file stored in %s", storePath)
245 |
246 | def exportResultsToExcelManyNoise(resultsList, path):
247 | storePath = path.replace(".pkl", ".xlsx")
248 | import xlsxwriter
249 | workbook = xlsxwriter.Workbook(storePath)
250 |
255 | row = 0
256 |
257 | if len(resultsList[0]['audio'].keys()) > 0: thisRunType = 'audio'
258 | if len(resultsList[0]['lipreading'].keys()) > 0: thisRunType = 'lipreading'
259 | if len(resultsList[0]['combined'].keys()) > 0: thisRunType = 'combined'
260 |     worksheetAudio = workbook.add_worksheet('audio')
261 |     audioRow = 0
262 |     worksheetLipreading = workbook.add_worksheet('lipreading')
263 |     lipreadingRow = 0
264 |     worksheetCombined = workbook.add_worksheet('combined')
265 |     combinedRow = 0
266 |
267 | for r in range(len(resultsList)):
268 | results = resultsList[r]
269 | noiseType = results['resultsType']
270 |
271 |         for runType in ['audio', 'lipreading', 'combined']:  # skip the 'resultsType' entry
272 | if len(results[runType]) == 0: continue
273 | if runType == 'audio': worksheet = worksheetAudio; row = audioRow
274 | if runType == 'lipreading': worksheet = worksheetLipreading; row = lipreadingRow
275 | if runType == 'combined': worksheet = worksheetCombined; row = combinedRow
276 |
277 | allNets = results[runType]
278 |
279 | # write the column titles
280 | startVals = 5
281 | colNames = ['Network Full Name', 'Network Name', 'Dataset', 'Test Dataset', 'Noise Type', 'Test Cost',
282 | 'Test Accuracy', 'Test Top 3 Accuracy']
283 | for i in range(len(colNames)):
284 | worksheet.write(0, i, colNames[i])
285 |
286 | # write the data for each network
287 | for netName in allNets.keys():
288 | row += 1
289 |
290 | thisNet = allNets[netName]
291 | # write the path and name
292 | worksheet.write(row, 0, os.path.basename(netName)) # netName)
293 | worksheet.write(row, 1, thisNet['niceName'])
294 | worksheet.write(row, 2, thisNet['dataset'])
295 | worksheet.write(row, 3, thisNet['test_dataset'])
296 | worksheet.write(row, 4, noiseType)
297 |
298 | # now write the values
299 | vals = thisNet['values'] # vals is list of [test_cost, test_acc, test_top3_acc]
300 | for i in range(len(vals)):
301 | worksheet.write(row, startVals + i, vals[i])
302 |
303 | if runType == 'audio': audioRow = row
304 | if runType == 'lipreading': lipreadingRow = row
305 | if runType == 'combined': combinedRow = row
306 |
307 | row += 1
308 |
309 | workbook.close()
310 |
311 | logger_combined.info("Excel file stored in %s", storePath)
312 |
313 | def getManyNetworkResults(networks, resultsType="unknownResults", roundParams=False, withNoise=False,
314 | noiseType='white', ratio_dB=0):
315 | results = {'resultsType': resultsType}
316 | results['audio'] = {}
317 | results['lipreading'] = {}
318 | results['combined'] = {}
319 |
320 | failures = []
321 |
322 | for networkParams in tqdm(networks, total=len(networks)):
323 | logger_combined.info("\n\n\n\n ################################")
324 | logger_combined.info("Getting results from network...")
325 | logger_combined.info("Network properties: ")
326 | # pprint(vars(networkParams))
327 | try:
328 |             if networkParams.forceTrain:
329 | runManyNetworks([networkParams])
330 | thisResults = {}
331 | thisResults['values'] = []
332 | thisResults['dataset'] = networkParams.dataset
333 | thisResults['test_dataset'] = networkParams.test_dataset
334 | thisResults['audio_dataset'] = networkParams.audio_dataset
335 |
336 | model_name, nice_name = getModelName(networkParams.AUDIO_LSTM_HIDDEN_LIST, networkParams.dataset)
337 |
338 | logger_combined.info("Getting results for %s", model_name + '.npz')
339 | network_train_info = getNetworkResults(model_name)
340 | if network_train_info == -1:
341 | raise IOError("this model doesn't have any stored results")
342 | # import pdb;pdb.set_trace()
343 |
344 | # audio networks can be run on TIMIT or combined as well
345 | if networkParams.runType != 'audio' and networkParams.test_dataset != networkParams.dataset:
346 | testType = "_" + networkParams.test_dataset
347 | else:
348 | testType = ""
349 |
350 | if roundParams:
351 | testType = "_roundParams" + testType
352 |
353 | if networkParams.runType != 'lipreading' and withNoise:
354 | thisResults['values'] = [
355 | network_train_info['final_test_cost_' + noiseType + "_" + "ratio" + str(ratio_dB) + testType],
356 | network_train_info['final_test_acc_' + noiseType + "_" + "ratio" + str(ratio_dB) + testType],
357 | network_train_info['final_test_top3_acc_' + noiseType + "_" + "ratio" + str(ratio_dB) + testType]]
358 | else:
359 | try:
360 | val_acc = max(network_train_info['val_acc'])
361 | except:
362 | try:
363 | val_acc = max(network_train_info['test_acc'])
364 | except:
365 | val_acc = network_train_info['final_test_acc']
366 | thisResults['values'] = [network_train_info['final_test_cost' + testType],
367 | network_train_info['final_test_acc' + testType],
368 | network_train_info['final_test_top3_acc' + testType], val_acc]
369 |
370 | thisResults['nb_params'] = network_train_info['nb_params']
371 | thisResults['niceName'] = nice_name
372 |
373 |             # e.g. results['audio']['2Layer_256_256_TIMIT'] = [0.8, 79.5, 92.6]  # test cost, test acc, test top3 acc
374 | results[networkParams.runType][model_name] = thisResults
375 |
376 |         except Exception:
377 |             logger_combined.info('caught this error: ' + traceback.format_exc())
378 | # import pdb;pdb.set_trace()
379 | failures.append(networkParams)
380 |
381 | logger_combined.info("\n\nDONE getting stored results from networks")
382 | logger_combined.info("####################################################")
383 |
384 | if len(failures) > 0:
385 | logger_combined.info("Couldn't get %s results from %s networks...", resultsType, len(failures))
386 | for failure in failures:
387 | pprint(vars(failure))
388 |     if autoTrain or query_yes_no("\nWould you like to evaluate the networks now?\n\n"):
389 | logger_combined.info("Running networks...")
390 | runManyNetworks(failures, withNoise=withNoise, noiseType=noiseType, ratio_dB=ratio_dB)
391 | mainGetResults(failures, withNoise=withNoise, noiseType=noiseType, ratio_dB=ratio_dB)
392 |
393 | logger_combined.info("Done training.\n\n")
394 | # import pdb; pdb.set_trace()
395 | return results
396 |
397 |
398 | def getNetworkResults(save_name, logger=logger_combined): # copy-pasted from loadPreviousResults
399 | if os.path.exists(save_name + ".npz") and os.path.exists(save_name + "_trainInfo.pkl"):
400 | old_train_info = unpickle(save_name + '_trainInfo.pkl')
401 | # import pdb;pdb.set_trace()
402 | if type(old_train_info) == dict: # normal case
403 | network_train_info = old_train_info # load old train info so it won't get lost on retrain
404 |
405 |             if 'final_test_cost' not in network_train_info:
406 |                 network_train_info['final_test_cost'] = min(network_train_info['test_cost'])
407 |             if 'final_test_acc' not in network_train_info:
408 |                 network_train_info['final_test_acc'] = max(network_train_info['test_acc'])
409 |             if 'final_test_top3_acc' not in network_train_info:
410 |                 network_train_info['final_test_top3_acc'] = max(network_train_info['test_topk_acc'])
411 | else:
412 | logger.warning("old trainInfo found, but wrong format: %s", save_name + "_trainInfo.pkl")
413 | # do nothing
414 | else:
415 | return -1
416 | return network_train_info
417 |
418 |
419 | # networks is a list of networkToRun objects, each holding the parameters needed for one training/evaluation run
420 | def runManyNetworks(networks, withNoise=False, noiseType='white', ratio_dB=0):
421 | results = {}
422 | failures = []
423 | if justTest:
424 | logger_combined.warning("\n\n!!!!!!!!! WARNING !!!!!!!!!! \n justTest = True")
425 | if not query_yes_no("\nAre you sure you want to continue?\n\n"):
426 | return -1
427 | for network in tqdm(networks, total=len(networks)):
428 | print("\n\n\n\n ################################")
429 | print("Training new network...")
430 | print("Network properties: ")
431 | pprint(vars(network))
432 | try:
433 | model_save, test_results = runNetwork(AUDIO_LSTM_HIDDEN_LIST=network.AUDIO_LSTM_HIDDEN_LIST,
434 | audio_features=network.audio_features,
435 | audio_bidirectional=network.audio_bidirectional,
436 | CNN_NETWORK=network.CNN_NETWORK,
437 | cnn_features=network.cnn_features,
438 | LIP_RNN_HIDDEN_LIST=network.LIP_RNN_HIDDEN_LIST,
439 | lipRNN_bidirectional=network.lipRNN_bidirectional,
440 | lipRNN_features=network.lipRNN_features,
441 | DENSE_HIDDEN_LIST=network.DENSE_HIDDEN_LIST,
442 | combinationType=network.combinationType,
443 | dataset=network.dataset, datasetType=network.datasetType,
444 | test_dataset=network.test_dataset,
445 | addNoisyAudio=network.addNoisyAudio,
446 | runType=network.runType,
447 | LR_start=network.LR_start,
448 | allowSubnetTraining=network.allowSubnetTraining,
449 | forceTrain=network.forceTrain,
450 | overwriteSubnets=network.overwriteSubnets,
451 | audio_dataset=network.audio_dataset,
452 | withNoise=withNoise, noiseType=noiseType, ratio_dB=ratio_dB)
453 | print(model_save)
454 | name = model_save + ("_Noise" + noiseType + "_" + str(ratio_dB) if withNoise else "")
455 | results[name] = test_results # should be test_cost, test_acc, test_topk_acc
456 |
457 |         except Exception:
458 |             print('caught this error: ' + traceback.format_exc())
459 | # import pdb; pdb.set_trace()
460 |
461 | failures.append(network)
462 | print("#########################################################")
463 | print("\n\n\n DONE running all networks")
464 |
465 | if len(failures) > 0:
466 | print("Some networks failed to run...")
467 | # import pdb;pdb.set_trace()
468 | return results
469 |
470 |
471 |
472 | if __name__ == "__main__":
473 | main()
474 |
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthijsvk/TIMITspeech/4294fe4af760d19dc807c4e01d01d07662ff7bde/tools/__init__.py
--------------------------------------------------------------------------------
/tools/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthijsvk/TIMITspeech/4294fe4af760d19dc807c4e01d01d07662ff7bde/tools/__init__.pyc
--------------------------------------------------------------------------------
/tools/createMLF.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import helpFunctions.writeToTxt as wrTxt
5 |
6 | """
7 | License: WTFPL http://www.wtfpl.net
8 | Copyright: Gabriel Synnaeve 2013
9 | """
10 |
11 | doc = """
12 | Usage:
13 | python createMLF.py [$folder_path]
14 |
15 | Will create "${folder}.mlf" and "labels.txt" files in $folder_path.
16 |
17 | If you run it only on the training folder, make sure that all the phones
18 | occurring in the test set are also present in training, so that the labels
19 | correspond.
20 | """
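# Example of the generated MLF layout (illustrative paths and sample times):
#   #!MLF!#
#   "TRAIN/spk1/utt1.phn"
#   0 1024 sil
#   1024 2048 ay
#   .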
21 |
22 |
23 | def process(folder):
24 | folder = folder.rstrip('/')
25 | countPhonemes = {}
26 | master_label_fname = folder + '/' + folder.split('/')[-1] + '.mlf'
27 | labels_fpath = folder + '/labels.txt'
28 | master_label_file = open(master_label_fname, 'w')
29 | master_label_file.write("#!MLF!#\n")
30 |
31 | for d, ds, fs in os.walk(folder):
32 | for fname in fs:
33 | fullname = d.rstrip('/') + '/' + fname
34 | print("Processing: ", fullname)
35 | extension = fname[-4:]
36 |
37 | phones = []
38 | if extension.lower() == '.phn':
39 | master_label_file.write('"' + fullname + '"\n')
40 | for line in open(fullname):
41 | master_label_file.write(line)
42 | phones.append(line.split()[2])
43 | for tmp_phn in phones:
44 | countPhonemes[tmp_phn] = countPhonemes.get(tmp_phn, 0) + 1
45 | master_label_file.write('\n.\n')
46 |
47 | master_label_file.close()
48 | print("written MLF file", master_label_fname)
49 |
50 | wrTxt.writeToTxt(sorted(countPhonemes.items()), labels_fpath)
51 | print("written labels", labels_fpath)
52 |
53 | print("phones counts:", countPhonemes)
54 | print("number of phones:", len(countPhonemes))
55 |
56 |
57 | if __name__ == '__main__':
58 | if len(sys.argv) > 1:
59 | if '--help' in sys.argv:
60 | print(doc)
61 | sys.exit(0)
62 |         l = filter(lambda x: '--' not in x[0:2], sys.argv)  # keep the non-option arguments
63 | print(l)
64 | foldername = '.'
65 | if len(l) > 1:
66 | foldername = l[1]
67 | print(foldername)
68 | process(foldername)
69 | else:
70 | process('.') # default
71 |
--------------------------------------------------------------------------------
/tools/datasetToPkl.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import timeit
4 |
5 | program_start_time = timeit.default_timer()
6 | import random
7 |
8 | random.seed(int(timeit.default_timer()))
9 |
10 | from phoneme_set import phoneme_set_39_list
11 | import formatting, preprocessWavs
12 | import general_tools
13 |
14 | import logging
15 |
16 | # prepare logging (File logger: see below META VARIABLES)
17 | logger = logging.getLogger('PrepTCDTIMIT')
18 | logger.setLevel(logging.DEBUG)
19 | FORMAT = '[$BOLD%(filename)s$RESET:%(lineno)d][%(levelname)-5s]: %(message)s '
20 | formatter = logging.Formatter(formatting.formatter_message(FORMAT, False))
21 |
22 | # create console handler with a higher log level
23 | ch = logging.StreamHandler()
24 | ch.setLevel(logging.DEBUG)
25 | ch.setFormatter(formatter)
26 | logger.addHandler(ch)
27 |
28 |
29 |
30 | ##### META VARIABLES #####
31 | DEBUG = False
32 | debug_size = 50
33 |
34 | # TODO: MODIFY THESE PARAMETERS for other nbPhonemes or mfccTypes. Save location is updated automatically.
35 | nbMFCCs = 39 # 13 => just mfcc (13 features). 26 => also derivative (26 features). 39 => also 2nd derivative (39 features)
36 |
37 | # set phoneme type and get dictionary with phoneme-number mappings
38 | nbPhonemes = 39
39 | phoneme_set_list = phoneme_set_39_list # import list of phonemes,
40 | values = [i for i in range(0, len(phoneme_set_list))]
41 | phoneme_classes = dict(zip(phoneme_set_list, values))
42 |
43 | ##### DATA Settings #####
44 |
45 | FRAC_VAL = 0.1 # fraction of training data to be used for validation
46 |
47 | root = os.path.expanduser("~/TCDTIMIT/audioSR/") # (keep the trailing slash)
48 | dataPreSplit = True # some datasets have a pre-defined TEST set (eg TIMIT). Set the dataset name below
49 |
50 | if dataPreSplit:
51 | dataset = "TIMIT" # eg TIMIT. You can also manually split up TCDTIMIT according to train/test split in Harte, N.; Gillen, E., "TCD-TIMIT: An Audio-Visual Corpus of Continuous Speech," doi: 10.1109/TMM.2015.2407694
52 | ## eg TIMIT ##
53 | dataRootDir = root + dataset + "/fixed" + str(nbPhonemes) + os.sep + dataset
54 | train_source_path = os.path.join(dataRootDir, 'TRAIN')
55 | test_source_path = os.path.join(dataRootDir, 'TEST')
56 |     outputDir = root + dataset + "/binary" + str(nbPhonemes) + os.sep + dataset
57 |
58 | else:
59 | ## just a bunch of wav and phn files, not split up in train and test -> need to create the split yourself.
60 | dataset = "TCDTIMIT"
61 | dataRootDir = root + dataset + "/fixed" + str(nbPhonemes) + "_nonSplit" + os.sep + dataset
62 | outputDir = root + dataset + "/binary" + str(nbPhonemes) + os.sep + os.path.basename(dataRootDir)
63 | # TOTAL = TRAINING + TEST = TRAIN + VALIDATION + TEST
64 | FRAC_TEST = 0.1
65 |     FRAC_TRAINING = 1 - FRAC_TEST  # the val set will be FRAC_TRAINING * FRAC_VAL = 9% of the data; train is 90% - 9% = 81%, test is 10%
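    # Worked example (illustrative): with 1000 utterances, test = 100 and
    # training = 900, of which validation = 90 and train = 810.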
66 |
67 |
68 | ##### Everything below is calculated automatically ##########
69 | #############################################################
70 |
71 | # store path
72 | target = os.path.join(outputDir, os.path.basename(dataRootDir) + '_' + str(nbMFCCs) + '_ch')
73 | target_path = target + '.pkl'
74 | if not os.path.exists(outputDir):
75 | os.makedirs(outputDir)
76 |
77 | # Already exists, ask if overwrite
78 | if (os.path.exists(target_path)):
79 | if (not general_tools.query_yes_no(target_path + " exists. Overwrite?", "no")):
80 | raise Exception("Not Overwriting")
81 |
82 | # set log file
83 | logFile = outputDir + os.sep + os.path.basename(target) + '.log'
84 | fh = logging.FileHandler(logFile, 'w') # create new logFile
85 | fh.setLevel(logging.DEBUG)
86 | fh.setFormatter(formatter)
87 | logger.addHandler(fh)
88 |
89 | if DEBUG:
90 | logger.info('DEBUG mode: \tACTIVE, only a small dataset will be preprocessed')
91 | target_path = target + '_DEBUG.pkl'
92 | else:
93 | logger.info('DEBUG mode: \tDEACTIVE')
94 | debug_size = None
95 |
96 | ##### The PREPROCESSING itself #####
97 | logger.info('Preprocessing data ...')
98 |
99 |
100 | # FIRST, gather the WAV and PHN files, generate MFCCs, extract labels to make inputs and targets for the network
101 | # for a dataset containing no TRAIN/TEST subdivision, just a bunch of wavs -> choose training set yourself
102 | def processDataset(FRAC_TRAINING, data_source_path, logger=None):
103 | logger.info(' Data: %s ', data_source_path)
104 | X_all, y_all, valid_frames_all = preprocessWavs.preprocess_dataset(source_path=data_source_path, nbMFCCs=nbMFCCs,
105 | logger=logger, debug=debug_size)
106 | assert len(X_all) == len(y_all) == len(valid_frames_all)
107 |
108 | logger.info(' Loading data complete.')
109 | logger.debug('Type and shape/len of X_all')
110 | logger.debug('type(X_all): {}'.format(type(X_all)))
111 | logger.debug('type(X_all[0]): {}'.format(type(X_all[0])))
112 | logger.debug('type(X_all[0][0]): {}'.format(type(X_all[0][0])))
113 | logger.debug('type(X_all[0][0][0]): {}'.format(type(X_all[0][0][0])))
114 | logger.info('Creating Validation index ...')
115 |
116 | total_size = len(X_all) # TOTAL = TRAINING + TEST = TRAIN + VAL + TEST
117 | total_training_size = int(math.ceil(FRAC_TRAINING * total_size)) # TRAINING = TRAIN + VAL
118 | test_size = total_size - total_training_size
119 |
120 | # split off a 'test' dataset
121 |     test_idx = random.sample(range(0, total_size), test_size)  # sample the test indices uniformly over the whole dataset
122 | test_idx = [int(i) for i in test_idx]
123 |     # ensure that the test set isn't empty
124 | if DEBUG:
125 | test_idx[0] = 0
126 | test_idx[1] = 1
127 | logger.info('Separating test and training set ...')
128 | X_training = []
129 | y_training = []
130 | valid_frames_training = []
131 | X_test = []
132 | y_test = []
133 | valid_frames_test = []
134 | for i in range(len(X_all)):
135 | if i in test_idx:
136 | X_test.append(X_all[i])
137 | y_test.append(y_all[i])
138 | valid_frames_test.append(valid_frames_all[i])
139 | else:
140 | X_training.append(X_all[i])
141 | y_training.append(y_all[i])
142 | valid_frames_training.append(valid_frames_all[i])
143 |
144 | assert len(X_test) == test_size
145 | assert len(X_training) == total_training_size
146 |
147 | return X_training, y_training, valid_frames_training, X_test, y_test, valid_frames_test
148 |
149 |
150 | def processDatasetSplit(train_source_path, test_source_path, logger=None):
151 | logger.info(' Training data: %s ', train_source_path)
152 | X_training, y_training, valid_frames_training = preprocessWavs.preprocess_dataset(source_path=train_source_path,
153 | logger=logger,
154 | nbMFCCs=nbMFCCs, debug=debug_size)
155 | logger.info(' Test data: %s', test_source_path)
156 | X_test, y_test, valid_frames_test = preprocessWavs.preprocess_dataset(source_path=test_source_path, logger=logger,
157 | nbMFCCs=nbMFCCs, debug=debug_size)
158 | return X_training, y_training, valid_frames_training, X_test, y_test, valid_frames_test
159 |
160 |
161 | if dataPreSplit:
162 | X_training, y_training, valid_frames_training, X_test, y_test, valid_frames_test = \
163 | processDatasetSplit(train_source_path, test_source_path, logger)
164 | else:
165 | X_training, y_training, valid_frames_training, X_test, y_test, valid_frames_test = \
166 | processDataset(FRAC_TRAINING, dataRootDir, logger)
167 |
168 | # SECOND, split off a 'validation' set from the training set. The remainder is the 'train' set
169 | total_training_size = len(X_training)
170 | val_size = int(math.ceil(total_training_size * FRAC_VAL))
171 | train_size = total_training_size - val_size
172 | val_idx = random.sample(range(0, total_training_size), val_size) # choose random indices to be validation data
173 | val_idx = [int(i) for i in val_idx]
174 |
175 | logger.info('Length of training')
176 | logger.info(" train X: %s", len(X_training))
177 |
178 | # ensure that the validation set isn't empty
179 | if DEBUG:
180 | val_idx[0] = 0
181 | val_idx[1] = 1
182 |
183 | logger.info('Separating training set into validation and train ...')
184 | X_train = []
185 | y_train = []
186 | valid_frames_train = []
187 | X_val = []
188 | y_val = []
189 | valid_frames_val = []
190 | for i in range(len(X_training)):
191 | if i in val_idx:
192 | X_val.append(X_training[i])
193 | y_val.append(y_training[i])
194 | valid_frames_val.append(valid_frames_training[i])
195 | else:
196 | X_train.append(X_training[i])
197 | y_train.append(y_training[i])
198 | valid_frames_train.append(valid_frames_training[i])
199 | assert len(X_val) == val_size
200 |
201 | # Print some information
202 | logger.info('Length of train, val, test')
203 | logger.info(" train X: %s", len(X_train))
204 | logger.info(" train y: %s", len(y_train))
205 | logger.info(" train valid_frames: %s", len(valid_frames_train))
206 |
207 | logger.info(" val X: %s", len(X_val))
208 | logger.info(" val y: %s", len(y_val))
209 | logger.info(" val valid_frames: %s", len(valid_frames_val))
210 |
211 | logger.info(" test X: %s", len(X_test))
212 | logger.info(" test y: %s", len(y_test))
213 | logger.info(" test valid_frames: %s", len(valid_frames_test))
214 |
215 | ### NORMALIZE data ###
216 | logger.info('Normalizing data ...')
217 | logger.info(' Each channel mean=0, sd=1 ...')
218 |
219 | mean_val, std_val, _ = preprocessWavs.calc_norm_param(X_train)
220 |
221 | X_train = preprocessWavs.normalize(X_train, mean_val, std_val)
222 | X_val = preprocessWavs.normalize(X_val, mean_val, std_val)
223 | X_test = preprocessWavs.normalize(X_test, mean_val, std_val)
224 |
225 | logger.debug('X train')
226 | logger.debug(' %s %s', type(X_train), len(X_train))
227 | logger.debug(' %s %s', type(X_train[0]), X_train[0].shape)
228 | logger.debug(' %s %s', type(X_train[0][0]), X_train[0][0].shape)
229 | logger.debug(' %s %s', type(X_train[0][0][0]), X_train[0][0].shape)
230 | logger.debug('y train')
231 | logger.debug(' %s %s', type(y_train), len(y_train))
232 | logger.debug(' %s %s', type(y_train[0]), y_train[0].shape)
233 | logger.debug(' %s %s', type(y_train[0][0]), y_train[0][0].shape)
234 |
235 | # make sure we're working with float32
236 | X_data_type = 'float32'
237 | X_train = preprocessWavs.set_type(X_train, X_data_type)
238 | X_val = preprocessWavs.set_type(X_val, X_data_type)
239 | X_test = preprocessWavs.set_type(X_test, X_data_type)
240 |
241 | y_data_type = 'int32'
242 | y_train = preprocessWavs.set_type(y_train, y_data_type)
243 | y_val = preprocessWavs.set_type(y_val, y_data_type)
244 | y_test = preprocessWavs.set_type(y_test, y_data_type)
245 |
246 | valid_frames_data_type = 'int32'
247 | valid_frames_train = preprocessWavs.set_type(valid_frames_train, valid_frames_data_type)
248 | valid_frames_val = preprocessWavs.set_type(valid_frames_val, valid_frames_data_type)
249 | valid_frames_test = preprocessWavs.set_type(valid_frames_test, valid_frames_data_type)
250 |
251 | # print some more to check that cast succeeded
252 | logger.debug('X train')
253 | logger.debug(' %s %s', type(X_train), len(X_train))
254 | logger.debug(' %s %s', type(X_train[0]), X_train[0].shape)
255 | logger.debug(' %s %s', type(X_train[0][0]), X_train[0][0].shape)
256 | logger.debug(' %s %s', type(X_train[0][0][0]), X_train[0][0].shape)
257 | logger.debug('y train')
258 | logger.debug(' %s %s', type(y_train), len(y_train))
259 | logger.debug(' %s %s', type(y_train[0]), y_train[0].shape)
260 | logger.debug(' %s %s', type(y_train[0][0]), y_train[0][0].shape)
261 |
262 | ### STORE DATA ###
263 | logger.info('Saving data to %s', target_path)
264 | dataList = [X_train, y_train, valid_frames_train, X_val, y_val, valid_frames_val, X_test, y_test, valid_frames_test]
265 | general_tools.saveToPkl(target_path, dataList)
266 |
267 | # these can be used to evaluate new data, so you don't have to load the whole dataset just to normalize
268 | meanStd_path = os.path.dirname(outputDir) + os.sep + os.path.basename(dataRootDir) + "MeanStd.pkl"
269 | logger.info('Saving Mean and Std_val to %s', meanStd_path)
270 | dataList = [mean_val, std_val]
271 | general_tools.saveToPkl(meanStd_path, dataList)
272 |
273 | logger.info('Preprocessing complete!')
274 | logger.info('Total time: {:.3f}'.format(timeit.default_timer() - program_start_time))
275 |
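A note on the split logic above: membership tests like `i in test_idx` scan a Python list, so the split loops are O(n * m). A minimal sketch of the same idea using a set for O(1) lookups; the helper name and seed argument are illustrative, not part of this repo:

import random

def split_by_indices(items, frac_held_out, seed=None):
    """Randomly hold out a fraction of items; returns (kept, held_out)."""
    rng = random.Random(seed)
    n_held_out = int(round(len(items) * frac_held_out))
    held_out_idx = set(rng.sample(range(len(items)), n_held_out))  # set -> O(1) membership tests
    kept = [x for i, x in enumerate(items) if i not in held_out_idx]
    held_out = [x for i, x in enumerate(items) if i in held_out_idx]
    return kept, held_out

# e.g.: train, test = split_by_indices(X_all, 1.0 - FRAC_TRAINING, seed=0)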
--------------------------------------------------------------------------------
/tools/formatting.py:
--------------------------------------------------------------------------------
1 | # from http://stackoverflow.com/questions/384076/how-can-i-color-python-logging-output?rq=1
2 | import logging
3 | import os
4 |
5 | BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8)
6 |
7 | # The background is set with 40 plus the number of the color, and the foreground with 30
8 |
9 | # These are the sequences needed to get colored output
10 | RESET_SEQ = "\033[0m"
11 | COLOR_SEQ = "\033[1;%dm"
12 | BOLD_SEQ = "\033[1m"
13 |
14 |
15 | def formatter_message(message, use_color=True):
16 | if use_color:
17 | message = message.replace("$RESET", RESET_SEQ).replace("$BOLD", BOLD_SEQ)
18 | else:
19 | message = message.replace("$RESET", "").replace("$BOLD", "")
20 | return message
21 |
22 |
23 | COLORS = {
24 | 'WARNING': YELLOW,
25 | 'INFO': WHITE,
26 | 'DEBUG': BLUE,
27 | 'CRITICAL': YELLOW,
28 | 'ERROR': RED
29 | }
30 |
31 |
32 | class ColoredFormatter(logging.Formatter):
33 | def __init__(self, msg, use_color=True):
34 | logging.Formatter.__init__(self, msg)
35 | self.use_color = use_color
36 |
37 | def format(self, record):
38 | levelname = record.levelname
39 | if self.use_color and levelname in COLORS:
40 | levelname_color = COLOR_SEQ % (30 + COLORS[levelname]) + levelname + RESET_SEQ
41 | record.levelname = levelname_color
42 | return logging.Formatter.format(self, record)
43 |
44 |
45 | # Custom logger class with multiple destinations
46 | class ColoredLogger(logging.Logger):
47 | FORMAT = '%(asctime)s - (%(filename)s:%(lineno)d) | %(message)s' # "[$BOLD%(name)-5s$RESET][%(levelname)-10s]($BOLD%(filename)s$RESET:%(lineno)d) %(message)s "
48 | COLOR_FORMAT = formatter_message(FORMAT, True)
49 |
50 | def __init__(self, name):
51 | logging.Logger.__init__(self, name, logging.WARNING)
52 |
53 | color_formatter = ColoredFormatter(self.COLOR_FORMAT)
54 |
55 | console = logging.StreamHandler()
56 | console.setFormatter(color_formatter)
57 | console.setLevel(logging.INFO)
58 |
59 | self.addHandler(console)
60 | return
61 |
62 | def addFileHandler(self, output_dir='.', log_name="logger.log"):
63 | fileHandler = logging.FileHandler(os.path.join(output_dir, log_name), "w", encoding=None, delay=True)
64 | fileHandler.setFormatter(ColoredFormatter(self.COLOR_FORMAT))
65 | fileHandler.setLevel(logging.DEBUG)
66 | self.addHandler(fileHandler)
67 |
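A minimal usage sketch for this logger class (logging.setLoggerClass is stock standard-library API; the logger name and file name here are arbitrary). Note that addFileHandler reuses ColoredFormatter, so ANSI escape codes also end up in the log file:

import logging
from tools.formatting import ColoredLogger  # assumes the repo root is on sys.path

logging.setLoggerClass(ColoredLogger)   # must be set before getLogger() for this name
log = logging.getLogger('audioSR.demo')
log.setLevel(logging.DEBUG)             # the constructor default is WARNING
log.addFileHandler(output_dir='.', log_name='demo.log')
log.info("colored on the console, and written to demo.log")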
--------------------------------------------------------------------------------
/tools/formatting.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthijsvk/TIMITspeech/4294fe4af760d19dc807c4e01d01d07662ff7bde/tools/formatting.pyc
--------------------------------------------------------------------------------
/tools/general_tools.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 |
5 | import numpy as np
6 | from six.moves import cPickle
7 |
8 | logger_GeneralTools = logging.getLogger('audioSR.generalTools')
9 | logger_GeneralTools.setLevel(logging.ERROR)
10 |
11 |
12 | def path_reader(filename):
13 | with open(filename) as f:
14 | path_list = f.read().splitlines()
15 | return path_list
16 |
17 |
18 | def unpickle(file_path):
19 | with open(file_path, 'rb') as cPickle_file:
20 | a = cPickle.load(cPickle_file)
21 | return a
22 |
23 |
24 | # find all files of a type under a directory, recursive
25 | def load_wavPhn(rootDir):
26 | wavs = loadWavs(rootDir)
27 | phns = loadPhns(rootDir)
28 | return wavs, phns
29 |
30 |
31 | def loadWavs(rootDir):
32 | wav_files = []
33 | for dirpath, dirs, files in os.walk(rootDir):
34 | for f in files:
35 | if (f.lower().endswith(".wav")):
36 | wav_files.append(os.path.join(dirpath, f))
37 | return sorted(wav_files)
38 |
39 |
40 | def loadPhns(rootDir):
41 | phn_files = []
42 | for dirpath, dirs, files in os.walk(rootDir):
43 | for f in files:
44 | if (f.lower().endswith(".phn")):
45 | phn_files.append(os.path.join(dirpath, f))
46 | return sorted(phn_files)
47 |
48 |
49 | def pad_sequences_X(sequences, maxlen=None, padding='post', truncating='post', value=0.):
50 | """
51 | Pad each sequence to the same length:
52 |     the length of the longest sequence.
53 |     If maxlen is provided, any sequence longer
54 |     than maxlen is truncated to maxlen. Truncation happens off either the end (default) or
55 |     the beginning of the sequence.
56 |     Supports post-padding (default) and pre-padding.
57 | """
58 | lengths = [len(s) for s in sequences]
59 |
60 | nb_samples = len(sequences)
61 | if maxlen is None:
62 | maxlen = np.max(lengths)
63 |
64 |     # X data is nested three levels deep (sample, frame, coefficient); grab the element dtype
65 | datatype = type(sequences[0][0][0]);
66 | logger_GeneralTools.debug('X data: %s, %s, %s', type(sequences[0][0]), sequences[0][0].shape, sequences[0][0])
67 |
68 | xSize = nb_samples;
69 | ySize = maxlen;
70 | zSize = sequences[0].shape[1];
71 | # sequences = [[np.reshape(subsequence, (subsequence.shape[0], 1)) for subsequence in sequence] for sequence in sequences]
72 |
73 | logger_GeneralTools.debug('new dimensions: %s, %s, %s', xSize, ySize, zSize)
74 | logger_GeneralTools.debug('intermediate matrix, estimated_size: %s',
75 | xSize * ySize * zSize * np.dtype(datatype).itemsize)
76 |
77 | x = (np.ones((xSize, ySize, zSize)) * value).astype(datatype)
78 |
79 | for idx, s in enumerate(sequences):
80 | if truncating == 'pre':
81 | trunc = s[-maxlen:]
82 | elif truncating == 'post':
83 | trunc = s[:maxlen]
84 | else:
85 |             raise ValueError("Truncating type '%s' not understood" % truncating)
86 |
87 | if padding == 'post':
88 | x[idx, :len(trunc), :] = trunc
89 | elif padding == 'pre':
90 | x[idx, -len(trunc):, :] = np.array(trunc, dtype='float32')
91 | else:
92 | raise ValueError("Padding type '%s' not understood" % padding)
93 |
94 | return x
95 |
96 |
97 | def pad_sequences_y(sequences, maxlen=None, padding='post', truncating='post', value=0.):
98 | """
99 | Pad each sequence to the same length:
100 |     the length of the longest sequence.
101 |     If maxlen is provided, any sequence longer
102 |     than maxlen is truncated to maxlen. Truncation happens off either the end (default) or
103 |     the beginning of the sequence.
104 |     Supports post-padding (default) and pre-padding.
105 | """
106 | lengths = [len(s) for s in sequences]
107 |
108 | nb_samples = len(sequences)
109 | if maxlen is None: maxlen = np.max(lengths)
110 |
111 | datatype = type(sequences[0][0]);
112 | logger_GeneralTools.debug('Y data: %s, %s, %s', type(sequences[0]), sequences[0].shape, sequences[0])
113 |
114 | xSize = nb_samples;
115 | ySize = maxlen;
116 | # sequences = [np.reshape(sequence, (sequence.shape[0], 1)) for sequence in sequences]
117 |
118 | logger_GeneralTools.debug('new dimensions: %s, %s', xSize, ySize)
119 | logger_GeneralTools.debug('intermediate matrix, estimated_size: %s', xSize * ySize * np.dtype(datatype).itemsize)
120 |
121 | y = (np.ones((xSize, ySize)) * value).astype(datatype)
122 |
123 | for idx, s in enumerate(sequences):
124 | if truncating == 'pre':
125 | trunc = s[-maxlen:]
126 | elif truncating == 'post':
127 | trunc = s[:maxlen]
128 | else:
129 |             raise ValueError("Truncating type '%s' not understood" % truncating)
130 |
131 | if padding == 'post':
132 | y[idx, :len(trunc)] = trunc
133 | elif padding == 'pre':
134 | y[idx, -len(trunc):] = np.array(trunc, dtype='float32')
135 | else:
136 | raise ValueError("Padding type '%s' not understood" % padding)
137 |
138 | return y
139 |
140 |
141 | def generate_masks(inputs, valid_frames=None, batch_size=1, max_length=1000,
142 | logger=logger_GeneralTools): # inputs = X. valid_frames = list of frames when we need to extract the phoneme
143 | ## all recurrent layers in lasagne accept a separate mask input which has shape
144 | # (batch_size, n_time_steps), which is populated such that mask[i, j] = 1 when j <= (length of sequence i) and mask[i, j] = 0 when j > (length
145 | # of sequence i). When no mask is provided, it is assumed that all sequences in the minibatch are of length n_time_steps.
146 | logger.debug("* Data information")
147 | logger.debug('%s %s', type(inputs), len(inputs))
148 | logger.debug('%s %s', type(inputs[0]), inputs[0].shape)
149 | logger.debug('%s %s', type(inputs[0][0]), inputs[0][0].shape)
150 | logger.debug('%s', type(inputs[0][0][0]))
151 |
152 | # max_input_length = max([len(inputs[i]) for i in range(len(inputs))])
153 | max_length = max([len(inputs[i]) for i in range(len(inputs))])
154 | input_dim = len(inputs[0][0])
155 |
156 | logger.debug("max_seq_len: %d", max_length)
157 | logger.debug("input_dim: %d", input_dim)
158 |
159 | # X = np.zeros((batch_size, max_input_length, input_dim))
160 | input_mask = np.zeros((batch_size, max_length), dtype='float32')
161 |
162 | for example_id in range(len(inputs)):
163 | try:
164 |             if valid_frames is not None:
165 |                 # Sometimes phonemes are so close to each other that all are mapped to the last frame -> gives an error
166 |                 if valid_frames[example_id][-1] >= max_length: valid_frames[example_id][-1] = max_length - 1
167 |                 if valid_frames[example_id][-2] >= max_length: valid_frames[example_id][-2] = max_length - 1
168 |                 if valid_frames[example_id][-3] >= max_length: valid_frames[example_id][-3] = max_length - 1
169 |
170 | input_mask[example_id, valid_frames[example_id]] = 1
171 | else:
172 | logger.warning("NO VALID FRAMES SPECIFIED!!!")
173 | # raise Exception("NO VALID FRAMES SPECIFIED!!!")
174 |
175 | logger.debug('%d', example_id)
176 | curr_seq_len = len(inputs[example_id])
177 | logger.debug('%d', curr_seq_len)
178 | input_mask[example_id, :curr_seq_len] = 1
179 | except Exception as e:
180 | print("Couldn't do it: %s" % e)
181 | import pdb;
182 | pdb.set_trace()
183 |
184 | return input_mask
185 |
186 |
187 | def query_yes_no(question, default="yes"):
188 | """Ask a yes/no question via raw_input() and return their answer.
189 |
190 | "question" is a string that is presented to the user.
191 |     "default" is the presumed answer if the user just hits Enter.
192 | It must be "yes" (the default), "no" or None (meaning
193 | an answer is required of the user).
194 |
195 | The "answer" return value is True for "yes" or False for "no".
196 | """
197 | valid = {
198 | "yes": True, "y": True, "ye": True,
199 | "no": False, "n": False
200 | }
201 | if default is None:
202 | prompt = " [y/n] "
203 | elif default == "yes":
204 | prompt = " [Y/n] "
205 | elif default == "no":
206 | prompt = " [y/N] "
207 | else:
208 | raise ValueError("invalid default answer: '%s'" % default)
209 |
210 | while True:
211 | sys.stdout.write(question + prompt)
212 | choice = raw_input().lower()
213 | if default is not None and choice == '':
214 | return valid[default]
215 | elif choice in valid:
216 | return valid[choice]
217 | else:
218 | sys.stdout.write("Please respond with 'yes' or 'no' "
219 | "(or 'y' or 'n').\n")
220 |
221 |
222 | def saveToPkl(target_path, data): # data can be list or dictionary
223 | if not os.path.exists(os.path.dirname(target_path)):
224 | os.makedirs(os.path.dirname(target_path))
225 | with open(target_path, 'wb') as cPickle_file:
226 | cPickle.dump(
227 | data,
228 | cPickle_file,
229 | protocol=cPickle.HIGHEST_PROTOCOL)
230 | return 0
231 |
232 |
233 | def depth(path):
234 | return path.count(os.sep)
235 |
236 |
237 | # stuff for getting relative paths between two directories
238 | def pathsplit(p, rest=[]):
239 | (h, t) = os.path.split(p)
240 | if len(h) < 1: return [t] + rest
241 | if len(t) < 1: return [h] + rest
242 | return pathsplit(h, [t] + rest)
243 |
244 |
245 | def commonpath(l1, l2, common=[]):
246 | if len(l1) < 1: return (common, l1, l2)
247 | if len(l2) < 1: return (common, l1, l2)
248 | if l1[0] != l2[0]: return (common, l1, l2)
249 | return commonpath(l1[1:], l2[1:], common + [l1[0]])
250 |
251 |
252 | # p1 = main path, p2= the one you want to get the relative path of
253 | def relpath(p1, p2):
254 | (common, l1, l2) = commonpath(pathsplit(p1), pathsplit(p2))
255 | p = []
256 | if len(l1) > 0:
257 | p = ['../' * len(l1)]
258 | p = p + l2
259 | return os.path.join(*p)
260 |
261 |
262 | # need this to traverse directories, find depth
263 | def directories(root):
264 | dirList = []
265 | for path, folders, files in os.walk(root):
266 | for name in folders:
267 | dirList.append(os.path.join(path, name))
268 | return sorted(dirList)
269 |
270 |
271 | def tryint(s):
272 | try:
273 | return int(s)
274 | except ValueError:
275 | return s
276 |
277 |
278 | def alphanum_key(s):
279 |     import re  # note: 're' is not imported at the top of this module
280 |     return [tryint(c) for c in re.split('([0-9]+)', s)]
281 |
282 | def sort_nicely(l):
283 | return sorted(l, key=alphanum_key)
284 |
285 |
286 | def set_type(X, type):
287 | for i in range(len(X)):
288 | X[i] = X[i].astype(type)
289 | return X
290 |
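A toy sketch of the two padding helpers (numpy only; the shapes are made up): two "utterances" of 5 and 3 frames with 39 features each get padded to a common length, with zeros appended to the short one:

import numpy as np
from tools.general_tools import pad_sequences_X, pad_sequences_y  # assumes the repo root is on sys.path

X = [np.random.rand(5, 39).astype('float32'),
     np.random.rand(3, 39).astype('float32')]
y = [np.zeros(5, dtype='int32'), np.ones(3, dtype='int32')]

X_pad = pad_sequences_X(X)   # shape (2, 5, 39); the 3-frame sample is zero-padded at the end
y_pad = pad_sequences_y(y)   # shape (2, 5)
print(X_pad.shape, y_pad.shape)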
--------------------------------------------------------------------------------
/tools/general_tools.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthijsvk/TIMITspeech/4294fe4af760d19dc807c4e01d01d07662ff7bde/tools/general_tools.pyc
--------------------------------------------------------------------------------
/tools/helpFunctions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthijsvk/TIMITspeech/4294fe4af760d19dc807c4e01d01d07662ff7bde/tools/helpFunctions/__init__.py
--------------------------------------------------------------------------------
/tools/helpFunctions/copyFilesOfType.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import os
4 | import os.path
5 | import re
6 | import shutil
7 | import sys
8 |
9 | protocolPattern = re.compile(r'^\w+://')
10 |
11 |
12 | def pathsplit(path):
13 | """ This version, in contrast to the original version, permits trailing
14 | slashes in the pathname (in the event that it is a directory).
15 | It also uses no recursion """
16 | return path.split(os.path.sep)
17 |
18 |
19 | def commonpath(l1, l2, common=[]):
20 | if len(l1) < 1: return (common, l1, l2)
21 | if len(l2) < 1: return (common, l1, l2)
22 | if l1[0] != l2[0]: return (common, l1, l2)
23 | return commonpath(l1[1:], l2[1:], common + [l1[0]])
24 |
25 |
26 | def relpath(p1, p2):
27 | (common, l1, l2) = commonpath(pathsplit(p1), pathsplit(p2))
28 | p = []
29 | if len(l1) > 0:
30 | p = ['../' * len(l1)]
31 | p = p + l2
32 | return os.path.join(*p)
33 |
34 |
35 | def isabs(string):
36 | if protocolPattern.match(string): return 1
37 | return os.path.isabs(string)
38 |
39 |
40 | def rel2abs(path, base=os.curdir):
41 | if isabs(path): return path
42 | retval = os.path.join(base, path)
43 | return os.path.abspath(retval)
44 |
45 |
46 | def abs2rel(path, base=os.curdir): # return a relative path from base to path.
47 | if protocolPattern.match(path): return path
48 | base = rel2abs(base)
49 | path = rel2abs(path) # redundant - should already be absolute
50 | return relpath(base, path)
51 |
52 |
53 | def test(p1, p2):
54 | print("from", p1, "to", p2, " -> ",
55 | relpath(p1, p2)) # this is what I need. p1 = AbsDirPath; p2 = AbsfilePath; out = filepathRelToDir
56 | print("from", p1, "to", p2, " -> ", rel2abs(p1, p2))
57 | print("from", p1, "to", p2, " -> ", abs2rel(p1, p2))
58 |
59 |
60 | def query_yes_no(question, default="yes"):
61 | """Ask a yes/no question via raw_input() and return their answer.
62 |
63 | "question" is a string that is presented to the user.
64 |     "default" is the presumed answer if the user just hits Enter.
65 | It must be "yes" (the default), "no" or None (meaning
66 | an answer is required of the user).
67 |
68 | The "answer" return value is True for "yes" or False for "no".
69 | """
70 | valid = {
71 | "yes": True, "y": True, "ye": True,
72 | "no": False, "n": False
73 | }
74 | if default is None:
75 | prompt = " [y/n] "
76 | elif default == "yes":
77 | prompt = " [Y/n] "
78 | elif default == "no":
79 | prompt = " [y/N] "
80 | else:
81 | raise ValueError("invalid default answer: '%s'" % default)
82 |
83 | while True:
84 | sys.stdout.write(question + prompt)
85 | choice = raw_input().lower()
86 | if default is not None and choice == '':
87 | return valid[default]
88 | elif choice in valid:
89 | return valid[choice]
90 | else:
91 | sys.stdout.write("Please respond with 'yes' or 'no' "
92 | "(or 'y' or 'n').\n")
93 |
94 |
95 | def copyFilesOfType(srcDir, dstDir, extension, interactive=False):
96 | print("Source Dir: %s, Destination Dir: %s, Extension: %s" % (srcDir, dstDir, extension))
97 |
98 | src = []
99 | dest = []
100 | for root, dirs, files in os.walk(srcDir):
101 | for file_ in files:
102 | # print(file_)
103 | if file_.lower().endswith(extension):
104 | srcPath = os.path.join(root, file_)
105 | relSrcPath = relpath(srcDir, srcPath).lstrip("../")
106 | # print(relSrcPath)
107 | destPath = os.path.join(dstDir, relSrcPath)
108 | # print("copying from : %s to \t\t %s" % (srcPath, destPath))
109 | src.append(srcPath)
110 | dest.append(destPath)
111 |
112 | print("Example: copying ", src[0], "to:", dest[0])
113 | print(len(src), " files will be copied in total")
114 |
115 |     if (interactive and (not query_yes_no("Are you sure you want to perform these operations?", "yes"))):
116 | print("Not doing a thing.")
117 | else:
118 | for i in range(len(src)):
119 | if (not os.path.exists(os.path.dirname(dest[i]))):
120 | os.makedirs(os.path.dirname(dest[i]))
121 | shutil.copy(src[i], dest[i])
122 | print("Done.")
123 |
124 | return 0
125 |
126 |
127 | if __name__ == '__main__':
128 |     srcDir = sys.argv[1]
129 |     dstDir = sys.argv[2]
130 |     type = sys.argv[3]
131 |     copyFilesOfType(srcDir, dstDir, type, interactive=True)
132 |
133 |
134 | # example usage: python copyFilesOfType.py ~/Documents/Dropbox/_MyDocs/_ku_leuven/Master_2/Thesis/convNets/code /media/matthijs/TOSHIBA_EXT/TCDTIMIT/zzzNPZmodels ".npz"
135 |
--------------------------------------------------------------------------------
/tools/helpFunctions/progress_bar.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 |
4 | # def update_progress(amtDone):
5 | # sys.stdout.write("\rProgress: [{0:50s}] {1:.1f}%".format('#' * int(amtDone * 50), amtDone * 100))
6 | # sys.stdout.flush()
7 |
8 | def show_progress(frac_done, bar_length=20):
9 | # for i in range(end_val):
10 | hashes = '#' * int(round(frac_done * bar_length))
11 | spaces = ' ' * (bar_length - len(hashes))
12 | sys.stdout.write("\rProgress: [{0}] {1}% ".format(hashes + spaces, int(round(frac_done * 100))))
13 | sys.stdout.flush()
14 |
15 |
16 | if __name__ == '__main__':
17 | show_progress(0.8)
18 |
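For example, driven from a loop (time.sleep stands in for real work; the loop is illustrative):

import time
from tools.helpFunctions.progress_bar import show_progress  # assumes the repo root is on sys.path

n = 50
for i in range(n):
    time.sleep(0.02)                       # stand-in for real work
    show_progress((i + 1) / float(n))
print("")                                  # final newline after the carriage-return updates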
--------------------------------------------------------------------------------
/tools/helpFunctions/removeEmptyDirs.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | '''
3 | Module to remove empty folders recursively. Can be used as standalone script or be imported into existing script.
4 | '''
5 |
6 | import os
7 | import sys
8 |
9 |
10 | def removeEmptyFolders(path, removeRoot=True):
11 | 'Function to remove empty folders'
12 | if not os.path.isdir(path):
13 | return
14 |
15 | # remove empty subfolders
16 | files = os.listdir(path)
17 | if len(files):
18 | for f in files:
19 | fullpath = os.path.join(path, f)
20 | if os.path.isdir(fullpath):
21 | removeEmptyFolders(fullpath)
22 |
23 | # if folder empty, delete it
24 | files = os.listdir(path)
25 | if len(files) == 0 and removeRoot:
26 | print "Removing empty folder:", path
27 | os.rmdir(path)
28 |
29 |
30 | def usageString():
31 | 'Return usage string to be output in error cases'
32 | return 'Usage: %s directory [removeRoot]' % sys.argv[0]
33 |
34 |
35 | if __name__ == "__main__":
36 | removeRoot = False
37 |
38 |     if len(sys.argv) < 2:
39 | print "Not enough arguments"
40 | sys.exit(usageString())
41 |
42 | if not os.path.isdir(sys.argv[1]):
43 | print "No such directory %s" % sys.argv[1]
44 | sys.exit(usageString())
45 |
46 |     if len(sys.argv) == 3 and sys.argv[2] == "True":
47 | removeRoot = True
48 |
49 | removeEmptyFolders(sys.argv[1], removeRoot)
50 |
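When imported instead of run as a script, it is a single call (the path here is hypothetical; like the rest of the repo, this module targets Python 2):

from tools.helpFunctions.removeEmptyDirs import removeEmptyFolders

removeEmptyFolders('/tmp/some_dataset', removeRoot=False)  # prune empty subfolders, keep the root itself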
--------------------------------------------------------------------------------
/tools/helpFunctions/resample.py:
--------------------------------------------------------------------------------
1 | '''
2 | Created on Apr 7, 2011
3 | by Uri Nieto
4 | uri@urinieto.com
5 |
6 | Sample Rate converter from 48kHz to 44.1kHz.
7 |
8 | USAGE:
9 | $>python resample.py -i input.wav [-o output.wav -q [0.0-1.0]]
10 |
11 | EXAMPLES:
12 | $>python resample.py -i onades.wav
13 | $>python resample.py -i onades.wav -o out.wav
14 | $>python resample.py -i onades-mono.wav -q 0.8
15 | $>python resample.py -i onades.wav -o out3.wav -q 0.5
16 |
17 | DESCRIPTION:
18 | The input has to be a WAV file sampled at 48 kHz/sec
19 | with a resolution of 16 bits/sample. It can have n>0
20 | number of channels (i.e. 1=mono, 2=stereo, ...).
21 |
22 | The output will be a WAV file sampled at 44.1 kHz/sec
23 | with a resolution of 16 bits/sample, and the same
24 | number of channels as the input.
25 |
26 | A quality parameter q can be provided (>0.0 to 1.0), and
27 | it will modify the number of zero crossings of the filter,
28 | making the output quality best when q = 1.0 and very bad
29 | as q tends to 0.0
30 |
31 | The sample rate factor is:
32 |
33 | 44100 147
34 | ------- = -----
35 | 48000 160
36 |
37 | To do the conversion, we upsample by 147, low pass filter,
38 | and downsample by 160 (in this order). This is done by
39 | using an efficient polyphase filter bank with resampling
40 | algorithm proposed by Vaidyanathan in [2].
41 |
42 | The low pass filter is an impulse response windowed
43 | by a Kaiser window to have a better filtering
44 | (around -60dB in the rejection band) [1].
45 |
46 | As a comparison between the Kaiser Window and the
47 | Rectangular Window, this algorithm plotted the following
48 | images included with this package:
49 |
50 | KaiserIR.png
51 | KaiserFR.png
52 | RectIR.png
53 | RectFR.png
54 |
55 | The images show the Impulse Responses and the Frequency
56 | Responses of the Kaiser and Rectangular Windows. As it can be
57 | clearly seen, the Kaiser window has a gain of around -60dB in the
58 | rejection band, whereas the Rect window has a gain of around -20dB
59 | and smoothly decreasing to -40dB. Thus, the Kaiser window
60 | method is rejecting the aliasing much better than the Rect window.
61 |
62 | The Filter Design is performed in the function:
63 | designFIR()
64 |
65 | The Upsampling, Filtering, and Downsampling is
66 | performed in the function:
67 | upSampleFilterDownSample()
68 |
69 | Also included in the package are two wav files sampled at 48kHz
70 | with 16bits/sample resolution. One is stereo and the other mono:
71 |
72 | onades.wav
73 | onades-mono.wav
74 |
75 | NOTES:
76 | You need numpy and scipy installed to run this script.
77 | You can find them here:
78 | http://numpy.scipy.org/
79 |
80 | You may want to have matplotlib as well if you want to
81 | print the plots by yourself (commented right now)
82 |
83 | This code would be much faster on C or C++, but my decision
84 | on using Python was to make the code more readable (yet useful)
85 | rather than focusing on performance.
86 |
87 | @author: uri
88 |
89 | REFERENCES:
90 | [1]: Smith, J.O., "Spectral Audio Signal Processing",
91 | W3K Publishing, March 2010
92 |
93 | [2] Vaidyanathan, P.P., "Multirate Systems and Filter Banks",
94 | Prentice Hall, 1993.
95 |
96 | COPYRIGHT NOTES:
97 |
98 | Copyright (C) 2011, Uri Nieto
99 |
100 | This program is free software: you can redistribute it and/or modify
101 | it under the terms of the GNU General Public License as published by
102 | the Free Software Foundation, either version 3 of the License, or
103 | any later version.
104 |
105 | This program is distributed in the hope that it will be useful,
106 | but WITHOUT ANY WARRANTY; without even the implied warranty of
107 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
108 | GNU General Public License for more details.
109 |
110 | You should have received a copy of the GNU General Public License
111 | along with this program. If not, see <http://www.gnu.org/licenses/>.
112 |
113 | '''
114 |
115 | import sys
116 | import time as tm
117 | from decimal import Decimal
118 | from fractions import Fraction
119 | from shutil import copyfile
120 |
121 | import numpy as np
122 | from scipy.io import wavfile
123 |
124 | # import matplotlib.pyplot as plt #Uncomment to plot
125 |
126 |
127 | '''
128 | h = flipTranspose(h, L, cpp)
129 | ...
130 | Desc: Flips and Transposes the impulse response h, dividing
131 | it into L different phases with cpp coefficients per phase.
132 | Following as described in [2].
133 | ...
134 | h: Impulse response to flip and Transpose
135 | L: Upsampling Factor
136 | cpp: Coeficients per Phase
137 | return hh: h flipped and transposed following the descritpion
138 | '''
139 |
140 |
141 | def flipTranspose(h, L, cpp):
142 | # Get the impulse response size
143 | N = len(h)
144 |
145 | # Init the output to 0
146 | hh = np.zeros(N)
147 |
148 | # Flip and Transpose:
149 | for i in range(L):
150 | hh[cpp - 1 + i * cpp:-N - 1 + i * cpp:-1] = h[i:cpp * L:L]
151 |
152 | return hh
153 |
154 |
155 | '''
156 | h = upSampleFilterDownSample(x, h, L, M)
157 | ...
158 | Desc: Upsamples the input x by L, filters it out using h, and
159 | downsamples it by M.
160 |
161 | The algorithm is based on the "efficient polyphase filter bank
162 | with resampling" found on page 129 of the book [2] (Figure 4.3-8d).
163 |
164 | ...
165 | x: input signal
166 | h: impulse response (assumes it has the correct cut-off freq)
167 | L: Upsampling Factor
168 | M: Downsampling Factor
169 | returns y: output signal (x upsampled, filtered, and downsampled)
170 | '''
171 |
172 |
173 | def upSampleFilterDownSample(x, h, L, M, printing=False):
174 | # Number of samples to convert
175 | N = len(x)
176 |
177 | # Compute the number of coefficients per phase
178 | cpp = len(h) / L
179 |
180 | # Flip and Transpose the impulse response
181 | h = flipTranspose(h, L, cpp)
182 |
183 | # Check number of channels
184 | if (np.shape(np.shape(x)) == (2,)):
185 | nchan = np.shape(x)[1]
186 |         y = np.zeros((int(np.ceil(N * L / float(M))), nchan))
187 | else:
188 | nchan = 1
189 | y = np.zeros(int(np.ceil(N * L / float(M))))
190 |
191 | # Init the output index
192 | y_i = 0
193 |
194 | # Init the phase index
195 | phase_i = 0
196 |
197 | # Init the main loop index
198 | i = 0
199 |
200 | # Main Loop
201 | while i < N:
202 |
203 | # Print % completed
204 | if (printing and (i % 30000 == 0)):
205 | print("%.2f %% completed" % float(100 * i / float(len(x))))
206 |
207 | # Compute the filter index
208 | h_i = phase_i * cpp
209 |
210 | # Compute the input index
211 | x_i = i - cpp + 1;
212 |
213 | # Update impulse index if needed (offset)
214 | if x_i < 0:
215 | h_i -= x_i
216 | x_i = 0
217 |
218 | # Compute the current output sample
219 | rang = i - x_i + 1
220 | if nchan == 1:
221 | y[y_i] = np.sum(x[x_i:x_i + rang] * h[h_i:h_i + rang])
222 | else:
223 | for c in range(nchan):
224 | y[y_i, c] = np.sum(x[x_i:x_i + rang, c] * h[h_i:h_i + rang])
225 |
226 | # Add the downsampling factor to the phase index
227 | phase_i += M
228 |
229 | # Compute the increment for the index of x with the new phase
230 | x_incr = phase_i / int(L)
231 |
232 | # Update phase index
233 | phase_i %= L
234 |
235 | # Update the main loop index
236 | i += x_incr
237 |
238 | # Update the output index
239 | y_i += 1
240 |
241 | return y
242 |
243 |
244 | '''
245 | h = impulse(M, L)
246 | ...
247 | M: Impulse Response Size
248 | T: Sampling Period
249 | returns h: The impulse response
250 | '''
251 |
252 |
253 | def impulse(M, T):
254 | # Create time array
255 | n = np.arange(-(M - 1) / 2, (M - 1) / 2 + 1)
256 |
257 | # Compute the impulse response using the sinc function
258 | h = (1 / T) * np.sinc((1 / T) * n)
259 |
260 | return h
261 |
262 |
263 | '''
264 | b = bessel(x)
265 | ...
266 | Desc: Zero-order modified Bessel function of the first kind, with
267 | approximation using the Maclaurin series, as described in [1]
268 | ...
269 | x: Input sample
270 | b: Zero-order modified Bessel function of the first kind
271 | '''
272 |
273 |
274 | def bessel(x):
275 |     return np.power(np.exp(x), 2)  # NB: a crude exponential stand-in, not a true Maclaurin-series I0
276 |
277 |
278 | '''
279 | k = kaiser(M, beta)
280 | ...
281 | Desc: Generates an M length Kaiser window with the
282 | specified beta parameter. Following instructions in [1]
283 | ...
284 | M: Number of samples of the window
285 | beta: Beta parameter of the Kaiser Window
286 | k: array(M,1) containing the Kaiser window with the specified beta
287 | '''
288 |
289 |
290 | def kaiser(M, beta):
291 | # Init Kaiser Window
292 | k = np.zeros(M)
293 |
294 | # Compute each sample of the Kaiser Window
295 | i = 0
296 | for n in np.arange(-(M - 1) / 2, (M - 1) / 2 + 1):
297 | samp = beta * np.sqrt(1 - np.power((n / (M / 2.0)), 2))
298 | samp = bessel(samp) / float(bessel(beta))
299 | k[i] = samp
300 | i = i + 1
301 |
302 | return k
303 |
304 |
305 | '''
306 | h = designFIR(N, L, M)
307 | ...
308 | Desc: Designs a low pass filter to perform the conversion of
309 | sampling frequencies given the upsampling and downsampling factors.
310 | It uses the Kaiser window to better filter out aliasing.
311 | ...
312 | N: Maximum size of the Impulse Response of the FIR
313 | L: Upsampling Factor
314 | M: Downsampling Factor
315 | returns h: Impulse Response of the FIR
316 | '''
317 |
318 |
319 | def designFIR(N, L, M):
320 | # Get the impulse response with the right Sampling Period
321 | h0 = impulse(N, float(M))
322 |
323 | # Compute a Kaiser Window
324 | alpha = 2.5 # Alpha factor for the Kaiser Window
325 | k = kaiser(N, alpha * np.pi)
326 |
327 | # Window the impulse response with the Kaiser window
328 | h = h0 * k
329 |
330 | # Filter Gain
331 | h = h * L
332 |
333 | # Reduce window by removing almost 0 values to improve filtering
334 | for i in range(len(h)):
335 | if abs(h[i]) > 1e-3:
336 | for j in range(i, 0, -1):
337 | if abs(h[j]) < 1e-7:
338 | h = h[j:len(h) - j]
339 | break
340 | break
341 |
342 | '''
343 | # For plotting purposes:
344 | N = len(h)
345 | Hww = fft(h, N)
346 | Hww = Hww[0:N/2.0]
347 | Hwwdb = 20.0*np.log10(np.abs(Hww))
348 | Hw = fft(h0, N)
349 | Hw = Hw[0:N/2.0]
350 | Hwdb = 20.0*np.log10(np.abs(Hw))
351 |
352 | plt.figure(1)
353 | plt.plot(h)
354 | plt.title('Kaiser Windowed Impulse Response of the low pass filter')
355 | plt.show()
356 |
357 | plt.figure(2)
358 | plt.plot(h0)
359 | plt.title('Rect Windowd Impulse Response of the low pass filter')
360 | plt.show()
361 |
362 | #print np.shape(np.arange(0, N/2.0-1)/float(N)), np.shape(Hwwdb)
363 |
364 | plt.figure(3)
365 | plt.plot(np.arange(0, N/2.0-1)/float(N),Hwwdb)
366 | plt.xlabel('Normalized Frequency');
367 | plt.ylabel('Magnitude (dB)');
368 | plt.title('Amplitude Response, Kaiser window with beta = ' + str(2.5*np.pi));
369 | plt.show()
370 |
371 | plt.figure(4)
372 | plt.plot(np.arange(0, N/2.0-1)/float(N),Hwdb)
373 | plt.xlabel('Normalized Frequency');
374 | plt.ylabel('Magnitude (dB)');
375 | plt.title('Amplitude Response using a Rect Window');
376 | plt.show()
377 | '''
378 |
379 | return h
380 |
381 |
382 | '''
383 | dieWithUsage()
384 | ...
385 | Desc: Stops program and prints usage
386 | '''
387 |
388 |
389 | def dieWithUsage():
390 | usage = """
391 |     USAGE: $>python resample.py -i input.wav [-o output.wav -q [0.0-1.0]]
392 |
393 | input.wav has to be sampled at 48kHz with 16bits/sample
394 | If no output file is provided, file will be written in "./output.wav"
395 | If no quality param is provided, 1.0 (max) will be used
396 |
397 | Description: Converts sampling frequency of input.wav from 48kHz to 16kHz and writes it into output.wav
398 | """
399 |     print(usage)
400 |
401 | sys.exit(1)
402 |
403 |
404 | '''
405 | Main
406 | ...
407 | Desc: Reads the input wave file, designs the filter,
408 | upsamples, filters, and downsamples the input, and
409 | finally writes it to the output wave file.
410 | '''
411 |
412 |
413 | def resampleWAV(inFile, outFile="output.wav", out_fr=16000.0, q=0.0):
414 | # Parse arguments
415 | inPath, outPath, out_fr, q = inFile, outFile, out_fr, q
416 |
417 | # Read input wave
418 | in_fr, in_data = wavfile.read(inPath)
419 | in_nbits = in_data.dtype
420 |
421 | # Time it
422 | start_time = tm.time()
423 |
424 | # Set output wave parameters
425 | out_nbits = in_nbits
426 |
427 | frac = Fraction(Decimal(str(float(in_fr) / out_fr))).limit_denominator(1000)
428 |
429 | if (float(frac) < 1.0):
430 | print("input file smaller sampling rate than output..")
431 | print("input: ", in_fr, "Hz. Output: ", out_fr, "Hz")
432 |
433 | elif (float(frac) == 1.0):
434 | try:
435 | copyfile(inPath, outPath)
436 | print("Input file", inPath, "already at correct sampling rate; just copying")
437 | except Exception, e:
438 | print(e.args)
439 | else:
440 | print("some weird error", inPath, outPath)
441 | return -1
442 | else:
443 | L = frac.denominator # Upsampling Factor
444 | M = frac.numerator # Downsampling Factor
445 | Nz = int(25 * q) # Max Number of Zero Crossings (depending on quality)
446 | h = designFIR(Nz * L, L, M)
447 |
448 | # Upsample, Filter, and Downsample
449 | out_data = upSampleFilterDownSample(in_data, h, L, M, printing=False) # control progression output
450 |
451 | # Make sure the output is 16 bits
452 | out_data = out_data.astype(out_nbits)
453 |
454 | # Write the output wave
455 | wavfile.write(outPath, out_fr, out_data)
456 |
457 | # Print Results
458 | duration = float(tm.time() - start_time)
459 | print(
460 | "File", outPath, " was successfully written. Resampled from ", in_fr, "to", out_fr, "in", duration,
461 | " seconds")
462 | return 0
463 |
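The L and M factors fall straight out of the in/out rate ratio; a quick standard-library check of the two cases relevant here:

from fractions import Fraction

for in_fr, out_fr in [(48000, 44100), (48000, 16000)]:
    frac = Fraction(in_fr, out_fr).limit_denominator(1000)
    # upsample by the denominator, filter, then downsample by the numerator
    print("%d -> %d Hz: L = %d, M = %d" % (in_fr, out_fr, frac.denominator, frac.numerator))
# 48000 -> 44100 Hz: L = 147, M = 160
# 48000 -> 16000 Hz: L = 1, M = 3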
--------------------------------------------------------------------------------
/tools/helpFunctions/resampleExperiment.py:
--------------------------------------------------------------------------------
1 | ######################################
2 | ########## RESAMPLING ###############
3 | ######################################
4 | from resample import *
5 | from wavToPng import *
6 |
7 | inFile = "sa1.wav"
8 | outFile44 = "sa1_out44.wav"
9 | outFile20 = "sa1_out20.wav"
10 | outFile16 = "sa1_out16.wav"
11 | outFile8 = "sa1_out8.wav"
12 | outFile4 = "sa1_out4.wav"
13 | outFile16_from44 = "sa1_out16_from44.wav"
14 |
15 | resampleWAV(inFile, outFile44, out_fr=44100, q=1.0)
16 | resampleWAV(inFile, outFile20, out_fr=20000, q=1.0)
17 | resampleWAV(inFile, outFile16, out_fr=16000, q=1.0)
19 | resampleWAV(inFile, outFile8, out_fr=8000, q=1.0)
20 | resampleWAV(inFile, outFile4, out_fr=4000, q=1.0)
21 | resampleWAV(outFile44, outFile16_from44, out_fr=16000, q=1.0)
22 |
23 | wavToPng("sa1.wav")
24 |
25 | ######################################
26 | ############## Fractions #############
27 | ######################################
28 | # from fractions import Fraction
29 | # from decimal import Decimal
30 | # print Fraction(Decimal('1.4'))
31 | #
32 | # a= Fraction(Decimal(str(48000/44100.0))).limit_denominator(1000)
33 | # print a.numerator
34 | # print a.denominator
35 | # print(type(a.numerator))
36 |
--------------------------------------------------------------------------------
/tools/helpFunctions/sa1.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthijsvk/TIMITspeech/4294fe4af760d19dc807c4e01d01d07662ff7bde/tools/helpFunctions/sa1.wav
--------------------------------------------------------------------------------
/tools/helpFunctions/visualizeMFC.m:
--------------------------------------------------------------------------------
1 | % EXAMPLE Simple demo of the MFCC function usage.
2 | %
3 | % This script is a step by step walk-through of computation of the
4 | % mel frequency cepstral coefficients (MFCCs) from a speech signal
5 | % using the MFCC routine.
6 | %
7 | % See also MFCC, COMPARE.
8 |
9 | % Author: Kamil Wojcicki, September 2011
10 | % see https://nl.mathworks.com/matlabcentral/fileexchange/32849-htk-mfcc-matlab
11 | % installation of toolbox:
12 | % Download zip, extract to Matlab bin/MATLAB/R2016b/toolbox
13 | % Then click 'Home', 'Set Path'; Add the 'toolbox/mfcc' folder;
14 | % Click 'Save'.
15 |
16 | format short g
17 |
18 | % Clean-up MATLAB's environment
19 | clear all; close all; clc;
20 |
21 | % Define variables
22 | Tw = 25; % analysis frame duration (ms)
23 | Ts = 10; % analysis frame shift (ms)
24 | alpha = 0.97; % preemphasis coefficient
25 | M = 20; % number of filterbank channels
26 | C = 12; % number of cepstral coefficients
27 | L = 22; % cepstral sine lifter parameter
28 | LF = 300; % lower frequency limit (Hz)
29 | HF = 3700; % upper frequency limit (Hz)
30 | wav_file = 'si650.wav'; % input audio filename
31 | mfc_file = 'si650.mfc';
32 |
33 | % Read MFC file for comparison (see bottom)
34 | htkmfc = readhtk(mfc_file);
35 |
36 | % Read speech samples, sampling rate and precision from file
37 | [ speech, fs] = audioread( wav_file );
38 | info = audioinfo(wav_file);
39 | nbits = info.BitsPerSample
40 |
41 | % Feature extraction (feature vectors as columns)
42 | [ MFCCs, FBEs, frames ] = ...
43 | mfcc( speech, fs, Tw, Ts, alpha, @hamming, [LF HF], M, C+1, L );
44 |
45 | disp(MFCCs)
46 | disp(frames)
47 |
48 | % Generate data needed for plotting
49 | [ Nw, NF ] = size( frames ); % frame length and number of frames
50 | time_frames = [0:NF-1]*Ts*0.001+0.5*Nw/fs; % time vector (s) for frames
51 | time = [ 0:length(speech)-1 ]/fs; % time vector (s) for signal samples
52 | logFBEs = 20*log10( FBEs ); % compute log FBEs for plotting
53 | logFBEs_floor = max(logFBEs(:))-50; % get logFBE floor 50 dB below max
54 | logFBEs( logFBEs<logFBEs_floor ) = logFBEs_floor; % limit logFBE dynamic range
   [... remainder of visualizeMFC.m (plotting code) lost in extraction ...]
--------------------------------------------------------------------------------
/tools/mergeAudioFiles.py:
--------------------------------------------------------------------------------
   [... lines 1-52 (imports and the matching volume loop for sound1) lost in extraction ...]
53 | while sound2.rms > targetRMS + min_acc:
54 | sound2 -= min_acc / 20.0
55 |
56 | # print(sound1.rms, targetRMS, sound2.rms)
57 |
58 | combined = sound1.overlay(sound2, loop=True)
59 |
60 | combined.export(out_path, format='wav')
61 |
62 |
63 | def generateBadAudio(outType, srcDir, dstDir, ratio_dB):
64 | # copy phoneme files
65 | copyFilesOfType(srcDir, dstDir, ".phn")
66 |
67 | # copy merged wav files
68 | noiseFile = createNoiseFile(ratio_dB)
69 | src_wavs = loadWavs(srcDir)
70 | for i in tqdm(range(len(src_wavs))):
71 | relSrcPath = relpath(srcDir, src_wavs[i]).lstrip("../")
72 | # print(relSrcPath)
73 | destPath = os.path.join(dstDir, relSrcPath)
74 | if outType == 'voices':
75 | # index of voice to merge
76 | j = random.randint(0, len(src_wavs) - 1)
77 | mergeAudioFiles(src_wavs[i], src_wavs[j], destPath, ratio_dB)
78 | else:
79 | mergeAudioFiles(src_wavs[i], noiseFile, destPath, ratio_dB)
80 |
81 |
82 | import random
83 |
84 |
85 | def createNoiseFile(ratio_dB, noise_path='noise.wav'):
86 | rate = 16000
87 | noise = np.random.normal(0, 1, rate * 3) # generate 3 seconds of white noise
88 | wav.write(noise_path, rate, noise)
89 |
90 | # change the volume to be ~ the TIMIT volume
91 | sound = AudioSegment.from_file(noise_path);
92 | loud2 = sound.dBFS
93 | while sound.dBFS > -30 + ratio_dB:
94 | sound -= 1
95 |
96 | sound.export(noise_path, format='wav')
97 | return os.path.abspath(noise_path)
98 |
99 |
100 | if __name__ == "__main__":
101 | main()
102 |
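The merge logic above nudges levels with pydub's dB arithmetic: adding or subtracting a number applies gain in dB. A compact sketch of the same idea, setting the noise a fixed number of dB below the speech in one step rather than in a loop (file names are hypothetical; AudioSegment.from_file, dBFS, overlay and export are stock pydub API):

from pydub import AudioSegment

def mix_at_offset(speech_path, noise_path, out_path, ratio_dB):
    speech = AudioSegment.from_file(speech_path)
    noise = AudioSegment.from_file(noise_path)
    noise = noise + (speech.dBFS - ratio_dB - noise.dBFS)  # place noise ratio_dB below the speech level
    mixed = speech.overlay(noise, loop=True)               # loop the noise if it is shorter
    mixed.export(out_path, format='wav')

# e.g.: mix_at_offset('sa1.wav', 'noise.wav', 'sa1_noisy.wav', ratio_dB=10)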
--------------------------------------------------------------------------------
/tools/phoneme_set.py:
--------------------------------------------------------------------------------
1 | # using the 39 phone set proposed in (Lee & Hon, 1989)
2 | # Table 3. Mapping from 61 classes to 39 classes, as proposed by Lee and Hon, (Lee & Hon,
3 | # 1989). The phones in the left column are folded into the labels of the right column. The
4 | # remaining phones are left intact.
5 | import logging
6 |
7 | logger_phonemeSet = logging.getLogger('phonemeSet')
8 | logger_phonemeSet.setLevel(logging.ERROR)
9 |
10 | phoneme_set_61_39 = {
11 | 'ao': 'aa', # 1
12 | 'ax': 'ah', # 2
13 | 'ax-h': 'ah',
14 | 'axr': 'er', # 3
15 | 'hv': 'hh', # 4
16 | 'ix': 'ih', # 5
17 | 'el': 'l', # 6
18 | 'em': 'm', # 6
19 | 'en': 'n', # 7
20 | 'nx': 'n',
21 | 'eng': 'ng', # 8
22 | 'zh': 'sh', # 9
23 | "ux": "uw", # 10
24 | "pcl": "sil", # 11
25 | "tcl": "sil",
26 | "kcl": "sil",
27 | "qcl": "sil",
28 | "bcl": "sil",
29 | "dcl": "sil",
30 | "gcl": "sil",
31 | "h#": "sil",
32 | "#h": "sil",
33 | "pau": "sil",
34 | "epi": "sil",
35 | "q": "sil",
36 | }
37 |
38 | # from https://www.researchgate.net/publication/275055833_TCD-TIMIT_An_audio-visual_corpus_of_continuous_speech
39 | phoneme_set_39_list = [
40 | 'iy', 'ih', 'eh', 'ae', 'ah', 'uw', 'uh', 'aa', 'ey', 'ay', 'oy', 'aw', 'ow', # 13 phns
41 | 'l', 'r', 'y', 'w', 'er', 'm', 'n', 'ng', 'ch', 'jh', 'dh', 'b', 'd', 'dx', # 14 phns
42 | 'g', 'p', 't', 'k', 'z', 'v', 'f', 'th', 's', 'sh', 'hh', 'sil' # 12 pns
43 | ]
44 | values = [i for i in range(0, len(phoneme_set_39_list))]
45 | phoneme_set_39 = dict(zip(phoneme_set_39_list, values))
46 | classToPhoneme39 = dict((v, k) for k, v in phoneme_set_39.iteritems())
47 |
48 | # from http://www.intechopen.com/books/speech-technologies/phoneme-recognition-on-the-timit-database, page 5
49 | phoneme_set_61_list = [
50 | 'iy', 'ih', 'eh', 'ey', 'ae', 'aa', 'aw', 'ay', 'ah', 'ao', 'oy', 'ow', 'uh', 'uw', 'ux', 'er', 'ax', 'ix', 'axr',
51 | 'ax-h', 'jh',
52 | 'ch', 'b', 'd', 'g', 'p', 't', 'k', 'dx', 's', 'sh', 'z', 'zh', 'f', 'th', 'v', 'dh', 'm', 'n', 'ng', 'em', 'nx',
53 | 'en', 'eng', 'l', 'r', 'w', 'y', 'hh', 'hv', 'el', 'bcl', 'dcl', 'gcl', 'pcl', 'tcl', 'kcl', 'q', 'pau', 'epi',
54 | 'h#',
55 | ]
56 | values = [i for i in range(0, len(phoneme_set_61_list))]
57 | phoneme_set_61 = dict(zip(phoneme_set_61_list, values))
58 |
59 |
60 | def convertPredictions(predictions, phoneme_list=classToPhoneme39, valid_frames=None, outputType="phonemes"):
61 |     # straight conversion from predicted class numbers to phoneme strings
62 | predictedPhonemes = [phoneme_list[predictedClass] for predictedClass in predictions]
63 |
64 |     # collapse runs of identical consecutive phonemes to a single occurrence
65 |     reducedPhonemes = []
66 |     for j in range(len(predictedPhonemes) - 1):
67 |         if predictedPhonemes[j] != predictedPhonemes[j + 1]:
68 |             reducedPhonemes.append(predictedPhonemes[j])
69 |     reducedPhonemes.append(predictedPhonemes[-1])  # keep the final phoneme as well
70 |     # get only the outputs for the valid frames
71 | validPredictions = [predictedPhonemes[frame] for frame in valid_frames]
72 |
73 | # return class outputs
74 | if outputType != "phonemes":
75 | predictedPhonemes = [phoneme_set_39[phoneme] for phoneme in predictedPhonemes]
76 | reducedPhonemes = [phoneme_set_39[phoneme] for phoneme in reducedPhonemes]
77 | validPredictions = [phoneme_set_39[phoneme] for phoneme in validPredictions]
78 |
79 | return predictedPhonemes, reducedPhonemes, validPredictions
80 |
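A quick sketch of folding a 61-label sequence down to the 39 classes with the dictionaries above (the sample labels are illustrative; like the rest of the repo, this assumes Python 2, since the module uses iteritems at import time):

from tools.phoneme_set import phoneme_set_61_39, phoneme_set_39

labels61 = ['h#', 'sh', 'ix', 'hv', 'eh', 'dcl']
labels39 = [phoneme_set_61_39.get(p, p) for p in labels61]  # fold; unmapped phones stay intact
classes = [phoneme_set_39[p] for p in labels39]
print(labels39)   # ['sil', 'sh', 'ih', 'hh', 'eh', 'sil']
print(classes)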
--------------------------------------------------------------------------------
/tools/preprocessWavs.py:
--------------------------------------------------------------------------------
1 | import timeit;
2 |
3 | import numpy as np
4 | import scipy.io.wavfile as wav
5 | from tqdm import tqdm
6 |
7 | program_start_time = timeit.default_timer()
8 | import os, pdb  # os is needed below to strip directories off file paths
9 | import python_speech_features
10 |
11 | from phoneme_set import phoneme_set_39_list
12 | import transform
13 |
14 | nbPhonemes = 39
15 | phoneme_set_list = phoneme_set_39_list # import list of phonemes,
16 | # convert to dictionary with number mappings (see phoneme_set.py)
17 | values = [i for i in range(0, len(phoneme_set_list))]
18 | phoneme_classes = dict(zip(phoneme_set_list, values))
19 |
20 |
21 | ## Functions ##
22 | def get_total_duration(file):
23 | """Get the length of the phoneme file, i.e. the 'time stamp' of the last phoneme"""
24 | for line in reversed(list(open(file))):
25 | [_, val, _] = line.split()
26 | return int(val)
27 |
28 |
29 | def create_mfcc(method, filename, type=2):
30 | """Perform standard preprocessing, as described by Alex Graves (2012)
31 | http://www.cs.toronto.edu/~graves/preprint.pdf
32 | Output consists of 12 MFCC and 1 energy, as well as the first derivative of these.
33 |     [1 energy, 12 MFCC, 1 diff(energy), 12 diff(MFCC)]
34 |
35 | method is a dummy input!!"""
36 |
37 | (rate, sample) = wav.read(filename)
38 |
39 | mfcc = python_speech_features.mfcc(sample, rate, winlen=0.025, winstep=0.01, numcep=13, nfilt=26,
40 | preemph=0.97, appendEnergy=True)
41 | out = mfcc
42 | if type > 13:
43 | derivative = np.zeros(mfcc.shape)
44 | for i in range(1, mfcc.shape[0] - 1):
45 | derivative[i, :] = mfcc[i + 1, :] - mfcc[i - 1, :]
46 |
47 | mfcc_derivative = np.concatenate((mfcc, derivative), axis=1)
48 | out = mfcc_derivative
49 | if type > 26:
50 | derivative2 = np.zeros(derivative.shape)
51 | for i in range(1, derivative.shape[0] - 1):
52 | derivative2[i, :] = derivative[i + 1, :] - derivative[i - 1, :]
53 |
54 | out = np.concatenate((mfcc, derivative, derivative2), axis=1)
55 | if type > 39:
56 | derivative3 = np.zeros(derivative2.shape)
57 | for i in range(1, derivative2.shape[0] - 1):
58 | derivative3[i, :] = derivative2[i + 1, :] - derivative2[i - 1, :]
59 |
60 | out = np.concatenate((mfcc, derivative, derivative2, derivative3), axis=1)
61 |
62 | return out, out.shape[0]
63 |
64 |
65 | def calc_norm_param(X):
66 | """Assumes X to be a list of arrays (of differing sizes)"""
67 | total_len = 0
68 | mean_val = np.zeros(X[0].shape[1])
69 | std_val = np.zeros(X[0].shape[1])
70 | for obs in X:
71 | obs_len = obs.shape[0]
72 | mean_val += np.mean(obs, axis=0) * obs_len
73 | std_val += np.std(obs, axis=0) * obs_len
74 | total_len += obs_len
75 |
76 | mean_val /= total_len
77 | std_val /= total_len
78 |
79 | return mean_val, std_val, total_len
80 |
81 |
82 | def normalize(X, mean_val, std_val):
83 | for i in range(len(X)):
84 | X[i] = (X[i] - mean_val) / std_val
85 | return X
86 |
87 |
88 | def set_type(X, type):
89 | for i in range(len(X)):
90 | X[i] = X[i].astype(type)
91 | return X
92 |
93 |
94 | def preprocess_dataset(source_path, nbMFCCs=39, logger=None, debug=None, verbose=False):
95 | """Preprocess data, ignoring compressed files and files starting with 'SA'"""
96 | X = []
97 | y = []
98 | valid_frames = []
99 |
100 | # source_path is the root dir of all the wav/phn files
101 | wav_files = transform.loadWavs(source_path)
102 | label_files = transform.loadPhns(source_path)
103 |
104 | logger.debug("Found %d WAV files" % len(wav_files))
105 | logger.debug("Found %d PHN files" % len(label_files))
106 | assert len(wav_files) == len(label_files)
107 | assert len(wav_files) != 0
108 |
109 | processed = 0
110 | for i in tqdm(range(len(wav_files))):
111 | phn_name = str(label_files[i])
112 | wav_name = str(wav_files[i])
113 |
114 |         if (os.path.basename(wav_name).startswith("SA")):  # TIMIT SA files contain strong dialects; skip them (match the file name, not the full path)
115 | continue
116 |
117 | # Get MFCC of the WAV
118 | X_val, total_frames = create_mfcc('DUMMY', wav_name,
119 | nbMFCCs) # get 3 levels: 0th, 1st and 2nd derivative (=> 3*13 = 39 coefficients)
120 | total_frames = int(total_frames)
121 |
122 | X.append(X_val)
123 |
124 | # Get phonemes and valid frame numbers out of .phn files
125 | total_duration = get_total_duration(phn_name)
126 | fr = open(phn_name)
127 |
128 | # some .PHN files don't start at 0. Set default phoneme to silence (expected at the end of phoneme_set_list)
129 | y_vals = np.zeros(total_frames) + phoneme_classes[phoneme_set_list[-1]]
130 | valid_frames_vals = []
131 |
132 | for line in fr:
133 | [start_time, end_time, phoneme] = line.rstrip('\n').split()
134 | start_time = int(start_time)
135 | end_time = int(end_time)
136 |             start_ind = int(np.round(start_time / (total_duration / float(total_frames))))
137 |             end_ind = int(np.round(end_time / (total_duration / float(total_frames))))
138 |
139 | valid_ind = int((start_ind + end_ind) / 2)
140 | valid_frames_vals.append(valid_ind)
141 |
142 |             phoneme_num = phoneme_classes.get(phoneme, -1)  # -1 when the phoneme is missing from the dict, checked below
143 | # check that phoneme is found in dict
144 | if (phoneme_num == -1):
145 | logger.error("In file: %s, phoneme not found: %s", phn_name, phoneme)
146 | pdb.set_trace()
147 | y_vals[start_ind:end_ind] = phoneme_num
148 |
149 | if verbose:
150 | logger.debug('%s', (total_frames / float(total_duration)))
151 | logger.debug('TIME start: %s end: %s, phoneme: %s, class: %s', start_time, end_time, phoneme,
152 | phoneme_num)
153 | logger.debug('FRAME start: %s end: %s, phoneme: %s, class: %s', start_ind, end_ind, phoneme,
154 | phoneme_num)
155 | fr.close()
156 |
157 | # append the target array to our y
158 | y.append(y_vals.astype('int32'))
159 |
160 | # append the valid_frames array to our valid_frames
161 | valid_frames_vals = np.array(valid_frames_vals)
162 | valid_frames.append(valid_frames_vals.astype('int32'))
163 |
164 | if verbose:
165 | logger.debug('(%s) create_target_vector: %s', i, phn_name[:-4])
166 | logger.debug('type(X_val): \t\t %s', type(X_val))
167 | logger.debug('X_val.shape: \t\t %s', X_val.shape)
168 | logger.debug('type(X_val[0][0]):\t %s', type(X_val[0][0]))
169 |
170 | logger.debug('type(y_val): \t\t %s', type(y_vals))
171 | logger.debug('y_val.shape: \t\t %s', y_vals.shape)
172 | logger.debug('type(y_val[0]):\t %s', type(y_vals[0]))
173 | logger.debug('y_val: \t\t %s', (y_vals))
174 |
175 | processed += 1
176 | if debug != None and processed >= debug:
177 | break
178 |
179 | return X, y, valid_frames
180 |
181 |
182 | def preprocess_unlabeled_dataset(source_path, nbMFCCs=39, verbose=False, logger=None): # TODO
183 | wav_files = transform.loadWavs(source_path)
184 | logger.debug("Found %d WAV files" % len(wav_files))
185 | assert len(wav_files) != 0
186 |
187 | X = []
188 | for i in tqdm(range(len(wav_files))):
189 | wav_name = str(wav_files[i])
190 | X_val, total_frames = create_mfcc('DUMMY', wav_name, nbMFCCs)
191 | X.append(X_val)
192 |
193 | if verbose:
194 | logger.debug('type(X_val): \t\t %s', type(X_val))
195 | logger.debug('X_val.shape: \t\t %s', X_val.shape)
196 | logger.debug('type(X_val[0][0]):\t %s', type(X_val[0][0]))
197 | return X
198 |
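A worked check of the sample-to-frame arithmetic used above (the numbers are made up: a 3 s file at 16 kHz gives 48000 samples and roughly 298 MFCC frames at a 10 ms hop; note the explicit float(), since under Python 2 the ratio would otherwise truncate):

total_duration = 48000               # samples covered by the .phn file (3 s @ 16 kHz)
total_frames = 298                   # frames returned by create_mfcc for that file

start_time, end_time = 9600, 12800   # a phoneme spanning 0.6 s - 0.8 s, in samples
samples_per_frame = total_duration / float(total_frames)
start_ind = int(round(start_time / samples_per_frame))
end_ind = int(round(end_time / samples_per_frame))
valid_ind = (start_ind + end_ind) // 2   # the single frame whose prediction is scored

print(start_ind, end_ind, valid_ind)     # 60 79 69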
--------------------------------------------------------------------------------
/tools/transform.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import csv
3 | import os
4 | import subprocess
5 | import sys
6 |
7 | from tqdm import tqdm
8 |
9 | from helpFunctions import resample
10 | from helpFunctions.writeToTxt import writeToTxt
11 | from phoneme_set import phoneme_set_39, phoneme_set_61_39
12 |
13 | debug = False
14 |
15 |
16 | # input dir should be the dir just above 'TRAIN' and 'TEST'. It expects
17 | # example usage: python transform.py phonemes -i /home/matthijs/TCDTIMIT/TIMIT/original/TIMIT/ -o /home/matthijs/TCDTIMIT/TIMIT/processed
18 |
19 |
20 | ##### Load Data ######
21 | # find all files of a type under a directory, recursive
22 | def load_wavPhn(rootDir):
23 | wavs = loadWavs(rootDir)
24 | phns = loadPhns(rootDir)
25 | return wavs, phns
26 |
27 |
28 | def loadWavs(rootDir):
29 | wav_files = []
30 | for dirpath, dirs, files in os.walk(rootDir):
31 | for f in files:
32 | if (f.lower().endswith(".wav")):
33 | wav_files.append(os.path.join(dirpath, f))
34 | return sorted(wav_files)
35 |
36 |
37 | def loadPhns(rootDir):
38 | phn_files = []
39 | for dirpath, dirs, files in os.walk(rootDir):
40 | for f in files:
41 | if (f.lower().endswith(".phn")):
42 | phn_files.append(os.path.join(dirpath, f))
43 | return sorted(phn_files)
44 |
45 |
46 | # generates for example: dstDir/TIMIT/TRAIN/DR2/MTAT1/SX239.PHN
47 | # from srcPath = someDir/TIMIT/TRAIN/DR2/MTAT1/SX239.PHN
48 | def getDestPath(srcPath, dstDir):
49 | filename = os.path.basename(srcPath)
50 |
51 | speakerPath = os.path.dirname(srcPath)
52 | speaker = os.path.basename(speakerPath)
53 |
54 | regionPath = os.path.dirname(speakerPath)
55 | region = os.path.basename(regionPath)
56 |
57 | setPath = os.path.dirname(regionPath)
58 | setName = os.path.basename(setPath)  # avoid shadowing the built-in 'set'
59 |
60 | timitPath = os.path.dirname(setPath)
61 | timit = os.path.basename(timitPath)
62 |
63 | dstPath = os.path.join(dstDir, timit, setName, region, speaker, filename)
64 | return dstPath
65 |
66 |
67 | ### TRANSFORM FUNCTIONS ###
68 | # rewrite a NIST SPHERE wav as a standard wav with a proper RIFF header
69 | def transformWav(wav_file, dstPath):
70 | output_dir = os.path.dirname(dstPath)
71 | if not os.path.exists(output_dir):
72 | os.makedirs(output_dir)
73 |
74 | if not os.path.exists(dstPath):
75 | command = ['mplayer',
76 | '-quiet',
77 | '-vo', 'null',
78 | '-vc', 'dummy',
79 | '-ao', 'pcm:waveheader:file=' + dstPath,
80 | wav_file]
81 |
82 | # actually run the command; both stdout and stderr are discarded, and we don't block waiting for it
83 | FNULL = open(os.devnull, 'w')
84 | p = subprocess.Popen(command, stdout=FNULL, stderr=subprocess.STDOUT, close_fds=True) # stdout=subprocess.PIPE
85 |
86 | # TODO this line is commented out to enable parallel file processing; uncomment if you need to access the file directly after creation
87 | # subprocess.Popen.wait(p) # wait for completion
88 | return 1
89 | else:
90 | return 0
91 |
92 | # generate a new .phn file with mapped phonemes (from 61 down to 39; see the dictionary in phoneme_set.py)
93 | def transformPhn(phn_file, dstPath):
94 | output_dir = os.path.dirname(dstPath)
95 | if not os.path.exists(output_dir):
96 | os.makedirs(output_dir)
97 |
98 | if not os.path.exists(dstPath):
99 | # extract label from phn
100 | phn_labels = []
101 | with open(phn_file, 'r') as csvfile:  # csv.reader needs text mode on Python 3
102 | phn_reader = csv.reader(csvfile, delimiter=' ')
103 | for row in phn_reader:
104 | start, stop, label = row[0], row[1], row[2]
105 |
106 | if label not in phoneme_set_39:  # map from 61 to 39 phonemes using the dict
107 | label = phoneme_set_61_39.get(label)
108 |
109 | classNumber = label  # keep the label string; use phoneme_set_39[label] - 1 instead to get a numeric class index
110 | phn_labels.append([start, stop, classNumber])
111 |
112 | # uncomment to inspect the parsed labels:
113 | # print(phn_labels)
114 | writeToTxt(phn_labels, dstPath)
115 |
116 |
117 | ########### High Level Functions ##########
118 | # just loop over all the found files
119 | def transformWavs(args):
120 | srcDir = args.srcDir
121 | dstDir = args.dstDir
122 |
123 | print("src: ", srcDir)
124 | print("dst: ", dstDir)
125 | srcWavs = loadWavs(srcDir)
126 | srcWavs.sort()
127 |
128 | if not os.path.exists(dstDir):
129 | os.makedirs(dstDir)
130 |
131 | resampled = []
132 | # transform: fix headers, then resample. Two separate loops avoid waiting for each fixed file to be written;
133 | # that's also why transformWav doesn't wait for completion (its Popen.wait(p) line is commented out)
134 | print("FIXING WAV HEADERS AND COPYING TO ", dstDir, "...")
135 | for srcPath in tqdm(srcWavs, total=len(srcWavs)):
136 | dstPath = getDestPath(srcPath, dstDir)
137 | resampled.append(dstPath)
138 | transformWav(srcPath, dstPath)
139 | if debug: print(srcPath, dstPath)
140 |
141 | print("RESAMPLING TO 16kHz...")
142 | for dstPath in tqdm(resampled, total=len(resampled)):
143 | resample.resampleWAV(dstPath, dstPath, out_fr=16000.0, q=1.0) # resample to 16 kHz from 48kHz
144 |
145 | ## TODO: the resampy library is about 4x faster, but sometimes crashes; needs 'import resampy' and 'from scipy.io import wavfile'
146 | # in_fr, in_data = wavfile.read(dstPath)
147 | # in_type = in_data.dtype
148 | # in_data = in_data.astype(float)
149 | # # in_data is now a 1-d numpy array with in_fr samples per second;
150 | # # we can resample it to any rate we like, say 16000 Hz
151 | # y_low = resampy.resample(in_data, in_fr, 16000)
152 | # y_low = y_low.astype(in_type)
153 | # wavfile.write(dstPath, 16000, y_low)
154 |
155 |
156 | def transformPhns(args):
157 | srcDir = args.srcDir
158 | dstDir = args.dstDir
159 | srcPhns = loadPhns(srcDir)
160 |
161 | print("Source Directory: ", srcDir)
162 | print("Destination Directory: ", dstDir)
163 |
164 | if not os.path.exists(dstDir):
165 | os.makedirs(dstDir)
166 |
167 | for srcPath in tqdm(srcPhns, total=len(srcPhns)):
168 | dstPath = getDestPath(srcPath, dstDir)
169 | # print("reading from: ", srcPath)
170 | # print("writing to: ", dstPath)
171 | transformPhn(srcPath, dstPath)
172 |
173 |
174 | ### help functions ###
175 | def readPhonemeDict(filePath):
176 | d = {}
177 | with open(filePath) as f:
178 | for line in f:
179 | (key, val) = line.split()
180 | d[int(key)] = val
181 | return d
182 |
183 |
184 | def checkDirs(args):
185 | if 'dstDir' in args and not os.path.exists(args.dstDir):
186 | os.makedirs(args.dstDir)
187 | if 'srcDir' in args and not os.path.exists(args.srcDir):
188 | raise Exception('Cannot find source data path')
189 |
190 |
191 | ################# PARSER #####################
192 | def prepare_parser():
193 | parser = argparse.ArgumentParser()
194 | sub_parsers = parser.add_subparsers()
195 |
196 | ## PHONEMES ##
197 | phn_parser = sub_parsers.add_parser('phonemes')
198 | phn_parser.set_defaults(func=transformPhns)
199 | phn_parser.add_argument('-i', '--srcDir',
200 | help="the directory storing source data",
201 | required=True)
202 | phn_parser.add_argument('-o', '--dstDir',
203 | help="the directory store output data",
204 | required=True)
205 | ## WAVS ##
206 | wav_parser = sub_parsers.add_parser('wavs')
207 | wav_parser.set_defaults(func=transformWavs)
208 | wav_parser.add_argument('-i', '--srcDir',
209 | help="the directory storing source data",
210 | required=True)
211 | wav_parser.add_argument('-o', '--dstDir',
212 | help="the directory store output data",
213 | required=True)
214 | return parser
215 |
216 |
217 | if __name__ == '__main__':
218 | arg_parser = prepare_parser()
219 | args = arg_parser.parse_args(sys.argv[1:])
220 | checkDirs(args)
221 | args.func(args)
222 |
--------------------------------------------------------------------------------
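To close, a minimal sketch of driving `transform.py` from Python rather than the shell, mirroring the example usage comment at the top of the file. The source and destination paths are placeholders; `prepare_parser`, `checkDirs`, and the `args.func` dispatch are as defined above.

```python
# Hypothetical programmatic driver -- paths are placeholders.
from transform import prepare_parser, checkDirs

parser = prepare_parser()

# equivalent to: python transform.py wavs -i <srcDir> -o <dstDir>
args = parser.parse_args(['wavs', '-i', '/data/TIMIT/original',
                                  '-o', '/data/TIMIT/fixed'])
checkDirs(args)
args.func(args)  # runs transformWavs: fix wav headers, then resample to 16 kHz

# equivalent to: python transform.py phonemes -i <srcDir> -o <dstDir>
args = parser.parse_args(['phonemes', '-i', '/data/TIMIT/original',
                                      '-o', '/data/TIMIT/fixed'])
checkDirs(args)
args.func(args)  # runs transformPhns: map the 61 phoneme labels down to 39
```

Running `wavs` before `phonemes` keeps the two output trees aligned, since `getDestPath` reproduces the same `TIMIT/<set>/<region>/<speaker>` hierarchy for both file types.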