├── data └── LibriSpeech │ ├── dev-clean-wav │ ├── 3752-4944-0041.txt │ ├── 777-126732-0068.txt │ ├── 3752-4944-0041.wav │ └── 777-126732-0068.wav │ ├── test-clean-wav │ ├── 4507-16021-0019.txt │ ├── 7176-92135-0009.txt │ ├── 4507-16021-0019.wav │ └── 7176-92135-0009.wav │ └── train-clean-100-wav │ ├── 3879-174923-0005.txt │ ├── 211-122425-0059.txt │ ├── 1970-28415-0023.txt │ ├── 2843-152918-0008.txt │ ├── 3259-158083-0026.txt │ ├── 1970-28415-0023.wav │ ├── 211-122425-0059.wav │ ├── 2843-152918-0008.wav │ ├── 3259-158083-0026.wav │ └── 3879-174923-0005.wav ├── requirements.txt ├── CONTRIBUTING.md ├── .travis.yml ├── README.md ├── LICENSE ├── .gitignore ├── tests └── tests_utils.py ├── demo.py ├── utils.py └── train.py /data/LibriSpeech/dev-clean-wav/3752-4944-0041.txt: -------------------------------------------------------------------------------- 1 | how delightful the grass smells -------------------------------------------------------------------------------- /data/LibriSpeech/test-clean-wav/4507-16021-0019.txt: -------------------------------------------------------------------------------- 1 | it is the language of wretchedness -------------------------------------------------------------------------------- /data/LibriSpeech/train-clean-100-wav/3879-174923-0005.txt: -------------------------------------------------------------------------------- 1 | he must vanish out of the world -------------------------------------------------------------------------------- /data/LibriSpeech/train-clean-100-wav/211-122425-0059.txt: -------------------------------------------------------------------------------- 1 | and the two will pass off together -------------------------------------------------------------------------------- /data/LibriSpeech/dev-clean-wav/777-126732-0068.txt: -------------------------------------------------------------------------------- 1 | that boy hears too much of what is talked about here 
-------------------------------------------------------------------------------- /data/LibriSpeech/test-clean-wav/7176-92135-0009.txt: -------------------------------------------------------------------------------- 1 | and i should begin with a short homily on soliloquy -------------------------------------------------------------------------------- /data/LibriSpeech/train-clean-100-wav/1970-28415-0023.txt: -------------------------------------------------------------------------------- 1 | where people were making their gifts to god -------------------------------------------------------------------------------- /data/LibriSpeech/train-clean-100-wav/2843-152918-0008.txt: -------------------------------------------------------------------------------- 1 | one day may be pleasant enough but two three four -------------------------------------------------------------------------------- /data/LibriSpeech/train-clean-100-wav/3259-158083-0026.txt: -------------------------------------------------------------------------------- 1 | i have a nephew fighting for democracy in france -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | python-speech-features==0.6 2 | tensorflow>=1.12.1 3 | tensorflow-tensorboard>=0.1.8 4 | -------------------------------------------------------------------------------- /data/LibriSpeech/dev-clean-wav/3752-4944-0041.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ugnelis/tensorflow-rnn-ctc/HEAD/data/LibriSpeech/dev-clean-wav/3752-4944-0041.wav -------------------------------------------------------------------------------- /data/LibriSpeech/dev-clean-wav/777-126732-0068.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ugnelis/tensorflow-rnn-ctc/HEAD/data/LibriSpeech/dev-clean-wav/777-126732-0068.wav -------------------------------------------------------------------------------- /data/LibriSpeech/test-clean-wav/4507-16021-0019.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ugnelis/tensorflow-rnn-ctc/HEAD/data/LibriSpeech/test-clean-wav/4507-16021-0019.wav -------------------------------------------------------------------------------- /data/LibriSpeech/test-clean-wav/7176-92135-0009.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ugnelis/tensorflow-rnn-ctc/HEAD/data/LibriSpeech/test-clean-wav/7176-92135-0009.wav -------------------------------------------------------------------------------- /data/LibriSpeech/train-clean-100-wav/1970-28415-0023.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ugnelis/tensorflow-rnn-ctc/HEAD/data/LibriSpeech/train-clean-100-wav/1970-28415-0023.wav -------------------------------------------------------------------------------- /data/LibriSpeech/train-clean-100-wav/211-122425-0059.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ugnelis/tensorflow-rnn-ctc/HEAD/data/LibriSpeech/train-clean-100-wav/211-122425-0059.wav -------------------------------------------------------------------------------- /data/LibriSpeech/train-clean-100-wav/2843-152918-0008.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ugnelis/tensorflow-rnn-ctc/HEAD/data/LibriSpeech/train-clean-100-wav/2843-152918-0008.wav -------------------------------------------------------------------------------- /data/LibriSpeech/train-clean-100-wav/3259-158083-0026.wav: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ugnelis/tensorflow-rnn-ctc/HEAD/data/LibriSpeech/train-clean-100-wav/3259-158083-0026.wav -------------------------------------------------------------------------------- /data/LibriSpeech/train-clean-100-wav/3879-174923-0005.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ugnelis/tensorflow-rnn-ctc/HEAD/data/LibriSpeech/train-clean-100-wav/3879-174923-0005.wav -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | First of all, thank you for taking the time to contribute to this project! 4 | 5 | When contributing to this repository, please first discuss the change you wish to make in an issue. 6 | 7 | ## Project Style Guide 8 | 9 | For Python files, please follow the PEP 8 standard (https://www.python.org/dev/peps/pep-0008/). 
10 | 11 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "3.5" 4 | install: 5 | - pip install -r requirements.txt 6 | 7 | # See https://gist.github.com/dan-blanchard/7045057 8 | - pip install --upgrade pip setuptools wheel 9 | - pip install --only-binary=numpy,scipy numpy scipy 10 | script: 11 | - python -m unittest discover tests 12 | notifications: 13 | email: 14 | on_success: never 15 | on_failure: always 16 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow RNN CTC [![Build Status](https://travis-ci.org/ugnelis/tensorflow-rnn-ctc.svg?branch=master)](https://travis-ci.org/ugnelis/tensorflow-rnn-ctc) 2 | 3 | Connectionist Temporal Classification (CTC) with a Recurrent Neural Network (RNN) in TensorFlow. 4 | 5 | ## Requirements 6 | 7 | - Python 2.7+ (for Linux) 8 | - Python 3.5+ (for Windows) 9 | - TensorFlow 1.12.1+ 10 | - NumPy 1.5+ 11 | - SciPy 0.12+ 12 | - python_speech_features 0.1+ 13 | 14 | ## Installation 15 | 16 | I suggest using [Anaconda](https://www.anaconda.com/download/). Install `TensorFlow` and `python_speech_features` with `pip`: 17 | 18 | ```bash 19 | $ activate anaconda_env_name 20 | (anaconda_env_name)$ pip install python_speech_features 21 | (anaconda_env_name)$ pip install --ignore-installed --upgrade tensorflow # without GPU 22 | (anaconda_env_name)$ pip install --ignore-installed --upgrade tensorflow-gpu # with GPU 23 | ``` 24 | 25 | ## Training 26 | 27 | Run training with the `train.py` script. 28 | 29 | ```bash 30 | python train.py 31 | ``` 32 | 33 | ## License 34 | This project is licensed under the terms of the MIT license. 
35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Ugnius Malūkas 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Project folders. 
2 | models/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | env/ 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # dotenv 86 | .env 87 | 88 | # virtualenv 89 | .venv 90 | venv/ 91 | ENV/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # IntelliJ 107 | .idea/ 108 | *.iml 109 | -------------------------------------------------------------------------------- /tests/tests_utils.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import utils 3 | 4 | from tempfile import mkdtemp, mktemp 5 | from shutil import 
rmtree 6 | 7 | # Extension of created test files. 8 | TEST_FILE_EXTENSION = 'txt' 9 | 10 | # Directory of test suite wav audio files. 11 | TEST_AUDIO_FILE_DIR = 'data/LibriSpeech/test-clean-wav' 12 | 13 | 14 | class UtilsTest(unittest.TestCase): 15 | def setUp(self): 16 | self.test_file_dir = mkdtemp() 17 | 18 | def tearDown(self): 19 | rmtree(self.test_file_dir) 20 | 21 | def test_read_text_file(self): 22 | content, file_path = self.create_test_file('test') 23 | self.assertEqual(utils.read_text_file(file_path), content) 24 | 25 | def test_make_char_array(self): 26 | self.assertEqual(utils.make_char_array('ab').tolist(), ['a', 'b']) 27 | 28 | def test_normalize_text(self): 29 | text = 'A\' ' 30 | 31 | self.assertEqual(utils.normalize_text(text), 'a') 32 | self.assertEqual(utils.normalize_text(text, False), 'a\'') 33 | 34 | def test_sparse_tuples_from_sequences(self): 35 | result = utils.sparse_tuples_from_sequences([[1], [2, 3]]) 36 | 37 | self.assertEqual(result[0].tolist(), [[0, 0], [1, 0], [1, 1]]) 38 | self.assertEqual(result[1].tolist(), [1, 2, 3]) 39 | self.assertEqual(result[2].tolist(), [2, 2]) 40 | 41 | def test_read_audio_files(self): 42 | self.assertTrue(utils.read_audio_files(TEST_AUDIO_FILE_DIR).size > 0) 43 | 44 | def test_read_text_files(self): 45 | content = self.create_test_file('test')[0] 46 | self.assertEqual(utils.read_text_files(self.test_file_dir, [TEST_FILE_EXTENSION]), content) 47 | 48 | def test_sequence_decoder(self): 49 | self.assertEqual(utils.sequence_decoder([1, 2, 3]), 'abc') 50 | 51 | def test_texts_encoder(self): 52 | self.assertEqual(utils.texts_encoder(['abc']).tolist()[0], [1, 2, 3]) 53 | 54 | def test_standardize_audios(self): 55 | files = utils.read_audio_files(TEST_AUDIO_FILE_DIR) 56 | self.assertEqual(utils.standardize_audios(files).size, files.size) 57 | 58 | def test_get_sequence_lengths(self): 59 | self.assertEqual(utils.get_sequence_lengths([[1], [], [1, 2]]).tolist(), [1, 0, 2]) 60 | 61 | def 
test_make_sequences_same_length(self): 62 | self.assertEqual(utils.make_sequences_same_length([[1, 2], []], [2, 0]).tolist(), [[1.0, 2.0], [0.0, 0.0]]) 63 | 64 | def create_test_file(self, content): 65 | """ 66 | Write a string to a new temporary test file. 67 | 68 | Args: 69 | content: 70 | Content to write to the text file. 71 | Returns: 72 | A tuple with (file content, file path). 73 | """ 74 | file_path = mktemp(suffix='.' + TEST_FILE_EXTENSION, dir=self.test_file_dir) 75 | 76 | with open(file_path, 'w') as f: 77 | f.write(content) 78 | 79 | return content, file_path 80 | 81 | 82 | if __name__ == '__main__': 83 | unittest.main() 84 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import time 6 | import logging 7 | import sys 8 | import math 9 | 10 | import tensorflow as tf 11 | 12 | import utils 13 | 14 | # Logging configuration. 15 | logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', 16 | level=logging.DEBUG, 17 | stream=sys.stdout) 18 | 19 | # Model path. 20 | MODEL_PATH = "./models/model.ckpt" 21 | 22 | # Summary directory. 23 | SUMMARY_PATH = "./logs/" 24 | 25 | # Data directories. 26 | DATA_DIR = "./data/LibriSpeech/" 27 | TRAIN_DIR = DATA_DIR + "train-clean-100-wav/" 28 | TEST_DIR = DATA_DIR + "test-clean-wav/" 29 | DEV_DIR = DATA_DIR + "dev-clean-wav/" 30 | 31 | # Constants. 32 | SPACE_TOKEN = '' 33 | SPACE_INDEX = 0 34 | FIRST_INDEX = ord('a') - 1 # 0 is reserved to space 35 | 36 | # Number of features. 37 | NUM_FEATURES = 13 38 | 39 | # Accounting the 0th index + space + blank label = 28 characters 40 | NUM_CLASSES = ord('z') - ord('a') + 1 + 1 + 1 41 | 42 | # Hyper-parameters. 
43 | NUM_EPOCHS = 200 44 | NUM_HIDDEN = 50 45 | NUM_LAYERS = 1 46 | BATCH_SIZE = 1 47 | 48 | # Optimizer parameters. 49 | INITIAL_LEARNING_RATE = 1e-2 50 | MOMENTUM = 0.9 51 | 52 | 53 | def main(argv): 54 | # Read test data files. 55 | test_texts = utils.read_text_files(TEST_DIR) 56 | test_labels = utils.texts_encoder(test_texts, 57 | first_index=FIRST_INDEX, 58 | space_index=SPACE_INDEX, 59 | space_token=SPACE_TOKEN) 60 | test_labels = utils.sparse_tuples_from_sequences(test_labels) 61 | test_inputs = utils.read_audio_files(TEST_DIR) 62 | test_inputs = utils.standardize_audios(test_inputs) 63 | test_sequence_lengths = utils.get_sequence_lengths(test_inputs) 64 | test_inputs = utils.make_sequences_same_length(test_inputs, test_sequence_lengths) 65 | 66 | with tf.device('/cpu:0'): 67 | config = tf.ConfigProto() 68 | 69 | graph = tf.Graph() 70 | with graph.as_default(): 71 | logging.debug("Starting new TensorFlow graph.") 72 | inputs_placeholder = tf.placeholder(tf.float32, [None, None, NUM_FEATURES]) 73 | 74 | # SparseTensor placeholder required by ctc_loss op. 75 | labels_placeholder = tf.sparse_placeholder(tf.int32) 76 | 77 | # 1d array of size [batch_size]. 78 | sequence_length_placeholder = tf.placeholder(tf.int32, [None]) 79 | 80 | # Defining the cell. 81 | cell = tf.contrib.rnn.LSTMCell(NUM_HIDDEN, state_is_tuple=True) 82 | 83 | # Stacking rnn cells. 84 | stack = tf.contrib.rnn.MultiRNNCell([cell] * NUM_LAYERS, 85 | state_is_tuple=True) 86 | 87 | # Creates a recurrent neural network. 88 | outputs, _ = tf.nn.dynamic_rnn(stack, inputs_placeholder, sequence_length_placeholder, dtype=tf.float32) 89 | 90 | shape = tf.shape(inputs_placeholder) 91 | batch_size, max_time_steps = shape[0], shape[1] 92 | 93 | # Reshaping to apply the same weights over the time steps. 
94 | outputs = tf.reshape(outputs, [-1, NUM_HIDDEN]) 95 | 96 | weights = tf.Variable(tf.truncated_normal([NUM_HIDDEN, NUM_CLASSES], stddev=0.1), 97 | name='weights') 98 | bias = tf.Variable(tf.constant(0., shape=[NUM_CLASSES]), 99 | name='bias') 100 | 101 | # Doing the affine projection. 102 | logits = tf.matmul(outputs, weights) + bias 103 | 104 | # Reshaping back to the original shape. 105 | logits = tf.reshape(logits, [batch_size, -1, NUM_CLASSES]) 106 | 107 | # Time is major. 108 | logits = tf.transpose(logits, (1, 0, 2)) 109 | 110 | # CTC decoder. 111 | decoded, neg_sum_logits = tf.nn.ctc_greedy_decoder(logits, sequence_length_placeholder) 112 | 113 | with tf.Session(config=config, graph=graph) as session: 114 | logging.debug("Starting TensorFlow session.") 115 | 116 | # Initialize the weights and biases. 117 | tf.global_variables_initializer().run() 118 | 119 | # Saver op to save and restore all the variables. 120 | saver = tf.train.Saver() 121 | 122 | # Restore model weights from previously saved model. 123 | saver.restore(session, MODEL_PATH) 124 | 125 | test_feed = {inputs_placeholder: test_inputs, 126 | sequence_length_placeholder: test_sequence_lengths} 127 | # Decoding. 
128 | decoded_outputs = session.run(decoded[0], feed_dict=test_feed) 129 | dense_decoded = tf.sparse_tensor_to_dense(decoded_outputs, default_value=-1).eval(session=session) 130 | test_num = test_texts.shape[0] 131 | 132 | for i, sequence in enumerate(dense_decoded): 133 | sequence = [s for s in sequence if s != -1] 134 | decoded_text = utils.sequence_decoder(sequence) 135 | 136 | logging.info("Sequence %d/%d", i + 1, test_num) 137 | logging.info("Original:\n%s", test_texts[i]) 138 | logging.info("Decoded:\n%s", decoded_text) 139 | 140 | 141 | if __name__ == '__main__': 142 | tf.app.run() 143 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import glob 7 | import re 8 | import logging 9 | import unicodedata 10 | import codecs 11 | 12 | import numpy as np 13 | import scipy.io.wavfile as wav 14 | from python_speech_features import mfcc 15 | 16 | 17 | def read_text_file(path): 18 | """ 19 | Read text from file 20 | 21 | Args: 22 | path: string. 23 | Path to text file. 24 | Returns: 25 | string. 26 | Read text. 27 | """ 28 | with codecs.open(path, encoding="utf-8") as file: 29 | return file.read() 30 | 31 | 32 | def normalize_text(text, remove_apostrophe=True): 33 | """ 34 | Normalize given text. 35 | 36 | Args: 37 | text: string. 38 | Given text. 39 | remove_apostrophe: bool. 40 | Whether to remove apostrophe in given text. 41 | Returns: 42 | string. 43 | Normalized text. 44 | """ 45 | 46 | # Convert unicode characters to ASCII. 47 | result = unicodedata.normalize("NFKD", text).encode("ascii", "ignore").decode() 48 | 49 | # Remove apostrophes. 
50 | if remove_apostrophe: 51 | result = result.replace("'", "") 52 | 53 | return re.sub("[^a-zA-Z']+", ' ', result).strip().lower() 54 | 55 | 56 | def read_text_files(dir, extensions=['txt']): 57 | """ 58 | Read text files. 59 | 60 | Args: 61 | dir: string. 62 | Data directory. 63 | extensions: list of strings. 64 | File extensions. 65 | Returns: 66 | files: array of texts. 67 | """ 68 | if not os.path.isdir(dir): 69 | logging.error("Text files directory %s is not found.", dir) 70 | return None 71 | 72 | if not all(isinstance(extension, str) for extension in extensions): 73 | logging.error("Variable 'extensions' is not a list of strings.") 74 | return None 75 | 76 | # Get files list. 77 | files_paths_list = [] 78 | for extension in extensions: 79 | file_glob = os.path.join(dir, '*.' + extension) 80 | files_paths_list.extend(glob.glob(file_glob)) 81 | 82 | # Read files. 83 | files = [] 84 | for file_path in files_paths_list: 85 | file = read_text_file(file_path) 86 | file = normalize_text(file) 87 | files.append(file) 88 | 89 | files = np.array(files) 90 | return files 91 | 92 | 93 | def read_audio_files(dir, extensions=['wav']): 94 | """ 95 | Read audio files. 96 | 97 | Args: 98 | dir: string. 99 | Data directory. 100 | extensions: list of strings. 101 | File extensions. 102 | Returns: 103 | files: array of audios. 104 | """ 105 | if not os.path.isdir(dir): 106 | logging.error("Audio files directory %s is not found.", dir) 107 | return None 108 | 109 | if not all(isinstance(extension, str) for extension in extensions): 110 | logging.error("Variable 'extensions' is not a list of strings.") 111 | return None 112 | 113 | # Get files list. 114 | files_paths_list = [] 115 | for extension in extensions: 116 | file_glob = os.path.join(dir, '*.' + extension) 117 | files_paths_list.extend(glob.glob(file_glob)) 118 | 119 | # Read files. 
120 | files = [] 121 | for file_path in files_paths_list: 122 | audio_rate, audio_data = wav.read(file_path) 123 | file = mfcc(audio_data, samplerate=audio_rate) 124 | files.append(file) 125 | 126 | files = np.array(files) 127 | return files 128 | 129 | 130 | def make_char_array(text, space_token=''): 131 | """ 132 | Convert text to a char array, replacing spaces with the space token. 133 | 134 | Args: 135 | text: string. 136 | Given text. 137 | space_token: string. 138 | Text which represents space char. 139 | Returns: 140 | string array. 141 | Split text. 142 | """ 143 | result = np.hstack([space_token if x == ' ' else list(x) for x in text]) 144 | return result 145 | 146 | 147 | def sparse_tuples_from_sequences(sequences, dtype=np.int32): 148 | """ 149 | Create a sparse representation of the input sequences. 150 | 151 | Args: 152 | sequences: a list of lists of type dtype where each element is a sequence 153 | Returns: 154 | A tuple with (indices, values, shape) 155 | """ 156 | indexes = [] 157 | values = [] 158 | 159 | for n, sequence in enumerate(sequences): 160 | indexes.extend(zip([n] * len(sequence), range(len(sequence)))) 161 | values.extend(sequence) 162 | 163 | indexes = np.asarray(indexes, dtype=np.int64) 164 | values = np.asarray(values, dtype=dtype) 165 | shape = np.asarray([len(sequences), np.asarray(indexes).max(0)[1] + 1], dtype=np.int64) 166 | 167 | return indexes, values, shape 168 | 169 | 170 | def sequence_decoder(sequence, first_index=(ord('a') - 1)): 171 | """ 172 | Decode an encoded sequence to text. 173 | 174 | Args: 175 | sequence: list of int. 176 | Encoded sequence. 177 | first_index: int. 178 | First index (usually index of 'a'). 179 | Returns: 180 | decoded_text: string. 181 | """ 182 | decoded_text = ''.join([chr(x) for x in np.asarray(sequence) + first_index]) 183 | # Remove the blank label. 184 | decoded_text = decoded_text.replace(chr(ord('z') + 1), '') 185 | # Replace the space label with a space. 
186 | decoded_text = decoded_text.replace(chr(ord('a') - 1), ' ') 187 | return decoded_text 188 | 189 | 190 | def texts_encoder(texts, first_index=(ord('a') - 1), space_index=0, space_token=''): 191 | """ 192 | Encode texts to numbers. 193 | 194 | Args: 195 | texts: list of texts. 196 | Texts to encode. 197 | first_index: int. 198 | First index (usually index of 'a'). 199 | space_index: int. 200 | Index of 'space'. 201 | space_token: string. 202 | 'space' representation. 203 | Returns: 204 | array of encoded texts. 205 | """ 206 | result = [] 207 | for text in texts: 208 | item = make_char_array(text, space_token) 209 | item = np.asarray([space_index if x == space_token else ord(x) - first_index for x in item]) 210 | result.append(item) 211 | 212 | return np.array(result) 213 | 214 | 215 | def standardize_audios(inputs): 216 | """ 217 | Standardize audio inputs. 218 | 219 | Args: 220 | inputs: array of audios. 221 | Audio files. 222 | Returns: 223 | array of standardized audios. 224 | """ 225 | result = [] 226 | for i in range(inputs.shape[0]): 227 | item = np.array((inputs[i] - np.mean(inputs[i])) / np.std(inputs[i])) 228 | result.append(item) 229 | 230 | return np.array(result) 231 | 232 | 233 | def get_sequence_lengths(inputs): 234 | """ 235 | Get sequence length of each sequence. 236 | 237 | Args: 238 | inputs: list of lists where each element is a sequence. 239 | Returns: 240 | array of sequence lengths. 241 | """ 242 | result = [] 243 | for input in inputs: 244 | result.append(len(input)) 245 | 246 | return np.array(result, dtype=np.int64) 247 | 248 | 249 | def make_sequences_same_length(sequences, sequences_lengths, default_value=0.0): 250 | """ 251 | Pad sequences to the same length to avoid the value 252 | error: setting an array element with a sequence. 253 | 254 | Args: 255 | sequences: list of sequence arrays. 256 | sequences_lengths: list of int. 257 | default_value: float32. 258 | Default value of newly created array. 
259 | Returns: 260 | result: array with dimensions [num_samples, max_length, num_features]. 261 | """ 262 | 263 | # Get number of sequences. 264 | num_samples = len(sequences) 265 | 266 | max_length = np.max(sequences_lengths) 267 | 268 | # Get shape of the first non-zero length sequence. 269 | sample_shape = tuple() 270 | for s in sequences: 271 | if len(s) > 0: 272 | sample_shape = np.asarray(s).shape[1:] 273 | break 274 | 275 | # Create an array of uniform size filled with the default value. 276 | result = (np.ones((num_samples, max_length) + sample_shape) * default_value) 277 | 278 | # Put sequences into new array. 279 | for idx, sequence in enumerate(sequences): 280 | result[idx, :len(sequence)] = sequence 281 | 282 | return result 283 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import time 6 | import logging 7 | import sys 8 | import math 9 | 10 | import tensorflow as tf 11 | 12 | import utils 13 | 14 | # Logging configuration. 15 | logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', 16 | level=logging.DEBUG, 17 | stream=sys.stdout) 18 | 19 | # Model path. 20 | MODEL_PATH = "./models/model.ckpt" 21 | 22 | # Summary directory. 23 | SUMMARY_PATH = "./logs/" 24 | 25 | # Data directories. 26 | DATA_DIR = "./data/LibriSpeech/" 27 | TRAIN_DIR = DATA_DIR + "train-clean-100-wav/" 28 | TEST_DIR = DATA_DIR + "test-clean-wav/" 29 | DEV_DIR = DATA_DIR + "dev-clean-wav/" 30 | 31 | # Constants. 32 | SPACE_TOKEN = '' 33 | SPACE_INDEX = 0 34 | FIRST_INDEX = ord('a') - 1 # 0 is reserved to space 35 | 36 | # Number of features. 37 | NUM_FEATURES = 13 38 | 39 | # Accounting the 0th index + space + blank label = 28 characters 40 | NUM_CLASSES = ord('z') - ord('a') + 1 + 1 + 1 41 | 42 | # Hyper-parameters. 
43 | NUM_EPOCHS = 200 44 | NUM_HIDDEN = 50 45 | NUM_LAYERS = 2 46 | BATCH_SIZE = 4 47 | 48 | # Optimizer parameters. 49 | INITIAL_LEARNING_RATE = 1e-2 50 | MOMENTUM = 0.9 51 | 52 | 53 | def main(argv): 54 | # Read train data files. 55 | train_texts = utils.read_text_files(TRAIN_DIR) 56 | train_labels = utils.texts_encoder(train_texts, 57 | first_index=FIRST_INDEX, 58 | space_index=SPACE_INDEX, 59 | space_token=SPACE_TOKEN) 60 | train_inputs = utils.read_audio_files(TRAIN_DIR) 61 | train_inputs = utils.standardize_audios(train_inputs) 62 | train_sequence_lengths = utils.get_sequence_lengths(train_inputs) 63 | train_inputs = utils.make_sequences_same_length(train_inputs, train_sequence_lengths) 64 | 65 | # Read validation data files. 66 | validation_texts = utils.read_text_files(DEV_DIR) 67 | validation_labels = utils.texts_encoder(validation_texts, 68 | first_index=FIRST_INDEX, 69 | space_index=SPACE_INDEX, 70 | space_token=SPACE_TOKEN) 71 | validation_labels = utils.sparse_tuples_from_sequences(validation_labels) 72 | validation_inputs = utils.read_audio_files(DEV_DIR) 73 | validation_inputs = utils.standardize_audios(validation_inputs) 74 | validation_sequence_lengths = utils.get_sequence_lengths(validation_inputs) 75 | validation_inputs = utils.make_sequences_same_length(validation_inputs, validation_sequence_lengths) 76 | 77 | # Read test data files. 
78 | test_texts = utils.read_text_files(TEST_DIR) 79 | test_labels = utils.texts_encoder(test_texts, 80 | first_index=FIRST_INDEX, 81 | space_index=SPACE_INDEX, 82 | space_token=SPACE_TOKEN) 83 | test_labels = utils.sparse_tuples_from_sequences(test_labels) 84 | test_inputs = utils.read_audio_files(TEST_DIR) 85 | test_inputs = utils.standardize_audios(test_inputs) 86 | test_sequence_lengths = utils.get_sequence_lengths(test_inputs) 87 | test_inputs = utils.make_sequences_same_length(test_inputs, test_sequence_lengths) 88 | 89 | with tf.device('/cpu:0'): 90 | config = tf.ConfigProto() 91 | 92 | graph = tf.Graph() 93 | with graph.as_default(): 94 | logging.debug("Starting new TensorFlow graph.") 95 | inputs_placeholder = tf.placeholder(tf.float32, [None, None, NUM_FEATURES]) 96 | 97 | # SparseTensor placeholder required by ctc_loss op. 98 | labels_placeholder = tf.sparse_placeholder(tf.int32) 99 | 100 | # 1d array of size [batch_size]. 101 | sequence_length_placeholder = tf.placeholder(tf.int32, [None]) 102 | 103 | # Defining the cell. 104 | def lstm_cell(): 105 | return tf.contrib.rnn.LSTMCell(NUM_HIDDEN, state_is_tuple=True) 106 | 107 | # Stacking rnn cells. 108 | stack = tf.contrib.rnn.MultiRNNCell( 109 | [lstm_cell() for _ in range(NUM_LAYERS)], state_is_tuple=True) 110 | 111 | # Creates a recurrent neural network. 112 | outputs, _ = tf.nn.dynamic_rnn(stack, inputs_placeholder, sequence_length_placeholder, dtype=tf.float32) 113 | 114 | shape = tf.shape(inputs_placeholder) 115 | batch_size, max_time_steps = shape[0], shape[1] 116 | 117 | # Reshaping to apply the same weights over the time steps. 118 | outputs = tf.reshape(outputs, [-1, NUM_HIDDEN]) 119 | 120 | weights = tf.Variable(tf.truncated_normal([NUM_HIDDEN, NUM_CLASSES], stddev=0.1), 121 | name='weights') 122 | bias = tf.Variable(tf.constant(0., shape=[NUM_CLASSES]), 123 | name='bias') 124 | 125 | # Doing the affine projection. 
            logits = tf.matmul(outputs, weights) + bias

            # Reshaping back to the original shape.
            logits = tf.reshape(logits, [batch_size, -1, NUM_CLASSES])

            # Time is major.
            logits = tf.transpose(logits, (1, 0, 2))

            with tf.name_scope('loss'):
                loss = tf.nn.ctc_loss(labels_placeholder, logits, sequence_length_placeholder)
                cost = tf.reduce_mean(loss)
                tf.summary.scalar("loss", cost)

            # Use the MOMENTUM constant instead of a hard-coded value.
            optimizer = tf.train.MomentumOptimizer(INITIAL_LEARNING_RATE, MOMENTUM).minimize(cost)

            # CTC decoder.
            decoded, neg_sum_logits = tf.nn.ctc_greedy_decoder(logits, sequence_length_placeholder)

            label_error_rate = tf.reduce_mean(tf.edit_distance(tf.cast(decoded[0], tf.int32),
                                                               labels_placeholder))

        with tf.Session(config=config, graph=graph) as session:
            logging.debug("Starting TensorFlow session.")

            # Saver op to save and restore all the variables.
            saver = tf.train.Saver()

            # Merge all the summaries and write them out.
            merged_summary = tf.summary.merge_all()

            # Initializing summary writer for TensorBoard.
            summary_writer = tf.summary.FileWriter(SUMMARY_PATH, tf.get_default_graph())

            # Initialize the weights and biases.
            tf.global_variables_initializer().run()

            train_num = train_inputs.shape[0]
            validation_num = validation_inputs.shape[0]

            # Check if there is any example.
            if train_num <= 0:
                logging.error("There are no training examples.")
                return

            # Float division and int() keep this correct under both Python 2 and 3.
            num_batches_per_epoch = int(math.ceil(train_num / float(BATCH_SIZE)))

            for current_epoch in range(NUM_EPOCHS):
                train_cost = 0
                train_label_error_rate = 0
                start_time = time.time()

                for step in range(num_batches_per_epoch):
                    # Format batches.
                    # The last batch may hold fewer than BATCH_SIZE examples.
                    batch_start = step * BATCH_SIZE
                    batch_end = min(batch_start + BATCH_SIZE, train_num)
                    indexes = list(range(batch_start, batch_end))

                    batch_train_inputs = train_inputs[indexes]
                    batch_train_sequence_lengths = train_sequence_lengths[indexes]
                    batch_train_targets = utils.sparse_tuples_from_sequences(train_labels[indexes])

                    feed = {inputs_placeholder: batch_train_inputs,
                            labels_placeholder: batch_train_targets,
                            sequence_length_placeholder: batch_train_sequence_lengths}

                    batch_cost, _, summary = session.run([cost, optimizer, merged_summary], feed)
                    # Weight each batch mean by its actual size so a short final batch is not over-counted.
                    train_cost += batch_cost * len(indexes)
                    train_label_error_rate += session.run(label_error_rate, feed_dict=feed) * len(indexes)

                    # Write logs at every iteration.
                    summary_writer.add_summary(summary, current_epoch * num_batches_per_epoch + step)

                train_cost /= train_num
                train_label_error_rate /= train_num

                validation_feed = {inputs_placeholder: validation_inputs,
                                   labels_placeholder: validation_labels,
                                   sequence_length_placeholder: validation_sequence_lengths}

                # cost and label_error_rate are already means over the fed examples
                # (tf.reduce_mean), so they are not divided by validation_num again.
                validation_cost, validation_label_error_rate = session.run([cost, label_error_rate],
                                                                           feed_dict=validation_feed)

                # Output intermediate step information.
                logging.info("Epoch %d/%d (time: %.3f s)",
                             current_epoch + 1,
                             NUM_EPOCHS,
                             time.time() - start_time)
                logging.info("Train cost: %.3f, train label error rate: %.3f",
                             train_cost,
                             train_label_error_rate)
                logging.info("Validation cost: %.3f, validation label error rate: %.3f",
                             validation_cost,
                             validation_label_error_rate)

            test_feed = {inputs_placeholder: test_inputs,
                         sequence_length_placeholder: test_sequence_lengths}
            # Decoding.
            decoded_outputs = session.run(decoded[0], feed_dict=test_feed)
            dense_decoded = tf.sparse_tensor_to_dense(decoded_outputs, default_value=-1).eval(session=session)
            test_num = test_texts.shape[0]

            for i, sequence in enumerate(dense_decoded):
                # Drop the -1 padding added when converting the sparse output to dense.
                sequence = [s for s in sequence if s != -1]
                decoded_text = utils.sequence_decoder(sequence)

                logging.info("Sequence %d/%d", i + 1, test_num)
                logging.info("Original:\n%s", test_texts[i])
                logging.info("Decoded:\n%s", decoded_text)

            # Save model weights to disk.
            save_path = saver.save(session, MODEL_PATH)
            logging.info("Model saved in file: %s", save_path)


if __name__ == '__main__':
    tf.app.run()

--------------------------------------------------------------------------------