├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── neural_network.ini
│   └── testing
│       └── neural_network.ini
├── data
│   └── raw
│       └── librivox
│           └── LibriSpeech
│               ├── dev-clean-wav
│               │   ├── 3752-4944-0041.txt
│               │   ├── 3752-4944-0041.wav
│               │   ├── 777-126732-0068.txt
│               │   └── 777-126732-0068.wav
│               ├── test-clean-wav
│               │   ├── 4507-16021-0019.txt
│               │   ├── 4507-16021-0019.wav
│               │   ├── 7176-92135-0009.txt
│               │   └── 7176-92135-0009.wav
│               └── train-clean-100-wav
│                   ├── 1970-28415-0023.txt
│                   ├── 1970-28415-0023.wav
│                   ├── 211-122425-0059.txt
│                   ├── 211-122425-0059.wav
│                   ├── 2843-152918-0008.txt
│                   ├── 2843-152918-0008.wav
│                   ├── 3259-158083-0026.txt
│                   ├── 3259-158083-0026.wav
│                   ├── 3879-174923-0005.txt
│                   └── 3879-174923-0005.wav
├── requirements.txt
└── src
    ├── __init__.py
    ├── data_manipulation
    │   ├── __init__.py
    │   └── datasets.py
    ├── features
    │   ├── __init__.py
    │   └── utils
    │       ├── __init__.py
    │       ├── load_audio_to_mem.py
    │       └── text.py
    ├── models
    │   ├── RNN
    │   │   ├── __init__.py
    │   │   ├── rnn.py
    │   │   └── utils.py
    │   └── __init__.py
    ├── tests
    │   ├── __init__.py
    │   └── train_framework
    │       ├── __init__.py
    │       └── tf_train_ctc_test.py
    ├── train_framework
    │   ├── __init__.py
    │   └── tf_train_ctc.py
    └── utils
        ├── __init__.py
        ├── gpu.py
        └── set_dirs.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 |
4 | # Jupyter Notebook
5 | .ipynb_checkpoints
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "{}"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright {yyyy} {name of copyright owner}
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Recurrent Neural Networks - A Short TensorFlow Tutorial
2 |
3 | ### Setup
4 | Clone this repo to your local machine, and export the RNN-Tutorial directory as an environment variable in your `~/.profile`. Instructions are given for the bash shell:
5 |
6 | ```bash
7 | git clone https://github.com/silicon-valley-data-science/RNN-Tutorial
8 | cd RNN-Tutorial
9 | echo "export RNN_TUTORIAL=${PWD}" >> ~/.profile
10 | echo "export PYTHONPATH=$RNN_TUTORIAL/src:${PYTHONPATH}" >> ~/.profile
11 | source ~/.profile
12 | ```
13 |
14 | Create a Conda environment (you will need to [install Conda](https://conda.io/docs/install/quick.html) first):
15 |
16 | ```bash
17 | conda create --name tf-rnn python=3
18 | source activate tf-rnn
19 | cd $RNN_TUTORIAL
20 | pip install -r requirements.txt
21 | ```
22 |
23 | ### Install TensorFlow
24 |
25 | If you have an NVIDIA GPU with [CUDA](http://docs.nvidia.com/cuda/cuda-installation-guide-linux/#package-manager-installation) already installed
26 |
27 | ```bash
28 | pip install tensorflow-gpu==1.0.1
29 | ```
30 |
31 | If you will be running TensorFlow on CPU only (e.g., on a MacBook Pro), use the following command (if you get an error the first time you run this command, read below):
32 |
33 | ```bash
34 | pip install --upgrade --ignore-installed\
35 | https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.0.1-py3-none-any.whl
36 | ```
37 |
38 | **Error note** (if you did not get an error, skip this paragraph): Depending on how you installed pip and/or conda, we've seen different outcomes. If you get an error the first time, rerunning the command may incorrectly report that the package installed without error. Try `pip install --upgrade https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.0.1-py3-none-any.whl --ignore-installed`. The `--ignore-installed` flag forces the package to be reinstalled. If that still doesn't work, please open an [issue](https://github.com/silicon-valley-data-science/RNN-Tutorial/issues), or try following the advice [here](https://www.tensorflow.org/install/install_mac).
39 |
40 |
41 | ### Run unittests
42 | We have included example unit tests for the `tf_train_ctc.py` script:
43 |
44 | ```bash
45 | python $RNN_TUTORIAL/src/tests/train_framework/tf_train_ctc_test.py
46 | ```
47 |
48 |
49 | ### Run RNN training
50 | All configurations for the RNN training script can be found in `$RNN_TUTORIAL/configs/neural_network.ini`
51 |
52 | ```bash
53 | python $RNN_TUTORIAL/src/train_framework/tf_train_ctc.py
54 | ```
55 |
56 | _NOTE: If you have a GPU available, the code will run faster if you set `tf_device = /gpu:0` in `configs/neural_network.ini`_
57 |
58 |
59 | ### TensorBoard configuration
60 | To visualize your results with TensorBoard:
61 |
62 | ```bash
63 | tensorboard --logdir=$RNN_TUTORIAL/models/nn/debug_models/summary/
64 | ```
65 |
66 | - TensorBoard can be found in your browser at [http://localhost:6006](http://localhost:6006).
67 | - `tf.name_scope` is used to define parts of the network for visualization in TensorBoard. TensorBoard automatically finds any similarly structured network parts, such as identical fully connected layers, and groups them in the graph visualization.
68 | - Related to this are the `tf.summary.*` methods that log the values of network parts, such as distributions of layer activations or error rate across epochs. These summaries are grouped within the `tf.name_scope`.
69 | - See the official TensorFlow documentation for more details.
70 |
71 |
72 | ### Add data
73 | We have included example data from the [LibriVox corpus](https://librivox.org) in `data/raw/librivox/LibriSpeech/`. The data is separated into folders:
74 |
75 | - Train: train-clean-100-wav (5 examples)
76 | - Test: test-clean-wav (2 examples)
77 | - Dev: dev-clean-wav (2 examples)
78 |
79 | If you would like to train a performant model, you can add additional .wav and .txt files to these folders, or create a new folder and update `configs/neural_network.ini` with the new folder locations.
80 |
81 |
82 | ### Remove additions
83 |
84 | We made a few additions to your `~/.profile`. Remove those additions if you want, or, if you want to keep the environment variables, load them from your `~/.bash_profile` by running:
85 |
86 | ```bash
87 | echo "source ~/.profile" >> .bash_profile
88 | ```
89 |
--------------------------------------------------------------------------------
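
A quick illustration of the TensorBoard notes above: a minimal sketch of the `tf.name_scope` / `tf.summary.*` pattern, using the TF 1.x API pinned by this README. Names such as `fc_example` and `/tmp/tb_demo` are illustrative, not part of this repo.

```python
import numpy as np
import tensorflow as tf

# Ops created under one name scope are drawn as a single collapsible node
with tf.name_scope('fc_example'):
    x = tf.placeholder(tf.float32, [None, 4], name='x')
    W = tf.Variable(tf.random_normal([4, 3], stddev=0.046875), name='W')
    b = tf.Variable(tf.zeros([3]), name='b')
    logits = tf.matmul(x, W) + b
    # Summaries logged here appear grouped under 'fc_example'
    tf.summary.histogram('weights', W)
    tf.summary.histogram('activations', logits)

summary_op = tf.summary.merge_all()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    writer = tf.summary.FileWriter('/tmp/tb_demo', sess.graph)
    summary = sess.run(summary_op, feed_dict={x: np.ones((2, 4), np.float32)})
    writer.add_summary(summary, global_step=0)
    writer.close()
```

Pointing `tensorboard --logdir=/tmp/tb_demo` at the log directory then shows the grouped graph and histograms.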
/configs/neural_network.ini:
--------------------------------------------------------------------------------
1 | [nn]
2 | epochs = 100
3 | network_type = BiRNN
4 | decode_train = True
5 | n_input = 26
6 | n_context = 9
7 | model_dir = nn/debug_models
8 | SAVE_MODEL_EPOCH_NUM = 2
9 | VALIDATION_EPOCH_NUM = 1
10 | CURR_VALIDATION_LER_DIFF = 0.005
11 | AVG_VALIDATION_LER_EPOCHS = 2
12 | beam_search_decoder = default
13 | shuffle_data_after_epoch = True
14 | min_dev_ler = 100.0
15 | tf_device = /gpu:0
16 | simultaneous_users_count = 4
17 |
18 | [data]
19 | #If data_dir does not start with '/', then home_dir is prepended in set_dirs.py
20 | data_dir = data/raw/librivox/LibriSpeech/
21 | dir_pattern_train = train-clean-100-wav
22 | dir_pattern_dev = dev-clean-wav
23 | dir_pattern_test = test-clean-wav
24 | n_train_limit = 5
25 | n_dev_limit = 2
26 | n_test_limit = 2
27 | batch_size_train = 2
28 | batch_size_dev = 2
29 | batch_size_test = 2
30 | start_idx_init_train = 0
31 | start_idx_init_dev = 0
32 | start_idx_init_test = 0
33 | sort_train = filesize_low_high
34 | sort_dev = random
35 | sort_test = random
36 |
37 | [optimizer]
38 | # AdamOptimizer (http://arxiv.org/abs/1412.6980) parameters
39 | beta1 = 0.9
40 | beta2 = 0.999
41 | epsilon = 1e-8
42 | learning_rate = 0.001
43 |
44 | [simplelstm]
45 | n_character = 29
46 | default_stddev = 0.046875
47 | b1_stddev = %(default_stddev)s
48 | h1_stddev = %(default_stddev)s
49 | n_layers = 2
50 | n_hidden_units = 512
51 |
52 | [birnn]
53 | n_character = 29
54 | use_warpctc = False
55 | dropout_rate = 0.05
56 | dropout_rate2 = %(dropout_rate)s
57 | dropout_rate3 = %(dropout_rate)s
58 | dropout_rate4 = 0.0
59 | dropout_rate5 = 0.0
60 | dropout_rate6 = %(dropout_rate)s
61 | dropout_rates = %(dropout_rate)s,%(dropout_rate2)s,%(dropout_rate3)s,%(dropout_rate4)s,%(dropout_rate5)s,%(dropout_rate6)s
62 | relu_clip = 20
63 | default_stddev = 0.046875
64 | b1_stddev = %(default_stddev)s
65 | h1_stddev = %(default_stddev)s
66 | b2_stddev = %(default_stddev)s
67 | h2_stddev = %(default_stddev)s
68 | b3_stddev = %(default_stddev)s
69 | h3_stddev = %(default_stddev)s
70 | b5_stddev = %(default_stddev)s
71 | h5_stddev = %(default_stddev)s
72 | b6_stddev = %(default_stddev)s
73 | h6_stddev = %(default_stddev)s
74 | n_hidden = 1024
75 | n_hidden_1 = %(n_hidden)s
76 | n_hidden_2 = %(n_hidden)s
77 | n_hidden_5 = %(n_hidden)s
78 | n_cell_dim = %(n_hidden)s
79 | n_hidden_3 = 2 * %(n_cell_dim)s
80 | n_hidden_6 = %(n_character)s
81 |
--------------------------------------------------------------------------------
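
A note on how this file is consumed: `%(name)s` references use ConfigParser's basic interpolation within the section (with the `os.environ` defaults the code passes in), and `n_hidden_3 = 2 * %(n_cell_dim)s` interpolates to the literal string `2 * 1024`, which `src/models/RNN/rnn.py` turns into a number with `eval`. A minimal sketch of that read path, assuming it is run from the repo root:

```python
import os
from configparser import ConfigParser

parser = ConfigParser(os.environ)  # env vars become interpolation defaults
parser.read('configs/neural_network.ini')

# Plain typed reads
epochs = parser.getint('nn', 'epochs')                      # 100
dropout_rate = parser.getfloat('birnn', 'dropout_rate')     # 0.05

# %(...)s interpolation resolves within the section at read time
n_cell_dim = parser.getint('birnn', 'n_cell_dim')           # 1024, via %(n_hidden)s

# 'n_hidden_3' interpolates to the string '2 * 1024'; rnn.py eval()s it
n_hidden_3 = int(eval(parser.get('birnn', 'n_hidden_3')))   # 2048
print(epochs, dropout_rate, n_cell_dim, n_hidden_3)
```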
/configs/testing/neural_network.ini:
--------------------------------------------------------------------------------
1 | [nn]
2 | epochs = 1
3 | network_type = BiRNN
4 | decode_train = True
5 | n_input = 26
6 | n_context = 9
7 | model_dir = nn/debug_models
8 | SAVE_MODEL_EPOCH_NUM = 2
9 | VALIDATION_EPOCH_NUM = 1
10 | CURR_VALIDATION_LER_DIFF = 0.005
11 | AVG_VALIDATION_LER_EPOCHS = 2
12 | beam_search_decoder = default
13 | shuffle_data_after_epoch = True
14 | min_dev_ler = 100.0
15 | tf_device = /gpu:0
16 | simultaneous_users_count = 4
17 |
18 | [data]
19 | #If data_dir does not start with '/', then home_dir is prepended in set_dirs.py
20 | data_dir = data/raw/librivox/LibriSpeech/
21 | dir_pattern_train = train-clean-100-wav
22 | dir_pattern_dev = dev-clean-wav
23 | dir_pattern_test = test-clean-wav
24 | n_train_limit = 2
25 | n_dev_limit = 2
26 | n_test_limit = 2
27 | batch_size_train = 2
28 | batch_size_dev = 2
29 | batch_size_test = 2
30 | start_idx_init_train = 0
31 | start_idx_init_dev = 0
32 | start_idx_init_test = 0
33 | sort_train = filesize_low_high
34 | sort_dev = random
35 | sort_test = random
36 |
37 | [optimizer]
38 | # AdamOptimizer (http://arxiv.org/abs/1412.6980) parameters
39 | beta1 = 0.9
40 | beta2 = 0.999
41 | epsilon = 1e-8
42 | learning_rate = 0.001
43 |
44 | [simplelstm]
45 | n_character = 29
46 | default_stddev = 0.046875
47 | b1_stddev = %(default_stddev)s
48 | h1_stddev = %(default_stddev)s
49 | n_layers = 2
50 | n_hidden_units = 512
51 |
52 | [birnn]
53 | n_character = 29
54 | use_warpctc = False
55 | dropout_rate = 0.05
56 | dropout_rate2 = %(dropout_rate)s
57 | dropout_rate3 = %(dropout_rate)s
58 | dropout_rate4 = 0.0
59 | dropout_rate5 = 0.0
60 | dropout_rate6 = %(dropout_rate)s
61 | dropout_rates = %(dropout_rate)s,%(dropout_rate2)s,%(dropout_rate3)s,%(dropout_rate4)s,%(dropout_rate5)s,%(dropout_rate6)s
62 | relu_clip = 20
63 | default_stddev = 0.046875
64 | b1_stddev = %(default_stddev)s
65 | h1_stddev = %(default_stddev)s
66 | b2_stddev = %(default_stddev)s
67 | h2_stddev = %(default_stddev)s
68 | b3_stddev = %(default_stddev)s
69 | h3_stddev = %(default_stddev)s
70 | b5_stddev = %(default_stddev)s
71 | h5_stddev = %(default_stddev)s
72 | b6_stddev = %(default_stddev)s
73 | h6_stddev = %(default_stddev)s
74 | n_hidden = 1024
75 | n_hidden_1 = %(n_hidden)s
76 | n_hidden_2 = %(n_hidden)s
77 | n_hidden_5 = %(n_hidden)s
78 | n_cell_dim = %(n_hidden)s
79 | n_hidden_3 = 2 * %(n_cell_dim)s
80 | n_hidden_6 = %(n_character)s
81 |
--------------------------------------------------------------------------------
/data/raw/librivox/LibriSpeech/dev-clean-wav/3752-4944-0041.txt:
--------------------------------------------------------------------------------
1 | how delightful the grass smells
--------------------------------------------------------------------------------
/data/raw/librivox/LibriSpeech/dev-clean-wav/3752-4944-0041.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrubash1/RNN-Tutorial/1d02ee9d9e4f0ff24fc88beabb8ff9da6b8c8886/data/raw/librivox/LibriSpeech/dev-clean-wav/3752-4944-0041.wav
--------------------------------------------------------------------------------
/data/raw/librivox/LibriSpeech/dev-clean-wav/777-126732-0068.txt:
--------------------------------------------------------------------------------
1 | that boy hears too much of what is talked about here
--------------------------------------------------------------------------------
/data/raw/librivox/LibriSpeech/dev-clean-wav/777-126732-0068.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrubash1/RNN-Tutorial/1d02ee9d9e4f0ff24fc88beabb8ff9da6b8c8886/data/raw/librivox/LibriSpeech/dev-clean-wav/777-126732-0068.wav
--------------------------------------------------------------------------------
/data/raw/librivox/LibriSpeech/test-clean-wav/4507-16021-0019.txt:
--------------------------------------------------------------------------------
1 | it is the language of wretchedness
--------------------------------------------------------------------------------
/data/raw/librivox/LibriSpeech/test-clean-wav/4507-16021-0019.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrubash1/RNN-Tutorial/1d02ee9d9e4f0ff24fc88beabb8ff9da6b8c8886/data/raw/librivox/LibriSpeech/test-clean-wav/4507-16021-0019.wav
--------------------------------------------------------------------------------
/data/raw/librivox/LibriSpeech/test-clean-wav/7176-92135-0009.txt:
--------------------------------------------------------------------------------
1 | and i should begin with a short homily on soliloquy
--------------------------------------------------------------------------------
/data/raw/librivox/LibriSpeech/test-clean-wav/7176-92135-0009.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrubash1/RNN-Tutorial/1d02ee9d9e4f0ff24fc88beabb8ff9da6b8c8886/data/raw/librivox/LibriSpeech/test-clean-wav/7176-92135-0009.wav
--------------------------------------------------------------------------------
/data/raw/librivox/LibriSpeech/train-clean-100-wav/1970-28415-0023.txt:
--------------------------------------------------------------------------------
1 | where people were making their gifts to god
--------------------------------------------------------------------------------
/data/raw/librivox/LibriSpeech/train-clean-100-wav/1970-28415-0023.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrubash1/RNN-Tutorial/1d02ee9d9e4f0ff24fc88beabb8ff9da6b8c8886/data/raw/librivox/LibriSpeech/train-clean-100-wav/1970-28415-0023.wav
--------------------------------------------------------------------------------
/data/raw/librivox/LibriSpeech/train-clean-100-wav/211-122425-0059.txt:
--------------------------------------------------------------------------------
1 | and the two will pass off together
--------------------------------------------------------------------------------
/data/raw/librivox/LibriSpeech/train-clean-100-wav/211-122425-0059.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrubash1/RNN-Tutorial/1d02ee9d9e4f0ff24fc88beabb8ff9da6b8c8886/data/raw/librivox/LibriSpeech/train-clean-100-wav/211-122425-0059.wav
--------------------------------------------------------------------------------
/data/raw/librivox/LibriSpeech/train-clean-100-wav/2843-152918-0008.txt:
--------------------------------------------------------------------------------
1 | one day may be pleasant enough but two three four
--------------------------------------------------------------------------------
/data/raw/librivox/LibriSpeech/train-clean-100-wav/2843-152918-0008.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrubash1/RNN-Tutorial/1d02ee9d9e4f0ff24fc88beabb8ff9da6b8c8886/data/raw/librivox/LibriSpeech/train-clean-100-wav/2843-152918-0008.wav
--------------------------------------------------------------------------------
/data/raw/librivox/LibriSpeech/train-clean-100-wav/3259-158083-0026.txt:
--------------------------------------------------------------------------------
1 | i have a nephew fighting for democracy in france
--------------------------------------------------------------------------------
/data/raw/librivox/LibriSpeech/train-clean-100-wav/3259-158083-0026.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrubash1/RNN-Tutorial/1d02ee9d9e4f0ff24fc88beabb8ff9da6b8c8886/data/raw/librivox/LibriSpeech/train-clean-100-wav/3259-158083-0026.wav
--------------------------------------------------------------------------------
/data/raw/librivox/LibriSpeech/train-clean-100-wav/3879-174923-0005.txt:
--------------------------------------------------------------------------------
1 | he must vanish out of the world
--------------------------------------------------------------------------------
/data/raw/librivox/LibriSpeech/train-clean-100-wav/3879-174923-0005.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrubash1/RNN-Tutorial/1d02ee9d9e4f0ff24fc88beabb8ff9da6b8c8886/data/raw/librivox/LibriSpeech/train-clean-100-wav/3879-174923-0005.wav
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | mako==1.0.6
2 | matplotlib==2.0.0
3 | numpy==1.12.0
4 | protobuf==3.2.0
5 | python-speech-features==0.5
6 | pyyaml==5.4
7 | pyxdg==0.26
8 | requests==2.20.0
9 | scipy==0.18.1
10 | sox==1.2.6
11 | tree==0.1.0
12 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrubash1/RNN-Tutorial/1d02ee9d9e4f0ff24fc88beabb8ff9da6b8c8886/src/__init__.py
--------------------------------------------------------------------------------
/src/data_manipulation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrubash1/RNN-Tutorial/1d02ee9d9e4f0ff24fc88beabb8ff9da6b8c8886/src/data_manipulation/__init__.py
--------------------------------------------------------------------------------
/src/data_manipulation/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | from math import ceil
3 | from random import random
4 | from glob import glob
5 | from configparser import ConfigParser
6 | import logging
7 | from collections import namedtuple
8 |
9 | from features.utils.load_audio_to_mem import get_audio_and_transcript, pad_sequences
10 | from features.utils.text import sparse_tuple_from
11 | from utils.set_dirs import get_data_dir
12 |
13 |
14 | DataSets = namedtuple('DataSets', 'train dev test')
15 |
16 |
17 | def read_datasets(conf_path, sets, numcep, numcontext,
18 | thread_count=8):
19 | '''Main function to create DataSet objects.
20 |
21 | This function calls an internal function _get_data_set_dict that
22 | reads the configuration file. Then it calls the internal function _read_data_set
23 | which collects the text files in the data directories, returning a DataSet object.
24 | This function returns a DataSets object containing the requested datasets.
25 |
26 | Args:
27 | sets (list): List of datasets to create. Options are: 'train', 'dev', 'test'
28 | numcep (int): Number of mel-frequency cepstral coefficients to compute.
29 | numcontext (int): For each time point, number of contextual samples to include.
30 | thread_count (int): Number of threads
31 |
32 | Returns:
33 | DataSets: A single `DataSets` instance containing each of the requested datasets
34 |
35 | E.g., when sets=['train'], datasets.train exists, with methods to retrieve examples.
36 |
37 | This function has been modified from Mozilla DeepSpeech:
38 | https://github.com/mozilla/DeepSpeech/blob/master/util/importers/librivox.py
39 |
40 | # This Source Code Form is subject to the terms of the Mozilla Public
41 | # License, v. 2.0. If a copy of the MPL was not distributed with this
42 | # file, You can obtain one at http://mozilla.org/MPL/2.0/.
43 | '''
44 | data_dir, dataset_config = _get_data_set_dict(conf_path, sets)
45 |
46 | def _read_data_set(config):
47 | path = os.path.join(data_dir, config['dir_pattern'])
48 | return DataSet.from_directory(path,
49 | thread_count=thread_count,
50 | batch_size=config['batch_size'],
51 | numcep=numcep,
52 | numcontext=numcontext,
53 | start_idx=config['start_idx'],
54 | limit=config['limit'],
55 | sort=config['sort']
56 | )
57 | datasets = {name: _read_data_set(dataset_config[name])
58 | if name in sets else None
59 | for name in ('train', 'dev', 'test')}
60 | return DataSets(**datasets)
61 |
62 |
63 | class DataSet:
64 | '''
65 | Train/test/dev dataset API for loading via threads and delivering batches.
66 |
67 | This class has been modified from Mozilla DeepSpeech:
68 | https://github.com/mozilla/DeepSpeech/blob/master/util/importers/librivox.py
69 |
70 | # This Source Code Form is subject to the terms of the Mozilla Public
71 | # License, v. 2.0. If a copy of the MPL was not distributed with this
72 | # file, You can obtain one at http://mozilla.org/MPL/2.0/.
73 | '''
74 |
75 | def __init__(self, txt_files, thread_count, batch_size, numcep, numcontext):
76 | self._coord = None
77 | self._numcep = numcep
78 | self._txt_files = txt_files
79 | self._batch_size = batch_size
80 | self._numcontext = numcontext
81 | self._start_idx = 0
82 |
83 | @classmethod
84 | def from_directory(cls, dirpath, thread_count, batch_size, numcep, numcontext, start_idx=0, limit=0, sort=None):
85 | if not os.path.exists(dirpath):
86 | raise IOError("'%s' does not exist" % dirpath)
87 | txt_files = txt_filenames(dirpath, start_idx=start_idx, limit=limit, sort=sort)
88 | if len(txt_files) == 0:
89 | raise RuntimeError('start_idx=%d and limit=%d arguments result in zero files' % (start_idx, limit))
90 | return cls(txt_files, thread_count, batch_size, numcep, numcontext)
91 |
92 | def next_batch(self, batch_size=None):
93 | if batch_size is None:
94 | batch_size = self._batch_size
95 |
96 | end_idx = min(len(self._txt_files), self._start_idx + batch_size)
97 | idx_list = range(self._start_idx, end_idx)
98 | txt_files = [self._txt_files[i] for i in idx_list]
99 | wav_files = [x.replace('.txt', '.wav') for x in txt_files]
100 | (source, _, target, _) = get_audio_and_transcript(txt_files,
101 | wav_files,
102 | self._numcep,
103 | self._numcontext)
104 | self._start_idx += batch_size
105 | # Verify that the start_idx is not larger than total available sample size
106 | if self._start_idx >= self.size:
107 | self._start_idx = 0
108 |
109 | # Pad input to max_time_step of this batch
110 | source, source_lengths = pad_sequences(source)
111 | sparse_labels = sparse_tuple_from(target)
112 | return source, source_lengths, sparse_labels
113 |
114 | @property
115 | def files(self):
116 | return self._txt_files
117 |
118 | @property
119 | def size(self):
120 | return len(self.files)
121 |
122 | @property
123 | def total_batches(self):
124 | # Note: If len(_txt_files) % _batch_size != 0, this re-uses initial _txt_files
125 | return int(ceil(float(len(self._txt_files)) / float(self._batch_size)))
126 | # END DataSet
127 |
128 | SORTS = ['filesize_low_high', 'filesize_high_low', 'alpha', 'random']
129 |
130 | def txt_filenames(dataset_path, start_idx=0, limit=None, sort='alpha'):
131 | # Obtain list of txt files
132 | txt_files = glob(os.path.join(dataset_path, "*.txt"))
133 | limit = limit or len(txt_files)
134 |
135 | # Optional: sort files to improve padding performance
136 | if sort not in SORTS:
137 |         raise ValueError('sort must be one of %s' % SORTS)
138 | reverse = False
139 | key = None
140 | if 'filesize' in sort:
141 | key = os.path.getsize
142 | if sort == 'filesize_high_low':
143 | reverse = True
144 | elif sort == 'random':
145 | key = lambda *args: random()
146 | txt_files = sorted(txt_files, key=key, reverse=reverse)
147 |
148 | return txt_files[start_idx:limit + start_idx]
149 |
150 |
151 | def _get_data_set_dict(conf_path, sets):
152 | parser = ConfigParser(os.environ)
153 | parser.read(conf_path)
154 | config_header = 'data'
155 | data_dir = get_data_dir(parser.get(config_header, 'data_dir'))
156 | data_dict = {}
157 |
158 | if 'train' in sets:
159 | d = {}
160 | d['dir_pattern'] = parser.get(config_header, 'dir_pattern_train')
161 | d['limit'] = parser.getint(config_header, 'n_train_limit')
162 | d['sort'] = parser.get(config_header, 'sort_train')
163 | d['batch_size'] = parser.getint(config_header, 'batch_size_train')
164 | d['start_idx'] = parser.getint(config_header, 'start_idx_init_train')
165 | data_dict['train'] = d
166 | logging.debug('Training configuration: %s', str(d))
167 |
168 | if 'dev' in sets:
169 | d = {}
170 | d['dir_pattern'] = parser.get(config_header, 'dir_pattern_dev')
171 | d['limit'] = parser.getint(config_header, 'n_dev_limit')
172 | d['sort'] = parser.get(config_header, 'sort_dev')
173 | d['batch_size'] = parser.getint(config_header, 'batch_size_dev')
174 | d['start_idx'] = parser.getint(config_header, 'start_idx_init_dev')
175 | data_dict['dev'] = d
176 | logging.debug('Dev configuration: %s', str(d))
177 |
178 | if 'test' in sets:
179 | d = {}
180 | d['dir_pattern'] = parser.get(config_header, 'dir_pattern_test')
181 | d['limit'] = parser.getint(config_header, 'n_test_limit')
182 | d['sort'] = parser.get(config_header, 'sort_test')
183 | d['batch_size'] = parser.getint(config_header, 'batch_size_test')
184 | d['start_idx'] = parser.getint(config_header, 'start_idx_init_test')
185 | data_dict['test'] = d
186 | logging.debug('Test configuration: %s', str(d))
187 |
188 | return data_dir, data_dict
189 |
--------------------------------------------------------------------------------
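
A hedged sketch of how `read_datasets` and `DataSet.next_batch` fit together, assuming `src` is on `PYTHONPATH` and `RNN_TUTORIAL` is set as in the README (the numbers follow the defaults in `configs/neural_network.ini`):

```python
import os
from data_manipulation.datasets import read_datasets

conf_path = os.path.join(os.environ['RNN_TUTORIAL'], 'configs/neural_network.ini')

# Build only the training set; datasets.dev and datasets.test will be None
datasets = read_datasets(conf_path, sets=['train'], numcep=26, numcontext=9)

train = datasets.train
print(train.size, train.total_batches)  # 5 files, batch_size_train=2 -> 3 batches

# Each batch: padded MFCC features, their true lengths, and sparse CTC labels
source, source_lengths, sparse_labels = train.next_batch()
print(source.shape)  # (2, max_time_in_batch, 26 + 2*26*9)
```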
/src/features/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrubash1/RNN-Tutorial/1d02ee9d9e4f0ff24fc88beabb8ff9da6b8c8886/src/features/__init__.py
--------------------------------------------------------------------------------
/src/features/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrubash1/RNN-Tutorial/1d02ee9d9e4f0ff24fc88beabb8ff9da6b8c8886/src/features/utils/__init__.py
--------------------------------------------------------------------------------
/src/features/utils/load_audio_to_mem.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import os
4 | import scipy.io.wavfile as wav
5 |
6 | import numpy as np
7 | from python_speech_features import mfcc
8 | from features.utils.text import text_to_char_array, normalize_txt_file
9 |
10 |
11 | def load_wavfile(wavfile):
12 | """
13 | Read a wav file using scipy.io.wavfile
14 | """
15 | rate, sig = wav.read(wavfile)
16 | data_name = os.path.splitext(os.path.basename(wavfile))[0]
17 | return rate, sig, data_name
18 |
19 |
20 | def get_audio_and_transcript(txt_files, wav_files, n_input, n_context):
21 | '''
22 | Loads audio files and text transcriptions from ordered lists of filenames.
23 |     Converts audio to MFCC arrays and text to numerical arrays.
24 |     Returns lists of arrays. The returned audio list can be padded with the
25 |     pad_sequences function in this same module.
26 | '''
27 | audio = []
28 | audio_len = []
29 | transcript = []
30 | transcript_len = []
31 |
32 | for txt_file, wav_file in zip(txt_files, wav_files):
33 | # load audio and convert to features
34 | audio_data = audiofile_to_input_vector(wav_file, n_input, n_context)
35 | audio_data = audio_data.astype('float32')
36 |
37 | audio.append(audio_data)
38 | audio_len.append(np.int32(len(audio_data)))
39 |
40 | # load text transcription and convert to numerical array
41 | target = normalize_txt_file(txt_file)
42 | target = text_to_char_array(target)
43 | transcript.append(target)
44 | transcript_len.append(len(target))
45 |
46 | audio = np.asarray(audio)
47 | audio_len = np.asarray(audio_len)
48 | transcript = np.asarray(transcript)
49 | transcript_len = np.asarray(transcript_len)
50 | return audio, audio_len, transcript, transcript_len
51 |
52 |
53 | def audiofile_to_input_vector(audio_filename, numcep, numcontext):
54 | '''
55 | Turn an audio file into feature representation.
56 |
57 | This function has been modified from Mozilla DeepSpeech:
58 | https://github.com/mozilla/DeepSpeech/blob/master/util/audio.py
59 |
60 | # This Source Code Form is subject to the terms of the Mozilla Public
61 | # License, v. 2.0. If a copy of the MPL was not distributed with this
62 | # file, You can obtain one at http://mozilla.org/MPL/2.0/.
63 | '''
64 |
65 | # Load wav files
66 | fs, audio = wav.read(audio_filename)
67 |
68 | # Get mfcc coefficients
69 | orig_inputs = mfcc(audio, samplerate=fs, numcep=numcep)
70 |
71 | # We only keep every second feature (BiRNN stride = 2)
72 | orig_inputs = orig_inputs[::2]
73 |
74 |     # For each time slice of the training set, we need to copy the context. This turns
75 |     # the numcep-dimensional vector into a vector of numcep + 2*numcep*numcontext dimensions,
76 |     # because of:
77 | # - numcep dimensions for the current mfcc feature set
78 | # - numcontext*numcep dimensions for each of the past and future (x2) mfcc feature set
79 | # => so numcep + 2*numcontext*numcep
80 | train_inputs = np.array([], np.float32)
81 | train_inputs.resize((orig_inputs.shape[0], numcep + 2 * numcep * numcontext))
82 |
83 |     # Prepare prefix/postfix context
84 | empty_mfcc = np.array([])
85 | empty_mfcc.resize((numcep))
86 |
87 | # Prepare train_inputs with past and future contexts
88 | time_slices = range(train_inputs.shape[0])
89 | context_past_min = time_slices[0] + numcontext
90 | context_future_max = time_slices[-1] - numcontext
91 | for time_slice in time_slices:
92 | # Reminder: array[start:stop:step]
93 |         # slices from index |start| up to |stop| (not included), every |step|
94 |
95 | # Add empty context data of the correct size to the start and end
96 | # of the MFCC feature matrix
97 |
98 | # Pick up to numcontext time slices in the past, and complete with empty
99 | # mfcc features
100 | need_empty_past = max(0, (context_past_min - time_slice))
101 | empty_source_past = list(empty_mfcc for empty_slots in range(need_empty_past))
102 | data_source_past = orig_inputs[max(0, time_slice - numcontext):time_slice]
103 | assert(len(empty_source_past) + len(data_source_past) == numcontext)
104 |
105 | # Pick up to numcontext time slices in the future, and complete with empty
106 | # mfcc features
107 | need_empty_future = max(0, (time_slice - context_future_max))
108 | empty_source_future = list(empty_mfcc for empty_slots in range(need_empty_future))
109 | data_source_future = orig_inputs[time_slice + 1:time_slice + numcontext + 1]
110 | assert(len(empty_source_future) + len(data_source_future) == numcontext)
111 |
112 | if need_empty_past:
113 | past = np.concatenate((empty_source_past, data_source_past))
114 | else:
115 | past = data_source_past
116 |
117 | if need_empty_future:
118 | future = np.concatenate((data_source_future, empty_source_future))
119 | else:
120 | future = data_source_future
121 |
122 | past = np.reshape(past, numcontext * numcep)
123 | now = orig_inputs[time_slice]
124 | future = np.reshape(future, numcontext * numcep)
125 |
126 | train_inputs[time_slice] = np.concatenate((past, now, future))
127 | assert(len(train_inputs[time_slice]) == numcep + 2 * numcep * numcontext)
128 |
129 | # Scale/standardize the inputs
130 | # This can be done more efficiently in the TensorFlow graph
131 | train_inputs = (train_inputs - np.mean(train_inputs)) / np.std(train_inputs)
132 | return train_inputs
133 |
134 |
135 | def pad_sequences(sequences, maxlen=None, dtype=np.float32,
136 | padding='post', truncating='post', value=0.):
137 |
138 | '''
139 | # From TensorLayer:
140 | # http://tensorlayer.readthedocs.io/en/latest/_modules/tensorlayer/prepro.html
141 |
142 | Pads each sequence to the same length of the longest sequence.
143 |
144 | If maxlen is provided, any sequence longer than maxlen is truncated to
145 | maxlen. Truncation happens off either the beginning or the end
146 | (default) of the sequence. Supports post-padding (default) and
147 | pre-padding.
148 |
149 | Args:
150 | sequences: list of lists where each element is a sequence
151 | maxlen: int, maximum length
152 | dtype: type to cast the resulting sequence.
153 | padding: 'pre' or 'post', pad either before or after each sequence.
154 | truncating: 'pre' or 'post', remove values from sequences larger
155 | than maxlen either in the beginning or in the end of the sequence
156 | value: float, value to pad the sequences to the desired value.
157 |
158 | Returns:
159 | numpy.ndarray: Padded sequences shape = (number_of_sequences, maxlen)
160 | numpy.ndarray: original sequence lengths
161 | '''
162 | lengths = np.asarray([len(s) for s in sequences], dtype=np.int64)
163 |
164 | nb_samples = len(sequences)
165 | if maxlen is None:
166 | maxlen = np.max(lengths)
167 |
168 | # take the sample shape from the first non empty sequence
169 | # checking for consistency in the main loop below.
170 | sample_shape = tuple()
171 | for s in sequences:
172 | if len(s) > 0:
173 | sample_shape = np.asarray(s).shape[1:]
174 | break
175 |
176 | x = (np.ones((nb_samples, maxlen) + sample_shape) * value).astype(dtype)
177 | for idx, s in enumerate(sequences):
178 | if len(s) == 0:
179 | continue # empty list was found
180 | if truncating == 'pre':
181 | trunc = s[-maxlen:]
182 | elif truncating == 'post':
183 | trunc = s[:maxlen]
184 | else:
185 | raise ValueError('Truncating type "%s" not understood' % truncating)
186 |
187 | # check `trunc` has expected shape
188 | trunc = np.asarray(trunc, dtype=dtype)
189 | if trunc.shape[1:] != sample_shape:
190 | raise ValueError('Shape of sample %s of sequence at position %s is different from expected shape %s' %
191 | (trunc.shape[1:], idx, sample_shape))
192 |
193 | if padding == 'post':
194 | x[idx, :len(trunc)] = trunc
195 | elif padding == 'pre':
196 | x[idx, -len(trunc):] = trunc
197 | else:
198 | raise ValueError('Padding type "%s" not understood' % padding)
199 | return x, lengths
200 |
--------------------------------------------------------------------------------
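
The context-window arithmetic in `audiofile_to_input_vector` is easy to sanity-check: with `numcep = 26` and `numcontext = 9` (the `n_input` and `n_context` defaults), each time slice carries 26 + 2*26*9 = 494 features. A small sketch of `pad_sequences` on toy arrays (not real MFCCs):

```python
import numpy as np
from features.utils.load_audio_to_mem import pad_sequences

numcep, numcontext = 26, 9
print(numcep + 2 * numcep * numcontext)  # 494 features per time slice

# Two "utterances" of different lengths, 494 features per frame
seqs = [np.zeros((7, 494), np.float32), np.ones((4, 494), np.float32)]
x, lengths = pad_sequences(seqs)
print(x.shape)   # (2, 7, 494) -- post-padded to the longest sequence
print(lengths)   # [7 4]       -- original lengths, needed downstream by CTC
```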
/src/features/utils/text.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import unicodedata
3 | import codecs
4 | import re
5 | from functools import reduce
6 | import tensorflow as tf
7 |
8 | # Constants
9 | SPACE_TOKEN = '<space>'
10 | SPACE_INDEX = 0
11 | FIRST_INDEX = ord('a') - 1  # 0 is reserved for space
12 |
13 |
14 | def normalize_txt_file(txt_file, remove_apostrophe=True):
15 | """
16 | Given a path to a text file, return contents with unsupported characters removed.
17 | """
18 | with codecs.open(txt_file, encoding="utf-8") as open_txt_file:
19 | return normalize_text(open_txt_file.read(), remove_apostrophe=remove_apostrophe)
20 |
21 |
22 | def normalize_text(original, remove_apostrophe=True):
23 | """
24 | Given a Python string ``original``, remove unsupported characters.
25 |
26 | The only supported characters are letters and apostrophes.
27 | """
28 | # convert any unicode characters to ASCII equivalent
29 | # then ignore anything else and decode to a string
30 | result = unicodedata.normalize("NFKD", original).encode("ascii", "ignore").decode()
31 | if remove_apostrophe:
32 | # remove apostrophes to keep contractions together
33 | result = result.replace("'", "")
34 | # return lowercase alphabetic characters and apostrophes (if still present)
35 | return re.sub("[^a-zA-Z']+", ' ', result).strip().lower()
36 |
37 |
38 | def text_to_char_array(original):
39 | """
40 | Given a Python string ``original``, map characters
41 | to integers and return a numpy array representing the processed string.
42 |
43 | This function has been modified from Mozilla DeepSpeech:
44 | https://github.com/mozilla/DeepSpeech/blob/master/util/text.py
45 |
46 | # This Source Code Form is subject to the terms of the Mozilla Public
47 | # License, v. 2.0. If a copy of the MPL was not distributed with this
48 | # file, You can obtain one at http://mozilla.org/MPL/2.0/.
49 | """
50 |
51 | # Create list of sentence's words w/spaces replaced by ''
52 |     result = original.replace(' ', '  ')
53 | result = result.split(' ')
54 |
55 | # Tokenize words into letters adding in SPACE_TOKEN where required
56 | result = np.hstack([SPACE_TOKEN if xt == '' else list(xt) for xt in result])
57 |
58 |     # Return characters mapped into indices
59 | return np.asarray([SPACE_INDEX if xt == SPACE_TOKEN else ord(xt) - FIRST_INDEX for xt in result])
60 |
61 |
62 | def sparse_tuple_from(sequences, dtype=np.int32):
63 | """
64 | Create a sparse representention of ``sequences``.
65 |
66 | Args:
67 | sequences: a list of lists of type dtype where each element is a sequence
68 | Returns:
69 | A tuple with (indices, values, shape)
70 |
71 | This function has been modified from Mozilla DeepSpeech:
72 | https://github.com/mozilla/DeepSpeech/blob/master/util/text.py
73 |
74 | # This Source Code Form is subject to the terms of the Mozilla Public
75 | # License, v. 2.0. If a copy of the MPL was not distributed with this
76 | # file, You can obtain one at http://mozilla.org/MPL/2.0/.
77 | """
78 |
79 | indices = []
80 | values = []
81 |
82 | for n, seq in enumerate(sequences):
83 | indices.extend(zip([n] * len(seq), range(len(seq))))
84 | values.extend(seq)
85 |
86 | indices = np.asarray(indices, dtype=np.int64)
87 | values = np.asarray(values, dtype=dtype)
88 | shape = np.asarray([len(sequences), indices.max(0)[1] + 1], dtype=np.int64)
89 |
90 | # return tf.SparseTensor(indices=indices, values=values, shape=shape)
91 | return indices, values, shape
92 |
93 |
94 | def sparse_tensor_value_to_texts(value):
95 | """
96 | Given a :class:`tf.SparseTensor` ``value``, return an array of Python strings
97 | representing its values.
98 |
99 | This function has been modified from Mozilla DeepSpeech:
100 | https://github.com/mozilla/DeepSpeech/blob/master/util/text.py
101 |
102 | # This Source Code Form is subject to the terms of the Mozilla Public
103 | # License, v. 2.0. If a copy of the MPL was not distributed with this
104 | # file, You can obtain one at http://mozilla.org/MPL/2.0/.
105 | """
106 | return sparse_tuple_to_texts((value.indices, value.values, value.dense_shape))
107 |
108 |
109 | def sparse_tuple_to_texts(sp_tuple):
110 | '''
111 | This function has been modified from Mozilla DeepSpeech:
112 | https://github.com/mozilla/DeepSpeech/blob/master/util/text.py
113 |
114 | # This Source Code Form is subject to the terms of the Mozilla Public
115 | # License, v. 2.0. If a copy of the MPL was not distributed with this
116 | # file, You can obtain one at http://mozilla.org/MPL/2.0/.
117 | '''
118 |     indices = sp_tuple[0]
119 |     values = sp_tuple[1]
120 |     results = [''] * sp_tuple[2][0]
121 | for i in range(len(indices)):
122 | index = indices[i][0]
123 | c = values[i]
124 | c = ' ' if c == SPACE_INDEX else chr(c + FIRST_INDEX)
125 | results[index] = results[index] + c
126 | # List of strings
127 | return results
128 |
129 |
130 | def ndarray_to_text(value):
131 | '''
132 | This function has been modified from Mozilla DeepSpeech:
133 | https://github.com/mozilla/DeepSpeech/blob/master/util/text.py
134 |
135 | # This Source Code Form is subject to the terms of the Mozilla Public
136 | # License, v. 2.0. If a copy of the MPL was not distributed with this
137 | # file, You can obtain one at http://mozilla.org/MPL/2.0/.
138 | '''
139 | results = ''
140 | for i in range(len(value)):
141 | results += chr(value[i] + FIRST_INDEX)
142 | return results.replace('`', ' ')
143 |
144 |
145 | def gather_nd(params, indices, shape):
146 | '''
147 |     # Function taken from https://github.com/tensorflow/tensorflow/issues/206#issuecomment-229678962
148 |
149 | '''
150 | rank = len(shape)
151 | flat_params = tf.reshape(params, [-1])
152 | multipliers = [reduce(lambda x, y: x * y, shape[i + 1:], 1) for i in range(0, rank)]
153 |     indices_unpacked = tf.unstack(tf.transpose(indices, [rank - 1] + list(range(0, rank - 1))))
154 | flat_indices = sum([a * b for a, b in zip(multipliers, indices_unpacked)])
155 | return tf.gather(flat_params, flat_indices)
156 |
157 |
158 | def ctc_label_dense_to_sparse(labels, label_lengths, batch_size):
159 | '''
160 | The CTC implementation in TensorFlow needs labels in a sparse representation,
161 | but sparse data and queues don't mix well, so we store padded tensors in the
162 | queue and convert to a sparse representation after dequeuing a batch.
163 |
164 | Taken from https://github.com/tensorflow/tensorflow/issues/1742#issuecomment-205291527
165 | '''
166 |
167 | # The second dimension of labels must be equal to the longest label length in the batch
168 | correct_shape_assert = tf.assert_equal(tf.shape(labels)[1], tf.reduce_max(label_lengths))
169 | with tf.control_dependencies([correct_shape_assert]):
170 | labels = tf.identity(labels)
171 |
172 | label_shape = tf.shape(labels)
173 | num_batches_tns = tf.stack([label_shape[0]])
174 | max_num_labels_tns = tf.stack([label_shape[1]])
175 |
176 | def range_less_than(previous_state, current_input):
177 | return tf.expand_dims(tf.range(label_shape[1]), 0) < current_input
178 |
179 | init = tf.cast(tf.fill(max_num_labels_tns, 0), tf.bool)
180 | init = tf.expand_dims(init, 0)
181 | dense_mask = tf.scan(range_less_than, label_lengths, initializer=init, parallel_iterations=1)
182 | dense_mask = dense_mask[:, 0, :]
183 |
184 | label_array = tf.reshape(tf.tile(tf.range(0, label_shape[1]), num_batches_tns), label_shape)
185 |
186 | label_ind = tf.boolean_mask(label_array, dense_mask)
187 |
188 | batch_array = tf.transpose(tf.reshape(tf.tile(tf.range(0, label_shape[0]), max_num_labels_tns),
189 | tf.reverse(label_shape, [0]))
190 | )
191 | batch_ind = tf.boolean_mask(batch_array, dense_mask)
192 | batch_label = tf.concat([batch_ind, label_ind], 0)
193 | indices = tf.transpose(tf.reshape(batch_label, [2, -1]))
194 | shape = [batch_size, tf.reduce_max(label_lengths)]
195 | vals_sparse = gather_nd(labels, indices, shape)
196 |
197 | return tf.SparseTensor(tf.to_int64(indices), vals_sparse, tf.to_int64(label_shape))
198 |
--------------------------------------------------------------------------------
/src/models/RNN/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/src/models/RNN/rnn.py:
--------------------------------------------------------------------------------
1 | # Note: All calls to tf.name_scope or tf.summary.* support TensorBoard visualization.
2 |
3 | import os
4 | import tensorflow as tf
5 | from configparser import ConfigParser
6 |
7 | from models.RNN.utils import variable_on_cpu
8 |
9 |
10 | def SimpleLSTM(conf_path, input_tensor, seq_length):
11 | '''
12 | This function was initially based on open source code from Mozilla DeepSpeech:
13 | https://github.com/mozilla/DeepSpeech/blob/master/DeepSpeech.py
14 |
15 | # This Source Code Form is subject to the terms of the Mozilla Public
16 | # License, v. 2.0. If a copy of the MPL was not distributed with this
17 | # file, You can obtain one at http://mozilla.org/MPL/2.0/.
18 | '''
19 | parser = ConfigParser(os.environ)
20 | parser.read(conf_path)
21 |
22 | # SimpleLSTM
23 | n_character = parser.getint('simplelstm', 'n_character')
24 | b1_stddev = parser.getfloat('simplelstm', 'b1_stddev')
25 | h1_stddev = parser.getfloat('simplelstm', 'h1_stddev')
26 | n_layers = parser.getint('simplelstm', 'n_layers')
27 | n_hidden_units = parser.getint('simplelstm', 'n_hidden_units')
28 |
29 | # Input shape: [batch_size, n_steps, n_input + 2*n_input*n_context]
30 | # batch_x_shape = tf.shape(batch_x)
31 |
32 | input_tensor_shape = tf.shape(input_tensor)
33 | n_items = input_tensor_shape[0]
34 |
35 | with tf.name_scope("lstm"):
36 | # Initialize weights
37 | # with tf.device('/cpu:0'):
38 | W = tf.get_variable('W', shape=[n_hidden_units, n_character],
39 | # initializer=tf.truncated_normal_initializer(stddev=h1_stddev),
40 | initializer=tf.random_normal_initializer(stddev=h1_stddev),
41 | )
42 | # Initialize bias
43 | # with tf.device('/cpu:0'):
44 | # b = tf.get_variable('b', initializer=tf.zeros_initializer([n_character]))
45 | b = tf.get_variable('b', shape=[n_character],
46 | # initializer=tf.constant_initializer(value=0),
47 | initializer=tf.random_normal_initializer(stddev=b1_stddev),
48 | )
49 |
50 | # Define the cell
51 | # Can be:
52 | # tf.contrib.rnn.BasicRNNCell
53 | # tf.contrib.rnn.GRUCell
54 | cell = tf.contrib.rnn.BasicLSTMCell(n_hidden_units, state_is_tuple=True)
55 |
56 | # Stacking rnn cells
57 | stack = tf.contrib.rnn.MultiRNNCell([cell] * n_layers, state_is_tuple=True)
58 |
59 |         # Get layer activations (the second output is the final state of the layer, which we do not need)
60 | outputs, _ = tf.nn.dynamic_rnn(stack, input_tensor, seq_length,
61 | time_major=False, dtype=tf.float32)
62 |
63 | # Reshape to apply the same weights over the timesteps
64 | outputs = tf.reshape(outputs, [-1, n_hidden_units])
65 |
66 | # Perform affine transformation to layer output:
67 | # multiply by weights (linear transformation), add bias (translation)
68 | logits = tf.add(tf.matmul(outputs, W), b)
69 |
70 | tf.summary.histogram("weights", W)
71 | tf.summary.histogram("biases", b)
72 | tf.summary.histogram("activations", logits)
73 |
74 | # Reshaping back to the original shape
75 | logits = tf.reshape(logits, [n_items, -1, n_character])
76 |
77 | # Put time as the major axis
78 | logits = tf.transpose(logits, (1, 0, 2))
79 |
80 | summary_op = tf.summary.merge_all()
81 |
82 | return logits, summary_op
83 |
84 | def BiRNN(conf_path, batch_x, seq_length, n_input, n_context):
85 | """
86 | This function was initially based on open source code from Mozilla DeepSpeech:
87 | https://github.com/mozilla/DeepSpeech/blob/master/DeepSpeech.py
88 |
89 | # This Source Code Form is subject to the terms of the Mozilla Public
90 | # License, v. 2.0. If a copy of the MPL was not distributed with this
91 | # file, You can obtain one at http://mozilla.org/MPL/2.0/.
92 | """
93 | parser = ConfigParser(os.environ)
94 | parser.read(conf_path)
95 |
96 | dropout = [float(x) for x in parser.get('birnn', 'dropout_rates').split(',')]
97 | relu_clip = parser.getint('birnn', 'relu_clip')
98 |
99 | b1_stddev = parser.getfloat('birnn', 'b1_stddev')
100 | h1_stddev = parser.getfloat('birnn', 'h1_stddev')
101 | b2_stddev = parser.getfloat('birnn', 'b2_stddev')
102 | h2_stddev = parser.getfloat('birnn', 'h2_stddev')
103 | b3_stddev = parser.getfloat('birnn', 'b3_stddev')
104 | h3_stddev = parser.getfloat('birnn', 'h3_stddev')
105 | b5_stddev = parser.getfloat('birnn', 'b5_stddev')
106 | h5_stddev = parser.getfloat('birnn', 'h5_stddev')
107 | b6_stddev = parser.getfloat('birnn', 'b6_stddev')
108 | h6_stddev = parser.getfloat('birnn', 'h6_stddev')
109 |
110 | n_hidden_1 = parser.getint('birnn', 'n_hidden_1')
111 | n_hidden_2 = parser.getint('birnn', 'n_hidden_2')
112 | n_hidden_5 = parser.getint('birnn', 'n_hidden_5')
113 | n_cell_dim = parser.getint('birnn', 'n_cell_dim')
114 |
115 |     n_hidden_3 = int(eval(parser.get('birnn', 'n_hidden_3')))  # config value may be an arithmetic expression, hence eval
116 | n_hidden_6 = parser.getint('birnn', 'n_hidden_6')
117 |
118 | # Input shape: [batch_size, n_steps, n_input + 2*n_input*n_context]
119 | batch_x_shape = tf.shape(batch_x)
120 |
121 | # Reshaping `batch_x` to a tensor with shape `[n_steps*batch_size, n_input + 2*n_input*n_context]`.
122 | # This is done to prepare the batch for input into the first layer which expects a tensor of rank `2`.
123 |
124 | # Permute n_steps and batch_size
125 | batch_x = tf.transpose(batch_x, [1, 0, 2])
126 | # Reshape to prepare input for first layer
127 | batch_x = tf.reshape(batch_x,
128 | [-1, n_input + 2 * n_input * n_context]) # (n_steps*batch_size, n_input + 2*n_input*n_context)
129 |
130 | # The next three blocks will pass `batch_x` through three hidden layers with
131 | # clipped RELU activation and dropout.
132 |
133 | # 1st layer
134 | with tf.name_scope('fc1'):
135 | b1 = variable_on_cpu('b1', [n_hidden_1], tf.random_normal_initializer(stddev=b1_stddev))
136 | h1 = variable_on_cpu('h1', [n_input + 2 * n_input * n_context, n_hidden_1],
137 | tf.random_normal_initializer(stddev=h1_stddev))
138 | layer_1 = tf.minimum(tf.nn.relu(tf.add(tf.matmul(batch_x, h1), b1)), relu_clip)
139 | layer_1 = tf.nn.dropout(layer_1, (1.0 - dropout[0]))
140 |
141 | tf.summary.histogram("weights", h1)
142 | tf.summary.histogram("biases", b1)
143 | tf.summary.histogram("activations", layer_1)
144 |
145 | # 2nd layer
146 | with tf.name_scope('fc2'):
147 | b2 = variable_on_cpu('b2', [n_hidden_2], tf.random_normal_initializer(stddev=b2_stddev))
148 | h2 = variable_on_cpu('h2', [n_hidden_1, n_hidden_2], tf.random_normal_initializer(stddev=h2_stddev))
149 | layer_2 = tf.minimum(tf.nn.relu(tf.add(tf.matmul(layer_1, h2), b2)), relu_clip)
150 | layer_2 = tf.nn.dropout(layer_2, (1.0 - dropout[1]))
151 |
152 | tf.summary.histogram("weights", h2)
153 | tf.summary.histogram("biases", b2)
154 | tf.summary.histogram("activations", layer_2)
155 |
156 | # 3rd layer
157 | with tf.name_scope('fc3'):
158 | b3 = variable_on_cpu('b3', [n_hidden_3], tf.random_normal_initializer(stddev=b3_stddev))
159 | h3 = variable_on_cpu('h3', [n_hidden_2, n_hidden_3], tf.random_normal_initializer(stddev=h3_stddev))
160 | layer_3 = tf.minimum(tf.nn.relu(tf.add(tf.matmul(layer_2, h3), b3)), relu_clip)
161 | layer_3 = tf.nn.dropout(layer_3, (1.0 - dropout[2]))
162 |
163 | tf.summary.histogram("weights", h3)
164 | tf.summary.histogram("biases", b3)
165 | tf.summary.histogram("activations", layer_3)
166 |
167 |     # Create the forward and backward LSTM cells, each with `n_cell_dim` units.
168 |     # The LSTM forget-gate bias is initialized at `1.0` (the default), meaning less
169 |     # forgetting at the beginning of training (the cells retain more prior context)
170 | with tf.name_scope('lstm'):
171 | # Forward direction cell:
172 | lstm_fw_cell = tf.contrib.rnn.BasicLSTMCell(n_cell_dim, forget_bias=1.0, state_is_tuple=True)
173 | lstm_fw_cell = tf.contrib.rnn.DropoutWrapper(lstm_fw_cell,
174 | input_keep_prob=1.0 - dropout[3],
175 | output_keep_prob=1.0 - dropout[3],
176 | # seed=random_seed,
177 | )
178 | # Backward direction cell:
179 | lstm_bw_cell = tf.contrib.rnn.BasicLSTMCell(n_cell_dim, forget_bias=1.0, state_is_tuple=True)
180 | lstm_bw_cell = tf.contrib.rnn.DropoutWrapper(lstm_bw_cell,
181 | input_keep_prob=1.0 - dropout[4],
182 | output_keep_prob=1.0 - dropout[4],
183 | # seed=random_seed,
184 | )
185 |
186 |         # `layer_3` is now reshaped into `[n_steps, batch_size, n_hidden_3]`,
187 | # as the LSTM BRNN expects its input to be of shape `[max_time, batch_size, input_size]`.
188 | layer_3 = tf.reshape(layer_3, [-1, batch_x_shape[0], n_hidden_3])
189 |
190 | # Now we feed `layer_3` into the LSTM BRNN cell and obtain the LSTM BRNN output.
191 | outputs, output_states = tf.nn.bidirectional_dynamic_rnn(cell_fw=lstm_fw_cell,
192 | cell_bw=lstm_bw_cell,
193 | inputs=layer_3,
194 | dtype=tf.float32,
195 | time_major=True,
196 | sequence_length=seq_length)
197 |
198 | tf.summary.histogram("activations", outputs)
199 |
200 | # Reshape outputs from two tensors each of shape [n_steps, batch_size, n_cell_dim]
201 | # to a single tensor of shape [n_steps*batch_size, 2*n_cell_dim]
202 | outputs = tf.concat(outputs, 2)
203 | outputs = tf.reshape(outputs, [-1, 2 * n_cell_dim])
204 |
205 | with tf.name_scope('fc5'):
206 | # Now we feed `outputs` to the fifth hidden layer with clipped RELU activation and dropout
207 | b5 = variable_on_cpu('b5', [n_hidden_5], tf.random_normal_initializer(stddev=b5_stddev))
208 | h5 = variable_on_cpu('h5', [(2 * n_cell_dim), n_hidden_5], tf.random_normal_initializer(stddev=h5_stddev))
209 | layer_5 = tf.minimum(tf.nn.relu(tf.add(tf.matmul(outputs, h5), b5)), relu_clip)
210 | layer_5 = tf.nn.dropout(layer_5, (1.0 - dropout[5]))
211 |
212 | tf.summary.histogram("weights", h5)
213 | tf.summary.histogram("biases", b5)
214 | tf.summary.histogram("activations", layer_5)
215 |
216 | with tf.name_scope('fc6'):
217 | # Now we apply the weight matrix `h6` and bias `b6` to the output of `layer_5`
218 | # creating `n_classes` dimensional vectors, the logits.
219 | b6 = variable_on_cpu('b6', [n_hidden_6], tf.random_normal_initializer(stddev=b6_stddev))
220 | h6 = variable_on_cpu('h6', [n_hidden_5, n_hidden_6], tf.random_normal_initializer(stddev=h6_stddev))
221 | layer_6 = tf.add(tf.matmul(layer_5, h6), b6)
222 |
223 | tf.summary.histogram("weights", h6)
224 | tf.summary.histogram("biases", b6)
225 | tf.summary.histogram("activations", layer_6)
226 |
227 | # Finally we reshape layer_6 from a tensor of shape [n_steps*batch_size, n_hidden_6]
228 | # to the slightly more useful shape [n_steps, batch_size, n_hidden_6].
229 |     # Note that this differs from the input in that it is time-major.
230 | layer_6 = tf.reshape(layer_6, [-1, batch_x_shape[0], n_hidden_6])
231 |
232 | summary_op = tf.summary.merge_all()
233 |
234 | # Output shape: [n_steps, batch_size, n_hidden_6]
235 | return layer_6, summary_op
236 |
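237 | # Minimal usage sketch (illustrative; assumes a config file containing the
238 | # sections read above, and input features of width n_features):
239 | #
240 | #     input_tensor = tf.placeholder(tf.float32, [None, None, n_features])
241 | #     seq_length = tf.placeholder(tf.int32, [None])
242 | #     logits, summary_op = SimpleLSTM(conf_path, input_tensor, seq_length)
243 | #
244 | # Both builders return time-major logits ([n_steps, batch_size, n_characters]),
245 | # the layout expected by the CTC loss and decoder ops.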
--------------------------------------------------------------------------------
/src/models/RNN/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tensorflow as tf
3 |
4 | from configparser import ConfigParser
5 | from utils.set_dirs import get_conf_dir
6 |
7 | conf_dir = get_conf_dir(debug=False)
8 | parser = ConfigParser(os.environ)
9 | parser.read(os.path.join(conf_dir, 'neural_network.ini'))
10 |
11 | # AdamOptimizer
12 | beta1 = parser.getfloat('optimizer', 'beta1')
13 | beta2 = parser.getfloat('optimizer', 'beta2')
14 | epsilon = parser.getfloat('optimizer', 'epsilon')
15 | learning_rate = parser.getfloat('optimizer', 'learning_rate')
16 |
17 |
18 | def variable_on_cpu(name, shape, initializer):
19 |     """
20 |     Create (or retrieve) a variable stored in CPU memory.
21 |     Pinning variables to the CPU lets several device towers share a
22 |     single copy of the weights.
23 |     """
24 |     # Use the /cpu:0 device for scoped operations
25 |     with tf.device('/cpu:0'):
26 |         # Create the named variable, or fetch it if it already exists
27 |         var = tf.get_variable(name=name, shape=shape, initializer=initializer)
28 |     return var
29 |
30 |
31 | def create_optimizer():
32 | optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
33 | beta1=beta1,
34 | beta2=beta2,
35 | epsilon=epsilon)
36 | return optimizer
37 |
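38 | # Usage sketch (illustrative): variable_on_cpu pins a variable to host memory
39 | # so that several device towers can share a single copy of the weights, e.g.
40 | #
41 | #     W = variable_on_cpu('W', [1024, 29], tf.random_normal_initializer(stddev=0.05))
42 | #
43 | # create_optimizer() returns an AdamOptimizer parameterized from
44 | # configs/neural_network.ini, typically consumed as optimizer.minimize(loss).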
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrubash1/RNN-Tutorial/1d02ee9d9e4f0ff24fc88beabb8ff9da6b8c8886/src/models/__init__.py
--------------------------------------------------------------------------------
/src/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrubash1/RNN-Tutorial/1d02ee9d9e4f0ff24fc88beabb8ff9da6b8c8886/src/tests/__init__.py
--------------------------------------------------------------------------------
/src/tests/train_framework/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrubash1/RNN-Tutorial/1d02ee9d9e4f0ff24fc88beabb8ff9da6b8c8886/src/tests/train_framework/__init__.py
--------------------------------------------------------------------------------
/src/tests/train_framework/tf_train_ctc_test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import unittest
3 | import os
4 | import logging
5 | import tensorflow as tf
6 |
7 | # custom modules
8 | import train_framework.tf_train_ctc as tf_train
9 |
10 |
11 | class TestTrain_ctc(unittest.TestCase):
12 | logging.basicConfig(level=logging.DEBUG)
13 |
14 | def setUp(self):
15 | '''
16 | Create the Tf_train_ctc instance
17 | '''
18 | self.tf_train_ctc = tf_train.Tf_train_ctc(
19 | config_file='neural_network.ini',
20 | debug=True,
21 | model_name=None)
22 |
23 | def tearDown(self):
24 | '''
25 | Close TF session if available
26 | '''
27 | if hasattr(self.tf_train_ctc, 'sess'):
28 | self.tf_train_ctc.sess.close()
29 | self.tf_train_ctc.writer.flush()
30 |
31 | def test_setup_tf_train_framework(self):
32 | '''
33 |         Does the instance have the expected fields (i.e. were config values cast to the expected types)?
34 | '''
35 | tf_train_ctc = self.tf_train_ctc
36 |
37 | # make sure everything is loaded correctly
38 |         self.assertIsInstance(tf_train_ctc.epochs, int)
39 |         self.assertIsInstance(tf_train_ctc.network_type, str)
40 |         self.assertIsInstance(tf_train_ctc.n_input, int)
41 |         self.assertIsInstance(tf_train_ctc.n_context, int)
42 |         self.assertIsInstance(tf_train_ctc.model_dir, str)
43 |         self.assertIsInstance(tf_train_ctc.session_name, str)
44 |         self.assertIsInstance(tf_train_ctc.SAVE_MODEL_EPOCH_NUM, int)
45 |         self.assertIsInstance(tf_train_ctc.VALIDATION_EPOCH_NUM, int)
46 |         self.assertIsInstance(tf_train_ctc.CURR_VALIDATION_LER_DIFF, float)
47 |         self.assertIsNotNone(tf_train_ctc.beam_search_decoder)
48 |         self.assertIsInstance(tf_train_ctc.AVG_VALIDATION_LERS, list)
49 |         self.assertIsInstance(tf_train_ctc.shuffle_data_after_epoch, bool)
50 |         self.assertIsInstance(tf_train_ctc.min_dev_ler, float)
51 |
52 |         # tests associated with the data_sets object
53 |         self.assertIsNotNone(tf_train_ctc.data_sets)
54 |         self.assertEqual(tf_train_ctc.sets, ['train', 'dev', 'test'])  # this will vary if changed in future
55 |         self.assertIsInstance(tf_train_ctc.n_examples_train, int)
56 |         self.assertIsInstance(tf_train_ctc.n_examples_dev, int)
57 |         self.assertIsInstance(tf_train_ctc.n_examples_test, int)
58 |         self.assertIsInstance(tf_train_ctc.batch_size, int)
59 |         self.assertIsInstance(tf_train_ctc.n_batches_per_epoch, int)
60 |
61 | # make sure folders are made
62 | self.assertTrue(os.path.exists(tf_train_ctc.SESSION_DIR))
63 | self.assertTrue(os.path.exists(tf_train_ctc.SUMMARY_DIR))
64 |
65 |         # since model_name was set to None, model_path should be None
66 |         self.assertIsNone(tf_train_ctc.model_path)
67 |
68 | def test_run_tf_train_gpu(self):
69 | '''
70 | Can a small model run on the GPU?
71 | '''
72 | tf_train_ctc = self.tf_train_ctc
73 | tf_train_ctc.run_model()
74 |
75 | # Verify objects of the train framework were created in the test training run
76 |         self.assertIsNotNone(tf_train_ctc.graph)
77 |         self.assertIsNotNone(tf_train_ctc.sess)
78 |         self.assertIsNotNone(tf_train_ctc.writer)
79 |         self.assertIsNotNone(tf_train_ctc.saver)
80 |         self.assertIsNotNone(tf_train_ctc.logits)
81 |
82 | # Verify some learning has been done
83 | self.assertTrue(tf_train_ctc.train_ler > 0)
84 |
85 |         # make sure the input targets (i.e. the .txt transcripts) are not empty strings
86 |         self.assertTrue(all(label != '' for label in tf_train_ctc.dense_labels))
87 |
88 | # shutdown the running model
89 | self.tearDown()
90 |
91 | def test_verify_if_cpu_can_be_used(self):
92 | '''
93 | Can a small model run on the CPU?
94 | '''
95 | tf_train_ctc = self.tf_train_ctc
96 | tf_train_ctc.tf_device = '/cpu:0'
97 |
98 | tf_train_ctc.graph = tf.Graph()
99 | with tf_train_ctc.graph.as_default(), tf.device(tf_train_ctc.tf_device):
100 |             with tf.device(tf_train_ctc.tf_device):  # redundant with the context manager above, but harmless
101 | tf_train_ctc.setup_network_and_graph()
102 | tf_train_ctc.load_placeholder_into_network()
103 | tf_train_ctc.setup_loss_function()
104 | tf_train_ctc.setup_optimizer()
105 | tf_train_ctc.setup_decoder()
106 | tf_train_ctc.setup_summary_statistics()
107 | tf_train_ctc.sess = tf.Session(
108 | config=tf.ConfigProto(allow_soft_placement=True), graph=tf_train_ctc.graph)
109 |
110 | # initialize the summary writer
111 | tf_train_ctc.writer = tf.summary.FileWriter(
112 | tf_train_ctc.SUMMARY_DIR, graph=tf_train_ctc.sess.graph)
113 |
114 | # Add ops to save and restore all the variables
115 | tf_train_ctc.saver = tf.train.Saver()
116 |
117 | # Verify objects of the train framework were created in the test training run
118 |         self.assertIsNotNone(tf_train_ctc.graph)
119 |         self.assertIsNotNone(tf_train_ctc.sess)
120 |         self.assertIsNotNone(tf_train_ctc.writer)
121 |         self.assertIsNotNone(tf_train_ctc.saver)
122 |
123 | # shutdown the running model
124 | self.tearDown()
125 |
126 |
127 | if __name__ == '__main__':
128 | unittest.main()
129 |
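130 | # The suite can also be run from the src/ directory (assuming $RNN_TUTORIAL
131 | # is set so the config and data paths resolve), e.g.:
132 | #
133 | #     python -m unittest tests.train_framework.tf_train_ctc_test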
--------------------------------------------------------------------------------
/src/train_framework/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrubash1/RNN-Tutorial/1d02ee9d9e4f0ff24fc88beabb8ff9da6b8c8886/src/train_framework/__init__.py
--------------------------------------------------------------------------------
/src/train_framework/tf_train_ctc.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import os
4 | import numpy as np
5 | import time
6 | import warnings
7 | from configparser import ConfigParser
8 | import logging
9 |
10 | import tensorflow as tf
11 | from tensorflow.python.ops import ctc_ops
12 |
13 | # Custom modules
14 | from features.utils.text import ndarray_to_text, sparse_tuple_to_texts
15 |
16 | # NOTE: these model utilities may diverge from the general utils module in the future
17 | from models.RNN.utils import create_optimizer
18 | from data_manipulation.datasets import read_datasets
19 | from utils.set_dirs import get_conf_dir, get_model_dir
20 | import utils.gpu as gpu_tool
21 |
22 | # Import the setup scripts for different types of model
23 | from models.RNN.rnn import BiRNN as BiRNN_model
24 | from models.RNN.rnn import SimpleLSTM as SimpleLSTM_model
25 |
26 | logger = logging.getLogger(__name__)
27 |
28 |
29 | class Tf_train_ctc(object):
30 | '''
31 | Class to train a speech recognition model with TensorFlow.
32 |
33 | Requirements:
34 | - TensorFlow 1.0.1
35 | - Python 3.5
36 | - Configuration: $RNN_TUTORIAL/configs/neural_network.ini
37 |
38 | Features:
39 | - Batch loading of input data
40 | - Checkpoints model
41 | - Label error rate is the edit (Levenshtein) distance of the top path vs true sentence
42 | - Logs summary stats for TensorBoard
43 | - Epoch 1: Train starting with shortest transcriptions, then shuffle
44 |
45 | # Note: All calls to tf.name_scope or tf.summary.* support TensorBoard visualization.
46 |
47 | This class was initially based on open source code from Mozilla DeepSpeech:
48 | https://github.com/mozilla/DeepSpeech/blob/master/DeepSpeech.py
49 |
50 | # This Source Code Form is subject to the terms of the Mozilla Public
51 | # License, v. 2.0. If a copy of the MPL was not distributed with this
52 | # file, You can obtain one at http://mozilla.org/MPL/2.0/.
53 | '''
54 |
55 | def __init__(self,
56 | config_file='neural_network.ini',
57 | model_name=None,
58 | debug=False):
59 | # set TF logging verbosity
60 | tf.logging.set_verbosity(tf.logging.INFO)
61 |
62 | # Load the configuration file depending on debug True/False
63 | self.debug = debug
64 | self.conf_path = get_conf_dir(debug=self.debug)
65 | self.conf_path = os.path.join(self.conf_path, config_file)
66 | self.load_configs()
67 |
68 |         # Verify that the GPU is operational; if not, fall back to the CPU
69 | if not gpu_tool.check_if_gpu_available(self.tf_device):
70 | self.tf_device = '/cpu:0'
71 | logging.info('Using this device for main computations: %s', self.tf_device)
72 |
73 | # set the directories
74 | self.set_up_directories(model_name)
75 |
76 | # set up the model
77 | self.set_up_model()
78 |
79 | def load_configs(self):
80 | parser = ConfigParser(os.environ)
81 | if not os.path.exists(self.conf_path):
82 | raise IOError("Configuration file '%s' does not exist" % self.conf_path)
83 | logging.info('Loading config from %s', self.conf_path)
84 | parser.read(self.conf_path)
85 |
86 | # set which set of configs to import
87 | config_header = 'nn'
88 |
89 | logger.info('config header: %s', config_header)
90 |
91 | self.epochs = parser.getint(config_header, 'epochs')
92 | logger.debug('self.epochs = %d', self.epochs)
93 |
94 | self.network_type = parser.get(config_header, 'network_type')
95 |
96 | # Number of mfcc features, 13 or 26
97 | self.n_input = parser.getint(config_header, 'n_input')
98 |
99 | # Number of contextual samples to include
100 | self.n_context = parser.getint(config_header, 'n_context')
101 |
102 | # self.decode_train = parser.getboolean(config_header, 'decode_train')
103 | # self.random_seed = parser.getint(config_header, 'random_seed')
104 | self.model_dir = parser.get(config_header, 'model_dir')
105 |
106 | # set the session name
107 | self.session_name = '{}_{}'.format(
108 | self.network_type, time.strftime("%Y%m%d-%H%M%S"))
109 | sess_prefix_str = 'develop'
110 | if len(sess_prefix_str) > 0:
111 | self.session_name = '{}_{}'.format(
112 | sess_prefix_str, self.session_name)
113 |
114 | # How often to save the model
115 | self.SAVE_MODEL_EPOCH_NUM = parser.getint(
116 | config_header, 'SAVE_MODEL_EPOCH_NUM')
117 |
118 | # decode dev set after N epochs
119 | self.VALIDATION_EPOCH_NUM = parser.getint(
120 | config_header, 'VALIDATION_EPOCH_NUM')
121 |
122 | # decide when to stop training prematurely
123 | self.CURR_VALIDATION_LER_DIFF = parser.getfloat(
124 | config_header, 'CURR_VALIDATION_LER_DIFF')
125 |
126 | self.AVG_VALIDATION_LER_EPOCHS = parser.getint(
127 | config_header, 'AVG_VALIDATION_LER_EPOCHS')
128 |         # initialize a list to hold the average validation LER at the end of each epoch
129 | self.AVG_VALIDATION_LERS = [
130 | 1.0 for _ in range(self.AVG_VALIDATION_LER_EPOCHS)]
131 |
132 | # setup type of decoder
133 | self.beam_search_decoder = parser.get(
134 | config_header, 'beam_search_decoder')
135 |
136 |         # determine if the data input order should be shuffled after every epoch
137 | self.shuffle_data_after_epoch = parser.getboolean(
138 | config_header, 'shuffle_data_after_epoch')
139 |
140 | # initialize to store the minimum validation set label error rate
141 | self.min_dev_ler = parser.getfloat(config_header, 'min_dev_ler')
142 |
143 | # set up GPU if available
144 | self.tf_device = str(parser.get(config_header, 'tf_device'))
145 |
146 |         # set the maximum number of simultaneous users;
147 |         # each process is limited to 1/simultaneous_users_count of the GPU memory
148 | self.simultaneous_users_count = parser.getint(config_header, 'simultaneous_users_count')
149 |
150 | def set_up_directories(self, model_name):
151 | # Set up model directory
152 | self.model_dir = os.path.join(get_model_dir(), self.model_dir)
153 | # summary will contain logs
154 | self.SUMMARY_DIR = os.path.join(
155 | self.model_dir, "summary", self.session_name)
156 | # session will contain models
157 | self.SESSION_DIR = os.path.join(
158 | self.model_dir, "session", self.session_name)
159 |
160 | if not os.path.exists(self.SESSION_DIR):
161 | os.makedirs(self.SESSION_DIR)
162 | if not os.path.exists(self.SUMMARY_DIR):
163 | os.makedirs(self.SUMMARY_DIR)
164 |
165 | # set the model name and restore if not None
166 | if model_name is not None:
167 | self.model_path = os.path.join(self.SESSION_DIR, model_name)
168 | else:
169 | self.model_path = None
170 |
171 | def set_up_model(self):
172 | self.sets = ['train', 'dev', 'test']
173 |
174 |         # read the data sets; the configuration path is passed through
175 |         # so the config file can specify where the data lives
176 | self.data_sets = read_datasets(self.conf_path,
177 | self.sets,
178 | self.n_input,
179 | self.n_context
180 | )
181 |
182 | self.n_examples_train = len(self.data_sets.train._txt_files)
183 | self.n_examples_dev = len(self.data_sets.dev._txt_files)
184 | self.n_examples_test = len(self.data_sets.test._txt_files)
185 | self.batch_size = self.data_sets.train._batch_size
186 | self.n_batches_per_epoch = int(np.ceil(
187 | self.n_examples_train / self.batch_size))
188 |
189 | logger.info('''Training model: {}
190 | Train examples: {:,}
191 | Dev examples: {:,}
192 | Test examples: {:,}
193 | Epochs: {}
194 | Training batch size: {}
195 | Batches per epoch: {}'''.format(
196 | self.session_name,
197 | self.n_examples_train,
198 | self.n_examples_dev,
199 | self.n_examples_test,
200 | self.epochs,
201 | self.batch_size,
202 | self.n_batches_per_epoch))
203 |
204 | def run_model(self):
205 | self.graph = tf.Graph()
206 | with self.graph.as_default(), tf.device('/cpu:0'):
207 |
208 | with tf.device(self.tf_device):
209 |                 # Run multiple functions on the specified tf_device
210 | # tf_device GPU set in configs, but is overridden if not available
211 | self.setup_network_and_graph()
212 | self.load_placeholder_into_network()
213 | self.setup_loss_function()
214 | self.setup_optimizer()
215 | self.setup_decoder()
216 |
217 | self.setup_summary_statistics()
218 |
219 | # create the configuration for the session
220 | tf_config = tf.ConfigProto()
221 | tf_config.allow_soft_placement = True
222 | tf_config.gpu_options.per_process_gpu_memory_fraction = \
223 | (1.0 / self.simultaneous_users_count)
224 |
225 | # create the session
226 | self.sess = tf.Session(config=tf_config)
227 |
228 | # initialize the summary writer
229 | self.writer = tf.summary.FileWriter(
230 | self.SUMMARY_DIR, graph=self.sess.graph)
231 |
232 | # Add ops to save and restore all the variables
233 | self.saver = tf.train.Saver()
234 |
235 | # For printing out section headers
236 | section = '\n{0:=^40}\n'
237 |
238 | # If there is a model_path declared, then restore the model
239 | if self.model_path is not None:
240 | self.saver.restore(self.sess, self.model_path)
241 | # If there is NOT a model_path declared, build the model from scratch
242 | else:
243 | # Op to initialize the variables
244 | init_op = tf.global_variables_initializer()
245 |
246 |                 # Initialize the weights and biases
247 | self.sess.run(init_op)
248 |
249 | # MAIN LOGIC for running the training epochs
250 | logger.info(section.format('Run training epoch'))
251 | self.run_training_epochs()
252 |
253 | logger.info(section.format('Decoding test data'))
254 |         # assume that decoding the test data happens at the final epoch
255 | _, self.test_ler = self.run_batches(self.data_sets.test, is_training=False,
256 | decode=True, write_to_file=False, epoch=self.epochs)
257 |
258 | # Add the final test data to the summary writer
259 | # (single point on the graph for end of training run)
260 | summary_line = self.sess.run(
261 | self.test_ler_op, {self.ler_placeholder: self.test_ler})
262 | self.writer.add_summary(summary_line, self.epochs)
263 |
264 | logger.info('Test Label Error Rate: {}'.format(self.test_ler))
265 |
266 | # save train summaries to disk
267 | self.writer.flush()
268 |
269 | self.sess.close()
270 |
271 | def setup_network_and_graph(self):
272 | # e.g: log filter bank or MFCC features
273 | # shape = [batch_size, max_stepsize, n_input + (2 * n_input * n_context)]
274 | # the batch_size and max_stepsize can vary along each step
275 | self.input_tensor = tf.placeholder(
276 | tf.float32, [None, None, self.n_input + (2 * self.n_input * self.n_context)], name='input')
277 |
278 | # Use sparse_placeholder; will generate a SparseTensor, required by ctc_loss op.
279 | self.targets = tf.sparse_placeholder(tf.int32, name='targets')
280 | # 1d array of size [batch_size]
281 | self.seq_length = tf.placeholder(tf.int32, [None], name='seq_length')
282 |
283 | def load_placeholder_into_network(self):
284 | # logits is the non-normalized output/activations from the last layer.
285 | # logits will be input for the loss function.
286 |         # the network builders (SimpleLSTM_model, BiRNN_model) are imported at the top of this module
287 | # summary_op variables are for tensorboard
288 | if self.network_type == 'SimpleLSTM':
289 | self.logits, summary_op = SimpleLSTM_model(
290 | self.conf_path,
291 | self.input_tensor,
292 | tf.to_int64(self.seq_length)
293 | )
294 | elif self.network_type == 'BiRNN':
295 | self.logits, summary_op = BiRNN_model(
296 | self.conf_path,
297 | self.input_tensor,
298 | tf.to_int64(self.seq_length),
299 | self.n_input,
300 | self.n_context
301 | )
302 | else:
303 | raise ValueError('network_type must be SimpleLSTM or BiRNN')
304 | self.summary_op = tf.summary.merge([summary_op])
305 |
306 | def setup_loss_function(self):
307 | with tf.name_scope("loss"):
308 | self.total_loss = ctc_ops.ctc_loss(
309 | self.targets, self.logits, self.seq_length)
310 | self.avg_loss = tf.reduce_mean(self.total_loss)
311 | self.loss_summary = tf.summary.scalar("avg_loss", self.avg_loss)
312 |
313 | self.cost_placeholder = tf.placeholder(dtype=tf.float32, shape=[])
314 |
315 | self.train_cost_op = tf.summary.scalar(
316 | "train_avg_loss", self.cost_placeholder)
317 |
318 | def setup_optimizer(self):
319 | # Note: The optimizer is created in models/RNN/utils.py
320 | with tf.name_scope("train"):
321 | self.optimizer = create_optimizer()
322 | self.optimizer = self.optimizer.minimize(self.avg_loss)
323 |
324 | def setup_decoder(self):
325 | with tf.name_scope("decode"):
326 | if self.beam_search_decoder == 'default':
327 | self.decoded, self.log_prob = ctc_ops.ctc_beam_search_decoder(
328 | self.logits, self.seq_length, merge_repeated=False)
329 | elif self.beam_search_decoder == 'greedy':
330 | self.decoded, self.log_prob = ctc_ops.ctc_greedy_decoder(
331 | self.logits, self.seq_length, merge_repeated=False)
332 | else:
333 |                 raise ValueError('beam_search_decoder must be "default" or "greedy", got: {}'.format(self.beam_search_decoder))
334 |
335 | def setup_summary_statistics(self):
336 |         # Create a placeholder for the summary statistics
337 | with tf.name_scope("accuracy"):
338 | # Compute the edit (Levenshtein) distance of the top path
339 | distance = tf.edit_distance(
340 | tf.cast(self.decoded[0], tf.int32), self.targets)
341 |
342 |             # Compute the label error rate (used as the accuracy metric)
343 | self.ler = tf.reduce_mean(distance, name='label_error_rate')
344 | self.ler_placeholder = tf.placeholder(dtype=tf.float32, shape=[])
345 | self.train_ler_op = tf.summary.scalar(
346 | "train_label_error_rate", self.ler_placeholder)
347 | self.dev_ler_op = tf.summary.scalar(
348 | "validation_label_error_rate", self.ler_placeholder)
349 | self.test_ler_op = tf.summary.scalar(
350 | "test_label_error_rate", self.ler_placeholder)
351 |
352 | def run_training_epochs(self):
353 | train_start = time.time()
354 | for epoch in range(self.epochs):
355 | # Initialize variables that can be updated
356 | save_dev_model = False
357 | stop_training = False
358 | is_checkpoint_step, is_validation_step = \
359 | self.validation_and_checkpoint_check(epoch)
360 |
361 | epoch_start = time.time()
362 |
363 | self.train_cost, self.train_ler = self.run_batches(
364 | self.data_sets.train,
365 | is_training=True,
366 | decode=False,
367 | write_to_file=False,
368 | epoch=epoch)
369 |
370 | epoch_duration = time.time() - epoch_start
371 |
372 |             log = ('Epoch {}/{}, train_cost: {:.3f}, '
373 |                    'train_ler: {:.3f}, time: {:.2f} sec')
374 | logger.info(log.format(
375 | epoch + 1,
376 | self.epochs,
377 | self.train_cost,
378 | self.train_ler,
379 | epoch_duration))
380 |
381 | summary_line = self.sess.run(
382 | self.train_ler_op, {self.ler_placeholder: self.train_ler})
383 | self.writer.add_summary(summary_line, epoch)
384 |
385 | summary_line = self.sess.run(
386 | self.train_cost_op, {self.cost_placeholder: self.train_cost})
387 | self.writer.add_summary(summary_line, epoch)
388 |
389 | # Shuffle the data for the next epoch
390 | if self.shuffle_data_after_epoch:
391 | np.random.shuffle(self.data_sets.train._txt_files)
392 |
393 |             # Run validation if this epoch is a validation step
394 | if is_validation_step:
395 | self.run_validation_step(epoch)
396 |
397 | if (epoch + 1) == self.epochs or is_checkpoint_step:
398 | # save the final model
399 | save_path = self.saver.save(self.sess, os.path.join(
400 | self.SESSION_DIR, 'model.ckpt'), epoch)
401 | logger.info("Model saved: {}".format(save_path))
402 |
403 | if save_dev_model:
404 |                 # If the dev set stops improving, training is halted
405 |                 # to prevent overfitting, and the model with the best
406 |                 # validation performance is saved
407 | save_path = self.saver.save(self.sess, os.path.join(
408 | self.SESSION_DIR, 'model-best.ckpt'))
409 | logger.info(
410 | "Model with best validation label error rate saved: {}".
411 | format(save_path))
412 |
413 | if stop_training:
414 | break
415 |
416 | train_duration = time.time() - train_start
417 | logger.info('Training complete, total duration: {:.2f} min'.format(
418 | train_duration / 60))
419 |
420 | def run_validation_step(self, epoch):
421 | dev_ler = 0
422 |
423 | _, dev_ler = self.run_batches(self.data_sets.dev,
424 | is_training=False,
425 | decode=True,
426 | write_to_file=False,
427 | epoch=epoch)
428 |
429 | logger.info('Validation Label Error Rate: {}'.format(dev_ler))
430 |
431 | summary_line = self.sess.run(
432 | self.dev_ler_op, {self.ler_placeholder: dev_ler})
433 | self.writer.add_summary(summary_line, epoch)
434 |
435 | if dev_ler < self.min_dev_ler:
436 | self.min_dev_ler = dev_ler
437 |
438 | # average historical LER
439 | history_avg_ler = np.mean(self.AVG_VALIDATION_LERS)
440 |
441 |         # if this LER is not better than the average of previous epochs, warn
442 |         if history_avg_ler - dev_ler <= self.CURR_VALIDATION_LER_DIFF:
443 |             log = ("Validation label error rate not improved by more than {:.2%} "
444 |                    "after {} epochs. Exit")
445 |             warnings.warn(log.format(self.CURR_VALIDATION_LER_DIFF,
446 |                                      self.AVG_VALIDATION_LER_EPOCHS))
447 |
448 |         # save this epoch's validation LER in the next rolling slot
449 | self.AVG_VALIDATION_LERS[
450 | epoch % self.AVG_VALIDATION_LER_EPOCHS] = dev_ler
451 |
452 | def validation_and_checkpoint_check(self, epoch):
453 |         # both default to False unless the checks below indicate otherwise
454 | is_checkpoint_step = False
455 | is_validation_step = False
456 |
457 | # Check if the current epoch is a validation or checkpoint step
458 | if (epoch > 0) and ((epoch + 1) != self.epochs):
459 | if (epoch + 1) % self.SAVE_MODEL_EPOCH_NUM == 0:
460 | is_checkpoint_step = True
461 | if (epoch + 1) % self.VALIDATION_EPOCH_NUM == 0:
462 | is_validation_step = True
463 |
464 | return is_checkpoint_step, is_validation_step
465 |
466 | def run_batches(self, dataset, is_training, decode, write_to_file, epoch):
467 | n_examples = len(dataset._txt_files)
468 |
469 | n_batches_per_epoch = int(np.ceil(n_examples / dataset._batch_size))
470 |
471 | self.train_cost = 0
472 | self.train_ler = 0
473 |
474 | for batch in range(n_batches_per_epoch):
475 | # Get next batch of training data (audio features) and transcripts
476 | source, source_lengths, sparse_labels = dataset.next_batch()
477 |
478 | feed = {self.input_tensor: source,
479 | self.targets: sparse_labels,
480 | self.seq_length: source_lengths}
481 |
482 |             # If is_training is False, we only decode, without computing the loss
483 | if is_training:
484 | # avg_loss is the loss_op, optimizer is the train_op;
485 | # running these pushes tensors (data) through graph
486 | batch_cost, _ = self.sess.run(
487 | [self.avg_loss, self.optimizer], feed)
488 | self.train_cost += batch_cost * dataset._batch_size
489 | logger.debug('Batch cost: %.2f | Train cost: %.2f',
490 | batch_cost, self.train_cost)
491 |
492 | self.train_ler += self.sess.run(self.ler, feed_dict=feed) * dataset._batch_size
493 | logger.debug('Label error rate: %.2f', self.train_ler)
494 |
495 |             # Decode only the first batch of each epoch
496 | if decode and batch == 0:
497 | d = self.sess.run(self.decoded[0], feed_dict={
498 | self.input_tensor: source,
499 | self.targets: sparse_labels,
500 | self.seq_length: source_lengths}
501 | )
502 | dense_decoded = tf.sparse_tensor_to_dense(
503 | d, default_value=-1).eval(session=self.sess)
504 | dense_labels = sparse_tuple_to_texts(sparse_labels)
505 |
506 |                 # only print a set number of example transcriptions
507 |                 counter = 0
508 |                 counter_max = 4
509 |                 # slice so that at most counter_max examples are printed
510 |                 for orig, decoded_arr in zip(dense_labels[:counter_max], dense_decoded[:counter_max]):
511 |                     # convert to strings
512 |                     decoded_str = ndarray_to_text(decoded_arr)
513 |                     logger.info('Batch {}, file {}'.format(batch, counter))
514 |                     logger.info('Original: {}'.format(orig))
515 |                     logger.info('Decoded: {}'.format(decoded_str))
516 |                     counter += 1
517 |
518 | # save out variables for testing
519 | self.dense_decoded = dense_decoded
520 | self.dense_labels = dense_labels
521 |
522 |         # Average the metrics over the epoch
523 | if is_training:
524 | self.train_cost /= n_examples
525 | self.train_ler /= n_examples
526 |
527 |         # Populate summaries for TensorBoard histograms and distributions (uses the final batch's feed)
528 | self.accuracy, summary_line = self.sess.run(
529 | [self.avg_loss, self.summary_op], feed)
530 | self.writer.add_summary(summary_line, epoch)
531 |
532 | return self.train_cost, self.train_ler
533 |
534 |
535 | # to run in console
536 | if __name__ == '__main__':
537 | import click
538 |
539 | # Use click to parse command line arguments
540 | @click.command()
541 | @click.option('--config', default='neural_network.ini', help='Configuration file name')
542 | @click.option('--name', default=None, help='Model name for logging')
543 | @click.option('--debug', type=bool, default=False,
544 | help='Use debug settings in config file')
545 | # Train RNN model using a given configuration file
546 | def main(config='neural_network.ini', name=None, debug=False):
547 | logging.basicConfig(level=logging.DEBUG,
548 | format='%(asctime)s [%(levelname)s] %(name)s: %(message)s')
549 | global logger
550 | logger = logging.getLogger(os.path.basename(__file__))
551 |
552 | # create the Tf_train_ctc class
553 | tf_train_ctc = Tf_train_ctc(
554 | config_file=config, model_name=name, debug=debug)
555 |
556 | # run the training
557 | tf_train_ctc.run_model()
558 |
559 | main()
560 |
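561 | # Example invocations (assuming $RNN_TUTORIAL is set and the dependencies in
562 | # requirements.txt are installed):
563 | #
564 | #     python tf_train_ctc.py
565 | #     python tf_train_ctc.py --config neural_network.ini --debug True
566 | #
567 | # With --debug True the config is read from configs/testing/, which points at
568 | # the small bundled LibriSpeech sample for a quick end-to-end run.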
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrubash1/RNN-Tutorial/1d02ee9d9e4f0ff24fc88beabb8ff9da6b8c8886/src/utils/__init__.py
--------------------------------------------------------------------------------
/src/utils/gpu.py:
--------------------------------------------------------------------------------
1 | from tensorflow.python.client import device_lib
2 |
3 |
4 | def get_available_gpus():
5 | """
6 |     Returns the device names of the GPUs available on this system.
7 | """
8 | local_device_protos = device_lib.list_local_devices()
9 | return [x.name for x in local_device_protos if x.device_type == 'GPU']
10 |
11 |
12 | def check_if_gpu_available(gpu_name):
13 | """
14 |     Returns a boolean indicating whether the specific gpu_name
15 |     (string, e.g. '/gpu:0') is available on the system
16 | """
17 | list_of_gpus = get_available_gpus()
18 | if gpu_name not in list_of_gpus:
19 | return False
20 | else:
21 | return True
22 |
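23 | 
24 | if __name__ == '__main__':
25 |     # Quick self-check (a sketch): list the detected GPU device names and
26 |     # probe one; in this TensorFlow generation names look like '/gpu:0'.
27 |     print('Available GPUs:', get_available_gpus())
28 |     print("'/gpu:0' available:", check_if_gpu_available('/gpu:0'))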
--------------------------------------------------------------------------------
/src/utils/set_dirs.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import os
4 |
5 |
6 | def get_relevant_directories(
7 | home_dir=None,
8 | data_dir=None,
9 | conf_dir=None,
10 | debug=False):
11 |
12 | home_dir = get_home_dir(home_dir=home_dir)
13 |
14 | data_dir = get_data_dir(data_dir=data_dir, home_dir=home_dir)
15 |
16 | conf_dir = get_conf_dir(conf_dir=conf_dir, home_dir=home_dir, debug=debug)
17 |
18 | return home_dir, data_dir, conf_dir
19 |
20 |
21 | def get_home_dir(home_dir=None):
22 | if home_dir is None:
23 | home_dir = os.environ['RNN_TUTORIAL']
24 | return home_dir
25 |
26 |
27 | def get_data_dir(data_dir=None, home_dir=None):
28 | if data_dir is None:
29 | data_dir = os.path.join(get_home_dir(home_dir=home_dir), 'data', 'raw')
30 |     # if data_dir is not an absolute path, prepend home_dir
31 | elif not os.path.isabs(data_dir):
32 | data_dir = os.path.join(get_home_dir(home_dir=home_dir), data_dir)
33 | return data_dir
34 |
35 |
36 | def get_conf_dir(conf_dir=None, home_dir=None, debug=False):
37 | if conf_dir is None:
38 | conf_dir = os.path.join(get_home_dir(home_dir=home_dir), 'configs')
39 | # Descend to the testing folder if debug==True
40 | if debug:
41 | conf_dir = os.path.join(conf_dir, 'testing')
42 | return conf_dir
43 |
44 |
45 | def get_model_dir(model_dir=None, home_dir=None):
46 | if model_dir is None:
47 | model_dir = os.path.join(get_home_dir(home_dir=home_dir), 'models')
48 | return model_dir
49 |
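50 | # Example results (illustrative), with RNN_TUTORIAL=/home/user/RNN-Tutorial:
51 | #
52 | #     get_data_dir()            # -> /home/user/RNN-Tutorial/data/raw
53 | #     get_conf_dir(debug=True)  # -> /home/user/RNN-Tutorial/configs/testing
54 | #     get_model_dir()           # -> /home/user/RNN-Tutorial/models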
--------------------------------------------------------------------------------