├── .gitignore ├── LICENSE ├── README.md ├── configs ├── neural_network.ini └── testing │ └── neural_network.ini ├── data └── raw │ └── librivox │ └── LibriSpeech │ ├── dev-clean-wav │ ├── 3752-4944-0041.txt │ ├── 3752-4944-0041.wav │ ├── 777-126732-0068.txt │ └── 777-126732-0068.wav │ ├── test-clean-wav │ ├── 4507-16021-0019.txt │ ├── 4507-16021-0019.wav │ ├── 7176-92135-0009.txt │ └── 7176-92135-0009.wav │ └── train-clean-100-wav │ ├── 1970-28415-0023.txt │ ├── 1970-28415-0023.wav │ ├── 211-122425-0059.txt │ ├── 211-122425-0059.wav │ ├── 2843-152918-0008.txt │ ├── 2843-152918-0008.wav │ ├── 3259-158083-0026.txt │ ├── 3259-158083-0026.wav │ ├── 3879-174923-0005.txt │ └── 3879-174923-0005.wav ├── requirements.txt └── src ├── __init__.py ├── data_manipulation ├── __init__.py └── datasets.py ├── features ├── __init__.py └── utils │ ├── __init__.py │ ├── load_audio_to_mem.py │ └── text.py ├── models ├── RNN │ ├── __init__.py │ ├── rnn.py │ └── utils.py └── __init__.py ├── tests ├── __init__.py └── train_framework │ ├── __init__.py │ └── tf_train_ctc_test.py ├── train_framework ├── __init__.py └── tf_train_ctc.py └── utils ├── __init__.py ├── gpu.py └── set_dirs.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | 4 | # Jupyter Notebook 5 | .ipynb_checkpoints -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Recurrent Neural Networks - A Short TensorFlow Tutorial 2 | 3 | ### Setup 4 | Clone this repo to your local machine, and add the RNN-Tutorial directory as a system variable to your `~/.profile`. 
Instructions are given for a bash shell:
5 | 
6 | ```bash
7 | git clone https://github.com/silicon-valley-data-science/RNN-Tutorial
8 | cd RNN-Tutorial
9 | echo "export RNN_TUTORIAL=${PWD}" >> ~/.profile
10 | echo 'export PYTHONPATH=$RNN_TUTORIAL/src:$PYTHONPATH' >> ~/.profile
11 | source ~/.profile
12 | ```
13 | 
14 | Create a Conda environment (you will need to [install Conda](https://conda.io/docs/install/quick.html) first):
15 | 
16 | ```bash
17 | conda create --name tf-rnn python=3
18 | source activate tf-rnn
19 | cd $RNN_TUTORIAL
20 | pip install -r requirements.txt
21 | ```
22 | 
23 | ### Install TensorFlow
24 | 
25 | If you have an NVIDIA GPU with [CUDA](http://docs.nvidia.com/cuda/cuda-installation-guide-linux/#package-manager-installation) already installed:
26 | 
27 | ```bash
28 | pip install tensorflow-gpu==1.0.1
29 | ```
30 | 
31 | If you will be running TensorFlow on CPU only (e.g., on a MacBook Pro), use the following command (if you get an error the first time you run it, read the note below):
32 | 
33 | ```bash
34 | pip install --upgrade --ignore-installed \
35 | https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.0.1-py3-none-any.whl
36 | ```
37 | 
38 | **Error note** (if you did not get an error, skip this paragraph): Depending on how you installed pip and/or conda, we've seen different outcomes. If you get an error the first time, rerunning the command may incorrectly report that it installed without error. Try running with `pip install --upgrade https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.0.1-py3-none-any.whl --ignore-installed`. The `--ignore-installed` flag tells pip to reinstall the package. If that still doesn't work, please open an [issue](https://github.com/silicon-valley-data-science/RNN-Tutorial/issues), or you can try to follow the advice [here](https://www.tensorflow.org/install/install_mac).
39 | 
40 | 
41 | ### Run unit tests
42 | We have included example unit tests for the `tf_train_ctc.py` script:
43 | 
44 | ```bash
45 | python $RNN_TUTORIAL/src/tests/train_framework/tf_train_ctc_test.py
46 | ```
47 | 
48 | 
49 | ### Run RNN training
50 | All configuration for the RNN training script can be found in `$RNN_TUTORIAL/configs/neural_network.ini`:
51 | 
52 | ```bash
53 | python $RNN_TUTORIAL/src/train_framework/tf_train_ctc.py
54 | ```
55 | 
56 | _NOTE: If you have a GPU available, the code will run faster if you set `tf_device = /gpu:0` in `configs/neural_network.ini`._
57 | 
58 | 
59 | ### TensorBoard configuration
60 | To visualize your results with TensorBoard:
61 | 
62 | ```bash
63 | tensorboard --logdir=$RNN_TUTORIAL/models/nn/debug_models/summary/
64 | ```
65 | 
66 | - TensorBoard can be found in your browser at [http://localhost:6006](http://localhost:6006).
67 | - `tf.name_scope` is used to define parts of the network for visualization in TensorBoard. TensorBoard automatically finds any similarly structured network parts, such as identical fully connected layers, and groups them in the graph visualization.
68 | - Related to this are the `tf.summary.*` methods, which log the values of network parts, such as distributions of layer activations or error rate across epochs. These summaries are grouped within their enclosing `tf.name_scope` (see the sketch below).
69 | - See the official TensorFlow documentation for more details.
70 | 
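For orientation, here is a minimal sketch of how `tf.name_scope` and `tf.summary.*` combine (TF 1.x API, mirroring the patterns in `src/models/RNN/rnn.py`; the layer sizes here are arbitrary illustrations, not the tutorial's values):

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 26], name='x')
with tf.name_scope('fc1'):  # groups the ops below under one node in the graph view
    W = tf.Variable(tf.truncated_normal([26, 64], stddev=0.046875), name='W')
    b = tf.Variable(tf.zeros([64]), name='b')
    activations = tf.nn.relu(tf.matmul(x, W) + b)
    tf.summary.histogram('weights', W)        # appears under fc1/ in TensorBoard
    tf.summary.histogram('activations', activations)

summary_op = tf.summary.merge_all()  # evaluate and write with a tf.summary.FileWriter
```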
71 | 
72 | ### Add data
73 | We have included example data from the LibriSpeech corpus (built from [LibriVox](https://librivox.org) recordings) in `data/raw/librivox/LibriSpeech/`. The data is separated into folders:
74 | 
75 | - Train: train-clean-100-wav (5 examples)
76 | - Test: test-clean-wav (2 examples)
77 | - Dev: dev-clean-wav (2 examples)
78 | 
79 | If you would like to train a performant model, you can add additional WAV and TXT files to these folders, or create a new folder and update `configs/neural_network.ini` with the folder locations.
80 | 
81 | 
82 | ### Remove additions
83 | 
84 | We made a few additions to your `.profile`. Remove them if you want; or, to keep the environment variables, source `.profile` from your `.bash_profile` by running:
85 | 
86 | ```bash
87 | echo "source ~/.profile" >> ~/.bash_profile
88 | ```
--------------------------------------------------------------------------------
/configs/neural_network.ini:
--------------------------------------------------------------------------------
1 | [nn]
2 | epochs = 100
3 | network_type = BiRNN
4 | decode_train = True
5 | n_input = 26
6 | n_context = 9
7 | model_dir = nn/debug_models
8 | SAVE_MODEL_EPOCH_NUM = 2
9 | VALIDATION_EPOCH_NUM = 1
10 | CURR_VALIDATION_LER_DIFF = 0.005
11 | AVG_VALIDATION_LER_EPOCHS = 2
12 | beam_search_decoder = default
13 | shuffle_data_after_epoch = True
14 | min_dev_ler = 100.0
15 | tf_device = /gpu:0
16 | simultaneous_users_count = 4
17 | 
18 | [data]
19 | # If data_dir does not start with '/', then home_dir is prepended in set_dirs.py
20 | data_dir = data/raw/librivox/LibriSpeech/
21 | dir_pattern_train = train-clean-100-wav
22 | dir_pattern_dev = dev-clean-wav
23 | dir_pattern_test = test-clean-wav
24 | n_train_limit = 5
25 | n_dev_limit = 2
26 | n_test_limit = 2
27 | batch_size_train = 2
28 | batch_size_dev = 2
29 | batch_size_test = 2
30 | start_idx_init_train = 0
31 | start_idx_init_dev = 0
32 | start_idx_init_test = 0
33 | sort_train = filesize_low_high
34 | sort_dev = random
35 | sort_test = random
36 | 
37 | [optimizer]
38 | # AdamOptimizer (http://arxiv.org/abs/1412.6980) parameters
39 | beta1 = 0.9
40 | beta2 = 0.999
41 | epsilon = 1e-8
42 | learning_rate = 0.001
43 | 
44 | [simplelstm]
45 | n_character = 29
46 | default_stddev = 0.046875
47 | b1_stddev = %(default_stddev)s
48 | h1_stddev = %(default_stddev)s
49 | n_layers = 2
50 | n_hidden_units = 512
51 | 
52 | [birnn]
53 | n_character = 29
54 | use_warpctc = False
55 | dropout_rate = 0.05
56 | dropout_rate2 = %(dropout_rate)s
57 | dropout_rate3 = %(dropout_rate)s
58 | dropout_rate4 = 0.0
59 | dropout_rate5 = 0.0
60 | dropout_rate6 = %(dropout_rate)s
61 | dropout_rates = %(dropout_rate)s,%(dropout_rate2)s,%(dropout_rate3)s,%(dropout_rate4)s,%(dropout_rate5)s,%(dropout_rate6)s
62 | relu_clip = 20
63 | default_stddev = 0.046875
64 | b1_stddev = %(default_stddev)s
65 | h1_stddev = %(default_stddev)s
66 | b2_stddev = %(default_stddev)s
67 | h2_stddev = %(default_stddev)s
68 | b3_stddev = %(default_stddev)s
69 | h3_stddev = %(default_stddev)s
70 | b5_stddev = %(default_stddev)s
71 | h5_stddev = %(default_stddev)s
72 | b6_stddev = %(default_stddev)s
73 | h6_stddev = %(default_stddev)s
74 | n_hidden = 1024
75 | n_hidden_1 = %(n_hidden)s
76 | n_hidden_2 = %(n_hidden)s
77 | n_hidden_5 = %(n_hidden)s
78 | n_cell_dim = %(n_hidden)s
79 | n_hidden_3 = 2 * %(n_cell_dim)s
80 | n_hidden_6 = %(n_character)s
81 | 
--------------------------------------------------------------------------------
/configs/testing/neural_network.ini:
--------------------------------------------------------------------------------
1 | [nn]
2 | epochs = 1
3 | network_type = BiRNN
4 | decode_train = True
5 | n_input = 26
6 | n_context = 9
7 
| model_dir = nn/debug_models 8 | SAVE_MODEL_EPOCH_NUM = 2 9 | VALIDATION_EPOCH_NUM = 1 10 | CURR_VALIDATION_LER_DIFF = 0.005 11 | AVG_VALIDATION_LER_EPOCHS = 2 12 | beam_search_decoder = default 13 | shuffle_data_after_epoch = True 14 | min_dev_ler = 100.0 15 | tf_device = /gpu:0 16 | simultaneous_users_count = 4 17 | 18 | [data] 19 | #If data_dir does not start with '/', then home_dir is prepended in set_dirs.py 20 | data_dir = data/raw/librivox/LibriSpeech/ 21 | dir_pattern_train = train-clean-100-wav 22 | dir_pattern_dev = dev-clean-wav 23 | dir_pattern_test = test-clean-wav 24 | n_train_limit = 2 25 | n_dev_limit = 2 26 | n_test_limit = 2 27 | batch_size_train = 2 28 | batch_size_dev = 2 29 | batch_size_test = 2 30 | start_idx_init_train = 0 31 | start_idx_init_dev = 0 32 | start_idx_init_test = 0 33 | sort_train = filesize_low_high 34 | sort_dev = random 35 | sort_test = random 36 | 37 | [optimizer] 38 | # AdamOptimizer (http://arxiv.org/abs/1412.6980) parameters 39 | beta1 = 0.9 40 | beta2 = 0.999 41 | epsilon = 1e-8 42 | learning_rate = 0.001 43 | 44 | [simplelstm] 45 | n_character = 29 46 | default_stddev = 0.046875 47 | b1_stddev = %(default_stddev)s 48 | h1_stddev = %(default_stddev)s 49 | n_layers = 2 50 | n_hidden_units = 512 51 | 52 | [birnn] 53 | n_character = 29 54 | use_warpctc = False 55 | dropout_rate = 0.05 56 | dropout_rate2 = %(dropout_rate)s 57 | dropout_rate3 = %(dropout_rate)s 58 | dropout_rate4 = 0.0 59 | dropout_rate5 = 0.0 60 | dropout_rate6 = %(dropout_rate)s 61 | dropout_rates = %(dropout_rate)s,%(dropout_rate2)s,%(dropout_rate3)s,%(dropout_rate4)s,%(dropout_rate5)s,%(dropout_rate6)s 62 | relu_clip = 20 63 | default_stddev = 0.046875 64 | b1_stddev = %(default_stddev)s 65 | h1_stddev = %(default_stddev)s 66 | b2_stddev = %(default_stddev)s 67 | h2_stddev = %(default_stddev)s 68 | b3_stddev = %(default_stddev)s 69 | h3_stddev = %(default_stddev)s 70 | b5_stddev = %(default_stddev)s 71 | h5_stddev = %(default_stddev)s 72 | b6_stddev = %(default_stddev)s 73 | h6_stddev = %(default_stddev)s 74 | n_hidden = 1024 75 | n_hidden_1 = %(n_hidden)s 76 | n_hidden_2 = %(n_hidden)s 77 | n_hidden_5 = %(n_hidden)s 78 | n_cell_dim = %(n_hidden)s 79 | n_hidden_3 = 2 * %(n_cell_dim)s 80 | n_hidden_6 = %(n_character)s 81 | -------------------------------------------------------------------------------- /data/raw/librivox/LibriSpeech/dev-clean-wav/3752-4944-0041.txt: -------------------------------------------------------------------------------- 1 | how delightful the grass smells -------------------------------------------------------------------------------- /data/raw/librivox/LibriSpeech/dev-clean-wav/3752-4944-0041.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrubash1/RNN-Tutorial/1d02ee9d9e4f0ff24fc88beabb8ff9da6b8c8886/data/raw/librivox/LibriSpeech/dev-clean-wav/3752-4944-0041.wav -------------------------------------------------------------------------------- /data/raw/librivox/LibriSpeech/dev-clean-wav/777-126732-0068.txt: -------------------------------------------------------------------------------- 1 | that boy hears too much of what is talked about here -------------------------------------------------------------------------------- /data/raw/librivox/LibriSpeech/dev-clean-wav/777-126732-0068.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mrubash1/RNN-Tutorial/1d02ee9d9e4f0ff24fc88beabb8ff9da6b8c8886/data/raw/librivox/LibriSpeech/dev-clean-wav/777-126732-0068.wav -------------------------------------------------------------------------------- /data/raw/librivox/LibriSpeech/test-clean-wav/4507-16021-0019.txt: -------------------------------------------------------------------------------- 1 | it is the language of wretchedness -------------------------------------------------------------------------------- /data/raw/librivox/LibriSpeech/test-clean-wav/4507-16021-0019.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrubash1/RNN-Tutorial/1d02ee9d9e4f0ff24fc88beabb8ff9da6b8c8886/data/raw/librivox/LibriSpeech/test-clean-wav/4507-16021-0019.wav -------------------------------------------------------------------------------- /data/raw/librivox/LibriSpeech/test-clean-wav/7176-92135-0009.txt: -------------------------------------------------------------------------------- 1 | and i should begin with a short homily on soliloquy -------------------------------------------------------------------------------- /data/raw/librivox/LibriSpeech/test-clean-wav/7176-92135-0009.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrubash1/RNN-Tutorial/1d02ee9d9e4f0ff24fc88beabb8ff9da6b8c8886/data/raw/librivox/LibriSpeech/test-clean-wav/7176-92135-0009.wav -------------------------------------------------------------------------------- /data/raw/librivox/LibriSpeech/train-clean-100-wav/1970-28415-0023.txt: -------------------------------------------------------------------------------- 1 | where people were making their gifts to god -------------------------------------------------------------------------------- /data/raw/librivox/LibriSpeech/train-clean-100-wav/1970-28415-0023.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrubash1/RNN-Tutorial/1d02ee9d9e4f0ff24fc88beabb8ff9da6b8c8886/data/raw/librivox/LibriSpeech/train-clean-100-wav/1970-28415-0023.wav -------------------------------------------------------------------------------- /data/raw/librivox/LibriSpeech/train-clean-100-wav/211-122425-0059.txt: -------------------------------------------------------------------------------- 1 | and the two will pass off together -------------------------------------------------------------------------------- /data/raw/librivox/LibriSpeech/train-clean-100-wav/211-122425-0059.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrubash1/RNN-Tutorial/1d02ee9d9e4f0ff24fc88beabb8ff9da6b8c8886/data/raw/librivox/LibriSpeech/train-clean-100-wav/211-122425-0059.wav -------------------------------------------------------------------------------- /data/raw/librivox/LibriSpeech/train-clean-100-wav/2843-152918-0008.txt: -------------------------------------------------------------------------------- 1 | one day may be pleasant enough but two three four -------------------------------------------------------------------------------- /data/raw/librivox/LibriSpeech/train-clean-100-wav/2843-152918-0008.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrubash1/RNN-Tutorial/1d02ee9d9e4f0ff24fc88beabb8ff9da6b8c8886/data/raw/librivox/LibriSpeech/train-clean-100-wav/2843-152918-0008.wav 
-------------------------------------------------------------------------------- /data/raw/librivox/LibriSpeech/train-clean-100-wav/3259-158083-0026.txt: -------------------------------------------------------------------------------- 1 | i have a nephew fighting for democracy in france -------------------------------------------------------------------------------- /data/raw/librivox/LibriSpeech/train-clean-100-wav/3259-158083-0026.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrubash1/RNN-Tutorial/1d02ee9d9e4f0ff24fc88beabb8ff9da6b8c8886/data/raw/librivox/LibriSpeech/train-clean-100-wav/3259-158083-0026.wav -------------------------------------------------------------------------------- /data/raw/librivox/LibriSpeech/train-clean-100-wav/3879-174923-0005.txt: -------------------------------------------------------------------------------- 1 | he must vanish out of the world -------------------------------------------------------------------------------- /data/raw/librivox/LibriSpeech/train-clean-100-wav/3879-174923-0005.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrubash1/RNN-Tutorial/1d02ee9d9e4f0ff24fc88beabb8ff9da6b8c8886/data/raw/librivox/LibriSpeech/train-clean-100-wav/3879-174923-0005.wav -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | mako==1.0.6 2 | matplotlib==2.0.0 3 | numpy==1.12.0 4 | protobuf==3.2.0 5 | python-speech-features==0.5 6 | pyyaml==5.4 7 | pyxdg==0.26 8 | requests==2.20.0 9 | scipy==0.18.1 10 | sox==1.2.6 11 | tree==0.1.0 12 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrubash1/RNN-Tutorial/1d02ee9d9e4f0ff24fc88beabb8ff9da6b8c8886/src/__init__.py -------------------------------------------------------------------------------- /src/data_manipulation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrubash1/RNN-Tutorial/1d02ee9d9e4f0ff24fc88beabb8ff9da6b8c8886/src/data_manipulation/__init__.py -------------------------------------------------------------------------------- /src/data_manipulation/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | from math import ceil 3 | from random import random 4 | from glob import glob 5 | from configparser import ConfigParser 6 | import logging 7 | from collections import namedtuple 8 | 9 | from features.utils.load_audio_to_mem import get_audio_and_transcript, pad_sequences 10 | from features.utils.text import sparse_tuple_from 11 | from utils.set_dirs import get_data_dir 12 | 13 | 14 | DataSets = namedtuple('DataSets', 'train dev test') 15 | 16 | 17 | def read_datasets(conf_path, sets, numcep, numcontext, 18 | thread_count=8): 19 | '''Main function to create DataSet objects. 20 | 21 | This function calls an internal function _get_data_set_dict that 22 | reads the configuration file. Then it calls the internal function _read_data_set 23 | which collects the text files in the data directories, returning a DataSet object. 24 | This function returns a DataSets object containing the requested datasets. 
25 | 26 | Args: 27 | sets (list): List of datasets to create. Options are: 'train', 'dev', 'test' 28 | numcep (int): Number of mel-frequency cepstral coefficients to compute. 29 | numcontext (int): For each time point, number of contextual samples to include. 30 | thread_count (int): Number of threads 31 | 32 | Returns: 33 | DataSets: A single `DataSets` instance containing each of the requested datasets 34 | 35 | E.g., when sets=['train'], datasets.train exists, with methods to retrieve examples. 36 | 37 | This function has been modified from Mozilla DeepSpeech: 38 | https://github.com/mozilla/DeepSpeech/blob/master/util/importers/librivox.py 39 | 40 | # This Source Code Form is subject to the terms of the Mozilla Public 41 | # License, v. 2.0. If a copy of the MPL was not distributed with this 42 | # file, You can obtain one at http://mozilla.org/MPL/2.0/. 43 | ''' 44 | data_dir, dataset_config = _get_data_set_dict(conf_path, sets) 45 | 46 | def _read_data_set(config): 47 | path = os.path.join(data_dir, config['dir_pattern']) 48 | return DataSet.from_directory(path, 49 | thread_count=thread_count, 50 | batch_size=config['batch_size'], 51 | numcep=numcep, 52 | numcontext=numcontext, 53 | start_idx=config['start_idx'], 54 | limit=config['limit'], 55 | sort=config['sort'] 56 | ) 57 | datasets = {name: _read_data_set(dataset_config[name]) 58 | if name in sets else None 59 | for name in ('train', 'dev', 'test')} 60 | return DataSets(**datasets) 61 | 62 | 63 | class DataSet: 64 | ''' 65 | Train/test/dev dataset API for loading via threads and delivering batches. 66 | 67 | This class has been modified from Mozilla DeepSpeech: 68 | https://github.com/mozilla/DeepSpeech/blob/master/util/importers/librivox.py 69 | 70 | # This Source Code Form is subject to the terms of the Mozilla Public 71 | # License, v. 2.0. If a copy of the MPL was not distributed with this 72 | # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
73 | '''
74 | 
75 |     def __init__(self, txt_files, thread_count, batch_size, numcep, numcontext):
76 |         self._coord = None
77 |         self._numcep = numcep
78 |         self._txt_files = txt_files
79 |         self._batch_size = batch_size
80 |         self._numcontext = numcontext
81 |         self._start_idx = 0
82 | 
83 |     @classmethod
84 |     def from_directory(cls, dirpath, thread_count, batch_size, numcep, numcontext, start_idx=0, limit=0, sort=None):
85 |         if not os.path.exists(dirpath):
86 |             raise IOError("'%s' does not exist" % dirpath)
87 |         txt_files = txt_filenames(dirpath, start_idx=start_idx, limit=limit, sort=sort)
88 |         if len(txt_files) == 0:
89 |             raise RuntimeError('start_idx=%d and limit=%d arguments result in zero files' % (start_idx, limit))
90 |         return cls(txt_files, thread_count, batch_size, numcep, numcontext)
91 | 
92 |     def next_batch(self, batch_size=None):
93 |         if batch_size is None:
94 |             batch_size = self._batch_size
95 | 
96 |         end_idx = min(len(self._txt_files), self._start_idx + batch_size)
97 |         idx_list = range(self._start_idx, end_idx)
98 |         txt_files = [self._txt_files[i] for i in idx_list]
99 |         wav_files = [x.replace('.txt', '.wav') for x in txt_files]
100 |         (source, _, target, _) = get_audio_and_transcript(txt_files,
101 |                                                           wav_files,
102 |                                                           self._numcep,
103 |                                                           self._numcontext)
104 |         self._start_idx += batch_size
105 |         # Verify that start_idx is not larger than the total available sample size
106 |         if self._start_idx >= self.size:
107 |             self._start_idx = 0
108 | 
109 |         # Pad input to the max_time_step of this batch
110 |         source, source_lengths = pad_sequences(source)
111 |         sparse_labels = sparse_tuple_from(target)
112 |         return source, source_lengths, sparse_labels
113 | 
114 |     @property
115 |     def files(self):
116 |         return self._txt_files
117 | 
118 |     @property
119 |     def size(self):
120 |         return len(self.files)
121 | 
122 |     @property
123 |     def total_batches(self):
124 |         # Note: if len(_txt_files) % _batch_size != 0, this re-uses initial _txt_files
125 |         return int(ceil(float(len(self._txt_files)) / float(self._batch_size)))
126 | # END DataSet
127 | 
128 | SORTS = ['filesize_low_high', 'filesize_high_low', 'alpha', 'random']
129 | 
130 | def txt_filenames(dataset_path, start_idx=0, limit=None, sort='alpha'):
131 |     # Obtain list of txt files
132 |     txt_files = glob(os.path.join(dataset_path, "*.txt"))
133 |     limit = limit or len(txt_files)
134 | 
135 |     # Optional: sort files to improve padding performance
136 |     if sort not in SORTS:
137 |         raise ValueError('sort must be one of %s' % SORTS)
138 |     reverse = False
139 |     key = None
140 |     if 'filesize' in sort:
141 |         key = os.path.getsize
142 |         if sort == 'filesize_high_low':
143 |             reverse = True
144 |     elif sort == 'random':
145 |         key = lambda *args: random()  # a random sort key yields a random shuffle via sorted()
146 |     txt_files = sorted(txt_files, key=key, reverse=reverse)
147 | 
148 |     return txt_files[start_idx:limit + start_idx]
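# Usage sketch (hypothetical config path; mirrors the API defined above --
# numcep/numcontext correspond to n_input/n_context in configs/neural_network.ini):
#
#     datasets = read_datasets('configs/neural_network.ini', ['train'],
#                              numcep=26, numcontext=9)
#     source, source_lengths, sparse_labels = datasets.train.next_batch()
#
# source is a padded [batch, max_time, 26 + 2*26*9] float array, and sparse_labels
# is the (indices, values, shape) triple expected by TensorFlow's CTC loss.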
149 | 
150 | 
151 | def _get_data_set_dict(conf_path, sets):
152 |     parser = ConfigParser(os.environ)
153 |     parser.read(conf_path)
154 |     config_header = 'data'
155 |     data_dir = get_data_dir(parser.get(config_header, 'data_dir'))
156 |     data_dict = {}
157 | 
158 |     if 'train' in sets:
159 |         d = {}
160 |         d['dir_pattern'] = parser.get(config_header, 'dir_pattern_train')
161 |         d['limit'] = parser.getint(config_header, 'n_train_limit')
162 |         d['sort'] = parser.get(config_header, 'sort_train')
163 |         d['batch_size'] = parser.getint(config_header, 'batch_size_train')
164 |         d['start_idx'] = parser.getint(config_header, 'start_idx_init_train')
165 |         data_dict['train'] = d
166 |         logging.debug('Training configuration: %s', str(d))
167 | 
168 |     if 'dev' in sets:
169 |         d = {}
170 |         d['dir_pattern'] = parser.get(config_header, 'dir_pattern_dev')
171 |         d['limit'] = parser.getint(config_header, 'n_dev_limit')
172 |         d['sort'] = parser.get(config_header, 'sort_dev')
173 |         d['batch_size'] = parser.getint(config_header, 'batch_size_dev')
174 |         d['start_idx'] = parser.getint(config_header, 'start_idx_init_dev')
175 |         data_dict['dev'] = d
176 |         logging.debug('Dev configuration: %s', str(d))
177 | 
178 |     if 'test' in sets:
179 |         d = {}
180 |         d['dir_pattern'] = parser.get(config_header, 'dir_pattern_test')
181 |         d['limit'] = parser.getint(config_header, 'n_test_limit')
182 |         d['sort'] = parser.get(config_header, 'sort_test')
183 |         d['batch_size'] = parser.getint(config_header, 'batch_size_test')
184 |         d['start_idx'] = parser.getint(config_header, 'start_idx_init_test')
185 |         data_dict['test'] = d
186 |         logging.debug('Test configuration: %s', str(d))
187 | 
188 |     return data_dir, data_dict
189 | 
--------------------------------------------------------------------------------
/src/features/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrubash1/RNN-Tutorial/1d02ee9d9e4f0ff24fc88beabb8ff9da6b8c8886/src/features/__init__.py
--------------------------------------------------------------------------------
/src/features/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrubash1/RNN-Tutorial/1d02ee9d9e4f0ff24fc88beabb8ff9da6b8c8886/src/features/utils/__init__.py
--------------------------------------------------------------------------------
/src/features/utils/load_audio_to_mem.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | 
3 | import os
4 | import scipy.io.wavfile as wav
5 | 
6 | import numpy as np
7 | from python_speech_features import mfcc
8 | from features.utils.text import text_to_char_array, normalize_txt_file
9 | 
10 | 
11 | def load_wavfile(wavfile):
12 |     """
13 |     Read a wav file using scipy.io.wavfile
14 |     """
15 |     rate, sig = wav.read(wavfile)
16 |     data_name = os.path.splitext(os.path.basename(wavfile))[0]
17 |     return rate, sig, data_name
18 | 
19 | 
20 | def get_audio_and_transcript(txt_files, wav_files, n_input, n_context):
21 |     '''
22 |     Loads audio files and text transcriptions from ordered lists of filenames.
23 |     Converts audio to MFCC arrays and text to numerical arrays.
24 |     Returns lists of arrays. The returned audio array list can be padded with the
25 |     pad_sequences function in this same module.
26 | ''' 27 | audio = [] 28 | audio_len = [] 29 | transcript = [] 30 | transcript_len = [] 31 | 32 | for txt_file, wav_file in zip(txt_files, wav_files): 33 | # load audio and convert to features 34 | audio_data = audiofile_to_input_vector(wav_file, n_input, n_context) 35 | audio_data = audio_data.astype('float32') 36 | 37 | audio.append(audio_data) 38 | audio_len.append(np.int32(len(audio_data))) 39 | 40 | # load text transcription and convert to numerical array 41 | target = normalize_txt_file(txt_file) 42 | target = text_to_char_array(target) 43 | transcript.append(target) 44 | transcript_len.append(len(target)) 45 | 46 | audio = np.asarray(audio) 47 | audio_len = np.asarray(audio_len) 48 | transcript = np.asarray(transcript) 49 | transcript_len = np.asarray(transcript_len) 50 | return audio, audio_len, transcript, transcript_len 51 | 52 | 53 | def audiofile_to_input_vector(audio_filename, numcep, numcontext): 54 | ''' 55 | Turn an audio file into feature representation. 56 | 57 | This function has been modified from Mozilla DeepSpeech: 58 | https://github.com/mozilla/DeepSpeech/blob/master/util/audio.py 59 | 60 | # This Source Code Form is subject to the terms of the Mozilla Public 61 | # License, v. 2.0. If a copy of the MPL was not distributed with this 62 | # file, You can obtain one at http://mozilla.org/MPL/2.0/. 63 | ''' 64 | 65 | # Load wav files 66 | fs, audio = wav.read(audio_filename) 67 | 68 | # Get mfcc coefficients 69 | orig_inputs = mfcc(audio, samplerate=fs, numcep=numcep) 70 | 71 | # We only keep every second feature (BiRNN stride = 2) 72 | orig_inputs = orig_inputs[::2] 73 | 74 | # For each time slice of the training set, we need to copy the context this makes 75 | # the numcep dimensions vector into a numcep + 2*numcep*numcontext dimensions 76 | # because of: 77 | # - numcep dimensions for the current mfcc feature set 78 | # - numcontext*numcep dimensions for each of the past and future (x2) mfcc feature set 79 | # => so numcep + 2*numcontext*numcep 80 | train_inputs = np.array([], np.float32) 81 | train_inputs.resize((orig_inputs.shape[0], numcep + 2 * numcep * numcontext)) 82 | 83 | # Prepare pre-fix post fix context 84 | empty_mfcc = np.array([]) 85 | empty_mfcc.resize((numcep)) 86 | 87 | # Prepare train_inputs with past and future contexts 88 | time_slices = range(train_inputs.shape[0]) 89 | context_past_min = time_slices[0] + numcontext 90 | context_future_max = time_slices[-1] - numcontext 91 | for time_slice in time_slices: 92 | # Reminder: array[start:stop:step] 93 | # slices from indice |start| up to |stop| (not included), every |step| 94 | 95 | # Add empty context data of the correct size to the start and end 96 | # of the MFCC feature matrix 97 | 98 | # Pick up to numcontext time slices in the past, and complete with empty 99 | # mfcc features 100 | need_empty_past = max(0, (context_past_min - time_slice)) 101 | empty_source_past = list(empty_mfcc for empty_slots in range(need_empty_past)) 102 | data_source_past = orig_inputs[max(0, time_slice - numcontext):time_slice] 103 | assert(len(empty_source_past) + len(data_source_past) == numcontext) 104 | 105 | # Pick up to numcontext time slices in the future, and complete with empty 106 | # mfcc features 107 | need_empty_future = max(0, (time_slice - context_future_max)) 108 | empty_source_future = list(empty_mfcc for empty_slots in range(need_empty_future)) 109 | data_source_future = orig_inputs[time_slice + 1:time_slice + numcontext + 1] 110 | assert(len(empty_source_future) + len(data_source_future) == 
numcontext) 111 | 112 | if need_empty_past: 113 | past = np.concatenate((empty_source_past, data_source_past)) 114 | else: 115 | past = data_source_past 116 | 117 | if need_empty_future: 118 | future = np.concatenate((data_source_future, empty_source_future)) 119 | else: 120 | future = data_source_future 121 | 122 | past = np.reshape(past, numcontext * numcep) 123 | now = orig_inputs[time_slice] 124 | future = np.reshape(future, numcontext * numcep) 125 | 126 | train_inputs[time_slice] = np.concatenate((past, now, future)) 127 | assert(len(train_inputs[time_slice]) == numcep + 2 * numcep * numcontext) 128 | 129 | # Scale/standardize the inputs 130 | # This can be done more efficiently in the TensorFlow graph 131 | train_inputs = (train_inputs - np.mean(train_inputs)) / np.std(train_inputs) 132 | return train_inputs 133 | 134 | 135 | def pad_sequences(sequences, maxlen=None, dtype=np.float32, 136 | padding='post', truncating='post', value=0.): 137 | 138 | ''' 139 | # From TensorLayer: 140 | # http://tensorlayer.readthedocs.io/en/latest/_modules/tensorlayer/prepro.html 141 | 142 | Pads each sequence to the same length of the longest sequence. 143 | 144 | If maxlen is provided, any sequence longer than maxlen is truncated to 145 | maxlen. Truncation happens off either the beginning or the end 146 | (default) of the sequence. Supports post-padding (default) and 147 | pre-padding. 148 | 149 | Args: 150 | sequences: list of lists where each element is a sequence 151 | maxlen: int, maximum length 152 | dtype: type to cast the resulting sequence. 153 | padding: 'pre' or 'post', pad either before or after each sequence. 154 | truncating: 'pre' or 'post', remove values from sequences larger 155 | than maxlen either in the beginning or in the end of the sequence 156 | value: float, value to pad the sequences to the desired value. 157 | 158 | Returns: 159 | numpy.ndarray: Padded sequences shape = (number_of_sequences, maxlen) 160 | numpy.ndarray: original sequence lengths 161 | ''' 162 | lengths = np.asarray([len(s) for s in sequences], dtype=np.int64) 163 | 164 | nb_samples = len(sequences) 165 | if maxlen is None: 166 | maxlen = np.max(lengths) 167 | 168 | # take the sample shape from the first non empty sequence 169 | # checking for consistency in the main loop below. 
170 | sample_shape = tuple() 171 | for s in sequences: 172 | if len(s) > 0: 173 | sample_shape = np.asarray(s).shape[1:] 174 | break 175 | 176 | x = (np.ones((nb_samples, maxlen) + sample_shape) * value).astype(dtype) 177 | for idx, s in enumerate(sequences): 178 | if len(s) == 0: 179 | continue # empty list was found 180 | if truncating == 'pre': 181 | trunc = s[-maxlen:] 182 | elif truncating == 'post': 183 | trunc = s[:maxlen] 184 | else: 185 | raise ValueError('Truncating type "%s" not understood' % truncating) 186 | 187 | # check `trunc` has expected shape 188 | trunc = np.asarray(trunc, dtype=dtype) 189 | if trunc.shape[1:] != sample_shape: 190 | raise ValueError('Shape of sample %s of sequence at position %s is different from expected shape %s' % 191 | (trunc.shape[1:], idx, sample_shape)) 192 | 193 | if padding == 'post': 194 | x[idx, :len(trunc)] = trunc 195 | elif padding == 'pre': 196 | x[idx, -len(trunc):] = trunc 197 | else: 198 | raise ValueError('Padding type "%s" not understood' % padding) 199 | return x, lengths 200 | -------------------------------------------------------------------------------- /src/features/utils/text.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import unicodedata 3 | import codecs 4 | import re 5 | 6 | import tensorflow as tf 7 | 8 | # Constants 9 | SPACE_TOKEN = '' 10 | SPACE_INDEX = 0 11 | FIRST_INDEX = ord('a') - 1 # 0 is reserved to space 12 | 13 | 14 | def normalize_txt_file(txt_file, remove_apostrophe=True): 15 | """ 16 | Given a path to a text file, return contents with unsupported characters removed. 17 | """ 18 | with codecs.open(txt_file, encoding="utf-8") as open_txt_file: 19 | return normalize_text(open_txt_file.read(), remove_apostrophe=remove_apostrophe) 20 | 21 | 22 | def normalize_text(original, remove_apostrophe=True): 23 | """ 24 | Given a Python string ``original``, remove unsupported characters. 25 | 26 | The only supported characters are letters and apostrophes. 27 | """ 28 | # convert any unicode characters to ASCII equivalent 29 | # then ignore anything else and decode to a string 30 | result = unicodedata.normalize("NFKD", original).encode("ascii", "ignore").decode() 31 | if remove_apostrophe: 32 | # remove apostrophes to keep contractions together 33 | result = result.replace("'", "") 34 | # return lowercase alphabetic characters and apostrophes (if still present) 35 | return re.sub("[^a-zA-Z']+", ' ', result).strip().lower() 36 | 37 | 38 | def text_to_char_array(original): 39 | """ 40 | Given a Python string ``original``, map characters 41 | to integers and return a numpy array representing the processed string. 42 | 43 | This function has been modified from Mozilla DeepSpeech: 44 | https://github.com/mozilla/DeepSpeech/blob/master/util/text.py 45 | 46 | # This Source Code Form is subject to the terms of the Mozilla Public 47 | # License, v. 2.0. If a copy of the MPL was not distributed with this 48 | # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
49 | """ 50 | 51 | # Create list of sentence's words w/spaces replaced by '' 52 | result = original.replace(' ', ' ') 53 | result = result.split(' ') 54 | 55 | # Tokenize words into letters adding in SPACE_TOKEN where required 56 | result = np.hstack([SPACE_TOKEN if xt == '' else list(xt) for xt in result]) 57 | 58 | # Return characters mapped into indicies 59 | return np.asarray([SPACE_INDEX if xt == SPACE_TOKEN else ord(xt) - FIRST_INDEX for xt in result]) 60 | 61 | 62 | def sparse_tuple_from(sequences, dtype=np.int32): 63 | """ 64 | Create a sparse representention of ``sequences``. 65 | 66 | Args: 67 | sequences: a list of lists of type dtype where each element is a sequence 68 | Returns: 69 | A tuple with (indices, values, shape) 70 | 71 | This function has been modified from Mozilla DeepSpeech: 72 | https://github.com/mozilla/DeepSpeech/blob/master/util/text.py 73 | 74 | # This Source Code Form is subject to the terms of the Mozilla Public 75 | # License, v. 2.0. If a copy of the MPL was not distributed with this 76 | # file, You can obtain one at http://mozilla.org/MPL/2.0/. 77 | """ 78 | 79 | indices = [] 80 | values = [] 81 | 82 | for n, seq in enumerate(sequences): 83 | indices.extend(zip([n] * len(seq), range(len(seq)))) 84 | values.extend(seq) 85 | 86 | indices = np.asarray(indices, dtype=np.int64) 87 | values = np.asarray(values, dtype=dtype) 88 | shape = np.asarray([len(sequences), indices.max(0)[1] + 1], dtype=np.int64) 89 | 90 | # return tf.SparseTensor(indices=indices, values=values, shape=shape) 91 | return indices, values, shape 92 | 93 | 94 | def sparse_tensor_value_to_texts(value): 95 | """ 96 | Given a :class:`tf.SparseTensor` ``value``, return an array of Python strings 97 | representing its values. 98 | 99 | This function has been modified from Mozilla DeepSpeech: 100 | https://github.com/mozilla/DeepSpeech/blob/master/util/text.py 101 | 102 | # This Source Code Form is subject to the terms of the Mozilla Public 103 | # License, v. 2.0. If a copy of the MPL was not distributed with this 104 | # file, You can obtain one at http://mozilla.org/MPL/2.0/. 105 | """ 106 | return sparse_tuple_to_texts((value.indices, value.values, value.dense_shape)) 107 | 108 | 109 | def sparse_tuple_to_texts(tuple): 110 | ''' 111 | This function has been modified from Mozilla DeepSpeech: 112 | https://github.com/mozilla/DeepSpeech/blob/master/util/text.py 113 | 114 | # This Source Code Form is subject to the terms of the Mozilla Public 115 | # License, v. 2.0. If a copy of the MPL was not distributed with this 116 | # file, You can obtain one at http://mozilla.org/MPL/2.0/. 117 | ''' 118 | indices = tuple[0] 119 | values = tuple[1] 120 | results = [''] * tuple[2][0] 121 | for i in range(len(indices)): 122 | index = indices[i][0] 123 | c = values[i] 124 | c = ' ' if c == SPACE_INDEX else chr(c + FIRST_INDEX) 125 | results[index] = results[index] + c 126 | # List of strings 127 | return results 128 | 129 | 130 | def ndarray_to_text(value): 131 | ''' 132 | This function has been modified from Mozilla DeepSpeech: 133 | https://github.com/mozilla/DeepSpeech/blob/master/util/text.py 134 | 135 | # This Source Code Form is subject to the terms of the Mozilla Public 136 | # License, v. 2.0. If a copy of the MPL was not distributed with this 137 | # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
128 | 
129 | 
130 | def ndarray_to_text(value):
131 |     '''
132 |     This function has been modified from Mozilla DeepSpeech:
133 |     https://github.com/mozilla/DeepSpeech/blob/master/util/text.py
134 | 
135 |     # This Source Code Form is subject to the terms of the Mozilla Public
136 |     # License, v. 2.0. If a copy of the MPL was not distributed with this
137 |     # file, You can obtain one at http://mozilla.org/MPL/2.0/.
138 |     '''
139 |     results = ''
140 |     for i in range(len(value)):
141 |         results += chr(value[i] + FIRST_INDEX)
142 |     return results.replace('`', ' ')
143 | 
144 | 
145 | def gather_nd(params, indices, shape):
146 |     '''
147 |     # Function taken from https://github.com/tensorflow/tensorflow/issues/206#issuecomment-229678962
148 |     '''
149 |     from functools import reduce  # reduce is not a builtin in Python 3
150 |     rank = len(shape)
151 |     flat_params = tf.reshape(params, [-1])
152 |     multipliers = [reduce(lambda x, y: x * y, shape[i + 1:], 1) for i in range(0, rank)]
153 |     # list(...) is required: a range object cannot be concatenated to a list in Python 3
154 |     indices_unpacked = tf.unstack(tf.transpose(indices, [rank - 1] + list(range(0, rank - 1))))
155 |     flat_indices = sum([a * b for a, b in zip(multipliers, indices_unpacked)])
156 |     return tf.gather(flat_params, flat_indices)
157 | 
158 | 
159 | def ctc_label_dense_to_sparse(labels, label_lengths, batch_size):
160 |     '''
161 |     The CTC implementation in TensorFlow needs labels in a sparse representation,
162 |     but sparse data and queues don't mix well, so we store padded tensors in the
163 |     queue and convert to a sparse representation after dequeuing a batch.
164 | 
165 |     Taken from https://github.com/tensorflow/tensorflow/issues/1742#issuecomment-205291527
166 |     '''
167 | 
168 |     # The second dimension of labels must be equal to the longest label length in the batch
169 |     correct_shape_assert = tf.assert_equal(tf.shape(labels)[1], tf.reduce_max(label_lengths))
170 |     with tf.control_dependencies([correct_shape_assert]):
171 |         labels = tf.identity(labels)
172 | 
173 |     label_shape = tf.shape(labels)
174 |     num_batches_tns = tf.stack([label_shape[0]])
175 |     max_num_labels_tns = tf.stack([label_shape[1]])
176 | 
177 |     def range_less_than(previous_state, current_input):
178 |         return tf.expand_dims(tf.range(label_shape[1]), 0) < current_input
179 | 
180 |     init = tf.cast(tf.fill(max_num_labels_tns, 0), tf.bool)
181 |     init = tf.expand_dims(init, 0)
182 |     dense_mask = tf.scan(range_less_than, label_lengths, initializer=init, parallel_iterations=1)
183 |     dense_mask = dense_mask[:, 0, :]
184 | 
185 |     label_array = tf.reshape(tf.tile(tf.range(0, label_shape[1]), num_batches_tns), label_shape)
186 | 
187 |     label_ind = tf.boolean_mask(label_array, dense_mask)
188 | 
189 |     batch_array = tf.transpose(tf.reshape(tf.tile(tf.range(0, label_shape[0]), max_num_labels_tns),
190 |                                           tf.reverse(label_shape, [0]))
191 |                                )
192 |     batch_ind = tf.boolean_mask(batch_array, dense_mask)
193 |     batch_label = tf.concat([batch_ind, label_ind], 0)
194 |     indices = tf.transpose(tf.reshape(batch_label, [2, -1]))
195 |     shape = [batch_size, tf.reduce_max(label_lengths)]
196 |     vals_sparse = gather_nd(labels, indices, shape)
197 | 
198 |     return tf.SparseTensor(tf.to_int64(indices), vals_sparse, tf.to_int64(label_shape))
--------------------------------------------------------------------------------
/src/models/RNN/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/src/models/RNN/rnn.py:
--------------------------------------------------------------------------------
1 | # Note: All calls to tf.name_scope or tf.summary.* support TensorBoard visualization.
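# A sketch of how the builders below are typically wired into a CTC objective
# (hypothetical call site, not part of this file; TF 1.x API):
#
#     logits, summary_op = BiRNN(conf_path, batch_x, seq_length, n_input, n_context)
#     loss = tf.reduce_mean(tf.nn.ctc_loss(labels=sparse_labels, inputs=logits,
#                                          sequence_length=seq_length))
#
# logits come back time-major ([max_time, batch_size, n_character]), which is the
# layout tf.nn.ctc_loss expects by default.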
2 | 3 | import os 4 | import tensorflow as tf 5 | from configparser import ConfigParser 6 | 7 | from models.RNN.utils import variable_on_cpu 8 | 9 | 10 | def SimpleLSTM(conf_path, input_tensor, seq_length): 11 | ''' 12 | This function was initially based on open source code from Mozilla DeepSpeech: 13 | https://github.com/mozilla/DeepSpeech/blob/master/DeepSpeech.py 14 | 15 | # This Source Code Form is subject to the terms of the Mozilla Public 16 | # License, v. 2.0. If a copy of the MPL was not distributed with this 17 | # file, You can obtain one at http://mozilla.org/MPL/2.0/. 18 | ''' 19 | parser = ConfigParser(os.environ) 20 | parser.read(conf_path) 21 | 22 | # SimpleLSTM 23 | n_character = parser.getint('simplelstm', 'n_character') 24 | b1_stddev = parser.getfloat('simplelstm', 'b1_stddev') 25 | h1_stddev = parser.getfloat('simplelstm', 'h1_stddev') 26 | n_layers = parser.getint('simplelstm', 'n_layers') 27 | n_hidden_units = parser.getint('simplelstm', 'n_hidden_units') 28 | 29 | # Input shape: [batch_size, n_steps, n_input + 2*n_input*n_context] 30 | # batch_x_shape = tf.shape(batch_x) 31 | 32 | input_tensor_shape = tf.shape(input_tensor) 33 | n_items = input_tensor_shape[0] 34 | 35 | with tf.name_scope("lstm"): 36 | # Initialize weights 37 | # with tf.device('/cpu:0'): 38 | W = tf.get_variable('W', shape=[n_hidden_units, n_character], 39 | # initializer=tf.truncated_normal_initializer(stddev=h1_stddev), 40 | initializer=tf.random_normal_initializer(stddev=h1_stddev), 41 | ) 42 | # Initialize bias 43 | # with tf.device('/cpu:0'): 44 | # b = tf.get_variable('b', initializer=tf.zeros_initializer([n_character])) 45 | b = tf.get_variable('b', shape=[n_character], 46 | # initializer=tf.constant_initializer(value=0), 47 | initializer=tf.random_normal_initializer(stddev=b1_stddev), 48 | ) 49 | 50 | # Define the cell 51 | # Can be: 52 | # tf.contrib.rnn.BasicRNNCell 53 | # tf.contrib.rnn.GRUCell 54 | cell = tf.contrib.rnn.BasicLSTMCell(n_hidden_units, state_is_tuple=True) 55 | 56 | # Stacking rnn cells 57 | stack = tf.contrib.rnn.MultiRNNCell([cell] * n_layers, state_is_tuple=True) 58 | 59 | # Get layer activations (second output is the final state of the layer, do not need) 60 | outputs, _ = tf.nn.dynamic_rnn(stack, input_tensor, seq_length, 61 | time_major=False, dtype=tf.float32) 62 | 63 | # Reshape to apply the same weights over the timesteps 64 | outputs = tf.reshape(outputs, [-1, n_hidden_units]) 65 | 66 | # Perform affine transformation to layer output: 67 | # multiply by weights (linear transformation), add bias (translation) 68 | logits = tf.add(tf.matmul(outputs, W), b) 69 | 70 | tf.summary.histogram("weights", W) 71 | tf.summary.histogram("biases", b) 72 | tf.summary.histogram("activations", logits) 73 | 74 | # Reshaping back to the original shape 75 | logits = tf.reshape(logits, [n_items, -1, n_character]) 76 | 77 | # Put time as the major axis 78 | logits = tf.transpose(logits, (1, 0, 2)) 79 | 80 | summary_op = tf.summary.merge_all() 81 | 82 | return logits, summary_op 83 | 84 | def BiRNN(conf_path, batch_x, seq_length, n_input, n_context): 85 | """ 86 | This function was initially based on open source code from Mozilla DeepSpeech: 87 | https://github.com/mozilla/DeepSpeech/blob/master/DeepSpeech.py 88 | 89 | # This Source Code Form is subject to the terms of the Mozilla Public 90 | # License, v. 2.0. If a copy of the MPL was not distributed with this 91 | # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
92 | """
93 | parser = ConfigParser(os.environ)
94 | parser.read(conf_path)
95 | 
96 | dropout = [float(x) for x in parser.get('birnn', 'dropout_rates').split(',')]
97 | relu_clip = parser.getint('birnn', 'relu_clip')
98 | 
99 | b1_stddev = parser.getfloat('birnn', 'b1_stddev')
100 | h1_stddev = parser.getfloat('birnn', 'h1_stddev')
101 | b2_stddev = parser.getfloat('birnn', 'b2_stddev')
102 | h2_stddev = parser.getfloat('birnn', 'h2_stddev')
103 | b3_stddev = parser.getfloat('birnn', 'b3_stddev')
104 | h3_stddev = parser.getfloat('birnn', 'h3_stddev')
105 | b5_stddev = parser.getfloat('birnn', 'b5_stddev')
106 | h5_stddev = parser.getfloat('birnn', 'h5_stddev')
107 | b6_stddev = parser.getfloat('birnn', 'b6_stddev')
108 | h6_stddev = parser.getfloat('birnn', 'h6_stddev')
109 | 
110 | n_hidden_1 = parser.getint('birnn', 'n_hidden_1')
111 | n_hidden_2 = parser.getint('birnn', 'n_hidden_2')
112 | n_hidden_5 = parser.getint('birnn', 'n_hidden_5')
113 | n_cell_dim = parser.getint('birnn', 'n_cell_dim')
114 | 
115 | n_hidden_3 = int(eval(parser.get('birnn', 'n_hidden_3')))  # eval lets the config hold an arithmetic expression, e.g. 2*1024
116 | n_hidden_6 = parser.getint('birnn', 'n_hidden_6')
117 | 
118 | # Input shape: [batch_size, n_steps, n_input + 2*n_input*n_context]
119 | batch_x_shape = tf.shape(batch_x)
120 | 
121 | # Reshaping `batch_x` to a tensor with shape `[n_steps*batch_size, n_input + 2*n_input*n_context]`.
122 | # This is done to prepare the batch for input into the first layer, which expects a tensor of rank `2`.
123 | 
124 | # Permute n_steps and batch_size
125 | batch_x = tf.transpose(batch_x, [1, 0, 2])
126 | # Reshape to prepare input for first layer
127 | batch_x = tf.reshape(batch_x,
128 | [-1, n_input + 2 * n_input * n_context])  # (n_steps*batch_size, n_input + 2*n_input*n_context)
129 | 
130 | # The next three blocks will pass `batch_x` through three hidden layers with
131 | # clipped ReLU activation and dropout.
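# The clipped ReLU used below is act(x) = min(max(0, x), relu_clip).
# A quick sketch with illustrative values (relu_clip = 20):
#   x = tf.constant([-2.0, 3.0, 25.0])
#   tf.minimum(tf.nn.relu(x), 20.0)  # -> [0., 3., 20.]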
132 | 
133 | # 1st layer
134 | with tf.name_scope('fc1'):
135 | b1 = variable_on_cpu('b1', [n_hidden_1], tf.random_normal_initializer(stddev=b1_stddev))
136 | h1 = variable_on_cpu('h1', [n_input + 2 * n_input * n_context, n_hidden_1],
137 | tf.random_normal_initializer(stddev=h1_stddev))
138 | layer_1 = tf.minimum(tf.nn.relu(tf.add(tf.matmul(batch_x, h1), b1)), relu_clip)
139 | layer_1 = tf.nn.dropout(layer_1, (1.0 - dropout[0]))
140 | 
141 | tf.summary.histogram("weights", h1)
142 | tf.summary.histogram("biases", b1)
143 | tf.summary.histogram("activations", layer_1)
144 | 
145 | # 2nd layer
146 | with tf.name_scope('fc2'):
147 | b2 = variable_on_cpu('b2', [n_hidden_2], tf.random_normal_initializer(stddev=b2_stddev))
148 | h2 = variable_on_cpu('h2', [n_hidden_1, n_hidden_2], tf.random_normal_initializer(stddev=h2_stddev))
149 | layer_2 = tf.minimum(tf.nn.relu(tf.add(tf.matmul(layer_1, h2), b2)), relu_clip)
150 | layer_2 = tf.nn.dropout(layer_2, (1.0 - dropout[1]))
151 | 
152 | tf.summary.histogram("weights", h2)
153 | tf.summary.histogram("biases", b2)
154 | tf.summary.histogram("activations", layer_2)
155 | 
156 | # 3rd layer
157 | with tf.name_scope('fc3'):
158 | b3 = variable_on_cpu('b3', [n_hidden_3], tf.random_normal_initializer(stddev=b3_stddev))
159 | h3 = variable_on_cpu('h3', [n_hidden_2, n_hidden_3], tf.random_normal_initializer(stddev=h3_stddev))
160 | layer_3 = tf.minimum(tf.nn.relu(tf.add(tf.matmul(layer_2, h3), b3)), relu_clip)
161 | layer_3 = tf.nn.dropout(layer_3, (1.0 - dropout[2]))
162 | 
163 | tf.summary.histogram("weights", h3)
164 | tf.summary.histogram("biases", b3)
165 | tf.summary.histogram("activations", layer_3)
166 | 
167 | # Create the forward and backward LSTM units; each has `n_cell_dim` hidden units.
168 | # The LSTM forget gate bias is initialized at `1.0` (the default), meaning less forgetting
169 | # at the beginning of training (remembers more previous info)
170 | with tf.name_scope('lstm'):
171 | # Forward direction cell:
172 | lstm_fw_cell = tf.contrib.rnn.BasicLSTMCell(n_cell_dim, forget_bias=1.0, state_is_tuple=True)
173 | lstm_fw_cell = tf.contrib.rnn.DropoutWrapper(lstm_fw_cell,
174 | input_keep_prob=1.0 - dropout[3],
175 | output_keep_prob=1.0 - dropout[3],
176 | # seed=random_seed,
177 | )
178 | # Backward direction cell:
179 | lstm_bw_cell = tf.contrib.rnn.BasicLSTMCell(n_cell_dim, forget_bias=1.0, state_is_tuple=True)
180 | lstm_bw_cell = tf.contrib.rnn.DropoutWrapper(lstm_bw_cell,
181 | input_keep_prob=1.0 - dropout[4],
182 | output_keep_prob=1.0 - dropout[4],
183 | # seed=random_seed,
184 | )
185 | 
186 | # `layer_3` is now reshaped into `[n_steps, batch_size, n_hidden_3]`,
187 | # as the LSTM BRNN expects its input to be of shape `[max_time, batch_size, input_size]`.
188 | layer_3 = tf.reshape(layer_3, [-1, batch_x_shape[0], n_hidden_3])
189 | 
190 | # Now we feed `layer_3` into the LSTM BRNN cell and obtain the LSTM BRNN output.
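# bidirectional_dynamic_rnn returns a pair of output tensors (forward, backward),
# each of shape [max_time, batch_size, n_cell_dim] when time_major=True;
# they are concatenated along the feature axis further below.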
191 | outputs, output_states = tf.nn.bidirectional_dynamic_rnn(cell_fw=lstm_fw_cell,
192 | cell_bw=lstm_bw_cell,
193 | inputs=layer_3,
194 | dtype=tf.float32,
195 | time_major=True,
196 | sequence_length=seq_length)
197 | 
198 | tf.summary.histogram("activations", outputs)
199 | 
200 | # Reshape outputs from two tensors each of shape [n_steps, batch_size, n_cell_dim]
201 | # to a single tensor of shape [n_steps*batch_size, 2*n_cell_dim]
202 | outputs = tf.concat(outputs, 2)
203 | outputs = tf.reshape(outputs, [-1, 2 * n_cell_dim])
204 | 
205 | with tf.name_scope('fc5'):
206 | # Now we feed `outputs` to the fifth hidden layer with clipped ReLU activation and dropout
207 | b5 = variable_on_cpu('b5', [n_hidden_5], tf.random_normal_initializer(stddev=b5_stddev))
208 | h5 = variable_on_cpu('h5', [(2 * n_cell_dim), n_hidden_5], tf.random_normal_initializer(stddev=h5_stddev))
209 | layer_5 = tf.minimum(tf.nn.relu(tf.add(tf.matmul(outputs, h5), b5)), relu_clip)
210 | layer_5 = tf.nn.dropout(layer_5, (1.0 - dropout[5]))
211 | 
212 | tf.summary.histogram("weights", h5)
213 | tf.summary.histogram("biases", b5)
214 | tf.summary.histogram("activations", layer_5)
215 | 
216 | with tf.name_scope('fc6'):
217 | # Now we apply the weight matrix `h6` and bias `b6` to the output of `layer_5`,
218 | # creating `n_classes`-dimensional vectors, the logits.
219 | b6 = variable_on_cpu('b6', [n_hidden_6], tf.random_normal_initializer(stddev=b6_stddev))
220 | h6 = variable_on_cpu('h6', [n_hidden_5, n_hidden_6], tf.random_normal_initializer(stddev=h6_stddev))
221 | layer_6 = tf.add(tf.matmul(layer_5, h6), b6)
222 | 
223 | tf.summary.histogram("weights", h6)
224 | tf.summary.histogram("biases", b6)
225 | tf.summary.histogram("activations", layer_6)
226 | 
227 | # Finally we reshape layer_6 from a tensor of shape [n_steps*batch_size, n_hidden_6]
228 | # to the slightly more useful shape [n_steps, batch_size, n_hidden_6].
229 | # Note that this differs from the input in that it is time-major.
230 | layer_6 = tf.reshape(layer_6, [-1, batch_x_shape[0], n_hidden_6])
231 | 
232 | summary_op = tf.summary.merge_all()
233 | 
234 | # Output shape: [n_steps, batch_size, n_hidden_6]
235 | return layer_6, summary_op
236 | 
--------------------------------------------------------------------------------
/src/models/RNN/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tensorflow as tf
3 | 
4 | from configparser import ConfigParser
5 | from utils.set_dirs import get_conf_dir
6 | 
7 | conf_dir = get_conf_dir(debug=False)
8 | parser = ConfigParser(os.environ)
9 | parser.read(os.path.join(conf_dir, 'neural_network.ini'))
10 | 
11 | # AdamOptimizer
12 | beta1 = parser.getfloat('optimizer', 'beta1')
13 | beta2 = parser.getfloat('optimizer', 'beta2')
14 | epsilon = parser.getfloat('optimizer', 'epsilon')
15 | learning_rate = parser.getfloat('optimizer', 'learning_rate')
16 | 
17 | 
18 | def variable_on_cpu(name, shape, initializer):
19 | """
20 | Create (or retrieve) a named variable that is pinned to CPU memory.
21 | Keeping model parameters on the CPU is a common pattern when the
22 | surrounding ops may run on one or more GPUs.
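
Example (as used in rnn.py):
    b1 = variable_on_cpu('b1', [n_hidden_1], tf.random_normal_initializer(stddev=b1_stddev))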
23 | """
24 | # Use the /cpu:0 device for scoped operations
25 | with tf.device('/cpu:0'):
26 | # Create or fetch the named variable
27 | var = tf.get_variable(name=name, shape=shape, initializer=initializer)
28 | return var
29 | 
30 | 
31 | def create_optimizer():
32 | optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
33 | beta1=beta1,
34 | beta2=beta2,
35 | epsilon=epsilon)
36 | return optimizer
37 | 
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrubash1/RNN-Tutorial/1d02ee9d9e4f0ff24fc88beabb8ff9da6b8c8886/src/models/__init__.py
--------------------------------------------------------------------------------
/src/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrubash1/RNN-Tutorial/1d02ee9d9e4f0ff24fc88beabb8ff9da6b8c8886/src/tests/__init__.py
--------------------------------------------------------------------------------
/src/tests/train_framework/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrubash1/RNN-Tutorial/1d02ee9d9e4f0ff24fc88beabb8ff9da6b8c8886/src/tests/train_framework/__init__.py
--------------------------------------------------------------------------------
/src/tests/train_framework/tf_train_ctc_test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import unittest
3 | import os
4 | import logging
5 | import tensorflow as tf
6 | 
7 | # custom modules
8 | import train_framework.tf_train_ctc as tf_train
9 | 
10 | 
11 | class TestTrain_ctc(unittest.TestCase):
12 | logging.basicConfig(level=logging.DEBUG)
13 | 
14 | def setUp(self):
15 | '''
16 | Create the Tf_train_ctc instance
17 | '''
18 | self.tf_train_ctc = tf_train.Tf_train_ctc(
19 | config_file='neural_network.ini',
20 | debug=True,
21 | model_name=None)
22 | 
23 | def tearDown(self):
24 | '''
25 | Close the TF session if available
26 | '''
27 | if hasattr(self.tf_train_ctc, 'sess'):
28 | self.tf_train_ctc.sess.close()
29 | self.tf_train_ctc.writer.flush()
30 | 
31 | def test_setup_tf_train_framework(self):
32 | '''
33 | Does the instance have the expected fields (i.e. were config values cast to the expected types)?
34 | '''
35 | tf_train_ctc = self.tf_train_ctc
36 | 
37 | # make sure everything is loaded correctly
38 | self.assertIsInstance(tf_train_ctc.epochs, int)
39 | self.assertIsInstance(tf_train_ctc.network_type, str)
40 | self.assertIsInstance(tf_train_ctc.n_input, int)
41 | self.assertIsInstance(tf_train_ctc.n_context, int)
42 | self.assertIsInstance(tf_train_ctc.model_dir, str)
43 | self.assertIsInstance(tf_train_ctc.session_name, str)
44 | self.assertIsInstance(tf_train_ctc.SAVE_MODEL_EPOCH_NUM, int)
45 | self.assertIsInstance(tf_train_ctc.VALIDATION_EPOCH_NUM, int)
46 | self.assertIsInstance(tf_train_ctc.CURR_VALIDATION_LER_DIFF, float)
47 | self.assertIsNotNone(tf_train_ctc.beam_search_decoder)
48 | self.assertIsInstance(tf_train_ctc.AVG_VALIDATION_LERS, list)
49 | self.assertIsInstance(tf_train_ctc.shuffle_data_after_epoch, bool)
50 | self.assertIsInstance(tf_train_ctc.min_dev_ler, float)
51 | 
52 | # tests associated with the data_sets object
53 | self.assertIsNotNone(tf_train_ctc.data_sets)
54 | self.assertEqual(tf_train_ctc.sets, ['train', 'dev', 'test'])  # this will vary if changed in future
55 | self.assertIsInstance(tf_train_ctc.n_examples_train, int)
56 | self.assertIsInstance(tf_train_ctc.n_examples_dev, int)
57 | self.assertIsInstance(tf_train_ctc.n_examples_test, int)
58 | self.assertIsInstance(tf_train_ctc.batch_size, int)
59 | self.assertIsInstance(tf_train_ctc.n_batches_per_epoch, int)
60 | 
61 | # make sure the folders were made
62 | self.assertTrue(os.path.exists(tf_train_ctc.SESSION_DIR))
63 | self.assertTrue(os.path.exists(tf_train_ctc.SUMMARY_DIR))
64 | 
65 | # since model_name was set to None, model_path should be None
66 | self.assertIsNone(tf_train_ctc.model_path)
67 | 
68 | def test_run_tf_train_gpu(self):
69 | '''
70 | Can a small model run on the GPU?
71 | '''
72 | tf_train_ctc = self.tf_train_ctc
73 | tf_train_ctc.run_model()
74 | 
75 | # Verify objects of the train framework were created in the test training run
76 | self.assertIsNotNone(tf_train_ctc.graph)
77 | self.assertIsNotNone(tf_train_ctc.sess)
78 | self.assertIsNotNone(tf_train_ctc.writer)
79 | self.assertIsNotNone(tf_train_ctc.saver)
80 | self.assertIsNotNone(tf_train_ctc.logits)
81 | 
82 | # Verify some learning has been done
83 | self.assertTrue(tf_train_ctc.train_ler > 0)
84 | 
85 | # make sure the input targets (i.e. the .txt files) are not the empty string
86 | self.assertNotEqual(tf_train_ctc.dense_labels, '')
87 | 
88 | # shut down the running model
89 | self.tearDown()
90 | 
91 | def test_verify_if_cpu_can_be_used(self):
92 | '''
93 | Can a small model run on the CPU?
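(Rebuilds the graph from scratch with tf_device pinned to '/cpu:0'.)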
94 | '''
95 | tf_train_ctc = self.tf_train_ctc
96 | tf_train_ctc.tf_device = '/cpu:0'
97 | 
98 | tf_train_ctc.graph = tf.Graph()
99 | with tf_train_ctc.graph.as_default():
100 | with tf.device(tf_train_ctc.tf_device):
101 | tf_train_ctc.setup_network_and_graph()
102 | tf_train_ctc.load_placeholder_into_network()
103 | tf_train_ctc.setup_loss_function()
104 | tf_train_ctc.setup_optimizer()
105 | tf_train_ctc.setup_decoder()
106 | tf_train_ctc.setup_summary_statistics()
107 | tf_train_ctc.sess = tf.Session(
108 | config=tf.ConfigProto(allow_soft_placement=True), graph=tf_train_ctc.graph)
109 | 
110 | # initialize the summary writer
111 | tf_train_ctc.writer = tf.summary.FileWriter(
112 | tf_train_ctc.SUMMARY_DIR, graph=tf_train_ctc.sess.graph)
113 | 
114 | # Add ops to save and restore all the variables
115 | tf_train_ctc.saver = tf.train.Saver()
116 | 
117 | # Verify objects of the train framework were created in the test training run
118 | self.assertIsNotNone(tf_train_ctc.graph)
119 | self.assertIsNotNone(tf_train_ctc.sess)
120 | self.assertIsNotNone(tf_train_ctc.writer)
121 | self.assertIsNotNone(tf_train_ctc.saver)
122 | 
123 | # shut down the running model
124 | self.tearDown()
125 | 
126 | 
127 | if __name__ == '__main__':
128 | unittest.main()
129 | 
--------------------------------------------------------------------------------
/src/train_framework/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrubash1/RNN-Tutorial/1d02ee9d9e4f0ff24fc88beabb8ff9da6b8c8886/src/train_framework/__init__.py
--------------------------------------------------------------------------------
/src/train_framework/tf_train_ctc.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | 
3 | import os
4 | import numpy as np
5 | import time
6 | import warnings
7 | from configparser import ConfigParser
8 | import logging
9 | 
10 | import tensorflow as tf
11 | from tensorflow.python.ops import ctc_ops
12 | 
13 | # Custom modules
14 | from features.utils.text import ndarray_to_text, sparse_tuple_to_texts
15 | 
16 | # note: models.RNN.utils may diverge from the generic utils package in the future
17 | from models.RNN.utils import create_optimizer
18 | from data_manipulation.datasets import read_datasets
19 | from utils.set_dirs import get_conf_dir, get_model_dir
20 | import utils.gpu as gpu_tool
21 | 
22 | # Import the setup scripts for different types of model
23 | from models.RNN.rnn import BiRNN as BiRNN_model
24 | from models.RNN.rnn import SimpleLSTM as SimpleLSTM_model
25 | 
26 | logger = logging.getLogger(__name__)
27 | 
28 | 
29 | class Tf_train_ctc(object):
30 | '''
31 | Class to train a speech recognition model with TensorFlow.
32 | 
33 | Requirements:
34 | - TensorFlow 1.0.1
35 | - Python 3.5
36 | - Configuration: $RNN_TUTORIAL/configs/neural_network.ini
37 | 
38 | Features:
39 | - Batch loading of input data
40 | - Checkpoints the model periodically
41 | - Label error rate is the edit (Levenshtein) distance of the top path vs the true sentence
42 | - Logs summary stats for TensorBoard
43 | - Epoch 1: Train starting with shortest transcriptions, then shuffle
44 | 
45 | # Note: All calls to tf.name_scope or tf.summary.* support TensorBoard visualization.
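
    Usage (mirrors the unit tests):
        tf_train_ctc = Tf_train_ctc(config_file='neural_network.ini', debug=True)
        tf_train_ctc.run_model()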
46 | 
47 | This class was initially based on open source code from Mozilla DeepSpeech:
48 | https://github.com/mozilla/DeepSpeech/blob/master/DeepSpeech.py
49 | 
50 | # This Source Code Form is subject to the terms of the Mozilla Public
51 | # License, v. 2.0. If a copy of the MPL was not distributed with this
52 | # file, You can obtain one at http://mozilla.org/MPL/2.0/.
53 | '''
54 | 
55 | def __init__(self,
56 | config_file='neural_network.ini',
57 | model_name=None,
58 | debug=False):
59 | # set TF logging verbosity
60 | tf.logging.set_verbosity(tf.logging.INFO)
61 | 
62 | # Load the configuration file depending on debug True/False
63 | self.debug = debug
64 | self.conf_path = get_conf_dir(debug=self.debug)
65 | self.conf_path = os.path.join(self.conf_path, config_file)
66 | self.load_configs()
67 | 
68 | # Verify that the GPU is operational; if not, use the CPU
69 | if not gpu_tool.check_if_gpu_available(self.tf_device):
70 | self.tf_device = '/cpu:0'
71 | logger.info('Using this device for main computations: %s', self.tf_device)
72 | 
73 | # set the directories
74 | self.set_up_directories(model_name)
75 | 
76 | # set up the model
77 | self.set_up_model()
78 | 
79 | def load_configs(self):
80 | parser = ConfigParser(os.environ)
81 | if not os.path.exists(self.conf_path):
82 | raise IOError("Configuration file '%s' does not exist" % self.conf_path)
83 | logger.info('Loading config from %s', self.conf_path)
84 | parser.read(self.conf_path)
85 | 
86 | # set which set of configs to import
87 | config_header = 'nn'
88 | 
89 | logger.info('config header: %s', config_header)
90 | 
91 | self.epochs = parser.getint(config_header, 'epochs')
92 | logger.debug('self.epochs = %d', self.epochs)
93 | 
94 | self.network_type = parser.get(config_header, 'network_type')
95 | 
96 | # Number of MFCC features, 13 or 26
97 | self.n_input = parser.getint(config_header, 'n_input')
98 | 
99 | # Number of contextual samples to include
100 | self.n_context = parser.getint(config_header, 'n_context')
101 | 
102 | # self.decode_train = parser.getboolean(config_header, 'decode_train')
103 | # self.random_seed = parser.getint(config_header, 'random_seed')
104 | self.model_dir = parser.get(config_header, 'model_dir')
105 | 
106 | # set the session name
107 | self.session_name = '{}_{}'.format(
108 | self.network_type, time.strftime("%Y%m%d-%H%M%S"))
109 | sess_prefix_str = 'develop'
110 | if len(sess_prefix_str) > 0:
111 | self.session_name = '{}_{}'.format(
112 | sess_prefix_str, self.session_name)
113 | 
114 | # How often to save the model
115 | self.SAVE_MODEL_EPOCH_NUM = parser.getint(
116 | config_header, 'SAVE_MODEL_EPOCH_NUM')
117 | 
118 | # decode the dev set after every N epochs
119 | self.VALIDATION_EPOCH_NUM = parser.getint(
120 | config_header, 'VALIDATION_EPOCH_NUM')
121 | 
122 | # decide when to stop training prematurely
123 | self.CURR_VALIDATION_LER_DIFF = parser.getfloat(
124 | config_header, 'CURR_VALIDATION_LER_DIFF')
125 | 
126 | self.AVG_VALIDATION_LER_EPOCHS = parser.getint(
127 | config_header, 'AVG_VALIDATION_LER_EPOCHS')
128 | # initialize list to hold the average validation LER at the end of each epoch
129 | self.AVG_VALIDATION_LERS = [
130 | 1.0 for _ in range(self.AVG_VALIDATION_LER_EPOCHS)]
131 | 
132 | # set up the type of decoder
133 | self.beam_search_decoder = parser.get(
134 | config_header, 'beam_search_decoder')
135 | 
136 | # determine if the data input order should be shuffled after every epoch
137 | self.shuffle_data_after_epoch = parser.getboolean(
138 | config_header, 'shuffle_data_after_epoch')
139 | 
140 | # initialize storage for the minimum validation-set label error rate
141 | self.min_dev_ler = parser.getfloat(config_header, 'min_dev_ler')
142 | 
143 | # set up GPU if available
144 | self.tf_device = str(parser.get(config_header, 'tf_device'))
145 | 
146 | # set the maximum number of simultaneous users;
147 | # each process is limited to 1/simultaneous_users_count of the GPU memory
148 | self.simultaneous_users_count = parser.getint(config_header, 'simultaneous_users_count')
149 | 
150 | def set_up_directories(self, model_name):
151 | # Set up the model directory
152 | self.model_dir = os.path.join(get_model_dir(), self.model_dir)
153 | # summary will contain logs
154 | self.SUMMARY_DIR = os.path.join(
155 | self.model_dir, "summary", self.session_name)
156 | # session will contain models
157 | self.SESSION_DIR = os.path.join(
158 | self.model_dir, "session", self.session_name)
159 | 
160 | if not os.path.exists(self.SESSION_DIR):
161 | os.makedirs(self.SESSION_DIR)
162 | if not os.path.exists(self.SUMMARY_DIR):
163 | os.makedirs(self.SUMMARY_DIR)
164 | 
165 | # set the model name and restore if not None
166 | if model_name is not None:
167 | self.model_path = os.path.join(self.SESSION_DIR, model_name)
168 | else:
169 | self.model_path = None
170 | 
171 | def set_up_model(self):
172 | self.sets = ['train', 'dev', 'test']
173 | 
174 | # read the data sets; the configuration path is passed along
175 | # so the data locations can be parsed from the config file
176 | self.data_sets = read_datasets(self.conf_path,
177 | self.sets,
178 | self.n_input,
179 | self.n_context
180 | )
181 | 
182 | self.n_examples_train = len(self.data_sets.train._txt_files)
183 | self.n_examples_dev = len(self.data_sets.dev._txt_files)
184 | self.n_examples_test = len(self.data_sets.test._txt_files)
185 | self.batch_size = self.data_sets.train._batch_size
186 | self.n_batches_per_epoch = int(np.ceil(
187 | self.n_examples_train / self.batch_size))
188 | 
189 | logger.info('''Training model: {}
190 | Train examples: {:,}
191 | Dev examples: {:,}
192 | Test examples: {:,}
193 | Epochs: {}
194 | Training batch size: {}
195 | Batches per epoch: {}'''.format(
196 | self.session_name,
197 | self.n_examples_train,
198 | self.n_examples_dev,
199 | self.n_examples_test,
200 | self.epochs,
201 | self.batch_size,
202 | self.n_batches_per_epoch))
203 | 
204 | def run_model(self):
205 | self.graph = tf.Graph()
206 | with self.graph.as_default(), tf.device('/cpu:0'):
207 | 
208 | with tf.device(self.tf_device):
209 | # Run multiple functions on the specified tf_device
210 | # tf_device GPU set in configs, but is overridden if not available
211 | self.setup_network_and_graph()
212 | self.load_placeholder_into_network()
213 | self.setup_loss_function()
214 | self.setup_optimizer()
215 | self.setup_decoder()
216 | 
217 | self.setup_summary_statistics()
218 | 
219 | # create the configuration for the session
220 | tf_config = tf.ConfigProto()
221 | tf_config.allow_soft_placement = True
222 | tf_config.gpu_options.per_process_gpu_memory_fraction = \
223 | (1.0 / self.simultaneous_users_count)
224 | 
225 | # create the session
226 | self.sess = tf.Session(config=tf_config)
227 | 
228 | # initialize the summary writer
229 | self.writer = tf.summary.FileWriter(
230 | self.SUMMARY_DIR, graph=self.sess.graph)
231 | 
232 | # Add ops to save and restore all the variables
233 | self.saver = tf.train.Saver()
234 | 
235 | # For printing out section headers
236 | section = '\n{0:=^40}\n'
237 | 
238 | # If there is a model_path declared, then restore the model
239 | if self.model_path is not None:
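# (model_path is a checkpoint prefix written by tf.train.Saver,
#  e.g. <SESSION_DIR>/model.ckpt-<epoch>; see run_training_epochs below)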
240 | self.saver.restore(self.sess, self.model_path)
241 | # If there is NOT a model_path declared, build the model from scratch
242 | else:
243 | # Op to initialize the variables
244 | init_op = tf.global_variables_initializer()
245 | 
246 | # Initialize the weights and biases
247 | self.sess.run(init_op)
248 | 
249 | # MAIN LOGIC for running the training epochs
250 | logger.info(section.format('Run training epoch'))
251 | self.run_training_epochs()
252 | 
253 | logger.info(section.format('Decoding test data'))
254 | # assume the decode of the test data happens at the last epoch
255 | _, self.test_ler = self.run_batches(self.data_sets.test, is_training=False,
256 | decode=True, write_to_file=False, epoch=self.epochs)
257 | 
258 | # Add the final test data to the summary writer
259 | # (single point on the graph for end of training run)
260 | summary_line = self.sess.run(
261 | self.test_ler_op, {self.ler_placeholder: self.test_ler})
262 | self.writer.add_summary(summary_line, self.epochs)
263 | 
264 | logger.info('Test Label Error Rate: {}'.format(self.test_ler))
265 | 
266 | # save train summaries to disk
267 | self.writer.flush()
268 | 
269 | self.sess.close()
270 | 
271 | def setup_network_and_graph(self):
272 | # e.g: log filter bank or MFCC features
273 | # shape = [batch_size, max_stepsize, n_input + (2 * n_input * n_context)]
274 | # the batch_size and max_stepsize can vary along each step
275 | self.input_tensor = tf.placeholder(
276 | tf.float32, [None, None, self.n_input + (2 * self.n_input * self.n_context)], name='input')
277 | 
278 | # Use sparse_placeholder; it will generate a SparseTensor, required by the ctc_loss op.
279 | self.targets = tf.sparse_placeholder(tf.int32, name='targets')
280 | # 1d array of size [batch_size]
281 | self.seq_length = tf.placeholder(tf.int32, [None], name='seq_length')
282 | 
283 | def load_placeholder_into_network(self):
284 | # logits is the non-normalized output/activations from the last layer.
285 | # logits will be the input to the loss function.
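# Both model functions return time-major logits of shape
# [n_steps, batch_size, n_classes], which is what ctc_ops.ctc_loss
# and the CTC decoders below consume.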
286 | # the model-building functions come from the imports at the top of this file
287 | # summary_op variables are for tensorboard
288 | if self.network_type == 'SimpleLSTM':
289 | self.logits, summary_op = SimpleLSTM_model(
290 | self.conf_path,
291 | self.input_tensor,
292 | tf.to_int64(self.seq_length)
293 | )
294 | elif self.network_type == 'BiRNN':
295 | self.logits, summary_op = BiRNN_model(
296 | self.conf_path,
297 | self.input_tensor,
298 | tf.to_int64(self.seq_length),
299 | self.n_input,
300 | self.n_context
301 | )
302 | else:
303 | raise ValueError('network_type must be SimpleLSTM or BiRNN')
304 | self.summary_op = tf.summary.merge([summary_op])
305 | 
306 | def setup_loss_function(self):
307 | with tf.name_scope("loss"):
308 | self.total_loss = ctc_ops.ctc_loss(
309 | self.targets, self.logits, self.seq_length)
310 | self.avg_loss = tf.reduce_mean(self.total_loss)
311 | self.loss_summary = tf.summary.scalar("avg_loss", self.avg_loss)
312 | 
313 | self.cost_placeholder = tf.placeholder(dtype=tf.float32, shape=[])
314 | 
315 | self.train_cost_op = tf.summary.scalar(
316 | "train_avg_loss", self.cost_placeholder)
317 | 
318 | def setup_optimizer(self):
319 | # Note: The optimizer is created in models/RNN/utils.py
320 | with tf.name_scope("train"):
321 | self.optimizer = create_optimizer()
322 | self.optimizer = self.optimizer.minimize(self.avg_loss)
323 | 
324 | def setup_decoder(self):
325 | with tf.name_scope("decode"):
326 | if self.beam_search_decoder == 'default':
327 | self.decoded, self.log_prob = ctc_ops.ctc_beam_search_decoder(
328 | self.logits, self.seq_length, merge_repeated=False)
329 | elif self.beam_search_decoder == 'greedy':
330 | self.decoded, self.log_prob = ctc_ops.ctc_greedy_decoder(
331 | self.logits, self.seq_length, merge_repeated=False)
332 | else:
333 | raise ValueError("beam_search_decoder must be 'default' or 'greedy'")
334 | 
335 | def setup_summary_statistics(self):
336 | # Create a placeholder for the summary statistics
337 | with tf.name_scope("accuracy"):
338 | # Compute the edit (Levenshtein) distance of the top path
339 | distance = tf.edit_distance(
340 | tf.cast(self.decoded[0], tf.int32), self.targets)
341 | 
342 | # Compute the label error rate (mean edit distance; lower is better)
343 | self.ler = tf.reduce_mean(distance, name='label_error_rate')
344 | self.ler_placeholder = tf.placeholder(dtype=tf.float32, shape=[])
345 | self.train_ler_op = tf.summary.scalar(
346 | "train_label_error_rate", self.ler_placeholder)
347 | self.dev_ler_op = tf.summary.scalar(
348 | "validation_label_error_rate", self.ler_placeholder)
349 | self.test_ler_op = tf.summary.scalar(
350 | "test_label_error_rate", self.ler_placeholder)
351 | 
352 | def run_training_epochs(self):
353 | train_start = time.time()
354 | for epoch in range(self.epochs):
355 | # Initialize the per-epoch flags that the validation step can update
356 | self.save_dev_model = False
357 | self.stop_training = False
358 | is_checkpoint_step, is_validation_step = \
359 | self.validation_and_checkpoint_check(epoch)
360 | 
361 | epoch_start = time.time()
362 | 
363 | self.train_cost, self.train_ler = self.run_batches(
364 | self.data_sets.train,
365 | is_training=True,
366 | decode=False,
367 | write_to_file=False,
368 | epoch=epoch)
369 | 
370 | epoch_duration = time.time() - epoch_start
371 | 
372 | log = 'Epoch {}/{}, train_cost: {:.3f}, \
373 | train_ler: {:.3f}, time: {:.2f} sec'
374 | logger.info(log.format(
375 | epoch + 1,
376 | self.epochs,
377 | self.train_cost,
378 | self.train_ler,
379 | epoch_duration))
380 | 
381 | summary_line = self.sess.run(
382 | self.train_ler_op, {self.ler_placeholder: self.train_ler})
383 | self.writer.add_summary(summary_line, epoch)
384 | 
385 | summary_line = self.sess.run(
386 | self.train_cost_op, {self.cost_placeholder: self.train_cost})
387 | self.writer.add_summary(summary_line, epoch)
388 | 
389 | # Shuffle the data for the next epoch
390 | if self.shuffle_data_after_epoch:
391 | np.random.shuffle(self.data_sets.train._txt_files)
392 | 
393 | # Run validation if it was determined to run a validation step
394 | if is_validation_step:
395 | self.run_validation_step(epoch)
396 | 
397 | if (epoch + 1) == self.epochs or is_checkpoint_step:
398 | # save the final model
399 | save_path = self.saver.save(self.sess, os.path.join(
400 | self.SESSION_DIR, 'model.ckpt'), epoch)
401 | logger.info("Model saved: {}".format(save_path))
402 | 
403 | if self.save_dev_model:
404 | # The dev set hit a new best label error rate:
405 | # save this model so the best-performing checkpoint
406 | # survives if training is stopped early to prevent overfitting
407 | save_path = self.saver.save(self.sess, os.path.join(
408 | self.SESSION_DIR, 'model-best.ckpt'))
409 | logger.info(
410 | "Model with best validation label error rate saved: {}".
411 | format(save_path))
412 | 
413 | if self.stop_training:
414 | break
415 | 
416 | train_duration = time.time() - train_start
417 | logger.info('Training complete, total duration: {:.2f} min'.format(
418 | train_duration / 60))
419 | 
420 | def run_validation_step(self, epoch):
421 | dev_ler = 0
422 | 
423 | _, dev_ler = self.run_batches(self.data_sets.dev,
424 | is_training=False,
425 | decode=True,
426 | write_to_file=False,
427 | epoch=epoch)
428 | 
429 | logger.info('Validation Label Error Rate: {}'.format(dev_ler))
430 | 
431 | summary_line = self.sess.run(
432 | self.dev_ler_op, {self.ler_placeholder: dev_ler})
433 | self.writer.add_summary(summary_line, epoch)
434 | 
435 | if dev_ler < self.min_dev_ler:
436 | self.min_dev_ler = dev_ler
437 | self.save_dev_model = True
438 | # average historical LER
439 | history_avg_ler = np.mean(self.AVG_VALIDATION_LERS)
440 | 
441 | # if this LER is not better than the average of previous epochs, stop training early
442 | if history_avg_ler - dev_ler <= self.CURR_VALIDATION_LER_DIFF:
443 | log = "Validation label error rate not improved by more than {:.2%} \
Exit" 445 | warnings.warn(log.format(self.CURR_VALIDATION_LER_DIFF, 446 | self.AVG_VALIDATION_LER_EPOCHS)) 447 | 448 | # save avg validation accuracy in the next slot 449 | self.AVG_VALIDATION_LERS[ 450 | epoch % self.AVG_VALIDATION_LER_EPOCHS] = dev_ler 451 | 452 | def validation_and_checkpoint_check(self, epoch): 453 | # initially set at False unless indicated to change 454 | is_checkpoint_step = False 455 | is_validation_step = False 456 | 457 | # Check if the current epoch is a validation or checkpoint step 458 | if (epoch > 0) and ((epoch + 1) != self.epochs): 459 | if (epoch + 1) % self.SAVE_MODEL_EPOCH_NUM == 0: 460 | is_checkpoint_step = True 461 | if (epoch + 1) % self.VALIDATION_EPOCH_NUM == 0: 462 | is_validation_step = True 463 | 464 | return is_checkpoint_step, is_validation_step 465 | 466 | def run_batches(self, dataset, is_training, decode, write_to_file, epoch): 467 | n_examples = len(dataset._txt_files) 468 | 469 | n_batches_per_epoch = int(np.ceil(n_examples / dataset._batch_size)) 470 | 471 | self.train_cost = 0 472 | self.train_ler = 0 473 | 474 | for batch in range(n_batches_per_epoch): 475 | # Get next batch of training data (audio features) and transcripts 476 | source, source_lengths, sparse_labels = dataset.next_batch() 477 | 478 | feed = {self.input_tensor: source, 479 | self.targets: sparse_labels, 480 | self.seq_length: source_lengths} 481 | 482 | # If the is_training is false, this means straight decoding without computing loss 483 | if is_training: 484 | # avg_loss is the loss_op, optimizer is the train_op; 485 | # running these pushes tensors (data) through graph 486 | batch_cost, _ = self.sess.run( 487 | [self.avg_loss, self.optimizer], feed) 488 | self.train_cost += batch_cost * dataset._batch_size 489 | logger.debug('Batch cost: %.2f | Train cost: %.2f', 490 | batch_cost, self.train_cost) 491 | 492 | self.train_ler += self.sess.run(self.ler, feed_dict=feed) * dataset._batch_size 493 | logger.debug('Label error rate: %.2f', self.train_ler) 494 | 495 | # Turn on decode only 1 batch per epoch 496 | if decode and batch == 0: 497 | d = self.sess.run(self.decoded[0], feed_dict={ 498 | self.input_tensor: source, 499 | self.targets: sparse_labels, 500 | self.seq_length: source_lengths} 501 | ) 502 | dense_decoded = tf.sparse_tensor_to_dense( 503 | d, default_value=-1).eval(session=self.sess) 504 | dense_labels = sparse_tuple_to_texts(sparse_labels) 505 | 506 | # only print a set number of example translations 507 | counter = 0 508 | counter_max = 4 509 | if counter < counter_max: 510 | for orig, decoded_arr in zip(dense_labels, dense_decoded): 511 | # convert to strings 512 | decoded_str = ndarray_to_text(decoded_arr) 513 | logger.info('Batch {}, file {}'.format(batch, counter)) 514 | logger.info('Original: {}'.format(orig)) 515 | logger.info('Decoded: {}'.format(decoded_str)) 516 | counter += 1 517 | 518 | # save out variables for testing 519 | self.dense_decoded = dense_decoded 520 | self.dense_labels = dense_labels 521 | 522 | # Metrics mean 523 | if is_training: 524 | self.train_cost /= n_examples 525 | self.train_ler /= n_examples 526 | 527 | # Populate summary for histograms and distributions in tensorboard 528 | self.accuracy, summary_line = self.sess.run( 529 | [self.avg_loss, self.summary_op], feed) 530 | self.writer.add_summary(summary_line, epoch) 531 | 532 | return self.train_cost, self.train_ler 533 | 534 | 535 | # to run in console 536 | if __name__ == '__main__': 537 | import click 538 | 539 | # Use click to parse command line arguments 540 | 
540 | @click.command()
541 | @click.option('--config', default='neural_network.ini', help='Configuration file name')
542 | @click.option('--name', default=None, help='Model name for logging')
543 | @click.option('--debug', type=bool, default=False,
544 | help='Use debug settings in config file')
545 | # Train the RNN model using a given configuration file
546 | def main(config='neural_network.ini', name=None, debug=False):
547 | logging.basicConfig(level=logging.DEBUG,
548 | format='%(asctime)s [%(levelname)s] %(name)s: %(message)s')
549 | global logger
550 | logger = logging.getLogger(os.path.basename(__file__))
551 | 
552 | # create the Tf_train_ctc class
553 | tf_train_ctc = Tf_train_ctc(
554 | config_file=config, model_name=name, debug=debug)
555 | 
556 | # run the training
557 | tf_train_ctc.run_model()
558 | 
559 | main()
560 | 
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrubash1/RNN-Tutorial/1d02ee9d9e4f0ff24fc88beabb8ff9da6b8c8886/src/utils/__init__.py
--------------------------------------------------------------------------------
/src/utils/gpu.py:
--------------------------------------------------------------------------------
1 | from tensorflow.python.client import device_lib
2 | 
3 | 
4 | def get_available_gpus():
5 | """
6 | Returns a list with the names of the GPU devices available on this system.
7 | """
8 | local_device_protos = device_lib.list_local_devices()
9 | return [x.name for x in local_device_protos if x.device_type == 'GPU']
10 | 
11 | 
12 | def check_if_gpu_available(gpu_name):
13 | """
14 | Returns True if a specific gpu_name (string, e.g. '/gpu:0')
15 | is available on the system.
16 | """
17 | list_of_gpus = get_available_gpus()
18 | return gpu_name in list_of_gpus
19 | 
--------------------------------------------------------------------------------
/src/utils/set_dirs.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | 
3 | import os
4 | 
5 | 
6 | def get_relevant_directories(
7 | home_dir=None,
8 | data_dir=None,
9 | conf_dir=None,
10 | debug=False):
11 | 
12 | home_dir = get_home_dir(home_dir=home_dir)
13 | 
14 | data_dir = get_data_dir(data_dir=data_dir, home_dir=home_dir)
15 | 
16 | conf_dir = get_conf_dir(conf_dir=conf_dir, home_dir=home_dir, debug=debug)
17 | 
18 | return home_dir, data_dir, conf_dir
19 | 
20 | 
21 | def get_home_dir(home_dir=None):
22 | if home_dir is None:
23 | home_dir = os.environ['RNN_TUTORIAL']
24 | return home_dir
25 | 
26 | 
27 | def get_data_dir(data_dir=None, home_dir=None):
28 | if data_dir is None:
29 | data_dir = os.path.join(get_home_dir(home_dir=home_dir), 'data', 'raw')
30 | # if data_dir is not an absolute path, prepend home_dir to it
31 | elif not os.path.isabs(data_dir):
32 | data_dir = os.path.join(get_home_dir(home_dir=home_dir), data_dir)
33 | return data_dir
34 | 
35 | 
36 | def get_conf_dir(conf_dir=None, home_dir=None, debug=False):
37 | if conf_dir is None:
38 | conf_dir = os.path.join(get_home_dir(home_dir=home_dir), 'configs')
39 | # Descend to the testing folder if debug==True
40 | if debug:
41 | conf_dir = os.path.join(conf_dir, 'testing')
42 | return conf_dir
43 | 
44 | 
45 | def get_model_dir(model_dir=None, home_dir=None):
46 | if model_dir is None:
47 | model_dir = os.path.join(get_home_dir(home_dir=home_dir), 'models')
48 | return model_dir
49 | 
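# Minimal usage sketch (assumes the RNN_TUTORIAL environment variable is set):
#   home_dir, data_dir, conf_dir = get_relevant_directories(debug=True)
#   # conf_dir then points at $RNN_TUTORIAL/configs/testing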
--------------------------------------------------------------------------------