├── LICENSE ├── README.md ├── arpabets.py ├── beamsearch.py ├── char_map.py ├── create_arpabet_json.py ├── create_desc_json.py ├── data_generator.py ├── download.sh ├── flac_to_wav.sh ├── flac_to_wav_ffmpeg.sh ├── model.py ├── model.tar.gz ├── model_wrp.py ├── models-evaluation.ipynb ├── plot.py ├── pre-trained ├── model_25_config.json └── model_45_config.json ├── test.py ├── train.py ├── trainer.py ├── utils.py └── visualize.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # deepspeech-playground 2 | 3 | This repo is a fork of [Baidu's DeepSpeech](https://github.com/baidu-research/ba-dls-deepspeech). 
Unlike Baidu's repo: 4 | 5 | - It works with both Tensorflow and Theano 6 | - It has helpers for better training by also training against auto-generated phonograms 7 | - Training with Theano can be much faster, since the CTC calculation can be done on the GPU 8 | 9 | 10 | ## Training 11 | 12 | If you want to train with Theano you'll need Theano>=0.10, since it has [bindings](http://deeplearning.net/software/theano_versions/dev/library/tensor/nnet/ctc.html) for Baidu's CTC. 13 | 14 | ### Using phonograms 15 | 16 | The `HalfPhonemeModelWrapper` class in the `model_wrp` module implements training of a model with half of its RNN layers trained on phonograms and the rest on the actual output text. To generate phonograms, the [Logios](https://github.com/skerit/cmusphinx/tree/master/logios) tool of CMU Sphinx can be used. Sphinx's phonogram symbols are called [Arpabets](http://www.speech.cs.cmu.edu/cgi-bin/cmudict). To generate Arpabets from Baidu's DeepSpeech [description files](https://github.com/baidu-research/ba-dls-deepspeech#data) you can run: 17 | ``` 18 | $ cat train_corpus.json | sed -e 's/.*"text": "\([^"]*\)".*/\1/' > train_corpus.txt 19 | # make_pronunciation.pl script is provided by logios 20 | # https://github.com/skerit/cmusphinx/tree/master/logios/Tools/MakeDict 21 | $ perl ./make_pronunciation.pl -tools ../ -dictdir . -words prons/train_corpus.txt -dict prons/train_corpus.dict 22 | $ python create_arpabet_json.py train_corpus.json train_corpus.dict train_corpus.arpadesc 23 | ``` 24 | 25 | ### Choose backend 26 | 27 | Select the Keras backend by setting the environment variable `KERAS_BACKEND` to `theano` or `tensorflow`. 28 | 29 | ### Train! 30 | Write a training routine, a function like this: 31 | 32 | ``` 33 | def train_sample_half_phoneme(datagen, save_dir, epochs, sortagrad, 34 | start_weights=False, mb_size=60): 35 | model_wrp = HalfPhonemeModelWrapper() 36 | model = model_wrp.compile(nodes=1000, conv_context=5, recur_layers=5) 37 | logger.info('model :\n%s' % (model.to_yaml(),)) 38 | 39 | if start_weights: 40 | model.load_weights(start_weights) 41 | 42 | train_fn, test_fn = (model_wrp.compile_train_fn(1e-4), 43 | model_wrp.compile_test_fn()) 44 | trainer = Trainer(model, train_fn, test_fn, on_text=True, on_phoneme=True) 45 | trainer.run(datagen, save_dir, epochs=epochs, do_sortagrad=sortagrad, 46 | mb_size=mb_size, stateful=False) 47 | return trainer, model_wrp 48 | ``` 49 | Then call it from `main()` of `train.py`. Training can be started with: 50 | ``` 51 | $ KERAS_BACKEND="tensorflow" python train.py descs/small.arpadesc descs/test-clean.arpadesc models/test --epochs 20 --use-arpabets --sortagrad 1 52 | ``` 53 | 54 | ## Evaluation 55 | 56 | `visualize.py` gives you a semi-interactive shell for testing your model on input files. There is also the [models-evaluation notebook](models-evaluation.ipynb), though it is a bit messy.
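For a quick scripted check, something along the following lines can also be used. This is only a sketch, not a script shipped with the repo: it assumes the Tensorflow backend (so the network emits softmax probabilities in batch-major layout), a weights file saved by an earlier training run with the same architecture, and the `860-1000` normalization statistics bundled in `data_generator.py`; the `.h5` and `.wav` paths are placeholders. It only reuses helpers that already exist here: `DataGenerator`, `HalfPhonemeModelWrapper` and `beamsearch.beam_decode`.
```
from beamsearch import beam_decode
from data_generator import DataGenerator
from model_wrp import HalfPhonemeModelWrapper

datagen = DataGenerator()
datagen.reload_norm('860-1000')  # reuse stored feature mean/std

model_wrp = HalfPhonemeModelWrapper()
model = model_wrp.compile(nodes=1000, conv_context=5, recur_layers=5)
model.load_weights('models/test/model_weights.h5')  # placeholder weights path
output_fn = model_wrp.compile_output_fn()

feat = datagen.normalize(datagen.featurize('sample.wav'))  # (timesteps, 161)
probs = output_fn([feat[None, :, :], 0])[0]  # batch of one, test phase
texts, scores = beam_decode(probs[0], beam_width=16)
print(texts[-1])  # hypotheses are ordered by score; the last one is the best
```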
57 | 58 | ## Pre-trained models 59 | 60 | These models were trained for about three days on the LibriSpeech corpus using a GTX 1080 Ti GPU: 61 | 62 | - A five-layer unidirectional RNN model trained on LibriSpeech using Theano: [mega](https://mega.nz/#!ZTIjXQgA!HK1vCRxYC1VyzJ_8LCwwcTrNH9aF7l-H8TYf7eE1v6g), [drive](https://drive.google.com/open?id=0B-xCVC7fUa3MZ3B1UVpYWlY1LWs) 63 | - A five-layer unidirectional RNN model trained on LibriSpeech using Tensorflow: [mega](https://mega.nz/#!APR1iRjT!pgJcnEWLTHzJ4m9dQXA_2gvrJxa_h9uwEHc6Sxwreow), [drive](https://drive.google.com/open?id=0B-xCVC7fUa3MdkdNc05zT2dyblk) 64 | 65 | Validation ~~WER~~ CER of these models is about 5% on `test-clean` and about 15% on `test-other`. 66 | -------------------------------------------------------------------------------- /arpabets.py: -------------------------------------------------------------------------------- 1 | from utils import for_tf_or_th 2 | 3 | sphinx_40_phones = ''' 4 | AA 5 | AE 6 | AH 7 | AO 8 | AW 9 | AY 10 | B 11 | CH 12 | D 13 | DH 14 | EH 15 | ER 16 | EY 17 | F 18 | G 19 | HH 20 | IH 21 | IY 22 | JH 23 | K 24 | L 25 | M 26 | N 27 | NG 28 | OW 29 | OY 30 | P 31 | R 32 | S 33 | SH 34 | T 35 | TH 36 | UH 37 | UW 38 | V 39 | W 40 | Y 41 | Z 42 | ZH 43 | IX 44 | ''' 45 | # SIL does not appear in MakeDict output, but IX does, as the translation for ' 46 | # sphinx_40_phones = sphinx_40_phones[:-4] + 'IX' 47 | 48 | phone_map, index_map = {}, {} 49 | start_index = for_tf_or_th(0, 1) 50 | for i, ph in enumerate(sphinx_40_phones.strip().split()): 51 | phone_map[ph] = i + start_index 52 | index_map[i+start_index] = ph 53 | 54 | 55 | def arpabet_to_int_sequence(phonograms): 56 | return [phone_map[ph] for ph in phonograms.split()] 57 | -------------------------------------------------------------------------------- /beamsearch.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from char_map import index_map 3 | from utils import for_tf_or_th 4 | 5 | 6 | def ctc_decode(acc, next_i): 7 | next_char = '' if next_i == for_tf_or_th(28, 0) else index_map[next_i] 8 | if not acc: 9 | return [next_char] 10 | if acc[-1] == next_char: 11 | return acc 12 | if acc[-1] == '': 13 | return acc[:-1] + [next_char] 14 | return acc + [next_char] 15 | 16 | 17 | def beam_decode(probs, beam_width, eps=1e-5): 18 | preds = [] 19 | 20 | for t in range(probs.shape[0]): 21 | best_ys = probs[t].argsort()[-beam_width:] 22 | if t == 0: 23 | pred_probs = np.log(probs[t, best_ys].clip(eps, 100)) 24 | preds = [[i] for i in best_ys] 25 | else: 26 | level_probs = np.array([ 27 | [pp + np.log(p.clip(eps, 100)) for p in probs[t, best_ys]] 28 | for pp in pred_probs]).flatten() 29 | best_probs = level_probs.argsort()[-beam_width:] 30 | preds = [preds[prob_ix // beam_width] + 31 | [best_ys[prob_ix % beam_width]] for prob_ix in best_probs] 32 | pred_probs = level_probs[best_probs] 33 | 34 | # decode ctc 35 | predicts = [] 36 | for pred in preds: 37 | prev, plist = -1, [] 38 | for i in pred: 39 | if i == prev: 40 | continue 41 | elif i != for_tf_or_th(28, 0): 42 | plist.append(index_map[i]) 43 | prev = i 44 | predicts.append(''.join(plist)) 45 | 46 | return predicts, pred_probs 47 | 48 | 49 | def beam_decode_u(probs, beam_width, eps=1e-5, normalize=False): 50 | u_preds = [] # unique predictions 51 | # give more characters a chance because we remove duplicates in each step 52 | mid_beam = max(beam_width+2, probs.shape[1]) 53 | 54 | # loop over each time 55 | for t in range(probs.shape[0]): 56 | 
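        # Each timestep widens the beam: every surviving prefix is extended with the
        # `mid_beam` most probable symbols, collapsed with the CTC rules in ctc_decode(),
        # de-duplicated, and only the `beam_width` best-scoring unique prefixes are kept.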
best_ys = probs[t].argsort()[-mid_beam:] 57 | if normalize: 58 | clipped_t_probs = probs[t].clip(eps, 100) 59 | norm_log_sum = np.log(np.exp(clipped_t_probs[best_ys]).sum()) 60 | if t == 0: 61 | if normalize: 62 | pred_probs = clipped_t_probs[best_ys] - norm_log_sum 63 | else: 64 | pred_probs = np.log(probs[t, best_ys].clip(eps, 100)) 65 | u_preds = [[] if i == for_tf_or_th(28, 0) else [index_map[i]] 66 | for i in best_ys] 67 | else: 68 | if normalize: 69 | level_probs = np.array([ 70 | [pp + p - norm_log_sum for p in clipped_t_probs[best_ys]] 71 | for pp in pred_probs]).flatten() 72 | else: 73 | level_probs = np.array([ 74 | [pp + np.log(p.clip(eps, 100)) for p in probs[t, best_ys]] 75 | for pp in pred_probs]).flatten() 76 | best_probs = level_probs.argsort()[-(beam_width*2):] 77 | level_preds = [(prob_ix, ctc_decode(u_preds[prob_ix // mid_beam], 78 | best_ys[prob_ix % mid_beam])) 79 | for prob_ix in best_probs] 80 | # delete duplicates 81 | new_preds, new_prob_ixs = [], [] 82 | for prob_ix, pred in level_preds[::-1]: 83 | if pred in new_preds: 84 | continue 85 | else: 86 | new_preds.append(pred) 87 | new_prob_ixs.append(prob_ix) 88 | u_preds = new_preds[:beam_width] 89 | pred_probs = level_probs[new_prob_ixs[:beam_width]] 90 | 91 | return [''.join(pred) for pred in u_preds], pred_probs 92 | 93 | 94 | def beam_decode_mul(probs, beam_width): 95 | nodes = [[]] * beam_width 96 | scores = None 97 | 98 | for t in range(probs.shape[0]): 99 | best_ys = probs[t].argsort()[-beam_width:] 100 | if t == 0: 101 | best_scores = probs[t, best_ys] / 10 102 | else: 103 | best_scores = (scores[:, None] * probs[t, best_ys]/10).flatten() 104 | best_is = best_scores.argsort()[-beam_width:] 105 | nodes = [nodes[si // beam_width] + [best_ys[si % beam_width]] 106 | for si in best_is] 107 | print (best_scores) 108 | scores = np.clip(best_scores[best_is], 1e-4, 1e4) 109 | preds = [] 110 | for strcode in nodes: 111 | preds.append([]) 112 | pred = -1 113 | for code in strcode: 114 | if code == pred: 115 | continue 116 | elif code != for_tf_or_th(28, 0): 117 | preds[-1].append(index_map[code]) 118 | pred = code 119 | 120 | return [''.join(p) for p in preds], scores 121 | -------------------------------------------------------------------------------- /char_map.py: -------------------------------------------------------------------------------- 1 | import keras.backend as K 2 | 3 | char_map_str = """ 4 | ' 1 5 | 2 6 | a 3 7 | b 4 8 | c 5 9 | d 6 10 | e 7 11 | f 8 12 | g 9 13 | h 10 14 | i 11 15 | j 12 16 | k 13 17 | l 14 18 | m 15 19 | n 16 20 | o 17 21 | p 18 22 | q 19 23 | r 20 24 | s 21 25 | t 22 26 | u 23 27 | v 24 28 | w 25 29 | x 26 30 | y 27 31 | z 28 32 | """ 33 | 34 | 35 | def ctc_idx(i): 36 | if K.backend() == 'tensorflow': 37 | return i - 1 38 | elif K.backend() == 'theano': 39 | return i 40 | raise ValueError 41 | 42 | char_map = {} 43 | index_map = {} 44 | for line in char_map_str.strip().split('\n'): 45 | ch, index = line.split() 46 | i = ctc_idx(int(index)) 47 | char_map[ch] = i 48 | index_map[i] = ch 49 | index_map[ctc_idx(2)] = ' ' 50 | -------------------------------------------------------------------------------- /create_arpabet_json.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import sys 4 | 5 | 6 | def main(input_f, dict_f, out_f): 7 | desc_i = open(input_f) 8 | dic = open(dict_f) 9 | 10 | dic_read_line = None 11 | prev_desc = None 12 | dic_file_ended = False 13 | 14 | arpabet_descs = [] 15 | 16 | for desc_line in 
desc_i: 17 | desc_ld = json.loads(desc_line) 18 | text = desc_ld['text'].lower() 19 | if dic_read_line: 20 | if dic_read_line[0].lower() == text: 21 | desc_ld['arpabet'] = dic_read_line[1] 22 | dic_read_line = None 23 | 24 | if 'arpabet' not in desc_ld: 25 | pronun_found = False 26 | while not pronun_found: 27 | line = dic.readline() 28 | if line == '': 29 | sys.stderr.write('WARNING: dictionary file ended while' 30 | ' still looking for: {}\n'.format(text)) 31 | dic_file_ended = True 32 | break 33 | dic_read_line = line[:-1].split('\t') 34 | if dic_read_line[0].lower() == text: 35 | desc_ld['arpabet'] = dic_read_line[1] 36 | dic_read_line = None 37 | pronun_found = True 38 | elif prev_desc and (prev_desc['text'] == 39 | dic_read_line[0].lower().split('(')[0]): 40 | sys.stderr.write('INFO: found another pronunciation for: ' 41 | '{}\n'.format(prev_desc['text'])) 42 | prev_desc_new = prev_desc.copy() 43 | prev_desc_new['arpabet'] = dic_read_line[1] 44 | arpabet_descs.append(prev_desc_new) 45 | else: 46 | break 47 | 48 | if 'arpabet' in desc_ld: 49 | arpabet_descs.append(desc_ld) 50 | else: 51 | sys.stderr.write("WARNING: couldn't find pronunciation for: {}\n" 52 | .format(text)) 53 | prev_desc = desc_ld 54 | if dic_file_ended: 55 | sys.stderr.write('WARNING: dictionary find ended sooner\n') 56 | break 57 | 58 | with open(out_f, 'w') as out: 59 | for desc in arpabet_descs: 60 | out.write(json.dumps(desc) + '\n') 61 | 62 | 63 | if __name__ == '__main__': 64 | parser = argparse.ArgumentParser() 65 | parser.add_argument('input_desc', type=str, help='Input json line file') 66 | parser.add_argument('dict', type=str, 67 | help='Arpabet translation file of input desc json') 68 | parser.add_argument('output_desc', type=str, 69 | help='Output json line file') 70 | args = parser.parse_args() 71 | 72 | main(args.input_desc, args.dict, args.output_desc) 73 | -------------------------------------------------------------------------------- /create_desc_json.py: -------------------------------------------------------------------------------- 1 | """ 2 | Use this script to create JSON-Line description files that can be used to 3 | train deep-speech models through this library. 
4 | This works with data directories that are organized like LibriSpeech: 5 | data_directory/group/speaker/[file_id1.wav, file_id2.wav, ..., 6 | speaker.trans.txt] 7 | 8 | Where speaker.trans.txt has in each line, file_id transcription 9 | """ 10 | 11 | from __future__ import absolute_import, division, print_function 12 | 13 | import argparse 14 | import json 15 | import os 16 | import wave 17 | 18 | 19 | def main(data_directory, output_file): 20 | labels = [] 21 | durations = [] 22 | keys = [] 23 | for group in os.listdir(data_directory): 24 | speaker_path = os.path.join(data_directory, group) 25 | for speaker in os.listdir(speaker_path): 26 | labels_file = os.path.join(speaker_path, speaker, 27 | '{}-{}.trans.txt' 28 | .format(group, speaker)) 29 | for line in open(labels_file): 30 | split = line.strip().split() 31 | file_id = split[0] 32 | label = ' '.join(split[1:]).lower() 33 | audio_file = os.path.join(speaker_path, speaker, 34 | file_id) + '.wav' 35 | audio = wave.open(audio_file) 36 | duration = float(audio.getnframes()) / audio.getframerate() 37 | audio.close() 38 | keys.append(audio_file) 39 | durations.append(duration) 40 | labels.append(label) 41 | with open(output_file, 'w') as out_file: 42 | for i in range(len(keys)): 43 | line = json.dumps({'key': keys[i], 'duration': durations[i], 44 | 'text': labels[i]}) 45 | out_file.write(line + '\n') 46 | 47 | 48 | if __name__ == '__main__': 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument('data_directory', type=str, 51 | help='Path to data directory') 52 | parser.add_argument('output_file', type=str, 53 | help='Path to output file') 54 | args = parser.parse_args() 55 | main(args.data_directory, args.output_file) 56 | -------------------------------------------------------------------------------- /data_generator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines a class that is used to featurize audio clips, and provide 3 | them to the network for training or testing. 4 | """ 5 | 6 | from __future__ import absolute_import, division, print_function 7 | 8 | import json 9 | import logging 10 | import numpy as np 11 | import random 12 | import keras.backend as K 13 | from keras.preprocessing.sequence import pad_sequences 14 | 15 | from concurrent.futures import ThreadPoolExecutor, wait 16 | 17 | from utils import calc_feat_dim, spectrogram_from_file, text_to_int_sequence 18 | from arpabets import arpabet_to_int_sequence 19 | 20 | RNG_SEED = 123 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | class DataGenerator(object): 25 | def __init__(self, step=10, window=20, max_freq=8000, desc_file=None, 26 | use_arpabets=False, use_durations=False): 27 | """ 28 | Params: 29 | step (int): Step size in milliseconds between windows 30 | window (int): FFT window size in milliseconds 31 | max_freq (int): Only FFT bins corresponding to frequencies between 32 | [0, max_freq] are returned 33 | desc_file (str, optional): Path to a JSON-line file that contains 34 | labels and paths to the audio files. 
If this is None, then 35 | load metadata right away 36 | """ 37 | self.feat_dim = calc_feat_dim(window, max_freq) 38 | self.feats_mean = np.zeros((self.feat_dim,)) 39 | self.feats_std = np.ones((self.feat_dim,)) 40 | self.rng = random.Random(RNG_SEED) 41 | if desc_file is not None: 42 | self.load_metadata_from_desc_file(desc_file) 43 | self.step = step 44 | self.window = window 45 | self.max_freq = max_freq 46 | self.use_arpabets = use_arpabets 47 | self.use_durations = use_durations 48 | 49 | def featurize(self, audio_clip): 50 | """ For a given audio clip, calculate the log of its Fourier Transform 51 | Params: 52 | audio_clip(str): Path to the audio clip 53 | """ 54 | return spectrogram_from_file( 55 | audio_clip, step=self.step, window=self.window, 56 | max_freq=self.max_freq) 57 | 58 | def load_metadata_from_desc_file(self, desc_file, partition='train', 59 | max_duration=10.0): 60 | """ Read metadata from the description file 61 | (possibly takes long, depending on the filesize) 62 | Params: 63 | desc_file (str): Path to a JSON-line file that contains labels and 64 | paths to the audio files 65 | partition (str): One of 'train', 'validation' or 'test' 66 | max_duration (float): In seconds, the maximum duration of 67 | utterances to train or test on 68 | """ 69 | logger.info('Reading description file: {} for partition: {}' 70 | .format(desc_file, partition)) 71 | audio_paths, durations, texts, arpabets = [], [], [], [] 72 | with open(desc_file) as json_line_file: 73 | for line_num, json_line in enumerate(json_line_file): 74 | try: 75 | spec = json.loads(json_line) 76 | if float(spec['duration']) > max_duration: 77 | continue 78 | audio_paths.append(spec['key']) 79 | durations.append(float(spec['duration'])) 80 | texts.append(spec['text']) 81 | if self.use_arpabets: 82 | arpabets.append(spec['arpabet']) 83 | except Exception as e: 84 | # Change to (KeyError, ValueError) or 85 | # (KeyError,json.decoder.JSONDecodeError), depending on 86 | # json module version 87 | logger.warn('Error reading line #{}: {}' 88 | .format(line_num, json_line)) 89 | logger.warn(str(e)) 90 | 91 | if not self.use_arpabets: 92 | arpabets = [''] * len(audio_paths) 93 | 94 | if partition == 'train': 95 | self.train_audio_paths = audio_paths 96 | self.train_durations = durations 97 | self.train_texts = texts 98 | self.train_arpabets = arpabets 99 | elif partition == 'validation': 100 | self.val_audio_paths = audio_paths 101 | self.val_durations = durations 102 | self.val_texts = texts 103 | self.val_arpabets = arpabets 104 | elif partition == 'test': 105 | self.test_audio_paths = audio_paths 106 | self.test_durations = durations 107 | self.test_texts = texts 108 | self.test_arpabets = arpabets 109 | else: 110 | raise Exception("Invalid partition to load metadata. 
" 111 | "Must be train/validation/test") 112 | 113 | def load_train_data(self, desc_file, max_duration=10.0): 114 | self.load_metadata_from_desc_file(desc_file, 'train', max_duration) 115 | 116 | def load_test_data(self, desc_file): 117 | self.load_metadata_from_desc_file(desc_file, 'test') 118 | 119 | def load_validation_data(self, desc_file): 120 | self.load_metadata_from_desc_file(desc_file, 'validation') 121 | 122 | @staticmethod 123 | def sort_by_duration(durations, audio_paths, texts, arpabets): 124 | x = sorted(zip(durations, audio_paths, texts, arpabets)) 125 | if K.backend() == 'theano': 126 | x.reverse() 127 | return zip(*x) 128 | 129 | def normalize(self, feature, eps=1e-14): 130 | return (feature - self.feats_mean) / (self.feats_std + eps) 131 | 132 | def prepare_minibatch(self, audio_paths, texts, durations, arpabets): 133 | """ Featurize a minibatch of audio, zero pad them and return a dictionary 134 | Params: 135 | audio_paths (list(str)): List of paths to audio files 136 | texts (list(str)): List of texts corresponding to the audio files 137 | Returns: 138 | dict: See below for contents 139 | """ 140 | assert len(audio_paths) == len(texts),\ 141 | "Inputs and outputs to the network must be of the same number" 142 | # Features is a list of (timesteps, feature_dim) arrays 143 | # Calculate the features for each audio clip, as the log of the 144 | # Fourier Transform of the audio 145 | features = [self.featurize(a) for a in audio_paths] 146 | input_lengths = [f.shape[0] for f in features] 147 | max_length = max(input_lengths) 148 | feature_dim = features[0].shape[1] 149 | mb_size = len(features) 150 | # Pad all the inputs so that they are all the same length 151 | x = np.zeros((mb_size, max_length, feature_dim)) 152 | y = [] 153 | label_lengths = [] 154 | for i in range(mb_size): 155 | feat = features[i] 156 | feat = self.normalize(feat) # Center using means and std 157 | x[i, :feat.shape[0], :] = feat 158 | label = text_to_int_sequence(texts[i]) 159 | y.append(label) 160 | label_lengths.append(len(label)) 161 | y = pad_sequences(y, maxlen=len(max(texts, key=len)), dtype='int32', 162 | padding='post', truncating='post', value=-1) 163 | res = { 164 | 'x': x, # (0-padded features of shape(mb_size,timesteps,feat_dim) 165 | 'y': y, # list(int) Flattened labels (integer sequences) 166 | 'texts': texts, # list(str) Original texts 167 | 'input_lengths': input_lengths, # list(int) Length of each input 168 | 'label_lengths': label_lengths # list(int) Length of each label 169 | # 'durations' [if use_durations] list(float) Duration of each sample 170 | # 'phonemes'[if use_arpabets] list(int) Flattened arpabet ints 171 | } 172 | if self.use_durations: 173 | res['durations'] = durations 174 | if self.use_arpabets: 175 | arpints, arpaint_lengths = [], [] 176 | for i in range(mb_size): 177 | arpaint_seq = arpabet_to_int_sequence(arpabets[i]) 178 | arpints.append(arpaint_seq) 179 | arpaint_lengths.append(len(arpaint_seq)) 180 | maxlen = len(max(arpints, key=len)) 181 | res['phonemes'] = pad_sequences(arpints, maxlen=maxlen, 182 | dtype='int32', padding='post', 183 | truncating='post', value=-1) 184 | res['phoneme_lengths'] = arpaint_lengths 185 | return res 186 | 187 | def iterate(self, audio_paths, texts, minibatch_size, durations=[], 188 | arpabets=[], max_iters=None, parallel=True): 189 | if max_iters is not None: 190 | k_iters = max_iters 191 | else: 192 | k_iters = int(np.ceil(len(audio_paths) / minibatch_size)) 193 | logger.info("Iters: {}".format(k_iters)) 194 | if parallel: 195 | pool 
= ThreadPoolExecutor(1) # Run a single I/O thread in parallel 196 | future = pool.submit(self.prepare_minibatch, 197 | audio_paths[:minibatch_size], 198 | texts[:minibatch_size], 199 | durations[:minibatch_size], 200 | arpabets[:minibatch_size]) 201 | else: 202 | minibatch = self.prepare_minibatch(audio_paths[:minibatch_size], 203 | texts[:minibatch_size], 204 | durations[:minibatch_size], 205 | arpabets[:minibatch_size]) 206 | 207 | start = minibatch_size 208 | for i in range(k_iters - 1): 209 | if parallel: 210 | wait([future]) 211 | minibatch = future.result() 212 | # While the current minibatch is being consumed, prepare the 213 | # next 214 | future = pool.submit(self.prepare_minibatch, 215 | audio_paths[start:start+minibatch_size], 216 | texts[start:start+minibatch_size], 217 | durations[start:start+minibatch_size], 218 | arpabets[start:start+minibatch_size]) 219 | yield minibatch 220 | if not parallel: 221 | minibatch = self.prepare_minibatch( 222 | audio_paths[start:start+minibatch_size], 223 | texts[start:start+minibatch_size], 224 | durations[start:start+minibatch_size], 225 | arpabets[start:start+minibatch_size] 226 | ) 227 | start += minibatch_size 228 | # Wait on the last minibatch 229 | if parallel: 230 | wait([future]) 231 | minibatch = future.result() 232 | yield minibatch 233 | else: 234 | yield minibatch 235 | 236 | def iterate_train(self, minibatch_size=16, sort_by_duration=False, 237 | shuffle=True): 238 | if sort_by_duration and shuffle: 239 | shuffle = False 240 | logger.warn("Both sort_by_duration and shuffle were set to True. " 241 | "Setting shuffle to False") 242 | durations, audio_paths, texts, arpabets = (self.train_durations, 243 | self.train_audio_paths, 244 | self.train_texts, 245 | self.train_arpabets) 246 | if shuffle: 247 | temp = zip(durations, audio_paths, texts, arpabets) 248 | self.rng.shuffle(temp) 249 | durations, audio_paths, texts, arpabets = zip(*temp) 250 | if sort_by_duration: 251 | logger.info('Sorting training samples by duration') 252 | if getattr(self, '_sorted_data', None) is None: 253 | self._sorted_data = DataGenerator.sort_by_duration( 254 | durations, audio_paths, texts, arpabets) 255 | durations, audio_paths, texts, arpabets = self._sorted_data 256 | return self.iterate(audio_paths, texts, minibatch_size, durations, 257 | arpabets) 258 | 259 | def iterate_test(self, minibatch_size=16): 260 | return self.iterate(self.test_audio_paths, self.test_texts, 261 | minibatch_size, self.test_durations) 262 | 263 | def iterate_validation(self, minibatch_size=16): 264 | return self.iterate(self.val_audio_paths, self.val_texts, 265 | minibatch_size, self.val_durations, 266 | self.val_arpabets) 267 | 268 | def fit_train(self, k_samples=100): 269 | """ Estimate the mean and std of the features from the training set 270 | Params: 271 | k_samples (int): Use this number of samples for estimation 272 | """ 273 | k_samples = min(k_samples, len(self.train_audio_paths)) 274 | samples = self.rng.sample(self.train_audio_paths, k_samples) 275 | feats = [self.featurize(s) for s in samples] 276 | feats = np.vstack(feats) 277 | self.feats_mean = np.mean(feats, axis=0) 278 | self.feats_std = np.std(feats, axis=0) 279 | 280 | def reload_norm(self, dataset): 281 | """ Set mean and std of features from previous calculations 282 | Params: 283 | dataset (str) 284 | """ 285 | if dataset == '860-1000': 286 | self.feats_std = np.array([ 287 | 4.25136062, 3.8713157, 4.27721627, 4.79254968, 5.047769, 288 | 5.00917253, 4.92034587, 4.95192179, 4.99958183, 4.98448796, 289 | 
4.93224872, 4.85590985, 4.78577772, 4.70706027, 4.62677301, 290 | 4.54424163, 4.455477, 4.38643766, 4.32992825, 4.28711064, 291 | 4.24306676, 4.24044366, 4.23590435, 4.21825687, 4.19820567, 292 | 4.17238816, 4.12828632, 3.903265, 3.88530966, 4.10232629, 293 | 4.15094822, 4.14674498, 4.13922566, 4.13210467, 4.12067026, 294 | 4.10835004, 4.09651096, 4.08038286, 4.06577381, 4.04688416, 295 | 4.01817645, 4.02679759, 4.02986556, 4.03453092, 4.04160862, 296 | 4.04830856, 4.0602057, 4.0771961, 4.09297194, 4.11034371, 297 | 4.11758663, 4.12095657, 4.11906109, 4.1155184, 4.10200893, 298 | 4.08276379, 4.08628075, 4.08675451, 4.07840435, 4.06359915, 299 | 4.04148782, 4.06030573, 4.06159643, 4.0473447, 4.03310411, 300 | 4.02725498, 4.02498171, 4.02632823, 4.02484766, 4.02769822, 301 | 4.02489051, 4.02088211, 4.02309526, 4.01872619, 4.01964194, 302 | 4.02153504, 4.02851296, 4.02778547, 4.0279664, 4.02255787, 303 | 4.00012165, 4.01658932, 3.93528177, 3.89534593, 4.017947, 304 | 4.03439452, 4.03349856, 4.03254631, 4.03193693, 4.0297471, 305 | 4.02667958, 4.02249605, 4.01419366, 4.01364902, 4.01290134, 306 | 4.01051293, 4.0089972, 4.00612032, 4.00165361, 3.98616987, 307 | 3.96209925, 3.98299328, 3.99713713, 3.99335162, 3.99078871, 308 | 3.98656532, 3.98739388, 3.98590306, 3.99035434, 3.98769832, 309 | 3.96287722, 3.97156738, 3.9831056, 3.97919869, 3.9740908, 310 | 3.96782821, 3.96331332, 3.95866512, 3.94998412, 3.92881555, 311 | 3.90712036, 3.92175492, 3.92782247, 3.92540498, 3.92125062, 312 | 3.91851146, 3.91551745, 3.90756256, 3.90717957, 3.90416995, 313 | 3.8988367, 3.89860032, 3.89073579, 3.8857745, 3.88896128, 314 | 3.88531893, 3.87826128, 3.84122702, 3.80223124, 3.84022503, 315 | 3.83359076, 3.85137669, 3.85647565, 3.85543909, 3.85545835, 316 | 3.85710792, 3.85664199, 3.85589913, 3.85612751, 3.85606959, 317 | 3.84913531, 3.84287049, 3.83940561, 3.84250948, 3.83669538, 318 | 3.83200409, 3.83749335, 3.8358794, 3.83973577, 3.83577876, 319 | 3.88315269]) 320 | self.feats_mean = np.array([ 321 | -19.9190417, -17.87074816, -17.18417253, -16.60615722, 322 | -16.47524177, -16.78722456, -17.22669022, -17.37899149, 323 | -17.43706583, -17.56693628, -17.78871635, -18.10035868, 324 | -18.49648794, -18.90612757, -19.25623952, -19.51816016, 325 | -19.7352671, -19.91201681, -20.07744978, -20.23590349, 326 | -20.36163737, -20.49420555, -20.59985973, -20.68975368, 327 | -20.76943646, -20.81912058, -20.83019722, -20.71162806, 328 | -20.69917742, -20.80050996, -20.80537402, -20.81130023, 329 | -20.8206865, -20.84339359, -20.87898781, -20.9239178, 330 | -20.97779635, -21.03484084, -21.11069463, -21.18509355, 331 | -21.23770916, -21.31738037, -21.37270682, -21.41930211, 332 | -21.4509242, -21.47946454, -21.50444787, -21.51825131, 333 | -21.51594888, -21.51341618, -21.50983657, -21.52451875, 334 | -21.55498166, -21.59505743, -21.63528843, -21.67396515, 335 | -21.73275922, -21.79585578, -21.84556578, -21.87028101, 336 | -21.87358228, -21.92915795, -21.98821111, -22.04652423, 337 | -22.10257075, -22.14423208, -22.17747545, -22.21208271, 338 | -22.24781483, -22.28908871, -22.33593842, -22.37691381, 339 | -22.42626383, -22.46079106, -22.48787287, -22.50766501, 340 | -22.53586539, -22.56786281, -22.59582998, -22.62581144, 341 | -22.64886813, -22.71750843, -22.7279599, -22.7583583, 342 | -22.87805837, -22.9381045, -22.98543052, -23.03572058, 343 | -23.09462637, -23.15124978, -23.20831305, -23.26313514, 344 | -23.3142818, -23.3671435, -23.41988972, -23.47336085, 345 | -23.52725449, -23.57844801, -23.63283689, -23.67411968, 346 
| -23.69833849, -23.77693613, -23.83653636, -23.87643816, 347 | -23.91009798, -23.94392507, -23.98188049, -24.0217485, 348 | -24.05824361, -24.09093447, -24.1127368, -24.15196192, 349 | -24.18803859, -24.21636283, -24.24457024, -24.27694599, 350 | -24.3077978, -24.33267521, -24.35511158, -24.3535704, 351 | -24.33095828, -24.37981144, -24.41567347, -24.42952345, 352 | -24.42856408, -24.43132484, -24.44428014, -24.46147466, 353 | -24.48292062, -24.50661327, -24.53453117, -24.56691744, 354 | -24.58867348, -24.61431382, -24.6420865, -24.66867143, 355 | -24.70376135, -24.71120825, -24.70337552, -24.74880836, 356 | -24.74934962, -24.8100975, -24.85395296, -24.88296942, 357 | -24.90949419, -24.93485977, -24.96167982, -24.98794161, 358 | -25.01261378, -25.04233288, -25.06206523, -25.0873588, 359 | -25.11614341, -25.1463879, -25.16707626, -25.18893841, 360 | -25.22892228, -25.26932148, -25.31942099, -25.34896047, 361 | -26.51045921]) 362 | else: 363 | raise ValueError 364 | -------------------------------------------------------------------------------- /download.sh: -------------------------------------------------------------------------------- 1 | # Use this script to download the entire LibriSpeech dataset 2 | 3 | #!/bin/bash 4 | 5 | base="http://www.openslr.org/resources/12/" 6 | 7 | 8 | 9 | for s in 'dev-clean' 'dev-other' 'test-clean' 'test-other' 'train-clean-100' 'train-clean-360' 'train-other-500' 10 | do 11 | linkname="${base}/${s}.tar.gz" 12 | echo $linkname 13 | wget -c $linkname 14 | done 15 | 16 | for s in 'dev-clean' 'dev-other' 'test-clean' 'test-other' 'train-clean-100' 'train-clean-360' 'train-other-500' 17 | do 18 | tar -xzvf $s.tar.gz 19 | done 20 | 21 | -------------------------------------------------------------------------------- /flac_to_wav.sh: -------------------------------------------------------------------------------- 1 | # Convert all .flac files within this folder to .wav files 2 | 3 | find . -iname "*.flac" | wc 4 | 5 | for flacfile in `find . -iname "*.flac"` 6 | do 7 | avconv -y -f flac -i $flacfile -ab 64k -ac 1 -ar 16000 -f wav "${flacfile%.*}.wav" 8 | done 9 | -------------------------------------------------------------------------------- /flac_to_wav_ffmpeg.sh: -------------------------------------------------------------------------------- 1 | # Convert all .flac files within this folder to .wav files 2 | 3 | find . -iname "*.flac" | wc 4 | 5 | for flacfile in `find . -iname "*.flac"` 6 | do 7 | [ -e "${flacfile%.*}.wav" ] || ffmpeg -i $flacfile -ab 64k -ac 1 -ar 16000 -loglevel 16 "${flacfile%.*}.wav" 8 | done 9 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Define functions used to construct a multilayer GRU CTC model, and 3 | functions for training and testing it. 
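The train and test routines below rely on the Theano CTC bindings (ctc.cpu_ctc_th);
backend-agnostic equivalents that also work with Tensorflow live in model_wrp.py.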
4 | """ 5 | 6 | import ctc 7 | import logging 8 | import keras.backend as K 9 | 10 | import keras 11 | from keras.layers import (BatchNormalization, Convolution1D, Dense, 12 | Input, GRU, TimeDistributed, Dropout) 13 | from keras.models import Model 14 | from keras.optimizers import SGD 15 | #import lasagne 16 | 17 | from utils import conv_output_length 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | k2 = keras.__version__[0] == '2' 22 | 23 | if k2: 24 | from keras.layers import Conv1D 25 | 26 | def batch_norm_compat(self, mode=0, **kwargs): 27 | if mode != 0: 28 | logger.warn('ignoring unsuported batchnorm mode of {} on keras 2' 29 | .format(mode)) 30 | self._compat_init(**kwargs) 31 | BatchNormalization._compat_init = BatchNormalization.__init__ 32 | BatchNormalization.__init__ = batch_norm_compat 33 | else: 34 | from keras.layers import Convolution1D 35 | 36 | def v1_compat(self, *args, **kwargs): 37 | v1_dict = {} 38 | for k, v in kwargs.items(): 39 | if k == 'padding': 40 | v1key = 'border_mode' 41 | elif k == 'strides': 42 | v1key = 'subsample_length' 43 | elif k == 'kernel_initializer': 44 | v1key = 'init' 45 | else: 46 | v1key = k 47 | v1_dict[v1key] = v 48 | self._v1_init(*args, **v1_dict) 49 | 50 | Conv1D = Convolution1D 51 | Conv1D._v1_init = Conv1D.__init__ 52 | Conv1D.__init__ = v1_compat 53 | 54 | Dense._v1_init = Dense.__init__ 55 | Dense.__init__ = v1_compat 56 | 57 | GRU._v1_init = GRU.__init__ 58 | GRU.__init__ = v1_compat 59 | 60 | 61 | def compile_train_fn(model, learning_rate=2e-4): 62 | """ Build the CTC training routine for speech models. 63 | Args: 64 | model: A keras model (built=True) instance 65 | Returns: 66 | train_fn (theano.function): Function that takes in acoustic inputs, 67 | and updates the model. Returns network outputs and ctc cost 68 | """ 69 | logger.info("Building train_fn") 70 | acoustic_input = model.inputs[0] 71 | network_output = model.outputs[0] 72 | output_lens = K.placeholder(ndim=1, dtype='int32') 73 | label = K.placeholder(ndim=1, dtype='int32') 74 | label_lens = K.placeholder(ndim=1, dtype='int32') 75 | network_output = network_output.dimshuffle((1, 0, 2)) 76 | 77 | ctc_cost = ctc.cpu_ctc_th(network_output, output_lens, 78 | label, label_lens).mean() 79 | trainable_vars = model.trainable_weights 80 | optimizer = SGD(nesterov=True, lr=learning_rate, momentum=0.9, 81 | clipnorm=100) 82 | updates = optimizer.get_updates(trainable_vars, [], ctc_cost) 83 | train_fn = K.function([acoustic_input, output_lens, label, label_lens, 84 | K.learning_phase()], 85 | [network_output, ctc_cost], 86 | updates=updates) 87 | return train_fn 88 | 89 | 90 | def compile_test_fn(model): 91 | """ Build a testing routine for speech models. 92 | Args: 93 | model: A keras model (built=True) instance 94 | Returns: 95 | val_fn (theano.function): Function that takes in acoustic inputs, 96 | and calculates the loss. 
Returns network outputs and ctc cost 97 | """ 98 | logger.info("Building val_fn") 99 | acoustic_input = model.inputs[0] 100 | network_output = model.outputs[0] 101 | output_lens = K.placeholder(ndim=1, dtype='int32') 102 | label = K.placeholder(ndim=1, dtype='int32') 103 | label_lens = K.placeholder(ndim=1, dtype='int32') 104 | network_output = network_output.dimshuffle((1, 0, 2)) 105 | 106 | ctc_cost = ctc.cpu_ctc_th(network_output, output_lens, 107 | label, label_lens).mean() 108 | val_fn = K.function([acoustic_input, output_lens, label, label_lens, 109 | K.learning_phase()], 110 | [network_output, ctc_cost]) 111 | return val_fn 112 | 113 | 114 | def compile_output_fn(model): 115 | """ Build a function that simply calculates the output of a model 116 | Args: 117 | model: A keras model (built=True) instance 118 | Returns: 119 | output_fn (theano.function): Function that takes in acoustic inputs, 120 | and returns network outputs 121 | """ 122 | logger.info("Building val_fn") 123 | acoustic_input = model.inputs[0] 124 | network_output = model.outputs[0] 125 | network_output = network_output.dimshuffle((1, 0, 2)) 126 | 127 | output_fn = K.function([acoustic_input, K.learning_phase()], 128 | [network_output]) 129 | return output_fn 130 | 131 | 132 | def compile_gru_model(input_dim=161, output_dim=29, recur_layers=3, nodes=1024, 133 | conv_context=11, conv_border_mode='valid', conv_stride=2, 134 | initialization='glorot_uniform', batch_norm=True): 135 | """ Build a recurrent network (CTC) for speech with GRU units """ 136 | logger.info("Building gru model") 137 | # Main acoustic input 138 | acoustic_input = Input(shape=(None, input_dim), name='acoustic_input') 139 | 140 | # Setup the network 141 | conv_1d = Convolution1D(nodes, conv_context, name='conv1d', 142 | border_mode=conv_border_mode, 143 | subsample_length=conv_stride, init=initialization, 144 | activation='relu')(acoustic_input) 145 | if batch_norm: 146 | output = BatchNormalization(name='bn_conv_1d', mode=2)(conv_1d) 147 | else: 148 | output = conv_1d 149 | output = Dropout(.2)(output) 150 | 151 | for r in range(recur_layers): 152 | output = GRU(nodes, activation='relu', 153 | name='rnn_{}'.format(r + 1), init=initialization, 154 | return_sequences=True)(output) 155 | if batch_norm: 156 | bn_layer = BatchNormalization(name='bn_rnn_{}'.format(r + 1), 157 | mode=2) 158 | output = bn_layer(output) 159 | 160 | # We don't softmax here because CTC does that 161 | network_output = TimeDistributed(Dense( 162 | output_dim, name='dense', activation='linear', init=initialization, 163 | ))(output) 164 | model = Model(input=acoustic_input, output=network_output) 165 | model.conv_output_length = lambda x: conv_output_length( 166 | x, conv_context, conv_border_mode, conv_stride) 167 | return model 168 | -------------------------------------------------------------------------------- /model.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reith/deepspeech-playground/55e57875733cc0e60fcb96e8953fcc105ba565e5/model.tar.gz -------------------------------------------------------------------------------- /model_wrp.py: -------------------------------------------------------------------------------- 1 | """ 2 | Model generator based on high level properties 3 | """ 4 | 5 | import logging 6 | from utils import for_tf_or_th 7 | 8 | import keras 9 | import keras.backend as K 10 | from keras.layers import (BatchNormalization, Dense, Input, GRU, concatenate, 11 | TimeDistributed, 
Dropout) 12 | from keras.layers.advanced_activations import LeakyReLU 13 | from keras.models import Model 14 | from keras.optimizers import Adam 15 | 16 | if K.backend() == 'theano': 17 | import theano.gpuarray.ctc as ctc_th 18 | logging.info('using theano.gpuarray.ctc') 19 | 20 | k2 = keras.__version__[0] == '2' 21 | 22 | if k2: 23 | from keras.layers import Conv1D 24 | 25 | def batch_norm_compat(self, mode=0, **kwargs): 26 | if mode != 0: 27 | logger.warn('ignoring unsuported batchnorm mode of {} on keras 2' 28 | .format(mode)) 29 | self._v2_init(**kwargs) 30 | BatchNormalization._v2_init = BatchNormalization.__init__ 31 | BatchNormalization.__init__ = batch_norm_compat 32 | else: 33 | from keras.layers import Convolution1D 34 | 35 | def v1_compat(self, *args, **kwargs): 36 | v1_dict = {} 37 | for k, v in kwargs.items(): 38 | if k == 'padding': 39 | v1key = 'border_mode' 40 | elif k == 'strides': 41 | v1key = 'subsample_length' 42 | elif k == 'kernel_initializer': 43 | v1key = 'init' 44 | else: 45 | v1key = k 46 | v1_dict[v1key] = v 47 | self._v1_init(*args, **v1_dict) 48 | 49 | Conv1D = Convolution1D 50 | Conv1D._v1_init = Conv1D.__init__ 51 | Conv1D.__init__ = v1_compat 52 | 53 | Dense._v1_init = Dense.__init__ 54 | Dense.__init__ = v1_compat 55 | 56 | GRU._v1_init = GRU.__init__ 57 | GRU.__init__ = v1_compat 58 | 59 | 60 | logger = logging.getLogger(__name__) 61 | 62 | 63 | def duration_cost(y, y_pred): 64 | """" A Loss function for duration costs """ 65 | return (y - y_pred)**2 66 | 67 | 68 | def model_output_dim(out_type): 69 | """ Return output dimention of model based on output type 70 | Args: 71 | out_type: string either 'text' or 'arpabet' 72 | """ 73 | if out_type == 'text': 74 | from char_map import index_map 75 | return len(index_map) + 1 76 | if out_type == 'arpabet': 77 | from arpabets import index_map 78 | return len(index_map) + 1 79 | raise ValueError 80 | 81 | 82 | class ModelWrapper(object): 83 | 84 | def __init__(self, outputs='text', stateful=False): 85 | if outputs == 'text': 86 | self.output_dim = model_output_dim('text') 87 | elif outputs == 'arpabet': 88 | self.output_dim = model_output_dim('arpabet') 89 | elif isinstance(outputs, list) and (sorted(outputs) == 90 | ['arpabet', 'text']): 91 | self.vocab_dim = model_output_dim('text') 92 | self.phono_dim = model_output_dim('arpabet') 93 | else: 94 | raise ValueError 95 | self.outputs = outputs 96 | self.stateful = stateful 97 | self.branch_vars = {} 98 | self.model = None 99 | 100 | @staticmethod 101 | def plug_model(old): 102 | if not isinstance(old, ModelWrapper): 103 | raise ValueError 104 | 105 | new = ModelWrapper(old.outputs, stateful=old.stateful) 106 | for attr in ('model', '_branch_labels', 'branch_vars', '_ctc_in_lens', 107 | 'branch_outputs', 'acoustic_input'): 108 | setattr(new, attr, getattr(old, attr)) 109 | return new 110 | 111 | @property 112 | def branch_labels(self): 113 | if getattr(self, '_branch_labels', None) is None: 114 | d = dict() 115 | for bname in self.branch_vars.keys(): 116 | d[bname] = ( 117 | K.placeholder(ndim=2, dtype='int32'), 118 | K.placeholder(ndim=1, dtype='int32') 119 | ) 120 | self._branch_labels = d 121 | return self._branch_labels 122 | 123 | @property 124 | def ctc_in_lens(self): 125 | if getattr(self, '_ctc_in_lens', None) is None: 126 | self._ctc_in_lens = K.placeholder(ndim=1, dtype='int32') 127 | return self._ctc_in_lens 128 | 129 | def compile_train_fn(self, learning_rate=2e-4): 130 | """ Build the CTC training routine for speech models. 
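        A separate CTC loss and Adam update is built for every output branch
        (phoneme and/or text), over the trainable weights collected in
        self.branch_vars for that branch.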
131 | Args: 132 | learning_rate (float) 133 | Returns: 134 | train_fn (theano.function): Function that takes in acoustic inputs, 135 | and updates the model. Returns network outputs and ctc cost 136 | """ 137 | logger.info("Building train_fn") 138 | f_inputs = [self.acoustic_input, self.ctc_in_lens] 139 | f_outputs = [] 140 | f_updates = [] 141 | for branch in self.branch_outputs: 142 | labels, label_lens = self.branch_labels[branch.name] 143 | f_inputs.append(labels) 144 | f_inputs.append(label_lens) 145 | 146 | if K.backend() == 'tensorflow': 147 | network_output = branch.output 148 | ctc_cost = K.mean(K.ctc_batch_cost(labels, network_output, 149 | self.ctc_in_lens, 150 | label_lens)) 151 | else: 152 | network_output = branch.output.dimshuffle((1, 0, 2)) 153 | ctc_cost = ctc_th.gpu_ctc(network_output, labels, 154 | self.ctc_in_lens).mean() 155 | 156 | f_outputs.extend([network_output, ctc_cost]) 157 | trainable_vars = self.branch_vars[branch.name] 158 | optmz = Adam(lr=learning_rate, clipnorm=100) 159 | f_updates.extend(optmz.get_updates(trainable_vars, [], ctc_cost)) 160 | 161 | f_inputs.append(K.learning_phase()) 162 | self.train_fn = K.function(f_inputs, f_outputs, f_updates) 163 | return self.train_fn 164 | 165 | def compile_test_fn(self): 166 | """ Build a testing routine for speech models. 167 | Returns: 168 | val_fn (theano.function): Function that takes in acoustic inputs, 169 | and calculates the loss. Returns network outputs and ctc cost 170 | """ 171 | logger.info("Building val_fn") 172 | f_inputs = [self.acoustic_input, self.ctc_in_lens] 173 | f_outputs = [] 174 | for branch in self.branch_outputs: 175 | labels, label_lens = self.branch_labels[branch.name] 176 | 177 | if K.backend() == 'tensorflow': 178 | network_output = branch.output 179 | ctc_cost = K.mean(K.ctc_batch_cost(labels, network_output, 180 | self.ctc_in_lens, 181 | label_lens)) 182 | else: 183 | network_output = branch.output.dimshuffle((1, 0, 2)) 184 | ctc_cost = ctc_th.gpu_ctc(network_output, labels, 185 | self.ctc_in_lens).mean() 186 | 187 | f_inputs.extend([labels, label_lens]) 188 | f_outputs.extend([network_output, ctc_cost]) 189 | f_inputs.append(K.learning_phase()) 190 | 191 | self.val_fn = K.function(f_inputs, f_outputs) 192 | return self.val_fn 193 | 194 | def compile_output_fn(self): 195 | """ Build a function that simply calculates the output of a model 196 | Returns: 197 | output_fn (theano.function): Function that takes in acoustic inputs, 198 | and returns network outputs 199 | """ 200 | logger.info("Bulding output_fn") 201 | if self.outputs in ['text', 'arpabet']: 202 | output_idx = 0 203 | elif self.outputs == ['arpabet', 'text']: 204 | output_idx = 1 205 | else: 206 | raise ValueError 207 | 208 | output = self.model.outputs[output_idx] 209 | if K.backend() == 'theano': 210 | output = output.dimshuffle((1, 0, 2)) 211 | 212 | output_fn = K.function([self.acoustic_input, K.learning_phase()], 213 | [output]) 214 | return output_fn 215 | 216 | 217 | class GruModelWrapper(ModelWrapper): 218 | """ Recurrent network (CTC) for speech with GRU units """ 219 | 220 | def compile(self, input_dim=161, recur_layers=3, nodes=1024, 221 | conv_context=11, conv_border_mode='valid', conv_stride=2, 222 | activation='relu', lirelu_alpha=.3, dropout=False, 223 | initialization='glorot_uniform', batch_norm=True, 224 | stateful=False, mb_size=None): 225 | logger.info("Building gru model") 226 | assert self.model is None 227 | 228 | leaky_relu = False 229 | if activation == 'lirelu': 230 | activation = 'linear' 231 | 
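            # 'lirelu' keeps the conv/GRU activations linear here; a separate
            # LeakyReLU layer is applied after each (batch-normalized) block below.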
leaky_relu = True 232 | 233 | if stateful: 234 | if mb_size is None: 235 | raise ValueError("Stateful GRU layer needs to know batch size") 236 | acoustic_input = Input(batch_shape=(mb_size, None, input_dim), 237 | name='acoustic_input') 238 | else: 239 | acoustic_input = Input(shape=(None, input_dim), 240 | name='acoustic_input') 241 | 242 | # Setup the network 243 | conv_1d = Conv1D(nodes, conv_context, name='conv_1d', 244 | padding=conv_border_mode, strides=conv_stride, 245 | kernel_initializer=initialization, 246 | activation=activation)(acoustic_input) 247 | 248 | if batch_norm: 249 | output = BatchNormalization(name='bn_conv_1d', mode=2)(conv_1d) 250 | else: 251 | output = conv_1d 252 | 253 | if leaky_relu: 254 | output = LeakyReLU(alpha=lirelu_alpha)(output) 255 | 256 | if dropout: 257 | output = Dropout(dropout)(output) 258 | 259 | for r in range(recur_layers): 260 | output = GRU(nodes, name='rnn_{}'.format(r + 1), 261 | kernel_initializer=initialization, stateful=stateful, 262 | return_sequences=True, activation=activation)(output) 263 | 264 | if batch_norm: 265 | bn_layer = BatchNormalization(name='bn_rnn_{}'.format(r + 1), 266 | mode=2) 267 | output = bn_layer(output) 268 | 269 | if leaky_relu: 270 | output = LeakyReLU(alpha=lirelu_alpha)(output) 271 | 272 | output_branch = TimeDistributed(Dense( 273 | self.output_dim, name='text_dense', init=initialization, 274 | activation=for_tf_or_th('softmax', 'linear') 275 | ), name=self.outputs) 276 | network_output = output_branch(output) 277 | 278 | self.model = Model(input=acoustic_input, output=[network_output]) 279 | self.branch_outputs = [output_branch] 280 | self.branch_vars[output_branch.name] = self.model.trainable_weights 281 | self.acoustic_input = self.model.inputs[0] 282 | return self.model 283 | 284 | 285 | class HalfPhonemeModelWrapper(ModelWrapper): 286 | 287 | def __init__(self, *args, **kwargs): 288 | super(HalfPhonemeModelWrapper, self).__init__(['arpabet', 'text'], 289 | *args, **kwargs) 290 | 291 | def compile(self, input_dim=161, recur_layers=3, nodes=1024, 292 | conv_context=11, conv_padding='valid', mb_size=16, 293 | activation='relu', lirelu_alpha=.3, conv_stride=2, 294 | initialization='glorot_uniform', fast_text=False, 295 | batch_norm=True, dropout=False, stateful=False): 296 | 297 | logger.info("Building half phoneme model") 298 | assert self.model is None 299 | 300 | leaky_relu = False 301 | if activation == 'lirelu': 302 | activation = 'linear' 303 | leaky_relu = True 304 | 305 | if stateful: 306 | if mb_size is None: 307 | raise ValueError("Stateful GRU layer needs to know batch size") 308 | acoustic_input = Input(batch_shape=(mb_size, None, input_dim), 309 | name='acoustic_input') 310 | else: 311 | acoustic_input = Input(shape=(None, input_dim), 312 | name='acoustic_input') 313 | 314 | branch = 'phoneme' 315 | self.branch_vars[branch] = [] 316 | conv_1dl = Conv1D(nodes, conv_context, name='conv_1d', 317 | padding=conv_padding, strides=conv_stride, 318 | kernel_initializer=initialization, 319 | activation=activation) 320 | output = conv_1dl(acoustic_input) 321 | self.branch_vars[branch].extend(conv_1dl.trainable_weights) 322 | 323 | if batch_norm: 324 | bn_l = BatchNormalization(name='bn_conv_1d') 325 | output = bn_l(output) 326 | self.branch_vars[branch].extend(bn_l.trainable_weights) 327 | 328 | if leaky_relu: 329 | output = LeakyReLU(alpha=lirelu_alpha)(output) 330 | 331 | if dropout: 332 | output = Dropout(dropout)(output) 333 | 334 | for r in range(recur_layers): 335 | gru_l = GRU(nodes, 
activation=activation, stateful=stateful, 336 | name='rnn_{}'.format(r + 1), 337 | kernel_initializer=initialization, 338 | return_sequences=True) 339 | output = gru_l(output) 340 | self.branch_vars[branch].extend(gru_l.trainable_weights) 341 | 342 | if batch_norm: 343 | bn_l = BatchNormalization(name='bn_rnn_{}'.format(r + 1), 344 | mode=2) 345 | output = bn_l(output) 346 | self.branch_vars[branch].extend(bn_l.trainable_weights) 347 | 348 | if leaky_relu: 349 | output = LeakyReLU(alpha=lirelu_alpha)(output) 350 | 351 | if r+1 == recur_layers // 2: 352 | phoneme_dense = Dense( 353 | self.phono_dim, name='phoneme_dense', 354 | activation=for_tf_or_th('softmax', 'linear'), 355 | kernel_initializer=initialization) 356 | phoneme_branch = TimeDistributed(phoneme_dense, name=branch) 357 | phoneme_out = phoneme_branch(output) 358 | 359 | branch = 'text' 360 | if fast_text: 361 | self.branch_vars[branch] = list(self.branch_vars['phoneme']) 362 | else: 363 | self.branch_vars[branch] = [] 364 | 365 | text_dense = Dense(self.vocab_dim, name='text_dense', 366 | activation=for_tf_or_th('softmax', 'linear'), 367 | kernel_initializer=initialization) 368 | text_branch = TimeDistributed(text_dense, name=branch) 369 | text_out = text_branch(output) 370 | 371 | self.branch_vars['phoneme'].extend(phoneme_branch.trainable_weights) 372 | self.branch_vars['text'].extend(text_branch.trainable_weights) 373 | 374 | self.model = Model(input=acoustic_input, output=[phoneme_out, text_out]) 375 | self.branch_outputs = [phoneme_branch, text_branch] 376 | self.acoustic_input = self.model.inputs[0] 377 | return self.model 378 | 379 | 380 | class TwoHornModelWrapper(ModelWrapper): 381 | 382 | def compile(self, input_dim=161, phoneme_recurs=2, nodes=1024, 383 | text_recurs=3, conv_context=11, conv_stride=2, 384 | conv_padding='valid', mb_size=16, 385 | initialization='glorot_uniform', stateful=False): 386 | assert self.model is None 387 | if stateful: 388 | if mb_size is None: 389 | raise ValueError("Stateful GRU layer needs to know batch size") 390 | acoustic_input = Input(batch_shape=(mb_size, None, input_dim), 391 | name='acoustic_input') 392 | else: 393 | acoustic_input = Input(shape=(None, input_dim), 394 | name='acoustic_input') 395 | 396 | branch = 'phoneme' 397 | self.branch_vars[branch] = [] 398 | ph_conv1 = Conv1D(nodes, conv_context, name='ph_conv1', 399 | padding=conv_padding, strides=conv_stride, 400 | kernel_initializer=initialization, activation='relu') 401 | ph_output = ph_conv1(acoustic_input) 402 | self.branch_vars[branch].extend(ph_conv1.trainable_weights) 403 | 404 | bn_l = BatchNormalization(name='bn_ph_conv1') 405 | ph_output = bn_l(ph_output) 406 | self.branch_vars[branch].extend(bn_l.trainable_weights) 407 | 408 | for r in range(phoneme_recurs): 409 | gru_l = GRU(nodes, activation='relu', name='ph_rnn_{}'.format(r+1), 410 | kernel_initializer=initialization, 411 | stateful=stateful, return_sequences=True) 412 | ph_output = gru_l(ph_output) 413 | self.branch_vars[branch].extend(gru_l.trainable_weights) 414 | 415 | bn_l = BatchNormalization(name='bn_ph_rnn_{}'.format(r+1)) 416 | ph_output = bn_l(ph_output) 417 | self.branch_vars[branch].extend(bn_l.trainable_weights) 418 | 419 | phoneme_dense = Dense(self.phono_dim, name='phoneme_dense', 420 | activation=for_tf_or_th('softmax', 'linear'), 421 | kernel_initializer=initialization) 422 | phoneme_branch = TimeDistributed(phoneme_dense, name=branch) 423 | phoneme_out = phoneme_branch(ph_output) 424 | 425 | branch = 'text' 426 | self.branch_vars[branch] = [] 
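# The text branch below gets its own Conv1D + GRU stack over the acoustic
# input; its output is then concatenated with the phoneme branch output,
# mixed by a linear Dense layer and passed through one final GRU before the
# time-distributed text softmax.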
427 | 428 | tx_conv1 = Conv1D(nodes, conv_context, name='tx_conv1', 429 | padding=conv_padding, strides=conv_stride, 430 | kernel_initializer=initialization, activation='relu') 431 | tx_output = tx_conv1(acoustic_input) 432 | self.branch_vars[branch].extend(tx_conv1.trainable_weights) 433 | 434 | bn_l = BatchNormalization(name='bn_tx_conv1') 435 | tx_output = bn_l(tx_output) 436 | self.branch_vars[branch].extend(bn_l.trainable_weights) 437 | 438 | for r in range(text_recurs-1): 439 | gru_l = GRU(nodes, activation='relu', name='tx_rnn_{}'.format(r+1), 440 | kernel_initializer=initialization, 441 | stateful=stateful, return_sequences=True) 442 | tx_output = gru_l(tx_output) 443 | self.branch_vars[branch].extend(gru_l.trainable_weights) 444 | 445 | bn_l = BatchNormalization(name='bn_tx_rnn_{}'.format(r+1)) 446 | tx_output = bn_l(tx_output) 447 | self.branch_vars[branch].extend(bn_l.trainable_weights) 448 | 449 | output = concatenate([ph_output, tx_output]) 450 | 451 | mix_l = Dense(nodes, name='mix_dense', activation='linear', 452 | kernel_initializer=initialization) 453 | output = mix_l(output) 454 | self.branch_vars[branch].extend(mix_l.trainable_weights) 455 | 456 | gru_l = GRU(nodes, activation='relu', 457 | name='tx_rnn_{}'.format(text_recurs), 458 | kernel_initializer=initialization, stateful=stateful, 459 | return_sequences=True) 460 | output = gru_l(output) 461 | self.branch_vars[branch].extend(gru_l.trainable_weights) 462 | 463 | bn_l = BatchNormalization(name='bn_tx_rnn_{}'.format(text_recurs)) 464 | output = bn_l(output) 465 | self.branch_vars[branch].extend(bn_l.trainable_weights) 466 | 467 | text_dense = Dense(self.vocab_dim, name='text_dense', 468 | activation=for_tf_or_th('softmax', 'linear'), 469 | kernel_initializer=initialization) 470 | text_branch = TimeDistributed(text_dense, name=branch) 471 | text_out = text_branch(output) 472 | 473 | self.branch_vars['phoneme'].extend(phoneme_branch.non_trainable_weights) 474 | self.branch_vars['text'].extend(text_branch.non_trainable_weights) 475 | 476 | self.model = Model(input=acoustic_input, output=[phoneme_out, text_out]) 477 | self.branch_outputs = [phoneme_branch, text_branch] 478 | self.acoustic_input = self.model.inputs[0] 479 | return self.model 480 | 481 | 482 | class ConvOverConvModelWrapper(ModelWrapper): 483 | """ Build a recurrent network (CTC) for speech with GRU units over 484 | multiple convolution layers""" 485 | 486 | def compile(self, conv_props, input_dim=161, recur_layers=3, 487 | nodes=1024, conv_border_mode='valid', 488 | initialization='glorot_uniform', batch_norm=True, 489 | stateful=False, mb_size=None): 490 | logger.info("Building gru model") 491 | assert self.model is None 492 | if stateful: 493 | if mb_size is None: 494 | raise ValueError("Stateful GRU layer needs to know batch size") 495 | acoustic_input = Input(batch_shape=(mb_size, None, input_dim), 496 | name='acoustic_input') 497 | else: 498 | acoustic_input = Input(shape=(None, input_dim), 499 | name='acoustic_input') 500 | 501 | # Setup the network 502 | output = acoustic_input 503 | for (c, (filters, size, stride)) in enumerate(conv_props): 504 | output = Conv1D(filters, size, name='conv_1d_{}'.format(c), 505 | padding=conv_border_mode, strides=stride, 506 | kernel_initializer=initialization, 507 | activation='relu')(output) 508 | 509 | if batch_norm: 510 | output = BatchNormalization(name='bn_conv_1d', mode=2)(output) 511 | 512 | for r in range(recur_layers): 513 | output = GRU(nodes, activation='relu', name='rnn_{}'.format(r + 1), 514 | 
kernel_initializer=initialization, 515 | stateful=stateful, return_sequences=True)(output) 516 | if batch_norm: 517 | bn_layer = BatchNormalization(name='bn_rnn_{}'.format(r + 1), 518 | mode=2) 519 | output = bn_layer(output) 520 | 521 | output_branch = TimeDistributed(Dense( 522 | self.output_dim, name='dense', init=initialization, 523 | activation=for_tf_or_th('softmax', 'linear') 524 | ), name=self.outputs) 525 | network_output = output_branch(output) 526 | 527 | self.model = Model(input=acoustic_input, output=[network_output]) 528 | self.branch_outputs = [output_branch] 529 | self.branch_vars[output_branch.name] = self.model.trainable_weights 530 | self.acoustic_input = self.model.inputs[0] 531 | return self.model 532 | -------------------------------------------------------------------------------- /models-evaluation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import os\n", 12 | "# use CPU or GPU\n", 13 | "os.environ['KERAS_BACKEND'] = 'theano'\n", 14 | "#os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n", 15 | "os.environ['THEANO_FLAGS'] = 'device=cuda0'" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 2, 21 | "metadata": {}, 22 | "outputs": [ 23 | { 24 | "name": "stderr", 25 | "output_type": "stream", 26 | "text": [ 27 | "Using Theano backend.\n", 28 | "Using cuDNN version 5110 on context None\n", 29 | "Mapped name None to device cuda0: GeForce GTX 1080 Ti (0000:02:00.0)\n" 30 | ] 31 | } 32 | ], 33 | "source": [ 34 | "import keras\n", 35 | "import numpy as np" 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "metadata": {}, 41 | "source": [ 42 | "### Data and weight loaders" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 3, 48 | "metadata": { 49 | "collapsed": true 50 | }, 51 | "outputs": [], 52 | "source": [ 53 | "from glob import glob\n", 54 | "from random import Random\n", 55 | "import json\n", 56 | "\n", 57 | "rng = Random(42)" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 4, 63 | "metadata": { 64 | "collapsed": true 65 | }, 66 | "outputs": [], 67 | "source": [ 68 | "import kenlm\n", 69 | "import beamsearch\n", 70 | "reload(beamsearch)\n", 71 | "from utils import argmax_decode, word_error_rate, for_tf_or_th\n", 72 | "from beamsearch import beam_decode, beam_decode_u\n", 73 | "\n", 74 | "lm = kenlm.Model('data/lm/lm.binary')\n", 75 | "\n", 76 | "def iterate_weights(model_path):\n", 77 | " \"\"\"Iterate over saved model weights\"\"\"\n", 78 | " for model_weight in glob(os.path.join(model_path, '') + '*.h5'):\n", 79 | " yield model_weight\n", 80 | "\n", 81 | "def pick_sample_files(desc_file, count, min_duration, max_duration):\n", 82 | " metadata = []\n", 83 | " with open(desc_file) as f:\n", 84 | " for line in f:\n", 85 | " metadata.append(json.loads(line))\n", 86 | " legitimates = [ sample for sample in metadata if min_duration <= sample['duration'] <= max_duration ]\n", 87 | " rng.shuffle(legitimates)\n", 88 | " return legitimates[:count]\n", 89 | "\n", 90 | "def test_generator(datagen, test_samples, batch_size=64, normalize=True):\n", 91 | " global in_\n", 92 | " texts = [s['text'] for s in test_samples]\n", 93 | " durations = [s['duration'] for s in test_samples]\n", 94 | " paths = [s['key'] for s in test_samples]\n", 95 | " features = [datagen.featurize(p) for p in paths]\n", 96 | " if 
normalize:\n", 97 | " features = [datagen.normalize(f) for f in features]\n", 98 | "\n", 99 | " for i in range( np.ceil(len(features) / float(batch_size)).astype(int) ):\n", 100 | " batch_durations = durations[i*batch_size: (i+1)*batch_size]\n", 101 | " batch_features = features[i*batch_size: (i+1)*batch_size]\n", 102 | " batch_texts = texts[i*batch_size: (i+1)*batch_size]\n", 103 | " batch_paths = paths[i*batch_size: (i+1)*batch_size]\n", 104 | " max_length = max([f.shape[0] for f in batch_features])\n", 105 | " batch_array = np.zeros((len(batch_features), max_length, features[0].shape[1]), dtype='float32')\n", 106 | " for fi in range(len(batch_features)):\n", 107 | " batch_array[fi, :batch_features[fi].shape[0], :] = batch_features[fi]\n", 108 | " yield {'x': batch_array, 'y': batch_texts, 'path': batch_paths, 'duration': batch_durations}\n", 109 | "\n", 110 | "def best_lm_alternative(true_sentence, wer, predictions, verbose=False):\n", 111 | " \"\"\" predictions is a list of tuples which first denote sentence and next is It's probablity\n", 112 | " \"\"\"\n", 113 | " best, best_score = None, np.finfo('float32').min\n", 114 | " for s, p in predictions:\n", 115 | " lm_score = lm.score(s)\n", 116 | " if lm_score > best_score:\n", 117 | " best, best_score = s, lm_score\n", 118 | " if best == predictions[0][0]:\n", 119 | " if verbose:\n", 120 | " print \"language model didn't change prediction\"\n", 121 | " best_wer = wer\n", 122 | " else:\n", 123 | " best_wer = word_error_rate([true_sentence], [best], decoded=True)[0]\n", 124 | " if verbose:\n", 125 | " print \"language model changed prediction, WER changed from {old_wer} to {new_wer}\".format(\n", 126 | " old_wer = wer, new_wer = best_wer\n", 127 | " )\n", 128 | " return best, best_wer\n", 129 | "\n", 130 | "def evaluate(batch_generator, output_fn, learning_phase=False, use_lm=False, beam_width=12):\n", 131 | " all_nolm_wers, all_lm_wers = [], []\n", 132 | " for batch in batch_generator:\n", 133 | " net_out = output_fn([batch['x'], learning_phase])[0]\n", 134 | " mtp_net_out = for_tf_or_th(net_out, net_out.swapaxes(0, 1))\n", 135 | " pred_texts = [argmax_decode(o) for o in mtp_net_out]\n", 136 | " nolm_wers = word_error_rate(batch['y'], pred_texts, True)\n", 137 | " all_nolm_wers.append(nolm_wers)\n", 138 | " \n", 139 | " if use_lm:\n", 140 | " alt_beam_preds = lambda i: zip(*beam_decode_u(mtp_net_out[i, :, :], beam_width, normalize=True))\n", 141 | " pred_texts, lm_wers = zip(*[best_lm_alternative(batch['y'][i], nolm_wers[i], alt_beam_preds(i))\n", 142 | " for i in range(mtp_net_out.shape[0])])\n", 143 | " all_lm_wers.append(np.array(lm_wers))\n", 144 | " all_wers = all_lm_wers\n", 145 | " else:\n", 146 | " all_wers = all_nolm_wers\n", 147 | " \n", 148 | " for i, y in enumerate(batch['y']):\n", 149 | " print 'r:{}\\np:{}\\n{}: WER: {}, DURATION: {}, PATH: {}'.format(y, pred_texts[i], i, all_wers[-1][i], batch['duration'][i], batch['path'][i])\n", 150 | " print 'batch mean WER: {}'.format(all_wers[-1].mean())\n", 151 | " if use_lm:\n", 152 | " print 'LM WER: {} No LM WER: {}'.format(np.concatenate(all_lm_wers).mean(), np.concatenate(all_nolm_wers).mean())\n", 153 | " else:\n", 154 | " 'whole mean WER: {}'.format(np.concatenate(all_wers).mean())\n", 155 | " return mtp_net_out, pred_texts, all_wers, batch['y']" 156 | ] 157 | }, 158 | { 159 | "cell_type": "markdown", 160 | "metadata": {}, 161 | "source": [ 162 | "### Customize data generator" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": 16, 168 | "metadata": { 
169 | "collapsed": true 170 | }, 171 | "outputs": [], 172 | "source": [ 173 | "test_desc = '/home/reith/deepspeech/ba-dls-deepspeech/descs/test-clean.json'\n", 174 | "#test_desc = '/home/reith/deepspeech/ba-dls-deepspeech/descs/test-other.json'\n", 175 | "#test_desc = '/home/reith/deepspeech/ba-dls-deepspeech/descs/dev-clean.json'" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 6, 181 | "metadata": { 182 | "collapsed": true 183 | }, 184 | "outputs": [], 185 | "source": [ 186 | "from data_generator import DataGenerator\n", 187 | "datagen = DataGenerator()" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": 18, 193 | "metadata": { 194 | "collapsed": true 195 | }, 196 | "outputs": [], 197 | "source": [ 198 | "test_samples = pick_sample_files(test_desc, 1024, 0, 30)" 199 | ] 200 | }, 201 | { 202 | "cell_type": "markdown", 203 | "metadata": {}, 204 | "source": [ 205 | "Normalize by input data" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": 8, 211 | "metadata": { 212 | "collapsed": true 213 | }, 214 | "outputs": [], 215 | "source": [ 216 | "train_desc = '/home/reith/deepspeech/ba-dls-deepspeech/descs/train-clean-360.json'\n", 217 | "datagen.load_train_data(train_desc, 15)\n", 218 | "datagen.fit_train(100)" 219 | ] 220 | }, 221 | { 222 | "cell_type": "markdown", 223 | "metadata": {}, 224 | "source": [ 225 | "Or load them" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": 17, 231 | "metadata": { 232 | "collapsed": true 233 | }, 234 | "outputs": [], 235 | "source": [ 236 | "datagen.reload_norm('860-1000')" 237 | ] 238 | }, 239 | { 240 | "cell_type": "markdown", 241 | "metadata": {}, 242 | "source": [ 243 | "### Load model" 244 | ] 245 | }, 246 | { 247 | "cell_type": "markdown", 248 | "metadata": {}, 249 | "source": [ 250 | "#### Theano mode" 251 | ] 252 | }, 253 | { 254 | "cell_type": "markdown", 255 | "metadata": {}, 256 | "source": [ 257 | "Load and test weights of a half-phoneme model" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": 10, 263 | "metadata": { 264 | "collapsed": true 265 | }, 266 | "outputs": [], 267 | "source": [ 268 | "#model_dir = '/home/reith/deepspeech/ba-dls-deepspeech/models/22-cont-23-i9696-lr1e-4-train-360-dur15/'\n", 269 | "#model_dir = '/home/reith/deepspeech/ba-dls-deepspeech/models/23-cont-i2494-joingrus-dur15-nobn-lr5e-5/'\n", 270 | "model_dir = '/home/reith/deepspeech/ba-dls-deepspeech/models/24-cont-train-860'" 271 | ] 272 | }, 273 | { 274 | "cell_type": "markdown", 275 | "metadata": {}, 276 | "source": [ 277 | "A summary of training procedure:\n", 278 | "- 7 Epochs of dual phoneme-text on train-100 (20)\n", 279 | "- 3 Epochs on train-500 for phoenme fine-tuning (21)\n", 280 | "- 3 Epochs on train-500 for text fine-tuning (22)\n", 281 | "- 2 Epochs on train-360 (23)\n", 282 | "- 2 Epochs on train-360 dropping phoneme branch and and batch normalization (24)" 283 | ] 284 | }, 285 | { 286 | "cell_type": "markdown", 287 | "metadata": {}, 288 | "source": [ 289 | "make half phoneme model " 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": 10, 295 | "metadata": {}, 296 | "outputs": [ 297 | { 298 | "name": "stderr", 299 | "output_type": "stream", 300 | "text": [ 301 | "WARNING:model_wrp:ignoring unsuported batchnorm mode of 2 on keras 2\n", 302 | "WARNING:model_wrp:ignoring unsuported batchnorm mode of 2 on keras 2\n", 303 | "WARNING:model_wrp:ignoring unsuported batchnorm mode of 2 on keras 2\n", 304 | 
"WARNING:model_wrp:ignoring unsuported batchnorm mode of 2 on keras 2\n", 305 | "WARNING:model_wrp:ignoring unsuported batchnorm mode of 2 on keras 2\n", 306 | "model_wrp.py:374: UserWarning: Update your `Model` call to the Keras 2 API: `Model(outputs=[Reshape{3..., inputs=/acoustic_...)`\n", 307 | " self.model = Model(input=acoustic_input, output=[phoneme_out, text_out])\n" 308 | ] 309 | } 310 | ], 311 | "source": [ 312 | "from model_wrp import HalfPhonemeModelWrapper\n", 313 | "model_wrp = HalfPhonemeModelWrapper()\n", 314 | "model = model_wrp.compile(nodes=1000, conv_context=5, recur_layers=5)\n", 315 | "output_fn = model_wrp.compile_output_fn()" 316 | ] 317 | }, 318 | { 319 | "cell_type": "markdown", 320 | "metadata": {}, 321 | "source": [ 322 | "or gru model" 323 | ] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "execution_count": 11, 328 | "metadata": {}, 329 | "outputs": [ 330 | { 331 | "name": "stderr", 332 | "output_type": "stream", 333 | "text": [ 334 | "model_wrp.py:274: UserWarning: Update your `Dense` call to the Keras 2 API: `Dense(29, activation=\"linear\", kernel_initializer=\"glorot_uniform\", name=\"text_dense\")`\n", 335 | " activation=for_tf_or_th('softmax', 'linear')\n", 336 | "model_wrp.py:278: UserWarning: Update your `Model` call to the Keras 2 API: `Model(outputs=[Reshape{3..., inputs=/acoustic_...)`\n", 337 | " self.model = Model(input=acoustic_input, output=[network_output])\n" 338 | ] 339 | } 340 | ], 341 | "source": [ 342 | "from model_wrp import GruModelWrapper\n", 343 | "model_wrp = GruModelWrapper()\n", 344 | "model = model_wrp.compile(nodes=1000, conv_context=5, recur_layers=5, batch_norm=False)\n", 345 | "output_fn = model_wrp.compile_output_fn()" 346 | ] 347 | }, 348 | { 349 | "cell_type": "code", 350 | "execution_count": 14, 351 | "metadata": { 352 | "collapsed": true 353 | }, 354 | "outputs": [], 355 | "source": [ 356 | "# model.load_weights(os.path.join(model_dir, 'best-val-weights.h5'))\n", 357 | "model.load_weights(os.path.join(model_dir, 'model_19336_weights.h5'))" 358 | ] 359 | }, 360 | { 361 | "cell_type": "markdown", 362 | "metadata": {}, 363 | "source": [ 364 | "#### Tensorflow model" 365 | ] 366 | }, 367 | { 368 | "cell_type": "markdown", 369 | "metadata": {}, 370 | "source": [ 371 | "A summary of training procedure:\n", 372 | "- 3 Epochs of dual phoneme-text on train-100 by dropout of 0.3 and leaky relu factor of 0.05 (40)\n", 373 | "- 5 Epochs on train-100 for phoenme fine-tuning (41)\n", 374 | "- 5 Epochs on train-100 for text fine-tuning (42)\n", 375 | "- 5 Epochs on train-360 (43)\n", 376 | "- 5 Epochs on train-860 dropping phoneme branch and and batch normalization and reduced dropout to 0.1 (44)\n", 377 | "- 20 Epochs on train-860 reduced learning rate down to 5e-5 and for samples up to 20 seconds long (45)" 378 | ] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "execution_count": 9, 383 | "metadata": { 384 | "collapsed": true 385 | }, 386 | "outputs": [], 387 | "source": [ 388 | "model_dir = '/home/reith/deepspeech/ba-dls-deepspeech/models/44-cont-45-i14490-dur20-lr5e-5'" 389 | ] 390 | }, 391 | { 392 | "cell_type": "code", 393 | "execution_count": 10, 394 | "metadata": {}, 395 | "outputs": [ 396 | { 397 | "name": "stderr", 398 | "output_type": "stream", 399 | "text": [ 400 | "model_wrp.py:274: UserWarning: Update your `Dense` call to the Keras 2 API: `Dense(29, activation=\"softmax\", kernel_initializer=\"glorot_uniform\", name=\"text_dense\")`\n", 401 | " activation=for_tf_or_th('softmax', 'linear')\n", 402 | "model_wrp.py:278: 
UserWarning: Update your `Model` call to the Keras 2 API: `Model(outputs=[1]\n", 494 | " replaces = [l + c + r[1:] for c in letters for l, r in splits if r]\n", 495 | " inserts = [l + c + r for c in letters for l, r in splits if r]\n", 496 | " return set(deletes + transposes + replaces + inserts)\n", 497 | "\n", 498 | "def edits_n(word, n):\n", 499 | " es = set([word])\n", 500 | " for i in range(n):\n", 501 | " es = reduce(lambda a, b: a.union(b), (edits(w) for w in es))\n", 502 | " return es\n", 503 | "\n", 504 | "def words(text):\n", 505 | " return text.split()\n", 506 | "\n", 507 | "def known_words(words):\n", 508 | " return {word for word in words if word in WORDS}\n", 509 | "\n", 510 | "def candidate_words(word):\n", 511 | " return (known_words([word]) or known_words(edits_n(word, 1)) or known_words(edits_n(word, 2)) or [word])\n", 512 | "\n", 513 | "list(candidate_words(\"swam\"))" 514 | ] 515 | }, 516 | { 517 | "cell_type": "code", 518 | "execution_count": 336, 519 | "metadata": { 520 | "collapsed": true 521 | }, 522 | "outputs": [], 523 | "source": [ 524 | "with open('./data/lm/words.txt') as f:\n", 525 | " WORDS = set(words(f.read()))" 526 | ] 527 | }, 528 | { 529 | "cell_type": "code", 530 | "execution_count": null, 531 | "metadata": { 532 | "collapsed": true 533 | }, 534 | "outputs": [], 535 | "source": [ 536 | "r:a ring of amethyst i could not wear here plainer to my sight than that first kiss\n", 537 | "p:a ring of amathyst i could not wear here plainer two my sight then that first kits" 538 | ] 539 | }, 540 | { 541 | "cell_type": "code", 542 | "execution_count": 197, 543 | "metadata": {}, 544 | "outputs": [ 545 | { 546 | "name": "stdout", 547 | "output_type": "stream", 548 | "text": [ 549 | "she doesn't take up with anybody you know\n", 550 | "she doesn't take up with anybody you know\n", 551 | "langauge model changed prediction, WER changed from 0.0243902439024 to 0.0\n" 552 | ] 553 | }, 554 | { 555 | "data": { 556 | "text/plain": [ 557 | "\"she doesn't take up with anybody you know\"" 558 | ] 559 | }, 560 | "execution_count": 197, 561 | "metadata": {}, 562 | "output_type": "execute_result" 563 | } 564 | ], 565 | "source": [ 566 | "best_lm_alternative(res[3][3], res[2][3], zip(*beam_decode_u(res[0][:, 3, :], 12, normalize=True)))" 567 | ] 568 | }, 569 | { 570 | "cell_type": "code", 571 | "execution_count": 187, 572 | "metadata": {}, 573 | "outputs": [ 574 | { 575 | "name": "stdout", 576 | "output_type": "stream", 577 | "text": [ 578 | "sir i have it in command to inform your excellency that you have been appointed governor of the crown colony which is called britannula\n", 579 | "sir i have in command to anform your excellency that you have been appointed governor of the crown colony which is called britain mula\n", 580 | "langauge model changed prediction, WER changed from 0.0334572490706 to 0.0334572490706\n", 581 | "sir i have in command to anform your excellency that you have been appointed governor of the crown colony which is called britain mula\n", 582 | "sir i have in command to anform your excellency that you have been appointed governor of the crown colony which is called britaan mula\n" 583 | ] 584 | } 585 | ], 586 | "source": [ 587 | "print best_lm_alternative(res[3][46], res[2][46], zip(*beam_decode_u(res[0][:, 46, :], 12, normalize=False)))\n", 588 | "print res[1][46]" 589 | ] 590 | }, 591 | { 592 | "cell_type": "code", 593 | "execution_count": 165, 594 | "metadata": {}, 595 | "outputs": [ 596 | { 597 | "data": { 598 | "text/plain": [ 599 | 
"0.16216216216216217" 600 | ] 601 | }, 602 | "execution_count": 165, 603 | "metadata": {}, 604 | "output_type": "execute_result" 605 | } 606 | ], 607 | "source": [ 608 | "import edit_distance\n", 609 | "ref = 'there is no danger of the modern commentators on the timaeus falling into the absurdities of the neo platonists'\n", 610 | "pre = 'there is old danger of the madern commontychers un ther to meas falling into dubsurdities of the newo platinists'\n", 611 | "pre = 'there is old danger of the madern commontychers un ther to mes falling into dubsurdities of the newo platinists'\n", 612 | "#print edit_distance.SequenceMatcher(ref, pre).ratio()\n", 613 | "word_error_rate([ref], [pre], decoded=True)[0]" 614 | ] 615 | }, 616 | { 617 | "cell_type": "markdown", 618 | "metadata": {}, 619 | "source": [ 620 | "#### custom samples" 621 | ] 622 | }, 623 | { 624 | "cell_type": "code", 625 | "execution_count": null, 626 | "metadata": { 627 | "collapsed": true 628 | }, 629 | "outputs": [], 630 | "source": [ 631 | "samples = [\n", 632 | " {\"duration\": 4.905, \"text\": \"he began a confused complaint against the wizard who had vanished behind the curtain on the left\", \"key\": \"/mnt/ml-data/LibriSpeech/test-clean/61/70968/61-70968-0000.wav\"},\n", 633 | " {\"duration\": 3.61, \"text\": \"give not so earnest a mind to these mummeries child\", \"key\": \"/mnt/ml-data/LibriSpeech/test-clean/61/70968/61-70968-0001.wav\"} \n", 634 | "]" 635 | ] 636 | }, 637 | { 638 | "cell_type": "code", 639 | "execution_count": null, 640 | "metadata": { 641 | "collapsed": true 642 | }, 643 | "outputs": [], 644 | "source": [ 645 | "evaluate(test_generator(datagen, samples, normalize=True), output_fn)\n" 646 | ] 647 | } 648 | ], 649 | "metadata": { 650 | "kernelspec": { 651 | "display_name": "Python 2", 652 | "language": "python", 653 | "name": "python2" 654 | }, 655 | "language_info": { 656 | "codemirror_mode": { 657 | "name": "ipython", 658 | "version": 2 659 | }, 660 | "file_extension": ".py", 661 | "mimetype": "text/x-python", 662 | "name": "python", 663 | "nbconvert_exporter": "python", 664 | "pygments_lexer": "ipython2", 665 | "version": "2.7.13" 666 | } 667 | }, 668 | "nbformat": 4, 669 | "nbformat_minor": 2 670 | } 671 | -------------------------------------------------------------------------------- /plot.py: -------------------------------------------------------------------------------- 1 | """ 2 | Plot training/validation curves for multiple models. 3 | """ 4 | 5 | 6 | from __future__ import division 7 | from __future__ import print_function 8 | import argparse 9 | import matplotlib 10 | import numpy as np 11 | import os 12 | matplotlib.use('Agg') # This must be called before importing pyplot 13 | import matplotlib.pyplot as plt 14 | 15 | 16 | COLORS_RGB = [ 17 | (228, 26, 28), (55, 126, 184), (77, 175, 74), 18 | (152, 78, 163), (255, 127, 0), (31, 75, 90), 19 | (166, 86, 40), (247, 129, 191), (153, 153, 153), 20 | (130, 22, 99), (18, 133, 114), (43, 202, 200), 21 | (141, 219, 221), (45, 10, 159), (7, 78, 47), 22 | (249, 15, 176), (114, 227, 216), (255, 138, 125) 23 | ] 24 | 25 | # Scale the RGB values to the [0, 1] range, which is the format 26 | # matplotlib accepts. 
27 | colors = [(r / 255, g / 255, b / 255) for r, g, b in COLORS_RGB] 28 | 29 | 30 | def parse_args(): 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument('-d', '--dirs', nargs='+', required=True, 33 | help='Directories where the model and costs are saved') 34 | parser.add_argument('-s', '--save_file', type=str, required=True, 35 | help='Filename of the output plot') 36 | return parser.parse_args() 37 | 38 | 39 | def re_range(ys, step, factor=1): 40 | """ Compress ys so for each step we'll have mean of that step. 41 | Params: 42 | ys: Outputs 43 | step: Each slice of outputs of this size will be averaged 44 | factor: Scale inputs by this factor 45 | """ 46 | n = len(ys) 47 | rang = [step * (i+1) for i in range(n // step)] 48 | new_ys = ys[n % step:].reshape((-1, step)).mean(1) 49 | if n % step: 50 | rang.append(n) 51 | new_ys = np.insert(new_ys, [0], ys[:n % step].mean()) 52 | if factor != 1: 53 | rang = [r*factor for r in rang] 54 | return new_ys, rang 55 | 56 | 57 | def have_cost(name, npfile): 58 | return name in npfile and npfile[name].shape[0] > 0 59 | 60 | 61 | def graph(dirs, save_file, average_window=100): 62 | """ Plot the training and validation costs and if exist, word error rate 63 | over iterations 64 | Params: 65 | dirs (list(str)): Directories where the model and costs are saved 66 | save_file (str): Filename of the output plot 67 | average_window (int): Window size for smoothening the graphs 68 | """ 69 | fig, ax = plt.subplots() 70 | ax.set_xlabel('Iters') 71 | ax.set_ylabel('Loss') 72 | average_filter = np.ones(average_window) / float(average_window) 73 | 74 | for i, d in enumerate(dirs): 75 | name = os.path.basename(os.path.abspath(d)) 76 | color = colors[i % len(colors)] 77 | costs = np.load(os.path.join(d, 'costs.npz')) 78 | train_costs = costs['train'] if 'train' in costs.files else None 79 | if have_cost('train', costs): 80 | train_costs = costs['train'] 81 | iters = train_costs.shape[0] 82 | if train_costs.ndim == 1: 83 | train_costs = np.convolve(train_costs, average_filter, 84 | mode='valid') 85 | ax.plot(train_costs, color=color, label=name + '_train', lw=1.5) 86 | else: 87 | assert 'phoneme' in costs.files 88 | if have_cost('phoneme', costs): 89 | phoneme_costs = costs['phoneme'] 90 | iters = phoneme_costs.shape[0] 91 | if phoneme_costs.ndim == 1: 92 | phoneme_costs = np.convolve(phoneme_costs, average_filter, 93 | mode='valid') 94 | ax.plot(phoneme_costs, color=color, label=name + '_phoneme', 95 | linestyle='--', lw=1.5) 96 | if have_cost('validation', costs): 97 | valid_costs = costs['validation'] 98 | valid_ys, valid_xs = re_range(valid_costs, 1, 99 | iters / valid_costs.shape[0]) 100 | ax.plot(valid_xs, valid_ys, '.', color=color, 101 | label=name + '_valid') 102 | if have_cost('wer', costs): 103 | wers = costs['wer'] 104 | if wers.shape[0] == iters: 105 | y, x = re_range(wers * 100, average_window) 106 | else: 107 | y, x = re_range(wers * 100, 10, iters / wers.shape[0]) 108 | ax.plot(x, y, color=color, label=name + '_wer', marker='*') 109 | if have_cost('val_wer', costs): 110 | valid_wers = costs['val_wer'] 111 | y, x = re_range(valid_wers * 100, 1, iters / valid_wers.shape[0]) 112 | ax.plot(x, y, color=color, label=name + '_val_wer', marker='+') 113 | if have_cost('val_phoneme', costs): 114 | val_phoneme = costs['val_phoneme'] 115 | y, x = re_range(val_phoneme, 1, 116 | iters / val_phoneme.shape[0]) 117 | ax.plot(x, y, color=color, label=name + '_val_phoneme', marker='v') 118 | 119 | ax.grid(True) 120 | lgd = ax.legend(bbox_to_anchor=(1, 1)) 
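# Passing the legend through bbox_extra_artists together with
# bbox_inches='tight' keeps the legend, which is anchored outside the axes,
# from being clipped when the figure is saved.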
121 | plt.savefig(save_file, bbox_extra_artists=(lgd,), bbox_inches='tight') 122 | 123 | 124 | if __name__ == '__main__': 125 | args = parse_args() 126 | graph(args.dirs, args.save_file) 127 | -------------------------------------------------------------------------------- /pre-trained/model_25_config.json: -------------------------------------------------------------------------------- 1 | {"class_name": "Model", "keras_version": "2.0.6", "model_wrapper": {"class_name": "GruModelWrapper", "compile_args": {"recur_layers": 5, "conv_context": 5, "nodes": 1000, "batch_norm": false}}, "config": {"layers": [{"class_name": "InputLayer", "config": {"dtype": "float32", "batch_input_shape": [null, null, 161], "name": "acoustic_input", "sparse": false}, "inbound_nodes": [], "name": "acoustic_input"}, {"class_name": "Conv1D", "config": {"kernel_constraint": null, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": null, "mode": "fan_avg"}}, "name": "conv_1d", "bias_regularizer": null, "bias_constraint": null, "activation": "relu", "trainable": true, "padding": "valid", "strides": [2], "dilation_rate": [1], "kernel_regularizer": null, "filters": 1000, "bias_initializer": {"class_name": "Zeros", "config": {}}, "use_bias": true, "activity_regularizer": null, "kernel_size": [5]}, "inbound_nodes": [[["acoustic_input", 0, 0, {}]]], "name": "conv_1d"}, {"class_name": "GRU", "config": {"recurrent_activation": "hard_sigmoid", "trainable": true, "recurrent_initializer": {"class_name": "Orthogonal", "config": {"seed": null, "gain": 1.0}}, "use_bias": true, "bias_regularizer": null, "return_state": false, "unroll": false, "activation": "relu", "bias_initializer": {"class_name": "Zeros", "config": {}}, "units": 1000, "activity_regularizer": null, "recurrent_dropout": 0.0, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": null, "mode": "fan_avg"}}, "kernel_constraint": null, "dropout": 0.0, "stateful": false, "recurrent_regularizer": null, "name": "rnn_1", "bias_constraint": null, "go_backwards": false, "implementation": 0, "kernel_regularizer": null, "return_sequences": true, "recurrent_constraint": null}, "inbound_nodes": [[["conv_1d", 0, 0, {}]]], "name": "rnn_1"}, {"class_name": "GRU", "config": {"recurrent_activation": "hard_sigmoid", "trainable": true, "recurrent_initializer": {"class_name": "Orthogonal", "config": {"seed": null, "gain": 1.0}}, "use_bias": true, "bias_regularizer": null, "return_state": false, "unroll": false, "activation": "relu", "bias_initializer": {"class_name": "Zeros", "config": {}}, "units": 1000, "activity_regularizer": null, "recurrent_dropout": 0.0, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": null, "mode": "fan_avg"}}, "kernel_constraint": null, "dropout": 0.0, "stateful": false, "recurrent_regularizer": null, "name": "rnn_2", "bias_constraint": null, "go_backwards": false, "implementation": 0, "kernel_regularizer": null, "return_sequences": true, "recurrent_constraint": null}, "inbound_nodes": [[["rnn_1", 0, 0, {}]]], "name": "rnn_2"}, {"class_name": "GRU", "config": {"recurrent_activation": "hard_sigmoid", "trainable": true, "recurrent_initializer": {"class_name": "Orthogonal", "config": {"seed": null, "gain": 1.0}}, "use_bias": true, "bias_regularizer": null, "return_state": false, "unroll": false, "activation": "relu", "bias_initializer": {"class_name": "Zeros", "config": {}}, 
"units": 1000, "activity_regularizer": null, "recurrent_dropout": 0.0, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": null, "mode": "fan_avg"}}, "kernel_constraint": null, "dropout": 0.0, "stateful": false, "recurrent_regularizer": null, "name": "rnn_3", "bias_constraint": null, "go_backwards": false, "implementation": 0, "kernel_regularizer": null, "return_sequences": true, "recurrent_constraint": null}, "inbound_nodes": [[["rnn_2", 0, 0, {}]]], "name": "rnn_3"}, {"class_name": "GRU", "config": {"recurrent_activation": "hard_sigmoid", "trainable": true, "recurrent_initializer": {"class_name": "Orthogonal", "config": {"seed": null, "gain": 1.0}}, "use_bias": true, "bias_regularizer": null, "return_state": false, "unroll": false, "activation": "relu", "bias_initializer": {"class_name": "Zeros", "config": {}}, "units": 1000, "activity_regularizer": null, "recurrent_dropout": 0.0, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": null, "mode": "fan_avg"}}, "kernel_constraint": null, "dropout": 0.0, "stateful": false, "recurrent_regularizer": null, "name": "rnn_4", "bias_constraint": null, "go_backwards": false, "implementation": 0, "kernel_regularizer": null, "return_sequences": true, "recurrent_constraint": null}, "inbound_nodes": [[["rnn_3", 0, 0, {}]]], "name": "rnn_4"}, {"class_name": "GRU", "config": {"recurrent_activation": "hard_sigmoid", "trainable": true, "recurrent_initializer": {"class_name": "Orthogonal", "config": {"seed": null, "gain": 1.0}}, "use_bias": true, "bias_regularizer": null, "return_state": false, "unroll": false, "activation": "relu", "bias_initializer": {"class_name": "Zeros", "config": {}}, "units": 1000, "activity_regularizer": null, "recurrent_dropout": 0.0, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": null, "mode": "fan_avg"}}, "kernel_constraint": null, "dropout": 0.0, "stateful": false, "recurrent_regularizer": null, "name": "rnn_5", "bias_constraint": null, "go_backwards": false, "implementation": 0, "kernel_regularizer": null, "return_sequences": true, "recurrent_constraint": null}, "inbound_nodes": [[["rnn_4", 0, 0, {}]]], "name": "rnn_5"}, {"class_name": "TimeDistributed", "config": {"layer": {"class_name": "Dense", "config": {"kernel_initializer": {"class_name": "VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": null, "mode": "fan_avg"}}, "name": "text_dense", "kernel_constraint": null, "bias_regularizer": null, "bias_constraint": null, "activation": "linear", "trainable": true, "kernel_regularizer": null, "bias_initializer": {"class_name": "Zeros", "config": {}}, "units": 29, "use_bias": true, "activity_regularizer": null}}, "trainable": true, "name": "text"}, "inbound_nodes": [[["rnn_5", 0, 0, {}]]], "name": "text"}], "input_layers": [["acoustic_input", 0, 0]], "output_layers": [["text", 0, 0]], "name": "model_1"}, "backend": "theano"} 2 | -------------------------------------------------------------------------------- /pre-trained/model_45_config.json: -------------------------------------------------------------------------------- 1 | {"class_name": "Model", "keras_version": "2.0.6", "model_wrapper": {"class_name": "GruModelWrapper", "compile_args": {"recur_layers": 5, "lirelu_alpha": 0.05, "conv_context": 5, "dropout": 0.1, "nodes": 1000, "batch_norm": false}}, "config": {"layers": [{"class_name": "InputLayer", "config": 
{"dtype": "float32", "batch_input_shape": [null, null, 161], "name": "acoustic_input", "sparse": false}, "inbound_nodes": [], "name": "acoustic_input"}, {"class_name": "Conv1D", "config": {"kernel_constraint": null, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": null, "mode": "fan_avg"}}, "name": "conv_1d", "bias_regularizer": null, "bias_constraint": null, "activation": "relu", "trainable": true, "padding": "valid", "strides": [2], "dilation_rate": [1], "kernel_regularizer": null, "filters": 1000, "bias_initializer": {"class_name": "Zeros", "config": {}}, "use_bias": true, "activity_regularizer": null, "kernel_size": [5]}, "inbound_nodes": [[["acoustic_input", 0, 0, {}]]], "name": "conv_1d"}, {"class_name": "Dropout", "config": {"rate": 0.1, "trainable": true, "name": "dropout_1"}, "inbound_nodes": [[["conv_1d", 0, 0, {}]]], "name": "dropout_1"}, {"class_name": "GRU", "config": {"recurrent_activation": "hard_sigmoid", "trainable": true, "recurrent_initializer": {"class_name": "Orthogonal", "config": {"seed": null, "gain": 1.0}}, "use_bias": true, "bias_regularizer": null, "return_state": false, "unroll": false, "activation": "relu", "bias_initializer": {"class_name": "Zeros", "config": {}}, "units": 1000, "activity_regularizer": null, "recurrent_dropout": 0.0, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": null, "mode": "fan_avg"}}, "kernel_constraint": null, "dropout": 0.0, "stateful": false, "recurrent_regularizer": null, "name": "rnn_1", "bias_constraint": null, "go_backwards": false, "implementation": 0, "kernel_regularizer": null, "return_sequences": true, "recurrent_constraint": null}, "inbound_nodes": [[["dropout_1", 0, 0, {}]]], "name": "rnn_1"}, {"class_name": "GRU", "config": {"recurrent_activation": "hard_sigmoid", "trainable": true, "recurrent_initializer": {"class_name": "Orthogonal", "config": {"seed": null, "gain": 1.0}}, "use_bias": true, "bias_regularizer": null, "return_state": false, "unroll": false, "activation": "relu", "bias_initializer": {"class_name": "Zeros", "config": {}}, "units": 1000, "activity_regularizer": null, "recurrent_dropout": 0.0, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": null, "mode": "fan_avg"}}, "kernel_constraint": null, "dropout": 0.0, "stateful": false, "recurrent_regularizer": null, "name": "rnn_2", "bias_constraint": null, "go_backwards": false, "implementation": 0, "kernel_regularizer": null, "return_sequences": true, "recurrent_constraint": null}, "inbound_nodes": [[["rnn_1", 0, 0, {}]]], "name": "rnn_2"}, {"class_name": "GRU", "config": {"recurrent_activation": "hard_sigmoid", "trainable": true, "recurrent_initializer": {"class_name": "Orthogonal", "config": {"seed": null, "gain": 1.0}}, "use_bias": true, "bias_regularizer": null, "return_state": false, "unroll": false, "activation": "relu", "bias_initializer": {"class_name": "Zeros", "config": {}}, "units": 1000, "activity_regularizer": null, "recurrent_dropout": 0.0, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": null, "mode": "fan_avg"}}, "kernel_constraint": null, "dropout": 0.0, "stateful": false, "recurrent_regularizer": null, "name": "rnn_3", "bias_constraint": null, "go_backwards": false, "implementation": 0, "kernel_regularizer": null, "return_sequences": true, "recurrent_constraint": null}, 
"inbound_nodes": [[["rnn_2", 0, 0, {}]]], "name": "rnn_3"}, {"class_name": "GRU", "config": {"recurrent_activation": "hard_sigmoid", "trainable": true, "recurrent_initializer": {"class_name": "Orthogonal", "config": {"seed": null, "gain": 1.0}}, "use_bias": true, "bias_regularizer": null, "return_state": false, "unroll": false, "activation": "relu", "bias_initializer": {"class_name": "Zeros", "config": {}}, "units": 1000, "activity_regularizer": null, "recurrent_dropout": 0.0, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": null, "mode": "fan_avg"}}, "kernel_constraint": null, "dropout": 0.0, "stateful": false, "recurrent_regularizer": null, "name": "rnn_4", "bias_constraint": null, "go_backwards": false, "implementation": 0, "kernel_regularizer": null, "return_sequences": true, "recurrent_constraint": null}, "inbound_nodes": [[["rnn_3", 0, 0, {}]]], "name": "rnn_4"}, {"class_name": "GRU", "config": {"recurrent_activation": "hard_sigmoid", "trainable": true, "recurrent_initializer": {"class_name": "Orthogonal", "config": {"seed": null, "gain": 1.0}}, "use_bias": true, "bias_regularizer": null, "return_state": false, "unroll": false, "activation": "relu", "bias_initializer": {"class_name": "Zeros", "config": {}}, "units": 1000, "activity_regularizer": null, "recurrent_dropout": 0.0, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": null, "mode": "fan_avg"}}, "kernel_constraint": null, "dropout": 0.0, "stateful": false, "recurrent_regularizer": null, "name": "rnn_5", "bias_constraint": null, "go_backwards": false, "implementation": 0, "kernel_regularizer": null, "return_sequences": true, "recurrent_constraint": null}, "inbound_nodes": [[["rnn_4", 0, 0, {}]]], "name": "rnn_5"}, {"class_name": "TimeDistributed", "config": {"layer": {"class_name": "Dense", "config": {"kernel_initializer": {"class_name": "VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": null, "mode": "fan_avg"}}, "name": "text_dense", "kernel_constraint": null, "bias_regularizer": null, "bias_constraint": null, "activation": "softmax", "trainable": true, "kernel_regularizer": null, "bias_initializer": {"class_name": "Zeros", "config": {}}, "units": 29, "use_bias": true, "activity_regularizer": null}}, "trainable": true, "name": "text"}, "inbound_nodes": [[["rnn_5", 0, 0, {}]]], "name": "text"}], "input_layers": [["acoustic_input", 0, 0]], "output_layers": [["text", 0, 0]], "name": "model_1"}, "backend": "tensorflow"} 2 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test a trained speech model over a dataset 3 | """ 4 | 5 | from __future__ import absolute_import, division, print_function 6 | import argparse 7 | 8 | import os 9 | import sys 10 | import json 11 | 12 | 13 | def load_model_wrapper(model_config_file, weights_file): 14 | """Loads a pre-trained model wrapper""" 15 | with open(model_config_file) as fp: 16 | model_config = json.load(fp) 17 | try: 18 | os.environ['KERAS_BACKEND'] = model_config['backend'] 19 | import model_wrp 20 | # pretrained_id = model_config['pre-trained-id'] 21 | wrapper_config = model_config['model_wrapper'] 22 | wrapper_class = getattr(sys.modules['model_wrp'], 23 | wrapper_config['class_name']) 24 | model_wrapper = wrapper_class(**wrapper_config.get('init_args', {})) 25 | model = 
model_wrapper.compile(**wrapper_config.get('compile_args', {})) 26 | model.load_weights(weights_file) 27 | except (KeyError, ): 28 | print ("Model is not known") 29 | sys.exit(1) 30 | return model_wrapper 31 | 32 | 33 | def main(test_desc_file, model_config_file, weights_file): 34 | # Load model 35 | model_wrapper = load_model_wrapper(model_config_file, weights_file) 36 | model, test_fn = model_wrapper.model, model_wrapper.compile_test_fn() 37 | 38 | # Prepare the data generator 39 | from data_generator import DataGenerator 40 | datagen = DataGenerator() 41 | # Load the JSON file that contains the dataset 42 | datagen.load_validation_data(test_desc_file) 43 | # Normalize input data by variance and mean of training input 44 | datagen.reload_norm('860-1000') 45 | 46 | from trainer import Trainer 47 | trainer = Trainer(model, None, test_fn) 48 | trainer.validate(datagen, 32, False, False, None) 49 | # Test the model 50 | print ("Test loss: {}".format(trainer.last_val_cost)) 51 | 52 | 53 | if __name__ == '__main__': 54 | parser = argparse.ArgumentParser() 55 | parser.add_argument('test_desc_file', type=str, 56 | help='Path to a JSON-line file that contains ' 57 | 'test labels and paths to the audio files. ') 58 | parser.add_argument('model_config', type=str, help='Path to model config') 59 | parser.add_argument('weights_file', type=str, 60 | help='Load weights from this file') 61 | args = parser.parse_args() 62 | main(args.test_desc_file, args.model_config, args.weights_file) 63 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train an end-to-end speech recognition model using CTC. 3 | Use $python train.py --help for usage 4 | """ 5 | 6 | from __future__ import absolute_import, division, print_function 7 | 8 | import argparse 9 | import logging 10 | import sys 11 | import os 12 | 13 | from data_generator import DataGenerator 14 | from utils import configure_logging 15 | 16 | from model_wrp import HalfPhonemeModelWrapper 17 | from trainer import Trainer 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | def train_sample_half_phoneme(datagen, save_dir, epochs, sortagrad, 23 | start_weights=False, mb_size=60): 24 | model_wrp = HalfPhonemeModelWrapper() 25 | model = model_wrp.compile(nodes=1000, conv_context=5, recur_layers=5) 26 | logger.info('model :\n%s' % (model.to_yaml(),)) 27 | 28 | if start_weights: 29 | model.load_weights(start_weights) 30 | 31 | train_fn, test_fn = (model_wrp.compile_train_fn(1e-4), 32 | model_wrp.compile_test_fn()) 33 | trainer = Trainer(model, train_fn, test_fn, on_text=True, on_phoneme=True) 34 | trainer.run(datagen, save_dir, epochs=epochs, do_sortagrad=sortagrad, 35 | mb_size=mb_size, stateful=False) 36 | return trainer, model_wrp 37 | 38 | 39 | def main(train_desc_file, val_desc_file, epochs, save_dir, sortagrad, 40 | use_arpabets, start_weights=None): 41 | if not os.path.exists(save_dir): 42 | os.makedirs(save_dir) 43 | # Configure logging 44 | configure_logging(file_log_path=os.path.join(save_dir, 'train_log.txt')) 45 | logger.info(' '.join(sys.argv)) 46 | 47 | # Prepare the data generator 48 | datagen = DataGenerator(use_arpabets=use_arpabets) 49 | # Load the JSON file that contains the dataset 50 | datagen.load_train_data(train_desc_file, max_duration=20) 51 | datagen.load_validation_data(val_desc_file) 52 | # Use a few samples from the dataset, to calculate the means and variance 53 | # of the features, so that we 
can center our inputs to the network 54 | # datagen.fit_train(100) 55 | datagen.reload_norm('860-1000') 56 | train_sample_half_phoneme(datagen, save_dir, epochs, sortagrad, 57 | start_weights, mb_size=48) 58 | 59 | 60 | if __name__ == '__main__': 61 | parser = argparse.ArgumentParser() 62 | parser.add_argument('train_desc_file', type=str, 63 | help='Path to a JSON-line file that contains ' 64 | 'training labels and paths to the audio files.') 65 | parser.add_argument('val_desc_file', type=str, 66 | help='Path to a JSON-line file that contains ' 67 | 'validation labels and paths to the audio files.') 68 | parser.add_argument('save_dir', type=str, 69 | help='Directory to store the model. This will be ' 70 | 'created if it doesn\'t already exist') 71 | parser.add_argument('--epochs', type=int, default=20, 72 | help='Number of epochs to train the model') 73 | parser.add_argument('--sortagrad', default=False, nargs='?', const=True, 74 | type=int, help='Sort utterances by duration for this ' 75 | 'number of epochs. Will sort all epochs if no value ' 76 | 'is given') 77 | parser.add_argument('--use-arpabets', default=False, 78 | help='Read arpabets', action='store_true') 79 | parser.add_argument('--start-weights', type=str, default=None, 80 | help='Load weights') 81 | args = parser.parse_args() 82 | 83 | main(args.train_desc_file, args.val_desc_file, args.epochs, args.save_dir, 84 | args.sortagrad, args.use_arpabets, args.start_weights) 85 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import numpy as np 4 | 5 | from utils import conv_chain_output_length, word_error_rate, save_model 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | def _last_of_list_or_none(l): 11 | return None if len(l) == 0 else l[-1] 12 | 13 | 14 | class Trainer(object): 15 | """ 16 | Training and validation routines 17 | 18 | Properties: 19 | best_cost (flaot) 20 | last_cost (float) 21 | best_val_cost (float) 22 | last_val_cost (float) 23 | wers (list(float)) 24 | val_wers (list(float)) 25 | """ 26 | def __init__(self, model, train_fn, val_fn, on_text=True, on_phoneme=False): 27 | self.model = model 28 | self.train_fn = train_fn 29 | self.val_fn = val_fn 30 | self.on_text = on_text 31 | self.on_phoneme = on_phoneme 32 | self.wers, self.text_costs, self.phoneme_costs = [], [], [] 33 | self.val_wers, self.val_text_costs, self.val_phoneme_costs = [], [], [] 34 | self.best_cost = np.iinfo(np.int32).max 35 | self.best_val_cost = np.iinfo(np.int32).max 36 | if not (on_text or on_phoneme): 37 | raise ValueError("Model should train against at least text or " 38 | "phoneme") 39 | 40 | @property 41 | def last_text_cost(self): 42 | return _last_of_list_or_none(self.text_costs) 43 | 44 | @property 45 | def last_phoneme_cost(self): 46 | return _last_of_list_or_none(self.phoneme_costs) 47 | 48 | @property 49 | def last_val_text_cost(self): 50 | return _last_of_list_or_none(self.val_text_costs) 51 | 52 | @property 53 | def last_val_phoneme_cost(self): 54 | return _last_of_list_or_none(self.val_phoneme_costs) 55 | 56 | @property 57 | def last_wer(self): 58 | return _last_of_list_or_none(self.wers) 59 | 60 | @property 61 | def last_val_wer(self): 62 | return _last_of_list_or_none(self.val_wers) 63 | 64 | @property 65 | def last_cost(self): 66 | """ Cost of last minibatch on train """ 67 | if self.on_text: 68 | return self.last_text_cost 69 | if self.on_phoneme: 70 | 
return self.last_phoneme_cost 71 | 72 | @property 73 | def last_val_cost(self): 74 | """ Last cost on whole validation set """ 75 | if self.on_text: 76 | return self.last_val_text_cost 77 | if self.on_phoneme: 78 | return self.last_val_phoneme_cost 79 | 80 | @property 81 | def best_cost(self): 82 | """ Best cost among minibatchs of training set """ 83 | if self.on_text: 84 | return self.best_text_cost 85 | if self.on_phoneme: 86 | return self.best_phoneme_cost 87 | 88 | @best_cost.setter 89 | def best_cost(self, val): 90 | if self.on_text: 91 | self.best_text_cost = val 92 | elif self.on_phoneme: 93 | self.best_phoneme_cost = val 94 | 95 | @property 96 | def best_val_cost(self): 97 | """ Best cost on whole validation set so far """ 98 | if self.on_text: 99 | return self.best_text_val_cost 100 | if self.on_phoneme: 101 | return self.best_phoneme_val_cost 102 | 103 | @best_val_cost.setter 104 | def best_val_cost(self, val): 105 | if self.on_text: 106 | self.best_text_val_cost = val 107 | elif self.on_phoneme: 108 | self.best_phoneme_val_cost = val 109 | 110 | def run(self, datagen, save_dir, epochs=10, mb_size=16, do_sortagrad=False, 111 | stateful=False, save_best_weights=False, save_best_val_weights=True, 112 | iters_to_valid=100, iters_to_checkout=500): 113 | """ Run trainig loop 114 | Args: 115 | datagen (DataGenerator) 116 | save_dir (str): directory path that will contain the model 117 | epochs (int): number of epochs 118 | mb_size (int): mini-batch size 119 | do_sortagrad (bool): sort dataset by duration on first epoch 120 | stateful (bool): is model stateful or not 121 | save_best_weights (bool): save weights whenever cost over 122 | training mini-batch reduced 123 | save_best_val_weights (bool): save weights whenever cost over 124 | validation set reduced 125 | iters_to_valid (int): after this amount of iterations validate 126 | model by whole validation set 127 | iters_to_checkout (int): after this amount of iterations save 128 | model 129 | """ 130 | logger.info("Training model..") 131 | iters = 0 132 | for e in range(epochs): 133 | if not isinstance(do_sortagrad, bool): 134 | sortagrad = e < do_sortagrad 135 | shuffle = not sortagrad 136 | elif do_sortagrad: 137 | shuffle = False 138 | sortagrad = True 139 | else: 140 | shuffle = True 141 | sortagrad = False 142 | 143 | train_iter = datagen.iterate_train(mb_size, shuffle=shuffle, 144 | sort_by_duration=sortagrad) 145 | for i, batch in enumerate(train_iter): 146 | if stateful and batch['x'].shape[0] != mb_size: 147 | break 148 | self.train_minibatch(batch, i % 10 == 0) 149 | if i % 10 == 0: 150 | logger.info("Epoch: {} Iteration: {}({}) TextLoss: {}" 151 | " PhonemeLoss: {} WER: {}" 152 | .format(e, i, iters, self.last_text_cost, 153 | self.last_phoneme_cost, 154 | self.last_wer)) 155 | iters += 1 156 | if save_best_weights and self.best_cost < self.last_cost: 157 | self.save_weights(save_dir, 'best-weights.h5') 158 | if iters_to_valid is not None and iters % iters_to_valid == 0: 159 | self.validate(datagen, mb_size, stateful, 160 | save_best_val_weights, save_dir) 161 | if i and i % iters_to_checkout == 0: 162 | self.save_model(save_dir, iters) 163 | if iters_to_valid is not None and iters % iters_to_valid != 0: 164 | self.validate(datagen, mb_size, stateful, save_best_val_weights, 165 | save_dir) 166 | if i % iters_to_checkout != 0: 167 | self.save_model(save_dir, iters) 168 | 169 | def train_minibatch(self, batch, compute_wer=False): 170 | inputs = batch['x'] 171 | input_lengths = batch['input_lengths'] 172 | ctc_input_lens = 
self.ctc_input_length(input_lengths) 173 | if self.on_text and self.on_phoneme: 174 | _, ctc_phoneme, pred_texts, ctc_text = self.train_fn([ 175 | inputs, ctc_input_lens, batch['phonemes'], 176 | batch['phoneme_lengths'], batch['y'], batch['label_lengths'], 177 | True]) 178 | elif self.on_text: 179 | pred_texts, ctc_text = self.train_fn([inputs, ctc_input_lens, 180 | batch['y'], 181 | batch['label_lengths'], True]) 182 | elif self.on_phoneme: 183 | _, ctc_phoneme = self.train_fn([inputs, ctc_input_lens, 184 | batch['phonemes'], 185 | batch['phoneme_lengths'], 186 | True]) 187 | if self.on_text: 188 | if compute_wer: 189 | wer = word_error_rate(batch['texts'], pred_texts).mean() 190 | self.wers.append(wer) 191 | self.text_costs.append(ctc_text) 192 | if self.on_phoneme: 193 | self.phoneme_costs.append(ctc_phoneme) 194 | 195 | def validate(self, datagen, mb_size, stateful, save_best_weights, save_dir): 196 | text_avg_cost, phoneme_avg_cost = 0.0, 0.0 197 | total_wers = [] 198 | i = 0 199 | for batch in datagen.iterate_validation(mb_size): 200 | if stateful and batch['x'].shape[0] != mb_size: 201 | break 202 | text_cost, phoneme_cost, wers = self.validate_minibatch(batch) 203 | if self.on_text: 204 | text_avg_cost += text_cost 205 | total_wers.append(wers) 206 | if self.on_phoneme: 207 | phoneme_avg_cost += phoneme_cost 208 | i += 1 209 | if i != 0: 210 | text_avg_cost /= i 211 | phoneme_avg_cost /= i 212 | if self.on_text: 213 | self.val_wers.append(np.concatenate(total_wers).mean()) 214 | self.val_text_costs.append(text_avg_cost) 215 | if self.on_phoneme: 216 | self.val_phoneme_costs.append(phoneme_avg_cost) 217 | logger.info("Validation TextLoss: {} Validation PhonemeLoss: {} " 218 | "Validation WER: {}".format(self.last_val_text_cost, 219 | self.last_val_phoneme_cost, 220 | self.last_val_wer)) 221 | if save_best_weights and self.last_val_cost < self.best_val_cost: 222 | self.best_val_cost = self.last_val_cost 223 | self.save_weights(save_dir, 'best-val-weights.h5') 224 | 225 | def validate_minibatch(self, batch): 226 | inputs = batch['x'] 227 | input_lengths = batch['input_lengths'] 228 | ctc_input_lens = self.ctc_input_length(input_lengths) 229 | text_ctc, phoneme_ctc, wers = None, None, None 230 | if self.on_text and self.on_phoneme: 231 | _, phoneme_ctc, pred_text, text_ctc = self.val_fn([ 232 | inputs, ctc_input_lens, batch['phonemes'], 233 | batch['phoneme_lengths'], batch['y'], batch['label_lengths'], 234 | True]) 235 | elif self.on_text: 236 | pred_text, text_ctc = self.val_fn([ 237 | inputs, ctc_input_lens, batch['y'], batch['label_lengths'], 238 | True]) 239 | elif self.on_phoneme: 240 | _, phoneme_ctc = self.val_fn([ 241 | inputs, ctc_input_lens, batch['phonemes'], 242 | batch['phoneme_lengths'], True 243 | ]) 244 | 245 | if self.on_text: 246 | wers = word_error_rate(batch['texts'], pred_text) 247 | 248 | return text_ctc, phoneme_ctc, wers 249 | 250 | def ctc_input_length(self, input_lengths): 251 | import keras.layers 252 | conv_class = (getattr(keras.layers, 'Conv1D', None) or 253 | keras.layers.Convolution1D) 254 | conv_lays = [l for l in self.model.layers if isinstance(l, conv_class)] 255 | return [conv_chain_output_length(l, conv_lays) for l in input_lengths] 256 | 257 | def save_weights(self, save_dir, filename): 258 | self.model.save_weights(os.path.join(save_dir, filename), 259 | overwrite=True) 260 | 261 | def save_model(self, save_dir, index): 262 | save_model(save_dir, self.model, self.text_costs, self.val_text_costs, 263 | wer=self.wers, val_wer=self.val_wers, 264 | 
phoneme=self.phoneme_costs, 265 | val_phoneme=self.val_phoneme_costs, index=index) 266 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import logging 3 | import os 4 | import numpy as np 5 | import re 6 | import soundfile 7 | import keras 8 | from keras.models import model_from_json 9 | from numpy.lib.stride_tricks import as_strided 10 | 11 | from char_map import char_map, index_map 12 | from edit_distance import SequenceMatcher 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | k2 = keras.__version__[0] == '2' 17 | 18 | 19 | def calc_feat_dim(window, max_freq): 20 | return int(0.001 * window * max_freq) + 1 21 | 22 | 23 | def conv_output_length(input_length, filter_size, border_mode, stride, 24 | dilation=1): 25 | """ Compute the length of the output sequence after 1D convolution along 26 | time. Note that this function is in line with the function used in 27 | Convolution1D class from Keras. 28 | Params: 29 | input_length (int): Length of the input sequence. 30 | filter_size (int): Width of the convolution kernel. 31 | border_mode (str): Only support `same` or `valid`. 32 | stride (int): Stride size used in 1D convolution. 33 | dilation (int) 34 | """ 35 | if input_length is None: 36 | return None 37 | assert border_mode in {'same', 'valid'} 38 | dilated_filter_size = filter_size + (filter_size - 1) * (dilation - 1) 39 | if border_mode == 'same': 40 | output_length = input_length 41 | elif border_mode == 'valid': 42 | output_length = input_length - dilated_filter_size + 1 43 | return (output_length + stride - 1) // stride 44 | 45 | 46 | def conv_chain_output_length(input_length, conv_layers): 47 | """ Compute output length after a sequence of 1D convolution layers 48 | Params: 49 | input_length (int): First layer input length 50 | conv_layers (list(Convolution1D)): List of keras Convolution1D layers 51 | """ 52 | length = input_length 53 | for layer in conv_layers: 54 | if k2: 55 | length = layer.compute_output_shape((None, length, None))[1] 56 | else: 57 | length = layer.get_output_shape_for((None, length))[1] 58 | return length 59 | 60 | 61 | def spectrogram(samples, fft_length=256, sample_rate=2, hop_length=128): 62 | """ 63 | Compute the spectrogram for a real signal. 64 | The parameters follow the naming convention of 65 | matplotlib.mlab.specgram 66 | 67 | Args: 68 | samples (1D array): input audio signal 69 | fft_length (int): number of elements in fft window 70 | sample_rate (scalar): sample rate 71 | hop_length (int): hop length (relative offset between neighboring 72 | fft windows). 73 | 74 | Returns: 75 | x (2D array): spectrogram [frequency x time] 76 | freq (1D array): frequency of each row in x 77 | 78 | Note: 79 | This is a truncating computation e.g. if fft_length=10, 80 | hop_length=5 and the signal has 23 elements, then the 81 | last 3 elements will be truncated. 82 | """ 83 | assert not np.iscomplexobj(samples), "Must not pass in complex numbers" 84 | 85 | window = np.hanning(fft_length)[:, None] 86 | window_norm = np.sum(window**2) 87 | 88 | # The scaling below follows the convention of 89 | # matplotlib.mlab.specgram which is the same as 90 | # matlabs specgram. 
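# Note on the scaling that follows (the 16 kHz figures are only an
# illustrative example): `scale = window_norm * sample_rate` is the
# power-spectral-density normaliser used by matplotlib.mlab.specgram —
# every bin is divided by the window energy times the sample rate, and all
# bins except DC and Nyquist are doubled further below because the
# one-sided rfft spectrum folds negative frequencies onto positive ones.
# For a 1 s clip at 16 kHz with window=20 ms (fft_length=320) and
# step=10 ms (hop_length=160), the stride trick below produces
# 1 + (16000 - 320) // 160 = 99 frames of 320 // 2 + 1 = 161 frequency bins.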
91 | scale = window_norm * sample_rate 92 | 93 | trunc = (len(samples) - fft_length) % hop_length 94 | x = samples[:len(samples) - trunc] 95 | 96 | # "stride trick" reshape to include overlap 97 | nshape = (fft_length, (len(x) - fft_length) // hop_length + 1) 98 | nstrides = (x.strides[0], x.strides[0] * hop_length) 99 | x = as_strided(x, shape=nshape, strides=nstrides) 100 | 101 | # window stride sanity check 102 | assert np.all(x[:, 1] == samples[hop_length:(hop_length + fft_length)]) 103 | 104 | # broadcast window, compute fft over columns and square mod 105 | x = np.fft.rfft(x * window, axis=0) 106 | x = np.absolute(x)**2 107 | 108 | # scale, 2.0 for everything except dc and fft_length/2 109 | x[1:-1, :] *= (2.0 / scale) 110 | x[(0, -1), :] /= scale 111 | 112 | freqs = float(sample_rate) / fft_length * np.arange(x.shape[0]) 113 | 114 | return x, freqs 115 | 116 | 117 | def spectrogram_from_file(filename, step=10, window=20, max_freq=None, 118 | eps=1e-14): 119 | """ Calculate the log of linear spectrogram from FFT energy 120 | Params: 121 | filename (str): Path to the audio file 122 | step (int): Step size in milliseconds between windows 123 | window (int): FFT window size in milliseconds 124 | max_freq (int): Only FFT bins corresponding to frequencies between 125 | [0, max_freq] are returned 126 | eps (float): Small value to ensure numerical stability (for ln(x)) 127 | """ 128 | with soundfile.SoundFile(filename) as sound_file: 129 | audio = sound_file.read(dtype='float32') 130 | sample_rate = sound_file.samplerate 131 | if audio.ndim >= 2: 132 | audio = np.mean(audio, 1) 133 | if max_freq is None: 134 | max_freq = sample_rate / 2 135 | if max_freq > sample_rate / 2: 136 | raise ValueError("max_freq must not be greater than half of " 137 | " sample rate") 138 | if step > window: 139 | raise ValueError("step size must not be greater than window size") 140 | hop_length = int(0.001 * step * sample_rate) 141 | fft_length = int(0.001 * window * sample_rate) 142 | pxx, freqs = spectrogram( 143 | audio, fft_length=fft_length, sample_rate=sample_rate, 144 | hop_length=hop_length) 145 | ind = np.where(freqs <= max_freq)[0][-1] + 1 146 | return np.transpose(np.log(pxx[:ind, :] + eps)) 147 | 148 | 149 | def save_model(save_dir, model, train=None, validation=None, wer=None, 150 | val_wer=None, phoneme=None, val_phoneme=None, index=None): 151 | """ Save the model and costs into a directory 152 | Params: 153 | save_dir (str): Directory used to store the model 154 | model (keras.models.Model) 155 | train (list(float)) 156 | validation (list(float)) 157 | index (int): If this is provided, add this index as a suffix to 158 | the weights (useful for checkpointing during training) 159 | """ 160 | logger.info("Checkpointing model to: {}".format(save_dir)) 161 | model_config_path = os.path.join(save_dir, 'model_config.json') 162 | with open(model_config_path, 'w') as model_config_file: 163 | model_json = model.to_json() 164 | model_config_file.write(model_json) 165 | if index is None: 166 | weights_format = 'model_weights.h5' 167 | else: 168 | weights_format = 'model_{}_weights.h5'.format(index) 169 | model_weights_file = os.path.join(save_dir, weights_format) 170 | model.save_weights(model_weights_file, overwrite=True) 171 | costs = {} 172 | for metric in ['train', 'validation', 'wer', 'val_wer', 'phoneme', 173 | 'val_phoneme']: 174 | metric_val = locals()[metric] 175 | if metric_val is not None: 176 | costs[metric] = metric_val 177 | np.savez(os.path.join(save_dir, 'costs.npz'), **costs) 178 | 179 | 180 
| def load_model(load_dir, weights_file=None): 181 | """ Load a model and its weights from a directory 182 | Params: 183 | load_dir (str): Path to the model directory 184 | weights_file (str): If this is not passed in, try to load the latest 185 | model_*weights.h5 file in the directory 186 | Returns: 187 | model (keras.models.Model) 188 | """ 189 | def atoi(text): 190 | return int(text) if text.isdigit() else text 191 | 192 | def natural_keys(text): 193 | # From http://stackoverflow.com/questions/5967500 194 | return [atoi(c) for c in re.split(r'(\d+)', text)] 195 | 196 | model_config_file = os.path.join(load_dir, 'model_config.json') 197 | model_config = open(model_config_file).read() 198 | model = model_from_json(model_config) 199 | 200 | if weights_file is None: 201 | # This will find all files of name model_*weights.h5 202 | # We try to use the latest one saved 203 | weights_files = glob.glob(os.path.join(load_dir, 'model_*weights.h5')) 204 | weights_files.sort(key=natural_keys) 205 | model_weights_file = weights_files[-1] # Use the latest model 206 | else: 207 | model_weights_file = weights_file 208 | model.load_weights(model_weights_file) 209 | return model 210 | 211 | 212 | def argmax_decode(prediction): 213 | """ Decode a prediction using the most probable character at each 214 | timestep. Then, simply convert the integer sequence to text 215 | Params: 216 | prediction (np.array): timestep * num_characters 217 | """ 218 | int_sequence = [] 219 | for timestep in prediction: 220 | int_sequence.append(np.argmax(timestep)) 221 | tokens = [] 222 | c_prev = -1 223 | for c in int_sequence: 224 | if c == c_prev: 225 | continue 226 | if c != for_tf_or_th(28, 0): # Blank 227 | tokens.append(c) 228 | c_prev = c 229 | 230 | text = ''.join([index_map[i] for i in tokens]) 231 | return text 232 | 233 | 234 | def text_to_int_sequence(text): 235 | """ Use a character map and convert text to an integer sequence """ 236 | int_sequence = [] 237 | for c in text: 238 | if c == ' ': 239 | ch = char_map['<SPACE>'] 240 | else: 241 | ch = char_map[c] 242 | int_sequence.append(ch) 243 | return int_sequence 244 | 245 | 246 | def configure_logging(console_log_level=logging.INFO, 247 | console_log_format=None, 248 | file_log_path=None, 249 | file_log_level=logging.INFO, 250 | file_log_format=None, 251 | clear_handlers=False): 252 | """Setup logging. 253 | 254 | This configures either a console handler, a file handler, or both and 255 | adds them to the root logger. 256 | 257 | Args: 258 | console_log_level (logging level): logging level for console logger 259 | console_log_format (str): log format string for console logger 260 | file_log_path (str): full filepath for file logger output 261 | file_log_level (logging level): logging level for file logger 262 | file_log_format (str): log format string for file logger 263 | clear_handlers (bool): clear existing handlers from the root logger 264 | 265 | Note: 266 | A logging level of `None` will disable the handler.
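    Example:
        A typical call (the log path below is only a placeholder):
        >>> configure_logging(console_log_level=logging.INFO,
        ...                   file_log_path='logs/train.log',
        ...                   file_log_level=logging.DEBUG,
        ...                   clear_handlers=True)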
267 | """ 268 | if file_log_format is None: 269 | file_log_format = \ 270 | '%(asctime)s %(levelname)-7s (%(name)s) %(message)s' 271 | 272 | if console_log_format is None: 273 | console_log_format = \ 274 | '%(asctime)s %(levelname)-7s (%(name)s) %(message)s' 275 | 276 | # configure root logger level 277 | root_logger = logging.getLogger() 278 | root_level = root_logger.level 279 | if console_log_level is not None: 280 | root_level = min(console_log_level, root_level) 281 | if file_log_level is not None: 282 | root_level = min(file_log_level, root_level) 283 | root_logger.setLevel(root_level) 284 | 285 | # clear existing handlers 286 | if clear_handlers and len(root_logger.handlers) > 0: 287 | print("Clearing {} handlers from root logger." 288 | .format(len(root_logger.handlers))) 289 | for handler in root_logger.handlers: 290 | root_logger.removeHandler(handler) 291 | 292 | # file logger 293 | if file_log_path is not None and file_log_level is not None: 294 | log_dir = os.path.dirname(os.path.abspath(file_log_path)) 295 | if not os.path.isdir(log_dir): 296 | os.makedirs(log_dir) 297 | file_handler = logging.FileHandler(file_log_path) 298 | file_handler.setLevel(file_log_level) 299 | file_handler.setFormatter(logging.Formatter(file_log_format)) 300 | root_logger.addHandler(file_handler) 301 | 302 | # console logger 303 | if console_log_level is not None: 304 | console_handler = logging.StreamHandler() 305 | console_handler.setLevel(console_log_level) 306 | console_handler.setFormatter(logging.Formatter(console_log_format)) 307 | root_logger.addHandler(console_handler) 308 | 309 | 310 | def char_error_rate(true_labels, pred_labels, decoded=False): 311 | cers = np.empty(len(true_labels)) 312 | 313 | if not decoded: 314 | pred_labels = for_tf_or_th(pred_labels, pred_labels.swapaxes(0, 1)) 315 | 316 | i = 0 317 | for true_label, pred_label in zip(true_labels, pred_labels): 318 | prediction = pred_label if decoded else argmax_decode(pred_label) 319 | ratio = SequenceMatcher(true_label, prediction).ratio() 320 | cers[i] = 1 - ratio 321 | i += 1 322 | return cers 323 | 324 | 325 | def word_error_rate(true_labels, pred_labels, decoded=False): 326 | wers = np.empty(len(true_labels)) 327 | 328 | if not decoded: 329 | pred_labels = for_tf_or_th(pred_labels, pred_labels.swapaxes(0, 1)) 330 | 331 | i = 0 332 | for true_label, pred_label in zip(true_labels, pred_labels): 333 | prediction = pred_label if decoded else argmax_decode(pred_label) 334 | # seq_matcher = SequenceMatcher(prediction, true_label) 335 | # errors = [c for c in seq_matcher.get_opcodes() if c[0] != 'equal'] 336 | # error_lens = [max(e[4] - e[3], e[2] - e[1]) for e in errors] 337 | # wers[i] = sum(error_lens) / float(len(true_label.split())) 338 | wers[i] = wer(true_label, prediction) 339 | i += 1 340 | return wers 341 | 342 | def wer(original, result): 343 | r""" 344 | The WER is defined as the editing/Levenshtein distance on word level 345 | divided by the amount of words in the original text. 346 | In case of the original having more words (N) than the result and both 347 | being totally different (all N words resulting in 1 edit operation each), 348 | the WER will always be 1 (N / N = 1). 349 | """ 350 | # The WER ist calculated on word (and NOT on character) level. 
351 | # Therefore we split the strings into words first: 352 | original = original.split() 353 | result = result.split() 354 | return levenshtein(original, result) / float(len(original)) 355 | 356 | def levenshtein(a,b): 357 | "Calculates the Levenshtein distance between a and b." 358 | n, m = len(a), len(b) 359 | if n > m: 360 | # Make sure n <= m, to use O(min(n,m)) space 361 | a,b = b,a 362 | n,m = m,n 363 | 364 | current = list(range(n+1)) 365 | for i in range(1,m+1): 366 | previous, current = current, [i]+[0]*n 367 | for j in range(1,n+1): 368 | add, delete = previous[j]+1, current[j-1]+1 369 | change = previous[j-1] 370 | if a[j-1] != b[i-1]: 371 | change = change + 1 372 | current[j] = min(add, delete, change) 373 | 374 | return current[n] 375 | 376 | def for_tf_or_th(tf_val, th_val): 377 | backname = keras.backend.backend() 378 | if backname == 'tensorflow': 379 | return tf_val 380 | elif backname == 'theano': 381 | return th_val 382 | raise ValueError('Unsupported backend {}'.format(backname)) 383 | -------------------------------------------------------------------------------- /visualize.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Use this script to visualize the output of a trained speech-model. 3 | Usage: python visualize.py /path/to/audio /path/to/training/json.json \ 4 | /path/to/model 5 | """ 6 | 7 | from __future__ import absolute_import, division, print_function 8 | import sys 9 | import argparse 10 | import matplotlib 11 | matplotlib.use('Agg') 12 | import matplotlib.pyplot as plt 13 | import numpy as np 14 | 15 | from test import load_model_wrapper 16 | 17 | 18 | def softmax(x): 19 | return np.exp(x) / np.sum(np.exp(x), axis=0) 20 | 21 | 22 | def prompt_loop(prompt_line, locs): 23 | """ Reads user codes and evaluates them then returns new locals scope """ 24 | 25 | while True: 26 | try: 27 | line = raw_input(prompt_line) 28 | except EOFError: 29 | break 30 | else: 31 | if line.strip() == '': 32 | break 33 | try: 34 | exec(line, globals(), locs) 35 | except Exception as exc: 36 | print(exc) 37 | continue 38 | 39 | return locs 40 | 41 | 42 | def visualize(model, test_file, train_desc_file): 43 | """ Get the prediction using the model, and visualize softmax outputs 44 | Params: 45 | model (keras.models.Model): Trained speech model 46 | test_file (str): Path to an audio clip 47 | train_desc_file(str): Path to the training file used to train this 48 | model 49 | """ 50 | from model import compile_output_fn 51 | from data_generator import DataGenerator 52 | from utils import argmax_decode 53 | 54 | datagen = DataGenerator() 55 | datagen.load_train_data(train_desc_file) 56 | datagen.fit_train(100) 57 | 58 | print ("Compiling test function...") 59 | test_fn = compile_output_fn(model) 60 | 61 | inputs = [datagen.featurize(test_file)] 62 | 63 | prediction = np.squeeze(test_fn([inputs, True])) 64 | # preds, probs = beam_decode(prediction, 8) 65 | # u_preds, u_probs = beam_decode_u(prediction, 8) 66 | 67 | softmax_file = "softmax.npy".format(test_file) 68 | softmax_img_file = "softmax.png".format(test_file) 69 | print ("Prediction: {}" 70 | .format(argmax_decode(prediction))) 71 | print ("Saving network output to: {}".format(softmax_file)) 72 | print ("As image: {}".format(softmax_img_file)) 73 | np.save(softmax_file, prediction) 74 | sm = softmax(prediction.T) 75 | sm = np.vstack((sm[0], sm[2], sm[3:][::-1])) 76 | fig, ax = plt.subplots() 77 | ax.pcolor(sm, cmap=plt.cm.Greys_r) 78 | column_labels = [chr(i) for i in range(97, 97 
+ 26)] + ['space', 'blank'] 79 | ax.set_yticks(np.arange(sm.shape[0]) + 0.5, minor=False) 80 | ax.set_yticklabels(column_labels[::-1], minor=False) 81 | plt.savefig(softmax_img_file) 82 | 83 | 84 | def interactive_vis(model_dir, model_config, train_desc_file, weights_file=None): 85 | """ Get predictions using the model and visualize the softmax outputs; 86 | supports running on multiple input files. 87 | Params: 88 | model_dir (str): Directory of a trained speech model, or None. If None, 89 | the user is prompted to build the model interactively. 90 | model_config (str): Path to a pre-trained model configuration 91 | train_desc_file (str): Path to the training file used to train this 92 | model 93 | weights_file (str): Path to a stored weights file for the model 94 | """ 95 | 96 | if model_dir is None: 97 | assert weights_file is not None 98 | if model_config is None: 99 | from model_wrp import HalfPhonemeModelWrapper, GruModelWrapper 100 | print ("""Make a new model and store it in `model`, e.g. 101 | >>> model_wrp = HalfPhonemeModelWrapper() 102 | >>> model = model_wrp.compile(nodes=1000, recur_layers=5, 103 | conv_context=5) 104 | """) 105 | 106 | model = prompt_loop('[model=]> ', locals())['model'] 107 | model.load_weights(weights_file) 108 | else: 109 | model_wrapper = load_model_wrapper(model_config, weights_file) 110 | test_fn = model_wrapper.compile_output_fn() 111 | else: 112 | from utils import load_model 113 | model = load_model(model_dir, weights_file) 114 | 115 | if model_config is None: 116 | print ("""Make a test function and store it in `test_fn`, e.g. 117 | >>> test_fn = model_wrp.compile_output_fn() 118 | """) 119 | test_fn = prompt_loop('[test_fn=]> ', locals())['test_fn'] 120 | 121 | from utils import argmax_decode 122 | from data_generator import DataGenerator 123 | datagen = DataGenerator() 124 | 125 | if train_desc_file is not None: 126 | datagen.load_train_data(train_desc_file) 127 | datagen.fit_train(100) 128 | else: 129 | datagen.reload_norm('860-1000') 130 | 131 | while True: 132 | try: 133 | test_file = raw_input('Input file: ') 134 | except EOFError: 135 | comm_mode = True 136 | while comm_mode: 137 | try: 138 | comm = raw_input("[w: load weights\t s: shell ] > ") 139 | if comm.strip() == 'w': 140 | w_path = raw_input("weights file path: ").strip() 141 | model.load_weights(w_path) 142 | if comm.strip() == 's': 143 | prompt_loop('> ', locals()) 144 | except EOFError: 145 | comm_mode = False 146 | except Exception as exc: 147 | print (exc) 148 | continue 149 | 150 | if test_file.strip() == '': 151 | break 152 | 153 | try: 154 | inputs = [datagen.normalize(datagen.featurize(test_file))] 155 | except Exception as exc: 156 | print (exc) 157 | continue 158 | 159 | prediction = np.squeeze(test_fn([inputs, False])) 160 | 161 | softmax_file = "softmax.npy".format(test_file) 162 | softmax_img_file = "softmax.png".format(test_file) 163 | print ("Prediction: {}".format(argmax_decode(prediction))) 164 | print ("Saving network output to: {}".format(softmax_file)) 165 | print ("As image: {}".format(softmax_img_file)) 166 | np.save(softmax_file, prediction) 167 | sm = softmax(prediction.T) 168 | sm = np.vstack((sm[0], sm[2], sm[3:][::-1])) 169 | fig, ax = plt.subplots() 170 | ax.pcolor(sm, cmap=plt.cm.Greys_r) 171 | column_labels = [chr(i) for i in range(97, 97+26)] + ['space', 'blank'] 172 | ax.set_yticks(np.arange(sm.shape[0]) + 0.5, minor=False) 173 | ax.set_yticklabels(column_labels[::-1], minor=False) 174 | plt.savefig(softmax_img_file) 175 | 176 | 177 | def main(): 178 | parser = argparse.ArgumentParser( 179 |
description="Evaluate model on input file(s).", epilog=""" 180 | This script can give an interactive shell for evaluation on multiple 181 | input files. If you want plain prediction as originally came from 182 | Baidu's repo and model is trained without `model_wrapper` helpers, 183 | arguments --test-file, --train-desc-file, --load-dir and --weights-file 184 | are necessary. Otherwise set --interactive and If model is shipped 185 | by this repo give model config by --model-config. 186 | """) 187 | parser.add_argument('--test-file', type=str, help='Path to an audio file') 188 | parser.add_argument('--train-desc-file', type=str, 189 | help='Path to the training JSON-line file. This will ' 190 | 'be used to extract feature means/variance') 191 | parser.add_argument('--load-dir', type=str, 192 | help='Directory where a trained model is stored.') 193 | parser.add_argument('--model-config', type=str, 194 | help='Path to pre-trained model configuration') 195 | parser.add_argument('--weights-file', type=str, default=None, 196 | help='Path to a model weights file') 197 | parser.add_argument('--interactive', default=False, action='store_true', 198 | help='Interactive interface, necessary for pre-trained' 199 | ' models with this repo.') 200 | args = parser.parse_args() 201 | 202 | if args.interactive: 203 | assert args.test_file is None 204 | interactive_vis(args.load_dir, args.model_config, args.train_desc_file, 205 | args.weights_file) 206 | else: 207 | from utils import load_model 208 | if args.load_dir is None or args.test_file is None: 209 | parser.print_usage() 210 | sys.exit(1) 211 | 212 | print ("Loading model") 213 | model = load_model(args.load_dir, args.weights_file) 214 | visualize(model, args.test_file, args.train_desc_file) 215 | 216 | 217 | if __name__ == '__main__': 218 | main() 219 | --------------------------------------------------------------------------------