├── asr
│   ├── __init__.py
│   ├── util
│   │   ├── __init__.py
│   │   ├── test_tf_installation.py
│   │   ├── csv_helper.py
│   │   ├── matplotlib_helper.py
│   │   ├── storage.py
│   │   ├── metrics.py
│   │   ├── tf_contrib.py
│   │   └── hooks.py
│   ├── evaluate.py
│   ├── labels.py
│   ├── predict.py
│   ├── train.py
│   ├── params.py
│   ├── input_functions.py
│   └── model.py
├── .gitignore
├── images
│   ├── network-architectures.png
│   ├── ds1-network-architecture.png
│   └── ds2-network-architecture.png
├── requirements.txt
├── log_temp.sh
├── .github
│   └── ISSUE_TEMPLATE
│       ├── feature_request.md
│       └── bug_report.md
├── train.sh
├── LICENSE
├── toc-gen.py
├── testruns.md
└── README.md

--------------------------------------------------------------------------------
/asr/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/asr/util/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
venv
*.idea
*__pycache__
nohup.out
*.swp
temp.log

--------------------------------------------------------------------------------
/images/network-architectures.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mdangschat/ctc-asr/HEAD/images/network-architectures.png

--------------------------------------------------------------------------------
/images/ds1-network-architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mdangschat/ctc-asr/HEAD/images/ds1-network-architecture.png

--------------------------------------------------------------------------------
/images/ds2-network-architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mdangschat/ctc-asr/HEAD/images/ds2-network-architecture.png

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
tensorflow >= 1.12.0
# tensorflow-gpu >= 1.12.0
tensorflow-estimator >= 1.12.0
numpy >= 1.15
scipy >= 1.1.0
matplotlib >= 2.2.0
tqdm >= 4.28.0
python_speech_features >= 0.6
gitpython >= 2.1.11
requests >= 2.20.0
ipython >= 6.5.0
nvidia-ml-py3 >= 7.352.0
librosa >= 0.6.2

--------------------------------------------------------------------------------
/log_temp.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

LOG_TEMP_EVERY_SECOND=30
LOG_TEMP_FILENAME="temp.log"

while true
do
    echo "========================================" >> ${LOG_TEMP_FILENAME}
    echo $(date) >> ${LOG_TEMP_FILENAME}
    echo "========================================" >> ${LOG_TEMP_FILENAME}
    echo "$(nvidia-smi -q -a | grep -E 'Power Draw|Memory Current|GPU Current Temp|Gpu|Used GPU Memory')" >> ${LOG_TEMP_FILENAME}
    # printf "\n" >> ${LOG_TEMP_FILENAME}
    # echo "$(sensors)" >> ${LOG_TEMP_FILENAME}
    printf "\n\n" >> ${LOG_TEMP_FILENAME}
    sleep ${LOG_TEMP_EVERY_SECOND}
done

--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
---
name: Feature request
about: Suggest an idea for this project

---

**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]

**Describe the solution you'd like**
A clear and concise description of what you want to happen.

**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.

**Additional context**
Add any other context or screenshots about the feature request here.

--------------------------------------------------------------------------------
/asr/util/test_tf_installation.py:
--------------------------------------------------------------------------------
"""Validate the TensorFlow installation and availability of GPU support."""

import tensorflow as tf

tf.logging.set_verbosity(tf.logging.INFO)


def test_environment():
    """Print TensorFlow installation information (GPU focused).

    Returns:
        Nothing.
    """
    print('TensorFlow version:', tf.VERSION)
    print('GPU device name:', tf.test.gpu_device_name())
    print('is GPU available:', tf.test.is_gpu_available())
    print('is built with CUDA:', tf.test.is_built_with_cuda())


if __name__ == '__main__':
    test_environment()

--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

# Configure a TMUX session named speech and prepare to start the training process.
# Note: Script MUST be run from the repository's root folder.

# First, kill all running log_temp.sh scripts.
pgrep -f "bash ./log_temp.sh" | xargs -n 1 bash -c 'kill "$0"'

tmux new-session -d -s speech '$SHELL'
tmux set -g window-status-current-bg blue
tmux select-pane -t 0
tmux send-keys './log_temp.sh &' C-m
tmux send-keys 'tail -f temp.log' C-m
tmux split-window -h '$SHELL'
tmux select-pane -t 1
tmux send-keys 'tensorboard --logdir ../speech_checkpoints' C-m
tmux split-window -v -p 80 '$SHELL'
tmux select-pane -t 2
tmux send-keys 'htop' C-m
tmux select-pane -t 0
tmux split-window -v -p 80 '$SHELL'
tmux select-pane -t 1

tmux attach-session -d -t speech

--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
---
name: Bug report
about: Create a report to help us improve

---

**Describe the bug**
A clear and concise description of what the bug is.

**To Reproduce**
Steps to reproduce the behavior:
1. Go to '...'
2. Click on '....'
3. Scroll down to '....'
4. See error

**Expected behavior**
A clear and concise description of what you expected to happen.

**Screenshots**
If applicable, add screenshots to help explain your problem.

**Desktop (please complete the following information):**
 - OS: [e.g. iOS]
 - Browser [e.g. chrome, safari]
 - Version [e.g. 22]

**Smartphone (please complete the following information):**
 - Device: [e.g. iPhone6]
 - OS: [e.g. iOS8.1]
 - Browser [e.g. stock browser, safari]
 - Version [e.g. 22]

**Additional context**
Add any other context about the problem here.

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Copyright (c) 2018 Marc Dangschat

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/asr/util/csv_helper.py:
--------------------------------------------------------------------------------
"""Helper methods to generate the CSV files."""

import csv
import os

from asr.params import CSV_HEADER_LENGTH, CSV_FIELDNAMES, CSV_DELIMITER, WIN_STEP


def get_bucket_boundaries(csv_path, num_buckets):
    """Generate a list of bucket boundaries, based on the example lengths in the CSV file.

    The boundaries are chosen based on the distribution of example lengths, to allow each
    bucket to fill up at roughly the same rate. This produces at most `num_buckets` buckets.

    Args:
        csv_path (str): Path to the CSV file. E.g. '../data/train.csv'.
        num_buckets (int): The maximum number of buckets to create.

    Returns:
        List[int]: List containing the bucket boundaries.
    """
    assert os.path.exists(csv_path) and os.path.isfile(csv_path)

    with open(csv_path, 'r', encoding='utf-8') as file_handle:
        reader = csv.DictReader(file_handle, delimiter=CSV_DELIMITER, fieldnames=CSV_FIELDNAMES)
        csv_data = [csv_entry for csv_entry in reader][1:]

    # Calculate optimal bucket sizes.
    lengths = [int(float(d[CSV_HEADER_LENGTH]) / WIN_STEP) for d in csv_data]
    step = len(lengths) // num_buckets

    buckets = set()
    for i in range(step, len(lengths), step):
        buckets.add(lengths[i])
    buckets = list(buckets)
    buckets.sort()

    return buckets
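

# Illustrative example (sketch, not part of the original module): four examples
# with lengths [1.0, 2.0, 3.0, 4.0] seconds and WIN_STEP = 0.010 give frame
# counts [100, 200, 300, 400]; with num_buckets=2 the step is 2 and the
# boundaries become [300]. Note that meaningful boundaries assume the CSV is
# sorted by example length.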

--------------------------------------------------------------------------------
/asr/evaluate.py:
--------------------------------------------------------------------------------
"""Evaluate a trained ASR model."""

import tensorflow as tf

from asr.input_functions import input_fn_generator
from asr.model import CTCModel
from asr.params import FLAGS

# Evaluation specific flags.
tf.flags.DEFINE_boolean('dev', False,
                        "`True` if evaluation should use the dev set, "
                        "`False` if it should use the test set.")

# Which dataset CSV file to use for evaluation. 'test' or 'dev'.
__EVALUATION_TARGET = 'dev' if FLAGS.dev else 'test'


def main(_):
    """TensorFlow evaluation starting routine."""

    # Setup TensorFlow run configuration and hooks.
    config = tf.estimator.RunConfig(
        model_dir=FLAGS.train_dir,
        save_summary_steps=FLAGS.log_frequency,
        session_config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement,
            gpu_options=tf.GPUOptions(allow_growth=FLAGS.allow_vram_growth)
        )
    )

    model = CTCModel()

    # Construct the estimator that embodies the model.
    estimator = tf.estimator.Estimator(
        model_fn=model.model_fn,
        model_dir=FLAGS.train_dir,
        config=config
    )

    # Evaluate the trained model.
    dev_input_fn = input_fn_generator(__EVALUATION_TARGET)
    evaluation_result = estimator.evaluate(input_fn=dev_input_fn, hooks=None)
    tf.logging.info('Evaluation results for this model: {}'.format(evaluation_result))


if __name__ == '__main__':
    # General TensorFlow setup.
    tf.logging.set_verbosity(tf.logging.INFO)

    # Run evaluation.
    tf.app.run()

--------------------------------------------------------------------------------
/asr/labels.py:
--------------------------------------------------------------------------------
"""Convert characters (chr) to integer (int) labels and vice versa.

REVIEW: index 0 bug, also see:
https://github.com/baidu-research/warp-ctc/tree/master/tensorflow_binding

`ctc_loss`_ maps labels from 0=<unused>, 1=<space>, 2=a, ..., 27=z, 28=<blank>

See: https://www.tensorflow.org/api_docs/python/tf/nn/ctc_loss
"""

__MAP = r' abcdefghijklmnopqrstuvwxyz'  # 27 characters, including <space>.
__CTOI = dict()
__ITOC = dict([(0, '')])  # This is in case the net decodes a 0 on step 0.

if not __CTOI or not __ITOC:
    for i, c in enumerate(__MAP):
        __CTOI.update({c: i + 1})
        __ITOC.update({i + 1: c})


def ctoi(char):
    """Convert character label to integer.

    Args:
        char (char): Character label.

    Returns:
        int: Integer representation.
    """
    if not len(char) == 1:
        raise ValueError('"{}" is not a valid character.'.format(char))
    if char not in __MAP:
        raise ValueError('Invalid input character \'{}\'.'.format(char))

    return __CTOI[char.lower()]


def itoc(integer):
    """Convert integer label to character.

    Args:
        integer (int): Integer label.

    Returns:
        char: Character representation.
    """
    if not 0 <= integer < num_classes():
        raise ValueError('Integer label ({}) out of range.'.format(integer))

    return __ITOC[integer]


def num_classes():
    """Return number of different classes, +1 for the <blank> label.

    Returns:
        int: Number of labels +1.
    """
    return len(__MAP) + 2
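

# Illustrative sanity check (sketch, not part of the original module):
#   ctoi('a') == 2; itoc(2) == 'a'; ctoi(' ') == 1
#   num_classes() == 29  # 27 characters + unused index 0 + the CTC <blank> label.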

--------------------------------------------------------------------------------
/asr/util/matplotlib_helper.py:
--------------------------------------------------------------------------------
"""Wrapper for matplotlib.

Configures image output for GUI and non-GUI systems."""

import os
from distutils.spawn import find_executable

import matplotlib
from matplotlib import rc


def pyplot_display(func):
    """Provides decorator for `matplotlib.pyplot` plots.

    It only uses `show()` if a display or a PyCharm remote session has been found.
    Else the plot is being saved to `/tmp/<function_name>.png`.

    Note:
        Wrapped methods need the imported pyplot argument as their first argument.
        Wrapped methods need to return the `fig = plt.figure(...)` argument after completion.

    Args:
        func (function): The plot function.

    Returns:
        function: The wrapped function.
    """

    def wrapper(*args, **kwargs):
        rc('font', **{'family': 'monospace',
                      'serif': ['DejaVu Sans'],
                      'size': 12
                      })
        usetex = find_executable('latex') is not None
        rc('text', usetex=usetex)

        # Setup plot output based on the availability of a display (PyCharm remote execution).
        display = 'DISPLAY' in os.environ or \
                  all(var in os.environ for var in ['PYCHARM_HOSTED', 'PYCHARM_MATPLOTLIB_PORT'])
        if display:
            from matplotlib import pyplot as plt
        else:
            matplotlib.use('Agg')
            from matplotlib import pyplot as plt

        fig = func(plt, *args, **kwargs)  # Call wrapped function.

        # Display or save the plot.
        if display:
            plt.show()
            # print('plt.show()')
        else:
            path = '/tmp/{}.png'.format(func.__name__)
            fig.savefig(path)
            print('Plot saved to: {}'.format(path))

    return wrapper

--------------------------------------------------------------------------------
/toc-gen.py:
--------------------------------------------------------------------------------
"""Generate the table of contents and insert it at the top of `README.md`.


This script always assumes that the first heading is the document title
and does NOT include it in the table of contents.
It is assumed that only the first heading is H1 (#) and that all
subsequent headings are at least H2 (##).

Add the following to your `README.md` file (in the same folder):

    [...]

    ## Contents

    <!-- toc -->

    <!-- tocstop -->

    [...]
"""

import re


_HEADER_REGEX = r'([#]+) ([^\n]+)'
_PUNCTUATION_REGEX = r'[^\w\- ]'
_HEADER_TEMPLATE = '{indent}* [{name}](#{anchor})'
# TOC delimiters; assumed to be the usual HTML comment markers (see docstring above).
_START_TOC = '<!-- toc -->'
_END_TOC = '<!-- tocstop -->'


def __anchor(name):
    anchor = name.lower().replace(' ', '-')
    anchor = re.sub(_PUNCTUATION_REGEX, '', anchor)
    return anchor


def __parse_header(header):
    r = re.match(_HEADER_REGEX, header)
    if r:
        level = len(r.group(1))
        name = r.group(2)
        return level, __anchor(name), name


def __iter_headers(md):
    headers = (line for line in md.splitlines()
               if line.startswith('#'))
    for header in headers:
        yield header


def __get_header_item(header):
    level, anchor, name = __parse_header(header)
    # Levels are 1 for H1, 2 for H2, etc. Assuming all listed headings are
    # at least H2, an H2 heading gets zero indentation.
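    # Example: an H3 ('###') heading has level 3 and gets one indent step.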
    indent = ' ' * max(0, level - 2)
    return _HEADER_TEMPLATE.format(**locals())


def __gen_items(md):
    for header in __iter_headers(md):
        item = __get_header_item(header)
        yield item


def __read_md(filename):
    with open(filename, 'r') as f:
        return f.read()


def gen_toc(filename):
    md = __read_md(filename)
    i = md.index(_START_TOC) + len(_START_TOC) + 2
    j = md.index(_END_TOC)
    with open(filename, 'w') as f:
        f.write(md[:i])
        for i, item in enumerate(__gen_items(md)):
            if i == 0:
                continue

            f.write(item + '\n')
        f.write('\n' + md[j:])


if __name__ == '__main__':
    gen_toc('README.md')

--------------------------------------------------------------------------------
/asr/predict.py:
--------------------------------------------------------------------------------
"""Transcribe a given audio file.

L8ER: Add flag to specify the checkpoint file to use.
"""

import tensorflow as tf

from asr.input_functions import load_sample
from asr.model import CTCModel
from asr.params import FLAGS

# Inference specific flags.
tf.flags.DEFINE_string('input', 'examples/idontunderstandawordyoujustsaid.wav',
                       "Path to the WAV file to transcribe.")


def predict_input_fn():
    """Generate a `tf.data.Dataset` containing the `FLAGS.input` file's spectrogram data.

    Returns:
        Dataset iterator.
    """
    dataset = tf.data.Dataset.from_generator(__predict_input_generator,
                                             (tf.float32, tf.int32),
                                             (tf.TensorShape([None, 80]), tf.TensorShape([]))
                                             )

    dataset = dataset.batch(1)
    iterator = dataset.make_one_shot_iterator()
    spectrogram, spectrogram_length = iterator.get_next()

    features = {
        'spectrogram': spectrogram,
        'spectrogram_length': spectrogram_length,
    }

    return features, None


def __predict_input_generator():
    yield load_sample(FLAGS.input)


def main(_):
    """TensorFlow prediction starting routine."""

    # Setup TensorFlow run configuration and hooks.
    config = tf.estimator.RunConfig(
        model_dir=FLAGS.train_dir,
        session_config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement,
            gpu_options=tf.GPUOptions(allow_growth=FLAGS.allow_vram_growth)
        )
    )

    model = CTCModel()

    # Construct the estimator that embodies the model.
    estimator = tf.estimator.Estimator(
        model_fn=model.model_fn,
        model_dir=FLAGS.train_dir,
        config=config
    )

    # Transcribe the given example.
    prediction = estimator.predict(input_fn=predict_input_fn, hooks=None)
    tf.logging.info('Inference results: {}'.format(list(prediction)))


if __name__ == '__main__':
    # General TensorFlow setup.
    tf.enable_eager_execution()
    tf.logging.set_verbosity(tf.logging.INFO)

    # Run prediction.
    tf.app.run()

--------------------------------------------------------------------------------
/asr/train.py:
--------------------------------------------------------------------------------
"""Train the ASR model.

Tested with Python 3.5, 3.6 and 3.7.
No Python 2 compatibility is provided.
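
The first epoch trains on the `train_batch` input and subsequent epochs on the
bucketed `train_bucket` input (a SortaGrad-style curriculum); an evaluation run
on the dev set follows each epoch. See `main` below.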
"""

import time

import tensorflow as tf

from asr.input_functions import input_fn_generator
from asr.model import CTCModel
from asr.params import FLAGS, get_parameters
from asr.util import storage

RANDOM_SEED = FLAGS.random_seed if FLAGS.random_seed != 0 else int(time.time())


def main(_):
    """TensorFlow starting routine."""

    # Delete old model data if requested.
    storage.maybe_delete_checkpoints(FLAGS.train_dir, FLAGS.delete)

    # Logging information about the run.
    print('TensorFlow-Version: {}; Tag-Version: {}; Branch: {}; Commit: {}\nParameters: {}'
          .format(tf.VERSION, storage.git_latest_tag(), storage.git_branch(),
                  storage.git_revision_hash(), get_parameters()))

    # Setup TensorFlow run configuration and hooks.
    config = tf.estimator.RunConfig(
        model_dir=FLAGS.train_dir,
        tf_random_seed=RANDOM_SEED,
        save_summary_steps=FLAGS.log_frequency,
        session_config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement,
            gpu_options=tf.GPUOptions(allow_growth=FLAGS.allow_vram_growth)
        ),
        keep_checkpoint_max=5,
        log_step_count_steps=FLAGS.log_frequency,
        train_distribute=None
    )

    model = CTCModel()

    # Construct the estimator that embodies the model.
    estimator = tf.estimator.Estimator(
        model_fn=model.model_fn,
        model_dir=FLAGS.train_dir,
        config=config
    )

    # Train the model.
    curriculum_train_input_fn = input_fn_generator('train_batch')
    estimator.train(input_fn=curriculum_train_input_fn, hooks=None)

    # Evaluate the trained model.
    dev_input_fn = input_fn_generator('dev')
    evaluation_result = estimator.evaluate(input_fn=dev_input_fn, hooks=None)
    tf.logging.info('Evaluation results of epoch {}: {}'.format(1, evaluation_result))

    # Train the model and evaluate after each epoch.
    for epoch in range(2, FLAGS.max_epochs + 1):
        # Train the model.
        train_input_fn = input_fn_generator('train_bucket')
        estimator.train(input_fn=train_input_fn, hooks=None)

        # L8ER: Possible replacement for evaluate every epoch:
        # https://www.tensorflow.org/api_docs/python/tf/contrib/estimator/InMemoryEvaluatorHook

        # Evaluate the trained model.
        dev_input_fn = input_fn_generator('dev')
        evaluation_result = estimator.evaluate(input_fn=dev_input_fn, hooks=None)
        tf.logging.info('Evaluation results of epoch {}: {}'.format(epoch, evaluation_result))


if __name__ == '__main__':
    # General TensorFlow setup.
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.set_random_seed(RANDOM_SEED)

    # Run training.
    tf.app.run()

--------------------------------------------------------------------------------
/asr/util/storage.py:
--------------------------------------------------------------------------------
"""Storage and version control helper methods."""

import hashlib
import os
import shutil
import tarfile
import time

import tensorflow as tf
from git import Repo


def git_revision_hash():
    """Return the git revision id/hash.

    Returns:
        str: Git revision hash.
    """
    repo = Repo('.', search_parent_directories=True)
    return repo.head.object.hexsha


def git_branch():
    """Return the active git branch's name.

    Returns:
        str: Git branch.
    """
    repo = Repo('.', search_parent_directories=True)
    try:
        branch_name = repo.active_branch.name
    except TypeError:
        branch_name = 'DETACHED HEAD'
    return branch_name


def git_latest_tag():
    """Return the latest added git tag.

    Returns:
        str: Git tag.
    """
    repo = Repo('.', search_parent_directories=True)
    tags = sorted(repo.tags, key=lambda t: t.commit.committed_datetime)
    return tags[-1].name


def delete_file_if_exists(path):
    """Delete the file for the given path, if it exists.

    Args:
        path (str): File path.

    Returns:
        Nothing.
    """
    if os.path.exists(path) and os.path.isfile(path):
        for i in range(5):
            try:
                os.remove(path)
                break
            except (OSError, ValueError) as exception:
                print('WARN: Error deleting ({}/5) file: {}'.format(i + 1, path))
                if i == 4:
                    raise RuntimeError(path) from exception
                time.sleep(1)


def delete_directory_if_exists(path):
    """Recursive delete of a folder and all contained files.

    Args:
        path (str): Directory path.

    Returns:
        Nothing.
    """

    if os.path.exists(path) and os.path.isdir(path):
        # https://docs.python.org/3/library/shutil.html#shutil.rmtree
        # Doesn't state which errors are possible.
        try:
            shutil.rmtree(path)
        except OSError as exception:
            raise exception


def maybe_delete_checkpoints(path, delete):
    """Delete a TensorFlow checkpoint directory if requested and necessary.

    Args:
        path (str):
            Path to directory e.g. `FLAGS.train_dir`.
        delete (bool):
            Whether to delete old checkpoints or not. Should probably correspond to `FLAGS.delete`.

    Returns:
        Nothing.
    """
    if tf.gfile.Exists(path) and delete:
        print('Deleting old checkpoint data from: {}'.format(path))
        tf.gfile.DeleteRecursively(path)
        tf.gfile.MakeDirs(path)
    elif tf.gfile.Exists(path) and not delete:
        print('Found old checkpoint data at: {}'.format(path))
    else:
        print('Starting a new training run in: {}'.format(path))
        tf.gfile.MakeDirs(path)


def md5(file_path):
    """Calculate the md5 checksum of files that do not fit in memory.

    Args:
        file_path (str): Path to file.

    Returns:
        str: md5 checksum.
    """
    hash_md5 = hashlib.md5()
    with open(file_path, 'rb') as file_handle:
        for chunk in iter(lambda: file_handle.read(4096), b''):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()


def tar_extract_all(tar_path, target_path):
    """Extract a TAR archive. Overwrites existing files.

    # L8ER: Deprecated: no longer needed in this project and will be removed.

    Args:
        tar_path (str): Path of TAR archive.
        target_path (str): Where to extract the archive.

    Returns:
        Nothing.
    """
    assert os.path.exists(target_path) and os.path.isdir(target_path), 'target_path does not exist.'
    with tarfile.open(tar_path, 'r') as tar:
        for file_ in tar:
            try:
                tar.extract(file_, path=target_path)
            except IOError:
                os.remove(os.path.join(target_path, file_.name))
                tar.extract(file_, path=target_path)
            finally:
                os.chmod(os.path.join(target_path, file_.name), file_.mode)

--------------------------------------------------------------------------------
/asr/util/metrics.py:
--------------------------------------------------------------------------------
"""Methods that calculate cost metrics."""

import numpy as np

from asr.labels import itoc
from asr.params import NP_FLOAT


def dense_to_text(decoded, originals):
    """Convert a dense, integer encoded `tf.Tensor` into a readable string.

    Create a summary comparing the decoded plaintext with a given original string.

    Args:
        decoded (np.ndarray):
            Integer array, containing the decoded sequences.
        originals (np.ndarray):
            String tensor, containing the original input string for comparison.
            `originals` can be an empty tensor.

    Returns:
        np.ndarray:
            1D string Tensor containing only the decoded text outputs.
            [decoded_string_0, ..., decoded_string_N]
        np.ndarray:
            2D string Tensor with layout:
            [[decoded_string_0, ..., decoded_string_N],
             [original_string_0, ..., original_string_N]]
    """
    decoded_strings = []
    original_strings = []

    for d in decoded:
        decoded_strings.append(''.join([itoc(i) for i in d]))

    if len(originals) > 0:
        for o in originals:
            original_strings.append(''.join([c for c in o.decode('utf-8')]))
    else:
        original_strings = ['n/a'] * len(decoded_strings)

    decoded_strings = np.array(decoded_strings, dtype=np.object)
    original_strings = np.array(original_strings, dtype=np.object)

    summary = np.vstack([decoded_strings, original_strings])

    return np.array(decoded_strings), summary


# The following function has been taken from:
# <https://github.com/mozilla/DeepSpeech>
def wer(original, result):
    """Calculate the WER.

    The Word Error Rate (WER) is defined as the editing/Levenshtein distance
    on word level divided by the amount of words in the original text.
    In case of the original having more words (N) than the result and both
    being totally different (all N words resulting in 1 edit operation each),
    the WER will always be 1 (N / N = 1).

    Args:
        original (np.string): The original sentences.
            A tf.Tensor converted to `np.ndarray` object bytes by `tf.py_func`.
        result (np.string): The decoded sentences.
            A tf.Tensor converted to `np.ndarray` object bytes by `tf.py_func`.

    Returns:
        np.ndarray: Numpy array containing float scalar.
    """
    # The WER is calculated on word (and NOT on character) level.
    # Therefore we split the strings into words first:
    original = original.split()
    result = result.split()
    levenshtein_distance = levenshtein(original, result) / float(len(original))
    return np.array(levenshtein_distance, dtype=NP_FLOAT)
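

# Example (sketch): wer('the cat sat', 'the cat sit') -> 1 edit / 3 words ≈ 0.333.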

# The following function has been taken from:
# <https://github.com/mozilla/DeepSpeech>
def wer_batch(originals, results):
    """Calculate the Word Error Rate (WER) for a batch.

    Args:
        originals (np.ndarray): 2D string Tensor with the original sentences. [batch_size, 1]
            A tf.Tensor converted to `np.ndarray` bytes by `tf.py_func`.
        results (np.ndarray): 2D string Tensor with the decoded sentences. [batch_size, 1]
            A tf.Tensor converted to `np.ndarray` bytes by `tf.py_func`.

    Returns:
        np.ndarray:
            Float array containing the WER for every sample within the batch. [batch_size]
        np.ndarray:
            Float scalar with the average WER for the batch.
    """
    count = len(originals)
    rates = np.array([], dtype=NP_FLOAT)
    mean = 0.0
    assert count == len(results)
    for i in range(count):
        rate = wer(originals[i], results[i])
        mean = mean + rate
        rates = np.append(rates, rate)

    return rates, np.array(mean / float(count), dtype=NP_FLOAT)


# The following code is from: <http://hetland.org/coding/python/levenshtein.py>
# This is a straightforward implementation of a well-known algorithm, and thus
# probably shouldn't be covered by copyright to begin with. But in case it is,
# the author (Magnus Lie Hetland) has, to the extent possible under law,
# dedicated all copyright and related and neighboring rights to this software
# to the public domain worldwide, by distributing it under the CC0 license,
# version 1.0. This software is distributed without any warranty. For more
# information, see <http://creativecommons.org/publicdomain/zero/1.0/>
def levenshtein(a, b):
    """Calculate the Levenshtein distance between `a` and `b`.

    Args:
        a (str): Original word.
        b (str): Decoded word.

    Returns:
        int: Levenshtein distance.
    """
    n, m = len(a), len(b)
    if n > m:
        # Make sure n <= m, to use O(min(n,m)) space
        a, b = b, a
        n, m = m, n

    current = list(range(n + 1))
    for i in range(1, m + 1):
        previous, current = current, [i] + [0] * n
        for j in range(1, n + 1):
            add, delete = previous[j] + 1, current[j - 1] + 1
            change = previous[j - 1]
            if a[j - 1] != b[i - 1]:
                change = change + 1
            current[j] = min(add, delete, change)

    return current[n]

--------------------------------------------------------------------------------
/testruns.md:
--------------------------------------------------------------------------------
## Testruns
Listing of test runs and results.


### COSY (Reduced Dataset)
| train_dir | Server | BS | Input | Norm. | Units | Ep | Layout | Loss | MED | WER | Notes |
|-----------------------|--------|---:|---------|--------------|------:|---:|-------:|-------:|------:|-------:|----------------|
| `3d1r2d_global` | cosy14 | 8 | 80 Mel | global | 2048 | 20 | 3d1r2d | 30.594 | 0.113 | 0.3195 | |
| `3d1r2d_local` | cosy15 | 8 | 80 Mel | local | 2048 | 20 | 3d1r2d | 29.022 | 0.107 | 0.3086 | |
| `3d1r2d_local_scalar` | cosy16 | 8 | 80 Mel | local scalar | 2048 | 20 | 3d1r2d | 31.882 | 0.114 | 0.3214 | |
| `3d1r2d_none` | cosy14 | 8 | 80 Mel | none | 2048 | 20 | 3d1r2d | 29.604 | 0.112 | 0.317 | |
| `3d1r2d_mfcc_local` | cosy15 | 8 | 80 MFCC | local | 2048 | 20 | 3d1r2d | 24.633 | 0.088 | 0.255 | |
| `3d1r2d_local_3000u` | cosy16 | 8 | 80 Mel | local | 3000 | 20 | 3d1r2d | 34.556 | 0.102 | 0.290 | |


#### Reduced Dataset
Note that runs marked with *Reduced Dataset* did not use the complete dataset.
* train: timit, tedlium, libri_speech, common_voice, ~~tatoeba~~
* test: libri_speech, common_voice
* dev: libri_speech


### COSY
| train_dir | Server | BS | Input | Norm. | Units | Ep | Layout | Loss | MED | WER | Notes |
|--------------------------------|--------|---:|---------|--------------|------:|---:|-------:|-------:|------:|-------:|---------------|
| `3d1r2d_global_mfcc_full` | cosy14 | 8 | 80 MFCC | global | 2048 | 20 | 3d1r2d | 25.606 | 0.106 | 0.304 | |
| `3d2r2d_local_mfcc_full` | cosy15 | 8 | 80 MFCC | local | 2048 | 16 | 3d2r2d | 18.988 | 0.074 | 0.211 | |
| `3d1r2d_global_mel_full` | cosy14 | 8 | 80 Mel | global | 2048 | 14 | 3d1r2d | 31.399 | 0.131 | 0.371 | |
| `3d1r2d_local_mel_full` | cosy15 | 8 | 80 Mel | local | 2048 | 15 | 3d1r2d | 29.520 | 0.125 | 0.354 | |
| `3d1r2d_local_scalar_mel_full` | cosy16 | 8 | 80 Mel | local scalar | 2048 | 15 | 3d1r2d | 31.669 | 0.132 | 0.373 | |
| `3d1r2d_none_mel_full` | cosy17 | 8 | 80 Mel | none | 2048 | 16 | 3d1r2d | 32.006 | 0.135 | 0.376 | |
| `3d1r2d_none_mfcc_full` | cosy14 | 8 | 80 MFCC | none | 2048 | 8 | 3d1r2d | 23.865 | 0.096 | 0.273 | |
| `3d1r2d_none_mel_full_2` | cosy14 | 8 | 80 Mel | none | 2048 | 8 | 3d1r2d | 28.915 | 0.121 | 0.335 | For R. above. |
| `3c1r2d_mel_local_full` | cosy17 | 8 | 80 Mel | local | 2048 | 8 | 3c1r2d | 22.695 | 0.091 | 0.2557 | |
| `3c1r2d_mel_localscalar_full` | cosy14 | 8 | 80 Mel | local scalar | 2048 | 9 | 3c1r2d | 23.579 | 0.090 | 0.2556 | |
| `3c1r2d_mel_global_full` | cosy15 | 8 | 80 Mel | global | 2048 | 9 | 3c1r2d | 24.059 | 0.094 | 0.2674 | |
| `3c1r2d_mel_none_full` | cosy16 | 8 | 80 Mel | none | 2048 | 9 | 3c1r2d | 26.979 | 0.106 | 0.2919 | |
| `3c1r2d_mfcc_local` | cosy15 | 8 | 80 MFCC | local | 2048 | 10 | 3c1r2d | 25.261 | 0.098 | 0.2724 | |
| `3c1r2d_mfcc_localscalar` | cosy16 | 8 | 80 MFCC | local scalar | 2048 | 12 | 3c1r2d | 28.494 | 0.118 | 0.3235 | |


### FB02TIITs04; V100 32GB
| train_dir | BS | Input | Norm. | Units | Ep | Layout | Loss | MED | WER | Notes |
|------------------------------|---:|---------|-------|------:|---:|-------:|------:|------:|-------:|-----------------------------|
| `3c1r2d_mel_local_full` | 8 | 80 Mel | local | 2048 | 20 | 3c4r2d | 25.43 | 0.083 | 0.2412 | |
| `3c3r2d_mel_local` | 8 | 80 Mel | local | 2048 | 11 | 3c3r2d | 17.32 | 0.062 | 0.1762 | Stopped early. |
| `3c4r2d_mel_local_full_lstm` | 8 | 80 Mel | local | 2048 | 5 | 3c4r2d | 11.849 | 0.045 | 0.1264 | LSTM cells. |
| `3c5r2d_mel_local_full` | 8 | 80 Mel | local | 2048 | 9 | 3c5r2d | 13.26 | 0.044 | 0.1292 | LSTM cells. Server crashed. |
| `3c5r2d_mfcc_local_lstm_2` | 16 | 80 MFCC | local | 2048 | 5 | 3c5r2d | 12.06 | 0.046 | 0.1271 | LSTM cells. |
| `3c4r2d_mel_local_tanh` | 16 | 80 Mel | local | 2048 | 9 | 3c4r2d | 25.58 | 0.113 | 0.308 | tanh RNN. |
| `3c4r2d_mel_local` | 16 | 80 Mel | local | 2048 | x | 3c4r2d | xx.xx | 0.xxx | 0.xxx | ReLU RNN. |


### FB11-NX-T02; 2xV100 16GB
| train_dir | BS | Input | Norm. | Units | Ep | Layout | Loss | MED | WER | Notes |
|-------------------------------|---:|---------|-------|------:|---:|-------:|------:|------:|-------:|----------------------------|
| `3c5r2d_mel_local_full_bs16` | 16 | 80 Mel | local | 2048 | 10 | 3c5r2d | 14.02 | 0.057 | 0.1583 | Stopped early. |
| `3c5r2d_mfcc_local_full_bs16` | 16 | 80 MFCC | local | 2048 | 17 | 3c5r2d | 19.63 | 0.081 | 0.2207 | Tanh RNN. |
| `3c4r2d_mfcc_local_bs16_relu` | 16 | 80 MFCC | local | 2048 | 16 | 3c4r2d | 20.45 | 0.081 | 0.2273 | ReLU RNN. HDD full. |
| `3c4r2d_mfcc_local_bs16_relu` | 16 | 80 MFCC | local | 2048 | 15 | 3c4r2d | 20.28 | 0.082 | 0.230 | ReLU RNN. |
| `3c4r2d_mel_local` | 16 | 80 Mel | local | 2048 | 9 | 3c4r2d | 17.42 | 0.068 | 0.194 | ReLU cells. For SortaGrad. |
| `3c2r2d_mel_local` | 16 | 80 Mel | local | 2048 | 12 | 3c2r2d | 18.19 | 0.076 | 0.215 | ReLU cells. |


### FB11-NX-T01; 1xV100 16GB
| train_dir | BS | Input | Norm. | Units | Ep | Layout | Loss | MED | WER | Notes |
|--------------------------------|---:|---------|-------|------:|---:|-------:|------:|------:|-------:|---------------------------|
| `3c4r2d_mfcc_local_bs16_gru` | 16 | 80 MFCC | local | 2048 | 10 | 3c4r2d | 16.78 | 0.067 | 0.1913 | GRU cells. |
| `3c3r2d_mel_local_bs16_tanh` | 16 | 80 Mel | local | 2048 | 15 | 3c3r2d | 17.72 | 0.072 | 0.2059 | ReLU cells, despite name. |
| `3c4r2d_mel_local_nosortagrad` | 16 | 80 Mel | local | 2048 | 15 | 3c4r2d | 17.89 | 0.070 | 0.2025 | ReLU cells. No SortaGrad. |



### Confidence Interval Runs
| Run | ED | WER |
|----:|------:|------:|
| 1 | 0.290 | 0.709 |
| 2 | 0.301 | 0.727 |
| 3 | 0.300 | 0.724 |
| 4 | 0.297 | 0.724 |
| 5 | 0.290 | 0.703 |
| 6 | 0.301 | 0.721 |
| 7 | 0.295 | 0.715 |
| 8 | 0.275 | 0.733 |
| 9 | 0.291 | 0.708 |
| 10 | 0.291 | 0.703 |
| 11 | 0.294 | 0.705 |
| 12 | 0.290 | 0.699 |
| 13 | 0.292 | 0.710 |
| 14 | 0.294 | 0.718 |
| 15 | 0.295 | 0.708 |
| 16 | 0.299 | 0.714 |
| 17 | 0.296 | 0.722 |
| 18 | 0.296 | 0.712 |
| 19 | 0.290 | 0.706 |
| 20 | 0.298 | 0.718 |

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# End-to-End Speech Recognition System Using Connectionist Temporal Classification
Automatic speech recognition (ASR) system implementation that utilizes the
[connectionist temporal classification (CTC)](http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.75.6306)
cost function.
It's inspired by Baidu's
[Deep Speech: Scaling up end-to-end speech recognition](https://arxiv.org/abs/1412.5567)
and
[Deep Speech 2: End-to-End Speech Recognition in English and Mandarin](https://arxiv.org/abs/1512.02595)
papers.
The system is trained on a combined corpus, containing 900+ hours.
It achieves a word error rate (WER) of 12.6% on the test dataset, without the use of an external
language model.


## Contents

<!-- toc -->

* [Contents](#contents)
* [Installation](#installation)
  * [Arch Linux](#arch-linux)
  * [Ubuntu](#ubuntu)
* [Configuration](#configuration)
* [Corpus](#corpus)
  * [CSV](#csv)
  * [Free Speech Corpora](#free-speech-corpora)
  * [Corpus Statistics](#corpus-statistics)
* [Usage](#usage)
  * [Training](#training)
  * [Evaluation](#evaluation)
  * [Prediction](#prediction)

<!-- tocstop -->


![Deep Speech 1 and 2 network architectures](images/network-architectures.png)

(a) shows the Deep Speech (1) model and (b) a version of the Deep Speech 2 model architecture.


## Installation
The system was tested on Arch Linux and Ubuntu 16.04, with Python version 3.5+ and TensorFlow
version 1.12.0. It's highly recommended to use TensorFlow with GPU support for training.


### Arch Linux
```terminal
# Install dependencies.
sudo pacman -S sox python-tensorflow-opt-cuda tensorboard

# Install optional dependencies. LaTeX is only required to plot nice looking graphs.
sudo pacman -S texlive-most

# Clone repository and install Python dependencies.
git clone https://github.com/mdangschat/ctc-asr.git
cd ctc-asr
git checkout <tag>

# Setup optional virtual environment.
pip install -r requirements.txt
```


### Ubuntu
Be aware that the [`requirements.txt`](requirements.txt) file lists `tensorflow` as dependency,
if you install TensorFlow through [pip](https://pypi.org/project/pip/) consider removing it as
dependency and install `tensorflow-gpu` instead.
It could also be worth it to [build TensorFlow from source](https://www.tensorflow.org/install/source).

```terminal
# Install dependencies.
sudo apt install python3-tk sox libsox-fmt-all

# Install optional dependencies. LaTeX is only required to plot nice looking graphs.
sudo apt install texlive

# Clone repository and install Python dependencies. Don't forget to use tensorflow-gpu.
git clone https://github.com/mdangschat/ctc-asr.git
cd ctc-asr
git checkout <tag>

# Setup optional virtual environment.
pip3 install -r requirements.txt
```


## Configuration
The network architecture and training parameters can be configured by adding the appropriate flags
or by directly editing the [`asr/params.py`](asr/params.py) configuration file.
The default configuration requires quite a lot of VRAM (about 16 GB), consider reducing the number of units per
layer (`num_units_dense`, `num_units_rnn`) and the amount of RNN layers (`num_layers_rnn`).


## Corpus
There is a list of some [free speech corpora](#free-speech-corpora) at the end of this section.
However, the corpus is not part of this repository and has to be acquired by each user.
For a quick start there is the [speech-corpus-dl](https://github.com/mdangschat/speech-corpus-dl)
helper, that downloads a few free corpora, prepares the data and creates a merged corpus.

All audio files have to be 16 kHz, mono, WAV files.
For my trainings, I removed examples shorter than 0.7 and longer than 17.0 seconds.
Additionally, TEDLIUM examples with labels of fewer than 5 words have also been removed.

The following tree shows a possible structure for the required directories:
```terminal
./ctc-asr
├── asr
├── [...]
├── LICENSE
├── README.md
├── requirements.txt
├── testruns.md
./ctc-asr-checkpoints
└── 3c2r2d-rnn
    ├── [...]
./speech-corpus
├── cache
├── corpus
│   ├── cvv2
│   ├── LibriSpeech
│   ├── tatoeba_audio_eng
│   └── TEDLIUM_release2
├── corpus.json
├── dev.csv
├── test.csv
└── train.csv
```
Assuming that this repository is cloned into `some/folder/ctc-asr`, then by default
the CSV files are expected to be in `some/folder/speech-corpus` and the audio files in
`some/folder/speech-corpus/corpus`.
TensorFlow checkpoints are written into `some/folder/ctc-asr-checkpoints`.
Both folders (`ctc-asr-checkpoints` and `speech-corpus`) must exist, they can be changed
in the [asr/params.py](asr/params.py) file.


### CSV
The CSV files (e.g. `train.csv`) have the following format:
```csv
path;label;length
relative/path/to/example;lower case transcription without punctuation;3.14159265359
[...]
```
Where `path` is the relative WAV path from the `DATA_DIR/corpus/` directory (String).
By default, `label` is the lower case transcription without punctuation (String).
Finally, `length` is the audio length in seconds (Float).
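
As a quick illustration, here is a minimal Python sketch (not part of the project) that
reads such a manifest; the file path is an assumption based on the default layout above:
```python
import csv

# Assumed manifest location, following the default directory layout shown above.
CSV_PATH = '../speech-corpus/train.csv'

with open(CSV_PATH, 'r', encoding='utf-8') as f:
    # The first row is the header: path;label;length
    for row in csv.DictReader(f, delimiter=';'):
        print(row['path'], row['label'], float(row['length']))
```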


### Free Speech Corpora
* [Common Voice](https://voice.mozilla.org/en/new) (v1)
* [LibriSpeech ASR Corpus](http://www.openslr.org/12/)
* [Tatoeba](https://tatoeba.org/eng/)
* [TED-Lium](http://www.openslr.org/19/) (v2)
* [TIMIT](https://catalog.ldc.upenn.edu/LDC93S1)


### Corpus Statistics
```terminal
ipython python/dataset/word_counts.py
Calculating statistics for /home/gpuinstall/workspace/ctc-asr/data/train.csv
Word based statistics:
    total_words = 10,069,671
    number_unique_words = 81,161
    mean_sentence_length = 14.52 words
    min_sentence_length = 1 words
    max_sentence_length = 84 words
    Most common words: [('the', 551055), ('to', 306197), ('and', 272729), ('of', 243032), ('a', 223722), ('i', 192151), ('in', 149797), ('that', 146820), ('you', 144244), ('it', 118133)]
    27416 words occurred only 1 time; 37,422 words occurred only 2 times; 49,939 words occurred only 5 times; 58,248 words occurred only 10 times.

Character based statistics:
    total_characters = 52,004,043
    mean_label_length = 75.00 characters
    min_label_length = 2 characters
    max_label_length = 422 characters
    Most common characters: [(' ', 9376326), ('e', 5264177), ('t', 4205041), ('o', 3451023), ('a', 3358945), ('i', 2944773), ('n', 2858788), ('s', 2624239), ('h', 2598897), ('r', 2316473), ('d', 1791668), ('l', 1686896), ('u', 1234080), ('m', 1176076), ('w', 1052166), ('c', 999590), ('y', 974918), ('g', 888446), ('f', 851710), ('p', 710252), ('b', 646150), ('v', 421126), ('k', 387714), ('x', 62547), ('j', 61048), ('q', 34558), ('z', 26416)]
    Most common characters: [' ', 'e', 't', 'o', 'a', 'i', 'n', 's', 'h', 'r', 'd', 'l', 'u', 'm', 'w', 'c', 'y', 'g', 'f', 'p', 'b', 'v', 'k', 'x', 'j', 'q', 'z']
```


## Usage
### Training
Start training by invoking `asr/train.py`.
Use `asr/train.py -- --delete` to start a clean run and remove the old checkpoints.
Please note that all commands are expected to be executed from the project's root folder.
The additional `--` before the actual flags begin is used to indicate the end of IPython flags.

The training progress can be monitored using Tensorboard.
To start Tensorboard use `tensorboard --logdir <train_dir>`.
By default it can then be accessed via [localhost:6006](http://localhost:6006).


### Evaluation
Evaluate the current model by invoking `asr/evaluate.py`.
Use `asr/evaluate.py -- --dev` to run on the development dataset, instead of the test set.


### Prediction
To transcribe a given 16 kHz, mono WAV file use `asr/predict.py --input <wav_file>`.

--------------------------------------------------------------------------------
/asr/params.py:
--------------------------------------------------------------------------------
"""Collection of hyperparameters, network layout, and reporting options."""

import os

import numpy as np
import tensorflow as tf

from asr.labels import num_classes

# Path to git root.
BASE_PATH = os.path.realpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../'))

# Directories:
# Note that the default `train_dir` is outside of the project directory.
tf.flags.DEFINE_string('train_dir',
                       os.path.join(BASE_PATH, '../ctc-asr-checkpoints/3c4r2d-rnn'),
                       "Directory where to write event logs and checkpoints.")

# Note that the default `corpus_dir` is outside of the project directory.
__CORPUS_DIR = os.path.join(BASE_PATH, '../speech-corpus')
tf.flags.DEFINE_string('corpus_dir', os.path.join(__CORPUS_DIR, 'corpus'),
                       "Directory that holds the corpus (audio) files.")
tf.flags.DEFINE_string('train_csv', os.path.join(__CORPUS_DIR, 'train.csv'),
                       "Path to the `train.csv` file.")
tf.flags.DEFINE_string('test_csv', os.path.join(__CORPUS_DIR, 'test.csv'),
                       "Path to the `test.csv` file.")
tf.flags.DEFINE_string('dev_csv', os.path.join(__CORPUS_DIR, 'dev.csv'),
                       "Path to the `dev.csv` file.")

# Layer and activation options:
tf.flags.DEFINE_string('used_model', 'ds2',
                       ("Used inference model. Supported are 'ds1', and 'ds2'. "
                        "Also see `FLAGS.feature_drop_every_second_frame`."))

tf.flags.DEFINE_integer('num_units_dense', 2048,
                        "Number of units per dense layer.")
tf.flags.DEFINE_float('relu_cutoff', 20.0,
                      "Cutoff ReLU activations that exceed the cutoff.")

tf.flags.DEFINE_multi_integer('conv_filters', [32, 32, 96],
                              "Number of filters for each convolutional layer.")

tf.flags.DEFINE_integer('num_layers_rnn', 4,
                        "Number of stacked RNN cells.")
tf.flags.DEFINE_integer('num_units_rnn', 2048,
                        "Number of hidden units in each of the RNN cells.")
# TODO: This is currently only implemented for cudnn (`FLAGS.cudnn = True`).
tf.flags.DEFINE_string('rnn_cell', 'rnn_relu',
                       "Used RNN cell type. Supported are the RNN versions 'rnn_relu' and "
                       "'rnn_tanh', as well as the 'lstm' and 'gru' cells")

# Inputs:
tf.flags.DEFINE_integer('batch_size', 16,
                        "Number of samples within a batch.")
tf.flags.DEFINE_string('feature_type', 'mfcc',
                       "Type of input features. Supported types are: 'mel' and 'mfcc'.")
tf.flags.DEFINE_string('feature_normalization', 'local',
                       ("Type of normalization applied to input features."
                        "Supported are: 'none', 'local', and 'local_scalar'"))
tf.flags.DEFINE_boolean('features_drop_every_second_frame', False,
                        "[Deep Speech 1] like dropping of every 2nd input time frame.")

# Learning Rate.
tf.flags.DEFINE_integer('max_epochs', 15,
                        "Number of epochs to run. [Deep Speech 1] uses about 20 epochs.")
tf.flags.DEFINE_float('learning_rate', 1e-5,
                      "Initial learning rate.")
# TODO: The following LR flags are (currently) no longer supported.
tf.flags.DEFINE_float('learning_rate_decay_factor', 4 / 5,
                      "Learning rate decay factor.")
tf.flags.DEFINE_integer('steps_per_decay', 75000,
                        "Number of steps after which learning rate decays.")
tf.flags.DEFINE_float('minimum_lr', 1e-6,
                      "Minimum value the learning rate can decay to.")

# Adam Optimizer:
tf.flags.DEFINE_float('adam_beta1', 0.9,
                      "Adam optimizer beta_1 power.")
tf.flags.DEFINE_float('adam_beta2', 0.999,
                      "Adam optimizer beta_2 power.")
tf.flags.DEFINE_float('adam_epsilon', 1e-8,
                      "Adam optimizer epsilon.")

# CTC decoder:
tf.flags.DEFINE_integer('beam_width', 1024,
                        "Beam width used in the CTC `beam_search_decoder`.")

# Dropout.
tf.flags.DEFINE_float('conv_dropout_rate', 0.0,
                      "Dropout rate for convolutional layers.")
tf.flags.DEFINE_float('rnn_dropout_rate', 0.0,
                      "Dropout rate for the RNN cell layers.")
tf.flags.DEFINE_float('dense_dropout_rate', 0.1,
                      "Dropout rate for dense layers.")

# Corpus:
tf.flags.DEFINE_integer('num_buckets', 96,
                        "The maximum number of buckets to use for bucketing.")

tf.flags.DEFINE_integer('num_classes', num_classes(),
                        "Number of classes. Contains the additional CTC <blank> label.")
tf.flags.DEFINE_integer('sampling_rate', 16000,
                        "The sampling rate of the audio files (e.g. 2 * 8kHz).")

# Performance / GPU:
tf.flags.DEFINE_boolean('cudnn', True,
                        "Whether to use Nvidia cuDNN implementations or the default TensorFlow "
                        "implementation.")
tf.flags.DEFINE_integer('shuffle_buffer_size', 2 ** 14,
                        "Number of elements the dataset shuffle buffer should hold. "
                        "This can consume a large amount of memory.")

# Logging:
tf.flags.DEFINE_integer('log_frequency', 200,
                        "How often (every `log_frequency` steps) to log results.")
tf.flags.DEFINE_integer('num_samples_to_report', 4,
                        "The maximum number of decoded and original text samples to report in "
                        "TensorBoard.")
tf.flags.DEFINE_integer('gpu_hook_query_frequency', 5,
                        "How often (every `gpu_hook_query_frequency` steps) statistics are "
                        "queried from the GPUs.")
tf.flags.DEFINE_integer('gpu_hook_average_queries', 100,
                        "The number of queries to store for calculating average values.")

# Miscellaneous:
tf.flags.DEFINE_boolean('delete', False,
                        "Whether to delete old checkpoints, or resume training.")
tf.flags.DEFINE_integer('random_seed', 0,
                        "TensorFlow random seed. Set to 0 to use the current timestamp instead.")
tf.flags.DEFINE_boolean('log_device_placement', False,
                        "Whether to log device placement.")
tf.flags.DEFINE_boolean('allow_vram_growth', True,
                        "Allow TensorFlow to allocate VRAM as needed, "
                        "as opposed to allocating the whole VRAM at program start.")

# ####### Changing the parameters below is likely to break something. #########
# Export names:
TF_FLOAT = tf.float32  # ctc_* functions don't support float64. See #13
NP_FLOAT = np.float32  # ctc_* functions don't support float64. See #13

# Minimum and maximum length of examples in datasets (in seconds).
MIN_EXAMPLE_LENGTH = 0.7
MAX_EXAMPLE_LENGTH = 17.0

# Feature extraction parameters:
WIN_LENGTH = 0.025  # Window length in seconds.
WIN_STEP = 0.010  # The step between successive windows in seconds.
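# For example, with WIN_STEP = 0.010 a 10-second example yields roughly 1000 feature
# frames (this is how `csv_helper.get_bucket_boundaries` estimates sequence lengths).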
NUM_FEATURES = 80  # Number of features to extract.

# CSV field names. The field order is always the same as this list from top to bottom.
CSV_HEADER_PATH = 'path'
CSV_HEADER_LABEL = 'label'
CSV_HEADER_LENGTH = 'length'
CSV_FIELDNAMES = [CSV_HEADER_PATH, CSV_HEADER_LABEL, CSV_HEADER_LENGTH]
CSV_DELIMITER = ';'

FLAGS = tf.flags.FLAGS


def get_parameters():
    """
    Generate a summary containing the training and network parameters.

    Returns:
        str: Summary of training parameters.
    """
    res = '\n\tLearning Rate (lr={}, steps_per_decay={:,d}, decay_factor={});\n' \
          '\tGPU-Options (cudnn={});\n' \
          '\tModel (used_model={}, beam_width={:,d})\n' \
          '\tConv (conv_filters={}); Dense (num_units={:,d});\n' \
          '\tRNN (num_units={:,d}, num_layers={:,d});\n' \
          '\tTraining (batch_size={:,d}, max_epochs={:,d}, log_frequency={:,d});\n' \
          '\tFeatures (type={}, normalization={}, skip_every_2nd_frame={});'

    res = res.format(FLAGS.learning_rate, FLAGS.steps_per_decay, FLAGS.learning_rate_decay_factor,
                     FLAGS.cudnn,
                     FLAGS.used_model, FLAGS.beam_width,
                     FLAGS.conv_filters, FLAGS.num_units_dense,
                     FLAGS.num_units_rnn, FLAGS.num_layers_rnn,
                     FLAGS.batch_size, FLAGS.max_epochs, FLAGS.log_frequency,
                     FLAGS.feature_type, FLAGS.feature_normalization,
                     FLAGS.features_drop_every_second_frame)

    return res

--------------------------------------------------------------------------------
/asr/util/tf_contrib.py:
--------------------------------------------------------------------------------
"""Utility and helper methods for TensorFlow speech learning."""

import tensorflow as tf

from asr.params import FLAGS, TF_FLOAT


class AdamOptimizerLogger(tf.train.AdamOptimizer):
    """Modified `AdamOptimizer`_ that logs its learning rate and step.

    .. _AdamOptimizer:
        https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer
    """

    def _apply_dense(self, grad, var):
        m = self.get_slot(var, 'm')
        v = self.get_slot(var, 'v')
        beta1_power, beta2_power = self._get_beta_accumulators()

        m_hat = m / (1.0 - beta1_power)
        v_hat = v / (1.0 - beta2_power)

        step = m_hat / (v_hat ** 0.5 + self._epsilon_t)

        # Use a histogram summary to monitor it during training.
        tf.summary.histogram('step', step)

        current_lr = self._lr_t * tf.sqrt(1.0 - beta2_power) / (1.0 - beta1_power)
        tf.summary.scalar('estimated_lr', current_lr)

        return super(AdamOptimizerLogger, self)._apply_dense(grad, var)
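

# Hypothetical usage sketch (not part of the original module):
#   optimizer = AdamOptimizerLogger(learning_rate=FLAGS.learning_rate,
#                                   beta1=FLAGS.adam_beta1,
#                                   beta2=FLAGS.adam_beta2,
#                                   epsilon=FLAGS.adam_epsilon)
# The 'step' histogram and 'estimated_lr' scalar then show up in TensorBoard.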
48 | """ 49 | 50 | with tf.variable_scope('dense'): 51 | output = sequences 52 | for _ in range(num_layers): 53 | output = tf.layers.dense(output, FLAGS.num_units_dense, 54 | activation=activation, 55 | kernel_initializer=initializer, 56 | kernel_regularizer=regularizer) 57 | output = tf.minimum(output, FLAGS.relu_cutoff) 58 | output = tf.layers.dropout(output, rate=FLAGS.dense_dropout_rate, training=training) 59 | # output = [batch_size, time, num_units_dense] 60 | 61 | return output 62 | 63 | 64 | def conv_layers(sequences, 65 | filters=FLAGS.conv_filters, 66 | kernel_sizes=((11, 41), (11, 21), (11, 21)), 67 | strides=((2, 2), (1, 2), (1, 2)), 68 | kernel_initializer=tf.glorot_normal_initializer(), 69 | kernel_regularizer=None, 70 | training=True): 71 | """Add 2D convolutional layers to the network's graph. New sequence length are being calculated. 72 | 73 | Convolutional layer output shapes: 74 | Conv 'VALID' output width (W) is calculated by: 75 | W = (W_i - K_w) // S_w + 1 76 | Conv 'SAME' output width (W) is calculated by: 77 | W = (W_i - K_w + 2*(K_w//2)) // S_w + 1 78 | Where W_i is the input width, K_w the kernel width, and S_w the stride width. 79 | Height (H) is calculated analog to width (W). 80 | 81 | For the default setup, the convolutional layers reduce `output` size to: 82 | conv1 = [batch_size, W, H, NUM_CHANNELS] = [batch_size, ~time / 2, 40, NUM_FILTERS] 83 | conv2 = [batch_size, W, H, NUM_CHANNELS] = [batch_size, ~time, 20, NUM_FILTERS] 84 | conv3 = [batch_size, W, H, NUM_CHANNELS] = [batch_size, ~time, 10, NUM_FILTERS] 85 | 86 | This values are reshaped to input for a following RNN layer by the following metric: 87 | [batch_size, time, 10 * NUM_FILTERS] 88 | where 10 is the number of frequencies left over from convolutions. 89 | 90 | Args: 91 | sequences (tf.Tensor): 92 | The input sequences. 93 | filters (Tuple[int]): 94 | Tuple of number of filters per convolutional layers. 95 | kernel_sizes (Tuple[Tuple[int, int]]): 96 | Tuple of tuples of height and width values. One tuple per convolutional layer. 97 | strides (Tuple[Tuple[int, int]]): 98 | Tuple of tuples of x and y stride values. One tuple per convolutional layer. 99 | kernel_initializer (tf.Tensor): 100 | TensorFlow kernel initializer. 101 | kernel_regularizer (tf.Tensor): 102 | TensorFlow kernel regularizer. 103 | training (bool): 104 | `FLAGS.conv_dropout_rate` is being applied during training only. 105 | 106 | Returns: 107 | tf.Tensor: `output` 108 | Convolutional layers output. 109 | tf.Tensor: `seq_length` 110 | Sequence length of the batch elements. Note that the shortest samples within a 111 | batch are stretched to the convolutional length of the longest one. 112 | 113 | .. 
114 |         https://www.tensorflow.org/api_docs/python/tf/layers/conv2d
115 |     """
116 | 
117 |     if not (len(filters) == len(kernel_sizes) == len(strides)):
118 |         raise ValueError('conv_layers(): Arguments filters, kernel_sizes, and strides must '
119 |                          'contain the same number of elements.')
120 | 
121 |     output = sequences
122 |     for tmp in zip(filters, kernel_sizes, strides):
123 |         _filter, kernel_size, stride = tmp
124 | 
125 |         output = tf.layers.conv2d(inputs=output,
126 |                                   filters=_filter,
127 |                                   kernel_size=kernel_size,
128 |                                   strides=stride,
129 |                                   padding='SAME',
130 |                                   activation=tf.nn.relu,
131 |                                   kernel_initializer=kernel_initializer,
132 |                                   kernel_regularizer=kernel_regularizer)
133 | 
134 |         output = tf.minimum(output, FLAGS.relu_cutoff)
135 |         output = tf.layers.dropout(output, rate=FLAGS.conv_dropout_rate, training=training)
136 | 
137 |     # Reshape to: conv3 = [batch_size, time, 10 * NUM_FILTERS], where 10 is the number of
138 |     # frequencies left over from the convolutions.
139 |     output = tf.reshape(output, [tf.shape(output)[0], -1, 10 * filters[-1]])
140 | 
141 |     # Update seq_length to match the convolutions. shape[1] = time steps; shape[0] = batch_size
142 |     # Note that the shortest samples within a batch are stretched to the convolutional
143 |     # length of the longest one.
144 |     seq_length = tf.tile([tf.shape(output)[1]], [tf.shape(output)[0]])
145 | 
146 |     return output, seq_length
147 | 
148 | 
149 | def bidirectional_cells(num_units, num_layers, dropout=1.0):
150 |     """Create two lists of forward and backward cells that can be used to build a BDLSTM stack.
151 | 
152 |     Args:
153 |         num_units (int): Number of units within the RNN cell.
154 |         num_layers (int): Amount of cells to create for each list.
155 |         dropout (float): Probability [0, 1] to drop an output. If it's constant 0,
156 |             no outputs will be dropped.
157 | 
158 |     Returns:
159 |         [tf.nn.rnn_cell.LSTMCell]: List of forward cells.
160 |         [tf.nn.rnn_cell.LSTMCell]: List of backward cells.
161 |     """
162 |     keep_prob = min(1.0, max(0.0, 1.0 - dropout))
163 | 
164 |     _fw_cells = [create_cell(num_units, keep_prob=keep_prob) for _ in range(num_layers)]
165 |     _bw_cells = [create_cell(num_units, keep_prob=keep_prob) for _ in range(num_layers)]
166 |     return _fw_cells, _bw_cells
167 | 
168 | 
169 | def create_cell(num_units, keep_prob=1.0):
170 |     """Create an RNN cell with an added dropout wrapper.
171 | 
172 |     Args:
173 |         num_units (int): Number of units within the RNN cell.
174 |         keep_prob (float): Probability [0, 1] to keep an output. If it's constant 1,
175 |             no outputs will be dropped.
176 | 
177 |     Returns:
178 |         tf.nn.rnn_cell.DropoutWrapper: The RNN cell wrapped with dropout.
179 |     """
180 |     # Can be: `tf.nn.rnn_cell.RNNCell`, `tf.nn.rnn_cell.GRUCell`, `tf.nn.rnn_cell.LSTMCell`.
181 | 
182 |     # https://www.tensorflow.org/api_docs/python/tf/contrib/rnn/LSTMCell
183 |     # cell = tf.nn.rnn_cell.LSTMCell(num_units=num_units, use_peepholes=True)
184 | 
185 |     # https://www.tensorflow.org/api_docs/python/tf/contrib/rnn/GRUCell
186 |     # cell = tf.nn.rnn_cell.GRUCell(num_units=num_units)
187 | 
188 |     # https://www.tensorflow.org/api_docs/python/tf/contrib/rnn/BasicRNNCell
189 |     cell = tf.nn.rnn_cell.BasicRNNCell(num_units=num_units, activation=tf.nn.tanh)
190 | 
191 |     return tf.nn.rnn_cell.DropoutWrapper(cell,
192 |                                          input_keep_prob=keep_prob,
193 |                                          output_keep_prob=keep_prob,
194 |                                          seed=FLAGS.random_seed)
195 | 
196 | 
197 | def variable_on_cpu(name, shape, initializer):
198 |     """Helper to create a variable stored on CPU memory.
199 | 
200 |     Args:
201 |         name (str): Name of the variable.
202 |         shape (list of int): List of ints, e.g. a numpy shape.
203 |         initializer: Initializer for the variable.
204 | 
205 |     Returns:
206 |         tf.Tensor: Variable tensor.
207 |     """
208 |     with tf.device('/cpu:0'):
209 |         return tf.get_variable(name, shape, initializer=initializer, dtype=TF_FLOAT)
210 | 
211 | 
212 | def variable_with_weight_decay(name, shape, stddev, weight_decay):
213 |     """Helper to create an initialized variable with weight decay.
214 | 
215 |     Note that the variable is initialized with a truncated normal distribution.
216 |     A weight decay is added only if one is specified.
217 | 
218 |     Args:
219 |         name (str): Name of the variable.
220 |         shape (list of int): List of ints, e.g. a numpy shape.
221 |         stddev (float): Standard deviation of the Gaussian.
222 |         weight_decay: Add L2Loss weight decay multiplied by this float.
223 |             If None, weight decay is not added for this variable.
224 | 
225 |     Returns:
226 |         tf.Tensor: Variable tensor.
227 |     """
228 |     initializer = tf.truncated_normal_initializer(stddev=stddev, dtype=TF_FLOAT)
229 |     var = variable_on_cpu(name, shape, initializer=initializer)
230 | 
231 |     if weight_decay is not None:
232 |         weight_decay = tf.multiply(tf.nn.l2_loss(var), weight_decay, name='weight_loss')
233 |         tf.add_to_collection('losses', weight_decay)
234 | 
235 |     return var
236 | 
--------------------------------------------------------------------------------
/asr/input_functions.py:
--------------------------------------------------------------------------------
1 | """Routines to load a corpus and perform the necessary preprocessing on the audio files and labels.
2 | 
3 | This also contains helper methods to load audio files.
4 | """
5 | 
6 | import csv
7 | import os
8 | import random
9 | 
10 | import numpy as np
11 | import python_speech_features as psf
12 | import tensorflow as tf
13 | from scipy.io import wavfile
14 | 
15 | from asr.labels import ctoi
16 | from asr.params import CSV_DELIMITER, CSV_FIELDNAMES, CSV_HEADER_LABEL, CSV_HEADER_PATH
17 | from asr.params import NP_FLOAT, WIN_LENGTH, WIN_STEP, NUM_FEATURES, FLAGS
18 | from asr.util.csv_helper import get_bucket_boundaries
19 | 
20 | 
21 | def input_fn_generator(target):
22 |     """Generate the `input_fn` for the TensorFlow estimator.
23 | 
24 |     Args:
25 |         target (str): The type of input; this affects the used CSV file, batching method, and epochs.
26 |             Supported targets are:
27 |             * 'train_bucket': Creates 1 epoch of training data, using bucketing.
28 |                 Examples are shuffled.
29 |             * 'train_batch': Creates 1 epoch of training data, using batches.
30 |                 Examples are in the order of the `train.csv` file.
31 |             * 'dev': Creates 1 epoch of evaluation data from the `dev.csv` file.
32 |                 Uses buckets. Examples are shuffled.
33 |             * 'test': Creates 1 epoch of evaluation data from the `test.csv` file.
34 |                 Uses buckets. Examples are shuffled.
35 | 
36 |     Returns:
37 |         function: Input function pointer.
38 |     """
39 |     if target == 'train_bucket':
40 |         csv_path = FLAGS.train_csv
41 |         use_buckets = True
42 |         epochs = 1
43 |     elif target == 'train_batch':
44 |         csv_path = FLAGS.train_csv
45 |         use_buckets = False
46 |         epochs = 1
47 |     elif target == 'dev':
48 |         csv_path = FLAGS.dev_csv
49 |         use_buckets = True
50 |         epochs = 1
51 |     elif target == 'test':
52 |         csv_path = FLAGS.test_csv
53 |         use_buckets = True
54 |         epochs = 1
55 |     else:
56 |         raise ValueError('Invalid target: "{}"'.format(target))
57 | 
58 |     # Read bucket boundaries from the CSV file.
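    # As an illustration of what bucketing buys here (the boundary values are made
    # up for this comment; `get_bucket_boundaries` from asr/util/csv_helper.py
    # derives the real ones from the CSV's `length` column): with boundaries such
    # as [100, 200, 300], examples are grouped into buckets of sequence length
    # <100, 100-199, 200-299, and >=300, so a batch is only padded to the longest
    # example of similar length instead of the longest example in the corpus.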
59 |     if use_buckets:
60 |         bucket_boundaries = get_bucket_boundaries(csv_path, FLAGS.num_buckets)
61 |         tf.logging.info('Using {} buckets for the {} set.'.format(len(bucket_boundaries), target))
62 | 
63 |     def input_fn():
64 |         # L8ER: Try out the following two (not working as of TF v1.12):
65 |         # https://www.tensorflow.org/api_docs/python/tf/data/experimental/latency_stats
66 |         # https://www.tensorflow.org/api_docs/python/tf/data/experimental/StatsAggregator
67 | 
68 |         def element_length_fn(_spectrogram, _spectrogram_length, _label_encoded, _label_plaintext):
69 |             del _spectrogram
70 |             del _label_encoded
71 |             del _label_plaintext
72 |             return _spectrogram_length
73 | 
74 |         assert os.path.exists(csv_path) and os.path.isfile(csv_path)
75 | 
76 |         with tf.device('/cpu:0'):
77 |             dataset = tf.data.Dataset.from_generator(
78 |                 __input_generator,
79 |                 (tf.float32, tf.int32, tf.int32, tf.string),
80 |                 (tf.TensorShape([None, 80]), tf.TensorShape([]),  # 80 == NUM_FEATURES
81 |                  tf.TensorShape([None]), tf.TensorShape([])),
82 |                 args=[csv_path, use_buckets])
83 | 
84 |             if use_buckets:
85 |                 # Set the shuffle buffer to an arbitrary size to ensure good enough shuffling.
86 |                 # At the moment, most shuffling is done by the `__input_generator` function.
87 |                 # Also see: https://stackoverflow.com/a/47025850/2785397
88 |                 dataset = dataset.shuffle(FLAGS.shuffle_buffer_size)
89 | 
90 |                 dataset = dataset.apply(
91 |                     tf.data.experimental.bucket_by_sequence_length(
92 |                         element_length_func=element_length_fn,
93 |                         bucket_boundaries=bucket_boundaries,
94 |                         bucket_batch_sizes=[FLAGS.batch_size] * (len(bucket_boundaries) + 1),
95 |                         pad_to_bucket_boundary=False,  # False => pad to the longest example in batch
96 |                         no_padding=False
97 |                     )
98 |                 )
99 | 
100 |             else:
101 |                 dataset = dataset.padded_batch(batch_size=FLAGS.batch_size,
102 |                                                padded_shapes=([None, 80], [], [None], []),
103 |                                                drop_remainder=True)
104 | 
105 |             # dataset.cache()
106 |             dataset = dataset.prefetch(64)
107 | 
108 |             # Number of epochs.
109 |             dataset = dataset.repeat(epochs)
110 | 
111 |             iterator = dataset.make_one_shot_iterator()
112 |             spectrogram, spectrogram_length, label_encoded, label_plaintext = iterator.get_next()
113 | 
114 |             features = {
115 |                 'spectrogram': spectrogram,
116 |                 'spectrogram_length': spectrogram_length,
117 |                 'label_plaintext': label_plaintext
118 |             }
119 | 
120 |             return features, label_encoded
121 | 
122 |     return input_fn
123 | 
124 | 
125 | def __input_generator(*args):
126 |     assert len(args) == 2, '__input_generator() arguments are a path and shuffle boolean.'
127 |     assert isinstance(args[0], bytes)
128 |     assert isinstance(args[1], np.bool_)
129 |     csv_path = str(args[0], 'utf-8')
130 |     shuffle = bool(args[1])
131 | 
132 |     with open(csv_path, 'r', encoding='utf-8') as file_handle:
133 |         reader = csv.DictReader(file_handle, delimiter=CSV_DELIMITER, fieldnames=CSV_FIELDNAMES)
134 |         lines = list(reader)[1: -1]  # Remove the CSV header and the final blank line.
135 | 
136 |         # Shuffle the CSV lines.
137 |         if shuffle:
138 |             random.shuffle(lines)
139 | 
140 |         # Read the CSV lines and extract the spectrogram and label for each line.
141 |         for line in lines:
142 |             path = line[CSV_HEADER_PATH]
143 |             label = line[CSV_HEADER_LABEL]
144 | 
145 |             path = os.path.join(FLAGS.corpus_dir, path)
146 | 
147 |             # Convert the WAV file into a spectrogram and determine its length.
148 |             spectrogram, spectrogram_length = load_sample(path)
149 | 
150 |             # Convert the character sequence label to an integer sequence.
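            # For illustration only; the actual mapping lives in asr/labels.py,
            # which is not part of this excerpt. Assuming `ctoi` maps roughly
            # ' ' -> 0, 'a' -> 1, ..., 'z' -> 26, the label 'cat' would become
            # the integer sequence [3, 1, 20].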
151 |             label_encoded = [ctoi(c) for c in label]
152 | 
153 |             yield spectrogram, spectrogram_length, label_encoded, label
154 | 
155 | 
156 | def load_sample(file_path, feature_type=None, feature_normalization=None):
157 |     """Load the wave file and convert it into feature vectors.
158 | 
159 |     Args:
160 |         file_path (str or bytes):
161 |             Path of the wave file to load.
162 |             `tf.py_func` converts the provided Tensor into `np.ndarray`s bytes.
163 | 
164 |         feature_type (str): Optional.
165 |             If `None` is provided, use `FLAGS.feature_type`.
166 |             Type of features to generate. Options are 'mel' and 'mfcc'.
167 | 
168 |         feature_normalization (str): Optional.
169 |             If `None` is provided, use `FLAGS.feature_normalization`.
170 | 
171 |             Whether to normalize the generated features with the stated method or not.
172 |             Please consult `__feature_normalization` for a complete list of normalization methods.
173 | 
174 |             'local': Use local (in sample) mean and standard deviation values, and apply the
175 |                 normalization element wise, like in `global`.
176 | 
177 |             'local_scalar': Uses only the mean and standard deviation of the current sample.
178 |                 The normalization is applied by ([sample] - mean_scalar) / std_scalar
179 | 
180 |             'none': No normalization is applied.
181 | 
182 |     Returns:
183 |         Tuple[np.ndarray, np.ndarray]:
184 |             2D array with [time, num_features] shape, containing `NP_FLOAT`.
185 | 
186 |             Array containing a single int32.
187 |     """
188 |     __supported_feature_types = ['mel', 'mfcc']
189 |     __supported_feature_normalizations = ['none', 'local', 'local_scalar']
190 | 
191 |     feature_type = feature_type if feature_type is not None else FLAGS.feature_type
192 |     feature_normalization = feature_normalization if feature_normalization is not None \
193 |         else FLAGS.feature_normalization
194 | 
195 |     if feature_type not in __supported_feature_types:
196 |         raise ValueError('Requested feature type of {} isn\'t supported.'
197 |                          .format(feature_type))
198 | 
199 |     if feature_normalization not in __supported_feature_normalizations:
200 |         raise ValueError('Requested feature normalization method {} is invalid.'
201 |                          .format(feature_normalization))
202 | 
203 |     if type(file_path) is not str:
204 |         file_path = str(file_path, 'utf-8')
205 | 
206 |     if not os.path.isfile(file_path):
207 |         raise ValueError('"{}" does not exist.'.format(file_path))
208 | 
209 |     # Load the audio file's sample rate and data.
210 |     (sampling_rate, audio_data) = wavfile.read(file_path)
211 | 
212 |     if len(audio_data) < 401:  # Shorter than a single 25 ms window (400 samples at 16 kHz).
213 |         raise RuntimeError('Sample length {:,d} too short: {}'.format(len(audio_data), file_path))
214 | 
215 |     if not sampling_rate == FLAGS.sampling_rate:
216 |         raise RuntimeError('Sampling rate is {:,d}, expected {:,d}.'
217 |                            .format(sampling_rate, FLAGS.sampling_rate))
218 | 
219 |     # At 16000 Hz, 512 samples ~= 32ms. At 16000 Hz, 200 samples = 12.5ms. 16 samples = 1ms @ 16kHz.
220 |     f_max = sampling_rate / 2.  # Maximum frequency (Nyquist rate).
221 |     f_min = 64.  # Minimum frequency.
222 |     n_fft = 1024  # FFT size; the 400-sample windows are zero-padded up to this length.
223 | 
224 |     if feature_type == 'mfcc':
225 |         sample = __mfcc(
226 |             audio_data, sampling_rate, WIN_LENGTH, WIN_STEP, NUM_FEATURES, n_fft, f_min, f_max
227 |         )
228 |     elif feature_type == 'mel':
229 |         sample = __mel(
230 |             audio_data, sampling_rate, WIN_LENGTH, WIN_STEP, NUM_FEATURES, n_fft, f_min, f_max
231 |         )
232 |     else:
233 |         raise ValueError('Unsupported feature type')
234 | 
235 |     # Make sure that the data type matches the TensorFlow type.
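    # (A worked example of the shape produced up to this point, assuming the
    # defaults of 16 kHz audio, 25 ms windows, and a 10 ms step: one second of
    # audio yields roughly 100 frames, i.e. an approximately [100, 80] matrix
    # with NUM_FEATURES == 80 values per frame. The exact frame count depends on
    # the rounding/padding behavior of the framing function.)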
236 |     sample = sample.astype(NP_FLOAT)
237 | 
238 |     # Drop every 2nd time frame, if requested.
239 |     if FLAGS.features_drop_every_second_frame:
240 |         # [time, NUM_FEATURES] => [time // 2, NUM_FEATURES]
241 |         sample = sample[:: 2, :]
242 | 
243 |     # Get the length of the sample.
244 |     sample_len = np.array(sample.shape[0], dtype=np.int32)
245 | 
246 |     # Apply feature normalization.
247 |     sample = __feature_normalization(sample, feature_normalization)
248 | 
249 |     # sample = [time, NUM_FEATURES], sample_len: scalar
250 |     return sample, sample_len
251 | 
252 | 
253 | def __mfcc(audio_data, sampling_rate, win_len, win_step, num_features, n_fft, f_min, f_max):
254 |     """Convert a wav signal into Mel Frequency Cepstral Coefficients (MFCC).
255 | 
256 |     Args:
257 |         audio_data (np.ndarray): Wav signal.
258 |         sampling_rate (int): Sampling rate.
259 |         win_len (float): Window length in seconds.
260 |         win_step (float): Window stride in seconds.
261 |         num_features (int): Number of features to generate.
262 |         n_fft (int): FFT size (number of samples per FFT).
263 |         f_min (float): Minimum frequency to consider.
264 |         f_max (float): Maximum frequency to consider.
265 | 
266 |     Returns:
267 |         np.ndarray: MFCC feature vectors. Shape: [time, num_features]
268 |     """
269 |     if num_features % 2 != 0:
270 |         raise ValueError('num_features is not a multiple of 2.')
271 | 
272 |     # Compute MFCC features.
273 |     mfcc = psf.mfcc(signal=audio_data, samplerate=sampling_rate, winlen=win_len, winstep=win_step,
274 |                     numcep=num_features // 2, nfilt=num_features, nfft=n_fft,
275 |                     lowfreq=f_min, highfreq=f_max,
276 |                     preemph=0.97, ceplifter=22, appendEnergy=True)
277 | 
278 |     # And the first-order differences (delta features).
279 |     mfcc_delta = psf.delta(mfcc, 2)
280 | 
281 |     # Combine MFCC with MFCC_delta.
282 |     return np.concatenate([mfcc, mfcc_delta], axis=1)
283 | 
284 | 
285 | def __mel(audio_data, sampling_rate, win_len, win_step, num_features, n_fft, f_min, f_max):
286 |     """Convert a wav signal into a logarithmically scaled mel filterbank.
287 | 
288 |     Args:
289 |         audio_data (np.ndarray): Wav signal.
290 |         sampling_rate (int): Sampling rate.
291 |         win_len (float): Window length in seconds.
292 |         win_step (float): Window stride in seconds.
293 |         num_features (int): Number of features to generate.
294 |         n_fft (int): FFT size (number of samples per FFT).
295 |         f_min (float): Minimum frequency to consider.
296 |         f_max (float): Maximum frequency to consider.
297 | 
298 |     Returns:
299 |         np.ndarray: Mel filterbank. Shape: [time, num_features]
300 |     """
301 |     mel = psf.logfbank(signal=audio_data, samplerate=sampling_rate, winlen=win_len,
302 |                        winstep=win_step, nfilt=num_features, nfft=n_fft,
303 |                        lowfreq=f_min, highfreq=f_max, preemph=0.97)
304 |     return mel
305 | 
306 | 
307 | def __feature_normalization(features, method):
308 |     """Normalize the given feature vector `features`, with the stated normalization `method`.
309 | 
310 |     Args:
311 |         features (np.ndarray):
312 |             The signal array.
313 | 
314 |         method (str):
315 |             Normalization method:
316 | 
317 |             'local': Use local (in sample) mean and standard deviation values, and apply the
318 |                 normalization element wise, like in `global`.
319 | 
320 |             'local_scalar': Uses only the mean and standard deviation of the current sample.
321 |                 The normalization is applied by ([sample] - mean_scalar) / std_scalar
322 | 
323 |             'none': No normalization is applied.
324 | 
325 |     Returns:
326 |         np.ndarray: The normalized feature vector.
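    Example:
        A short sketch of the 'local' behavior:

            feats = np.arange(6, dtype=np.float32).reshape(3, 2)
            normalized = __feature_normalization(feats, 'local')
            # Each of the 2 feature columns now has zero mean and unit variance,
            # computed over the 3 time frames (axis=0).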
327 | """ 328 | if method == 'none': 329 | return features 330 | if method == 'local': 331 | return (features - np.mean(features, axis=0)) / np.std(features, axis=0) 332 | if method == 'local_scalar': 333 | # Option 'local' uses scalar values. 334 | return (features - np.mean(features)) / np.std(features) 335 | raise ValueError('Invalid normalization method.') 336 | 337 | 338 | # Create a dataset for testing purposes. 339 | if __name__ == '__main__': 340 | __NEXT_ELEMENT = input_fn_generator('train_bucket') 341 | 342 | with tf.Session() as session: 343 | # for example in range(FLAGS.num_examples_train): 344 | for example in range(5): 345 | print('Dataset elements:', session.run(__NEXT_ELEMENT)) 346 | 347 | print('The End.') 348 | -------------------------------------------------------------------------------- /asr/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains the ASR system's model definition. 3 | """ 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | import tensorflow.contrib as tfc 8 | 9 | from asr.params import FLAGS, TF_FLOAT 10 | from asr.util import tf_contrib, metrics 11 | from asr.util.hooks import GPUStatisticsHook 12 | 13 | 14 | class CTCModel: 15 | """Container class for the ASR system's TensorFlow model.""" 16 | 17 | def __init__(self): 18 | # Initialize attributes. 19 | self.loss_op = None 20 | self.train_op = None 21 | self.hooks = None 22 | 23 | def model_fn(self, features, labels, mode): 24 | """Create model graph and return a configured estimator. 25 | 26 | Model function that constructs the model's graph and returns the configured 27 | `tf.estimator.EstimatorSpec`. 28 | 29 | This method sets the `self.loss_op` and `self.train_op` variables. 30 | 31 | Args: 32 | features (tf.Tensor or Dict[tf.Tensor]): 33 | This is the first item returned from the `input_fn` passed to `train`, `evaluate`, 34 | and `predict`. This should be a single `tf.Tensor` or `dict` of same. 35 | 36 | labels: 37 | This is the second item returned from the `input_fn` passed to `train`, `evaluate`, 38 | and `predict`. This should be a single `tf.Tensor` or `dict` of same 39 | (for multi-head models). If mode is `tf.estimator.ModeKeys.PREDICT`, `labels=None` 40 | will be passed. If the `model_fn`'s signature does not accept `mode`, the `model_fn` 41 | must still be able to handle `labels=None`. 42 | 43 | mode (tf.estimator.ModeKeys): Optional. 44 | Specifies if this training, evaluation or prediction. See `tf.estimator.ModeKeys`. 45 | 46 | Returns: 47 | `tf.estimator.EstimatorSpec` 48 | """ 49 | spectrogram_length = features['spectrogram_length'] 50 | spectrogram = features['spectrogram'] 51 | 52 | # Create the inference graph. 53 | logits, seq_length = self.inference_fn( 54 | spectrogram, spectrogram_length, training=(mode == tf.estimator.ModeKeys.TRAIN)) 55 | 56 | if mode == tf.estimator.ModeKeys.PREDICT: 57 | # CTC decode. 58 | decoded, plaintext, _ = self.decode_fn(logits, seq_length, None) 59 | 60 | prediction = { 61 | 'decoded': tf.sparse.to_dense(decoded), 62 | 'plaintext': plaintext 63 | } 64 | 65 | return tf.estimator.EstimatorSpec(mode=mode, predictions=prediction) 66 | 67 | tf.summary.scalar('spectrogram_length', spectrogram_length[0]) 68 | tf.summary.image('spectrogram', tf.expand_dims(spectrogram, 3)) 69 | 70 | # Convert dense labels tensor into sparse tensor. 71 | labels = tfc.layers.dense_to_sparse(labels) 72 | 73 | # CTC loss operator. 
74 |         self.loss_op = self.loss_fn(logits, seq_length, labels)
75 | 
76 |         # During training.
77 |         if mode == tf.estimator.ModeKeys.TRAIN:
78 |             # Set up the optimizer for training.
79 |             global_step = tf.train.get_global_step()
80 |             optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate,
81 |                                                beta1=FLAGS.adam_beta1, beta2=FLAGS.adam_beta2,
82 |                                                epsilon=FLAGS.adam_epsilon)
83 |             self.train_op = optimizer.minimize(loss=self.loss_op, global_step=global_step)
84 | 
85 |             # Add various hooks.
86 |             self.hooks_fn()
87 | 
88 |         # Code for training and evaluation.
89 |         label_plaintext = features['label_plaintext']
90 | 
91 |         # CTC decode.
92 |         decoded, plaintext, plaintext_summary = self.decode_fn(logits,
93 |                                                                seq_length,
94 |                                                                label_plaintext)
95 | 
96 |         tf.summary.text('decoded_text', plaintext_summary[:, : FLAGS.num_samples_to_report])
97 | 
98 |         # Error metrics for the decoded text.
99 |         _, mean_ed, _, wer = self.error_rates_fn(labels, label_plaintext, decoded, plaintext)
100 | 
101 |         tf.summary.scalar('mean_edit_distance', mean_ed, family='Metrics')
102 |         tf.summary.scalar('word_error_rate', wer, family='Metrics')
103 | 
104 |         if mode == tf.estimator.ModeKeys.TRAIN:
105 |             return tf.estimator.EstimatorSpec(mode=mode,
106 |                                               loss=self.loss_op,
107 |                                               train_op=self.train_op,
108 |                                               training_hooks=self.hooks)
109 | 
110 |         # During evaluation.
111 |         if mode == tf.estimator.ModeKeys.EVAL:
112 |             eval_metrics_ops = {
113 |                 'mean_edit_distance': tf.metrics.mean(mean_ed, name='mean_edit_distance'),
114 |                 'word_error_rate': tf.metrics.mean(wer, name='word_error_rate')
115 |             }
116 | 
117 |             return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss_op,
118 |                                               eval_metric_ops=eval_metrics_ops)
119 | 
120 |         # This should never be reached.
121 |         raise RuntimeError('Invalid mode.')
122 | 
123 |     @staticmethod
124 |     def inference_fn(sequences, seq_length, training=True):
125 |         """Build a TensorFlow inference graph according to the model selected in `FLAGS.used_model`.
126 | 
127 |         Supports the default [Deep Speech 1] model ('ds1') and a [Deep Speech 2] inspired
128 |         implementation ('ds2').
129 | 
130 |         Args:
131 |             sequences (tf.Tensor):
132 |                 3D float Tensor with the input sequences. [batch_size, time, NUM_INPUTS]
133 | 
134 |             seq_length (tf.Tensor):
135 |                 1D int Tensor with the sequence lengths. [batch_size]
136 | 
137 |             training (bool):
138 |                 If `True`, apply dropout; if `False`, the data is passed through unaltered.
139 | 
140 |         Returns:
141 |             tf.Tensor: `logits`
142 |                 Softmax layer (logits) pre-activation, i.e. layer(X*W + b)
143 |             tf.Tensor: `seq_length`
144 |                 1D Tensor containing the approximated sequence lengths.
145 |         """
146 |         initializer = tf.truncated_normal_initializer(stddev=0.046875, dtype=TF_FLOAT)
147 |         regularizer = tfc.layers.l2_regularizer(0.0046875)
148 | 
149 |         if FLAGS.used_model == 'ds1':
150 |             # Dense input layers.
151 |             output3 = tf_contrib.dense_layers(sequences, training, regularizer, initializer)
152 |             # output3 = [batch_size, time, num_units_dense]
153 | 
154 |         elif FLAGS.used_model == 'ds2':
155 |             # 2D convolutional input layers.
156 |             with tf.variable_scope('conv'):
157 |                 # sequences = [batch_size, time, NUM_INPUTS] => [batch_size, time, NUM_INPUTS, 1]
158 |                 sequences = tf.expand_dims(sequences, 3)
159 | 
160 |                 # Apply convolutions.
161 |                 output3, seq_length = tf_contrib.conv_layers(sequences)
162 |         else:
163 |             raise ValueError('Unsupported model "{}" in flags.'.format(FLAGS.used_model))
164 | 
165 |         # RNN layers.
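        # (A note on the two branches below: both produce the same output contract,
        # a tensor of shape [batch_size, time, num_units_rnn * 2], since the forward
        # and backward outputs of the bidirectional RNN are concatenated. The
        # `FLAGS.cudnn` flag therefore only switches the implementation, not the
        # architecture.)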
166 |         with tf.variable_scope('rnn'):
167 |             rnn_dropout_rate = FLAGS.rnn_dropout_rate if training else 0.0
168 | 
169 |             if not FLAGS.cudnn:  # Use TensorFlow RNNs.
170 |                 # Create a stack of RNN cells.
171 |                 fw_cells, bw_cells = tf_contrib.bidirectional_cells(FLAGS.num_units_rnn,
172 |                                                                     FLAGS.num_layers_rnn,
173 |                                                                     dropout=rnn_dropout_rate)
174 | 
175 |                 # https://www.tensorflow.org/api_docs/python/tf/contrib/rnn/stack_bidirectional_dynamic_rnn
176 |                 output_rnn, _, _ = tfc.rnn.stack_bidirectional_dynamic_rnn(
177 |                     fw_cells, bw_cells,
178 |                     inputs=output3,
179 |                     dtype=TF_FLOAT,
180 |                     sequence_length=seq_length,
181 |                     parallel_iterations=64,
182 |                     time_major=False
183 |                 )
184 |                 # output_rnn = [batch_size, time, num_units_rnn * 2]
185 | 
186 |             else:  # FLAGS.cudnn: Use cuDNN RNNs.
187 |                 # cuDNN RNNs only support time-major inputs.
188 |                 output3 = tfc.rnn.transpose_batch_time(output3)
189 | 
190 |                 # https://www.tensorflow.org/api_docs/python/tf/contrib/cudnn_rnn/CudnnRNNRelu
191 |                 # https://www.tensorflow.org/api_docs/python/tf/contrib/cudnn_rnn/CudnnRNNTanh
192 |                 # https://www.tensorflow.org/api_docs/python/tf/contrib/cudnn_rnn/CudnnLSTM
193 |                 # https://www.tensorflow.org/api_docs/python/tf/contrib/cudnn_rnn/CudnnGRU
194 |                 supported_rnns = {
195 |                     'rnn_relu': tfc.cudnn_rnn.CudnnRNNRelu,
196 |                     'rnn_tanh': tfc.cudnn_rnn.CudnnRNNTanh,
197 |                     'gru': tfc.cudnn_rnn.CudnnGRU,
198 |                     'lstm': tfc.cudnn_rnn.CudnnLSTM
199 |                 }
200 |                 assert FLAGS.rnn_cell in supported_rnns
201 | 
202 |                 rnn = supported_rnns[FLAGS.rnn_cell](num_layers=FLAGS.num_layers_rnn,
203 |                                                      num_units=FLAGS.num_units_rnn,
204 |                                                      input_mode='linear_input',
205 |                                                      direction='bidirectional',
206 |                                                      dropout=rnn_dropout_rate,
207 |                                                      seed=FLAGS.random_seed,
208 |                                                      dtype=TF_FLOAT,
209 |                                                      # Glorot Uniform Initializer
210 |                                                      kernel_initializer=None,
211 |                                                      # Constant 0.0 Initializer
212 |                                                      bias_initializer=None)
213 | 
214 |                 output_rnn, _ = rnn(output3)
215 |                 output_rnn = tfc.rnn.transpose_batch_time(output_rnn)
216 |                 # output_rnn = [batch_size, time, num_units_rnn * 2]
217 | 
218 |         # Dense4
219 |         with tf.variable_scope('dense4'):
220 |             dense4 = tf.layers.dense(output_rnn, FLAGS.num_units_dense,
221 |                                      activation=tf.nn.relu,
222 |                                      kernel_initializer=initializer,
223 |                                      kernel_regularizer=regularizer)
224 |             dense4 = tf.minimum(dense4, FLAGS.relu_cutoff)
225 |             dense4 = tf.layers.dropout(dense4, rate=FLAGS.dense_dropout_rate, training=training)
226 |             # dense4 = [batch_size, conv_time, num_units_dense]
227 | 
228 |         # Logits: layer(X*W + b).
229 |         # We don't apply softmax here because most TensorFlow loss functions perform
230 |         # a softmax activation as needed, and therefore don't expect activated logits.
231 |         with tf.variable_scope('logits'):
232 |             logits = tf.layers.dense(dense4, FLAGS.num_classes, kernel_initializer=initializer)
233 |             logits = tfc.rnn.transpose_batch_time(logits)
234 | 
235 |         # logits = [time, batch_size, NUM_CLASSES]
236 |         return logits, seq_length
237 | 
238 |     @staticmethod
239 |     def loss_fn(logits, seq_length, labels):
240 |         """Calculate the network's CTC loss.
241 | 
242 |         Args:
243 |             logits (tf.Tensor):
244 |                 3D float Tensor. If time_major == False, this will be a Tensor shaped:
245 |                 [batch_size, max_time, num_classes]. If time_major == True (default), this will be a
246 |                 Tensor shaped: [max_time, batch_size, num_classes]. The logits.
247 | 
248 |             labels (tf.SparseTensor):
249 |                 An int32 SparseTensor. labels.indices[i, :] == [b, t] means labels.values[i] stores
250 |                 the id for (batch b, time t).
251 | 
252 |             seq_length (tf.Tensor):
253 |                 1D int32 vector, size [batch_size]. The sequence lengths.
254 | 
255 |         Returns:
256 |             tf.Tensor: 1D float Tensor with size [1], containing the mean loss.
257 |         """
258 |         # https://www.tensorflow.org/api_docs/python/tf/nn/ctc_loss
259 |         total_loss = tf.nn.ctc_loss(labels=labels,
260 |                                     inputs=logits,
261 |                                     sequence_length=seq_length,
262 |                                     preprocess_collapse_repeated=False,
263 |                                     ctc_merge_repeated=True,
264 |                                     time_major=True)
265 | 
266 |         # Average the CTC loss.
267 |         mean_loss = tf.reduce_mean(total_loss)
268 |         tf.summary.scalar('loss', mean_loss, family='Metrics')
269 |         return mean_loss
270 | 
271 |     @staticmethod
272 |     def decode_fn(logits, seq_len, originals=None):
273 |         """Decode a given inference (`logits`) and convert it to plaintext.
274 | 
275 |         Args:
276 |             logits (tf.Tensor):
277 |                 Logits Tensor of shape [time (input), batch_size, num_classes].
278 | 
279 |             seq_len (tf.Tensor):
280 |                 Tensor containing the batch's sequence lengths, of shape [batch_size].
281 | 
282 |             originals (tf.Tensor or None): Optional, default `None`.
283 |                 String Tensor of shape [batch_size] with the original plaintext.
284 | 
285 |         Returns:
286 |             tf.Tensor: Decoded integer labels.
287 |             tf.Tensor: Decoded plaintexts.
288 |             tf.Tensor: Decoded plaintexts and original texts for comparison in `tf.summary.text`.
289 |         """
290 |         # tf.nn.ctc_beam_search_decoder provides more accurate results, but is slower.
291 |         # https://www.tensorflow.org/api_docs/python/tf/nn/ctc_beam_search_decoder
292 |         decoded, _ = tf.nn.ctc_beam_search_decoder(inputs=logits,
293 |                                                    sequence_length=seq_len,
294 |                                                    beam_width=FLAGS.beam_width,
295 |                                                    top_paths=1,
296 |                                                    merge_repeated=False)
297 | 
298 |         # ctc_beam_search_decoder returns a list with one SparseTensor as its only element, if `top_paths=1`.
299 |         decoded = tf.cast(decoded[0], tf.int32)
300 | 
301 |         dense = tf.sparse.to_dense(decoded)
302 | 
303 |         originals = originals if originals is not None else np.array([], dtype=np.int32)
304 | 
305 |         # Translate the decoded integer data back to character strings.
306 |         plaintext, plaintext_summary = tf.py_func(metrics.dense_to_text, [dense, originals],
307 |                                                   [tf.string, tf.string], name='py_dense_to_text')
308 | 
309 |         return decoded, plaintext, plaintext_summary
310 | 
311 |     @staticmethod
312 |     def error_rates_fn(labels, originals, decoded, decoded_texts):
313 |         """Calculate the edit distance and word error rate.
314 | 
315 |         Args:
316 |             labels (tf.SparseTensor or tf.Tensor):
317 |                 Integer SparseTensor containing the target.
318 |                 With dense shape [batch_size, time (target)].
319 |                 Dense Tensors are converted into SparseTensors if `FLAGS.use_warp_ctc == True`.
320 | 
321 |             originals (tf.Tensor):
322 |                 String Tensor of shape [batch_size] with the original plaintext.
323 | 
324 |             decoded (tf.Tensor):
325 |                 Integer tensor of the decoded output labels.
326 | 
327 |             decoded_texts (tf.Tensor):
328 |                 String tensor with the decoded output labels converted to normal text.
329 | 
330 |         Returns:
331 |             tf.Tensor: Edit distances for the batch.
332 |             tf.Tensor: Mean edit distance.
333 |             tf.Tensor: Word error rates for the batch.
334 |             tf.Tensor: Word error rate.
335 |         """
336 | 
337 |         # Edit distances and the average edit distance.
338 |         edit_distances = tf.edit_distance(decoded, labels)
339 |         mean_edit_distance = tf.reduce_mean(edit_distances)
340 | 
341 |         # Word error rates for the batch and the average word error rate (WER).
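        # (For reference, and stated as an assumption since asr/util/metrics.py is
        # not part of this excerpt: `metrics.wer_batch` is expected to implement the
        # usual definition WER = (S + D + I) / N, i.e. word substitutions, deletions,
        # and insertions relative to the number of words N in the reference text.)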
342 |         wers, wer = tf.py_func(metrics.wer_batch, [originals, decoded_texts],
343 |                                [TF_FLOAT, TF_FLOAT], name='py_wer_batch')
344 | 
345 |         return edit_distances, mean_edit_distance, wers, wer
346 | 
347 |     def hooks_fn(self):
348 |         """Produce and configure session hooks.
349 | 
350 |         Returns:
351 |             List[tf.train.SessionRunHook]: List containing the TensorFlow session hooks.
352 |         """
353 | 
354 |         # GPU statistics hook.
355 |         gpu_stats_hook = GPUStatisticsHook(
356 |             log_every_n_steps=FLAGS.log_frequency,
357 |             query_every_n_steps=FLAGS.gpu_hook_query_frequency,
358 |             average_n=FLAGS.gpu_hook_average_queries,
359 |             stats=['mem_util', 'gpu_util'],
360 |             output_dir=FLAGS.train_dir,
361 |             suppress_stdout=False,
362 |             group_tag='gpu'
363 |         )
364 | 
365 |         # Session hooks.
366 |         self.hooks = [
367 |             # Monitors the loss tensor and stops training if loss is NaN.
368 |             tf.train.NanTensorHook(self.loss_op),
369 |             gpu_stats_hook
370 |         ]
371 | 
372 |         return self.hooks
373 | 
--------------------------------------------------------------------------------
/asr/util/hooks.py:
--------------------------------------------------------------------------------
1 | """Collection of TensorFlow hooks."""
2 | 
3 | import time
4 | from datetime import datetime
5 | 
6 | import pynvml as nvml
7 | import tensorflow as tf
8 | from tensorflow.core.framework.summary_pb2 import Summary
9 | from tensorflow.python.platform import tf_logging as logging
10 | from tensorflow.python.training import summary_io, training_util, session_run_hook
11 | 
12 | from asr.params import FLAGS
13 | 
14 | 
15 | class GPUStatisticsHook(tf.train.SessionRunHook):
16 |     """A session hook that logs GPU statistics to tensorboard and to the log stream."""
17 | 
18 |     def __init__(self,
19 |                  log_every_n_steps=None,
20 |                  log_every_n_secs=None,
21 |                  query_every_n_steps=None,
22 |                  query_every_n_secs=None,
23 |                  output_dir=None,
24 |                  summary_writer=None,
25 |                  stats=('mem_used', 'mem_free', 'mem_total', 'mem_util', 'gpu_util'),
26 |                  average_n=1,
27 |                  suppress_stdout=False,
28 |                  group_tag='gpu'):
29 |         """Create an instance of `GPUStatisticsHook`.
30 | 
31 |         Arguments:
32 |             log_every_n_steps (int):
33 |                 Integer controlling after how many (global) steps the hook is supposed to log the
34 |                 averaged values to tensorboard or the logging stream.
35 |                 When set, `log_every_n_secs` must be None.
36 |             log_every_n_secs (int):
37 |                 Integer controlling after how many seconds the hook is supposed to log the
38 |                 averaged values to tensorboard or the logging stream.
39 |                 When set, `log_every_n_steps` must be None.
40 |             query_every_n_steps (int):
41 |                 Integer controlling after how many (global) steps the hook is supposed to query
42 |                 values from the hardware.
43 |                 When set, `query_every_n_secs` must be None.
44 |             query_every_n_secs (int):
45 |                 Integer controlling after how many seconds the hook is supposed to query
46 |                 values from the hardware.
47 |                 When set, `query_every_n_steps` must be None.
48 |             output_dir (str):
49 |                 In case `summary_writer` is None, this parameter is used to construct a
50 |                 FileWriter for writing a summary statistic.
51 |             summary_writer (tensorflow.summary.FileWriter):
52 |                 FileWriter to use for writing the summary statistics.
53 |             stats (:obj:`tuple` of `str`):
54 |                 List of strings to control what statistics are written to tensorboard.
55 |                 Valid strings are ('mem_used', 'mem_free', 'mem_total', 'mem_util', 'gpu_util').
56 |                 Note that ('mem_used', 'mem_free', 'mem_total') are logged in MiB and encompass
57 |                 the global GPU state (therefore including all processes running on that GPU).
58 | Note that ('mem_util', 'gpu_util') are given in percent (0, 100). 59 | average_n (int): 60 | Integer controlling how many values (i.e. results of a query) should be memorized 61 | for averaging. 62 | Default is 1, resulting in only the value from the last query execution 63 | being remembered. 64 | suppress_stdout (bool): 65 | If True, statistics are only logged to tensorboard. 66 | If False, statistics are logged to tensorboard and are written into tensorflow 67 | logging with INFO level. 68 | group_tag (str): 69 | Name of the tag under which the values will appear in tensorboard. 70 | Default is 'gpu' 71 | """ 72 | 73 | # Check if only log_every_n_steps or only log_every_n_secs is set. 74 | if (log_every_n_steps is None) == (log_every_n_secs is None): 75 | raise ValueError("exactly one of log_every_n_steps and log_every_n_secs should be " 76 | "provided.") 77 | 78 | # Check if only query_every_n_steps or only query_every_n_secs is set. 79 | if (query_every_n_steps is None) == (query_every_n_secs is None): 80 | raise ValueError("exactly one of query_every_n_steps and query_every_n_secs should be " 81 | "provided.") 82 | 83 | # Timer controlling how often the statistics are queried from the GPUs. 84 | self._query_timer = tf.train.SecondOrStepTimer(every_steps=query_every_n_steps, 85 | every_secs=query_every_n_secs) 86 | 87 | # Timer controlling how often statistics are logged (i.e. written to TB or to logging). 88 | self._log_timer = tf.train.SecondOrStepTimer(every_steps=log_every_n_steps, 89 | every_secs=log_every_n_secs) 90 | 91 | # Initialize the internal variables. 92 | self._summary_writer = summary_writer 93 | self._output_dir = output_dir 94 | self._last_global_step = None 95 | self._global_step_check_count = 0 96 | self._steps_per_run = 1 97 | self._global_step_tensor = None 98 | self._statistics_to_log = stats 99 | self._suppress_stdout = suppress_stdout 100 | self._group_tag = group_tag 101 | 102 | self._average_n = average_n 103 | self._gpu_statistics = dict() 104 | 105 | self._global_step_write_count = 0 106 | 107 | # Initialize the NVML interface. 108 | nvml.nvmlInit() 109 | 110 | # Query the number of available GPUs. 111 | self._device_count = nvml.nvmlDeviceGetCount() 112 | 113 | # Create a summary dict for each GPU. 114 | for gpu_id in range(self._device_count): 115 | self._gpu_statistics[gpu_id] = self.__init_gpu_summaries() 116 | 117 | # def _set_steps_per_run(self, steps_per_run): 118 | # self._steps_per_run = steps_per_run 119 | 120 | @staticmethod 121 | def __statistic_keys(): 122 | """Get the keys for all statistics that the hook can query. 123 | 124 | Returns: 125 | list: 126 | List of keys. 127 | """ 128 | return [ 129 | 'mem_used', # Used memory. 130 | 'mem_free', # Free memory. 131 | 'mem_total', # Total memory. 132 | 'mem_util', # Memory IO utilization. 133 | 'gpu_util' # GPU utilization. 134 | ] 135 | 136 | def __init_gpu_summaries(self): 137 | """ 138 | Create a dictionary with all summary keys initialized as empty lists. 139 | 140 | Returns: 141 | dict: 142 | Dictionary containing an empty list for each key from `__statistic_keys`. 143 | """ 144 | summaries = dict() 145 | for key in self.__statistic_keys(): 146 | summaries[key] = list() 147 | 148 | return summaries 149 | 150 | @staticmethod 151 | def __query_mem(handle): 152 | """ 153 | Query information on the memory of a GPU. 154 | 155 | Arguments: 156 | handle: 157 | NVML device handle. 
158 | 
159 |         Returns:
160 |             summaries (:obj:`dict`):
161 |                 Dictionary containing the memory values for ['mem_used', 'mem_free', 'mem_total'].
162 |                 All values are given in MiB as integers.
163 |         """
164 |         # Query information on the GPU's memory usage.
165 |         info = nvml.nvmlDeviceGetMemoryInfo(handle)
166 | 
167 |         summaries = dict()
168 |         bytes_mib = 1024.0 ** 2
169 |         summaries['mem_used'] = int(info.used / bytes_mib)
170 |         summaries['mem_free'] = int(info.free / bytes_mib)
171 |         summaries['mem_total'] = int(info.total / bytes_mib)
172 | 
173 |         return summaries
174 | 
175 |     @staticmethod
176 |     def __query_util(handle):
177 |         """Query information on the utilization of a GPU.
178 | 
179 |         Arguments:
180 |             handle:
181 |                 NVML device handle.
182 | 
183 |         Returns:
184 |             summaries (:obj:`dict`):
185 |                 Dictionary containing the utilization values for ['mem_util', 'gpu_util'].
186 |                 All values are given as integers in the range (0, 100).
187 |         """
188 |         # Query information on the GPU utilization.
189 |         util = nvml.nvmlDeviceGetUtilizationRates(handle)
190 | 
191 |         summaries = dict()
192 |         # Percent of time over the past second during which global (device) memory was being
193 |         # read or written.
194 |         summaries['mem_util'] = util.memory
195 |         # Percent of time over the past second during which one or more kernels was executing
196 |         # on the GPU.
197 |         summaries['gpu_util'] = util.gpu
198 | 
199 |         return summaries
200 | 
201 |     def begin(self):
202 |         """Called once before graph finalization.
203 | 
204 |         Is called once before the default graph in the active tensorflow session is
205 |         finalized and the training starts.
206 |         The hook can modify the graph by adding new operations to it.
207 |         After the begin() call the graph will be finalized and the other callbacks can not modify
208 |         the graph anymore. A second call of begin() on the same graph should not change the graph.
209 |         """
210 |         # Create a summary writer if possible.
211 |         if self._summary_writer is None and self._output_dir:
212 |             self._summary_writer = summary_io.SummaryWriterCache.get(self._output_dir)
213 | 
214 |         # Get read access to the global step tensor.
215 |         # pylint: disable=protected-access
216 |         self._global_step_tensor = training_util._get_or_create_global_step_read()
217 |         if self._global_step_tensor is None:
218 |             raise RuntimeError("Global step should be created to use GPUStatisticsHook.")
219 | 
220 |     def end(self, session):
221 |         """Called at the end of a session.
222 | 
223 |         Arguments:
224 |             session (tf.Session):
225 |                 The `session` argument can be used in case the hook wants to run final ops,
226 |                 such as saving a last checkpoint.
227 |         """
228 |         # Shutdown the NVML interface.
229 |         nvml.nvmlShutdown()
230 | 
231 |     def before_run(self, run_context):
232 |         """Is called once before each call to session.run (a training iteration in general).
233 | 
234 |         At this point the graph is finalized and you can not add ops.
235 | 
236 |         Arguments:
237 |             run_context (tf.train.SessionRunContext):
238 |                 The `run_context` argument is a `SessionRunContext` that provides
239 |                 information about the upcoming `run()` call: the originally requested
240 |                 op/tensors, the TensorFlow Session.
241 |                 SessionRunHook objects can stop the loop by calling `request_stop()` of
242 |                 `run_context`.
243 |                 Sadly you have to take a look at 'tensorflow/python/training/session_run_hook.py'
244 |                 for more details.
245 |         Returns:
246 |             tf.train.SessionRunArgs:
247 |                 None or a `SessionRunArgs` object.
248 |                 Represents arguments to be added to a `Session.run()` call.
249 |                 Sadly you have to take a look at 'tensorflow/python/training/session_run_hook.py'
250 |                 for more details.
251 |         """
252 |         # Request to read the global step tensor when running the hook.
253 |         # The content of the requested tensors is passed to the hook's `after_run` function.
254 |         fetches = [
255 |             # This will deliver the global step as it was before the `session.run`
256 |             # call was executed.
257 |             self._global_step_tensor
258 |         ]
259 |         return session_run_hook.SessionRunArgs(fetches=fetches)
260 | 
261 |     def after_run(self, run_context, run_values):
262 |         """Is called once after each call to session.run (a training iteration in general).
263 | 
264 |         At this point the graph is finalized and you can not add ops.
265 | 
266 |         Arguments:
267 |             run_context (tf.train.SessionRunContext):
268 |                 The `run_context` argument is a `SessionRunContext` that provides
269 |                 information about the `run()` call that just completed: the originally
270 |                 requested op/tensors, the TensorFlow Session.
271 |                 SessionRunHook objects can stop the loop by calling `request_stop()` of
272 |                 `run_context`.
273 |                 Sadly you have to take a look at 'tensorflow/python/training/session_run_hook.py'
274 |                 for more details.
275 |             run_values (tf.train.SessionRunValues):
276 |                 Contains the results of `Session.run()`.
277 |                 However, this only seems to contain the results for the operations requested with
278 |                 `before_run`.
279 |                 Sadly you have to take a look at 'tensorflow/python/training/session_run_hook.py'
280 |                 for more details.
281 |         """
282 |         # Note: `run_context` is not ignored here; it is used below to query the
283 |         # up-to-date global step from the session.
284 | 
285 |         # Get the values of the tensors requested inside the `before_run` function:
286 |         # read the global step as it was before the `session.run` call was executed.
287 |         stale_global_step = run_values.results[0]
288 | 
289 |         # Check if the query timer should trigger for the current global step (i.e. last step + 1).
290 |         if self._query_timer.should_trigger_for_step(stale_global_step + self._steps_per_run):
291 |             # Get the actual global step from the global step tensor.
292 |             global_step = run_context.session.run(self._global_step_tensor)
293 |             if self._query_timer.should_trigger_for_step(global_step):
294 |                 # Get the elapsed time and elapsed steps since the last trigger event.
295 |                 elapsed_time, elapsed_steps = self._query_timer.update_last_triggered_step(
296 |                     global_step)
297 |                 if elapsed_time is not None:
298 |                     self._update_statistics(elapsed_steps, elapsed_time, global_step)
299 | 
300 |         # Check if the log timer should trigger for the current global step (i.e. last step + 1).
301 |         if self._log_timer.should_trigger_for_step(stale_global_step + self._steps_per_run):
302 |             # Get the actual global step from the global step tensor.
303 |             global_step = run_context.session.run(self._global_step_tensor)
304 |             if self._log_timer.should_trigger_for_step(global_step):
305 |                 # Get the elapsed time and elapsed steps since the last trigger event.
306 |                 elapsed_time, elapsed_steps = self._log_timer.update_last_triggered_step(
307 |                     global_step)
308 |                 if elapsed_time is not None:
309 |                     self._log_statistics(elapsed_steps, elapsed_time, global_step)
310 | 
311 |         # Check whether the global step has been increased. Here, we do not use the
312 |         # timer.last_triggered_step as the timer might record a different global
313 |         # step value such that the comparison could be unreliable. For simplicity,
314 |         # we just compare the stale_global_step with the previously recorded version.
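        # (Illustration of the "stale" value: `before_run` fetched the global step
        # before the training op ran, so directly after a successful step the
        # session already reports one more than `stale_global_step`; this is why
        # the timer checks above test `stale_global_step + self._steps_per_run`.)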
315 |         if stale_global_step == self._last_global_step:
316 |             # Here, we use a counter to count how many times we have observed that the
317 |             # global step has not been increased. For some optimizers, the global step
318 |             # is not increased each time by design. For example, SyncReplicasOptimizer
319 |             # doesn't increase the global step in the worker's main train step.
320 |             self._global_step_check_count += 1
321 |             if self._global_step_check_count % 20 == 0:
322 |                 self._global_step_check_count = 0
323 |                 logging.warning(
324 |                     "It seems that global step (tf.train.get_global_step) has not "
325 |                     "been increased. Current value (could be stable): %s vs previous "
326 |                     "value: %s. You could increase the global step by passing "
327 |                     "tf.train.get_global_step() to Optimizer.apply_gradients or "
328 |                     "Optimizer.minimize.", stale_global_step, self._last_global_step)
329 |         else:
330 |             # Whenever we observe the increment, reset the counter.
331 |             self._global_step_check_count = 0
332 | 
333 |         self._last_global_step = stale_global_step
334 | 
335 |     def _update_statistics(self, elapsed_steps, elapsed_time, global_step):
336 |         """Query the GPUs and store the current statistics.
337 | 
338 |         Arguments:
339 |             elapsed_steps (int):
340 |                 The number of steps between the current trigger event and the last one.
341 |             elapsed_time (float):
342 |                 The number of seconds between the current trigger event and the last one.
343 |             global_step (tf.Tensor):
344 |                 Global step tensor.
345 |         """
346 |         # Iterate over the available GPUs.
347 |         for gpu_id in range(self._device_count):
348 |             summaries = dict()
349 | 
350 |             # Acquire a GPU device handle.
351 |             handle = nvml.nvmlDeviceGetHandleByIndex(gpu_id)
352 | 
353 |             # Query information on the GPU's memory usage.
354 |             summaries.update(self.__query_mem(handle))
355 | 
356 |             # Query information on the GPU's utilization.
357 |             summaries.update(self.__query_util(handle))
358 | 
359 |             # Update the value history for the current GPU.
360 |             for k in summaries.keys():
361 |                 if k in self._statistics_to_log:
362 |                     self._gpu_statistics[gpu_id][k] = \
363 |                         self._gpu_statistics[gpu_id][k][-self._average_n:] + [summaries[k]]
364 | 
365 |     def _log_statistics(self, elapsed_steps, elapsed_time, global_step):
366 |         """Write the averaged summary values to tensorboard and the logging stream.
367 | 
368 |         Arguments:
369 |             elapsed_steps (int):
370 |                 The number of steps between the current trigger event and the last one.
371 |             elapsed_time (float):
372 |                 The number of seconds between the current trigger event and the last one.
373 |             global_step (tf.Tensor):
374 |                 Global step tensor.
375 |         """
376 | 
377 |         # Write the summary for tensorboard.
378 |         if self._summary_writer is not None:
379 |             summary_list = list()
380 |             # Add only summaries.
381 |             for gpu_id in self._gpu_statistics.keys():
382 |                 for statistic in self._gpu_statistics[gpu_id].keys():
383 |                     # Only add them if they are requested for logging.
384 |                     if statistic in self._statistics_to_log:
385 |                         values = self._gpu_statistics[gpu_id][statistic]
386 |                         # Only calculate and write the average if there is data available.
387 |                         if values:
388 |                             avg_value = sum(values) / len(values)
389 |                             avg_summary = Summary.Value(tag='{}/{}:{}'
390 |                                                         .format(self._group_tag, gpu_id, statistic),
391 |                                                         simple_value=avg_value)
392 |                             summary_list.append(avg_summary)
393 | 
394 |             # Write all statistics as simple scalar summaries.
395 |             summary = Summary(value=summary_list)
396 |             self._summary_writer.add_summary(summary, global_step)
397 | 
398 |         # Log summaries to the logging stream.
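        # (The tags written above follow the pattern '<group_tag>/<gpu_id>:<statistic>',
        # e.g. 'gpu/0:gpu_util' with the defaults, so every GPU gets its own scalar
        # series in tensorboard.)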
399 |         if not self._suppress_stdout:
400 |             for gpu_id in self._gpu_statistics.keys():
401 |                 # Acquire a GPU device handle.
402 |                 handle = nvml.nvmlDeviceGetHandleByIndex(gpu_id)
403 | 
404 |                 # Query the device name.
405 |                 name = nvml.nvmlDeviceGetName(handle).decode('utf-8')
406 | 
407 |                 for statistic in self._gpu_statistics[gpu_id].keys():
408 |                     # Log utilization information with INFO level.
409 |                     logging.info("%s: %s", name, '{}: {}'
410 |                                  .format(statistic, self._gpu_statistics[gpu_id][statistic]))
411 | 
412 | 
413 | # The following code has been inspired by:
414 | class TraceHook(tf.train.SessionRunHook):
415 |     """Hook to perform traces every N steps."""
416 | 
417 |     def __init__(self, file_writer, log_frequency, trace_level=tf.RunOptions.FULL_TRACE):
418 |         self._trace = log_frequency == 1
419 |         self.writer = file_writer
420 |         self.trace_level = trace_level
421 |         self.log_frequency = log_frequency
422 |         self._global_step_tensor = None
423 | 
424 |     def begin(self):
425 |         self._global_step_tensor = tf.train.get_global_step()
426 |         if self._global_step_tensor is None:
427 |             raise RuntimeError("Global step should be created to use TraceHook.")
428 | 
429 |     def before_run(self, run_context):
430 |         if self._trace:
431 |             options = tf.RunOptions(trace_level=self.trace_level)
432 |         else:
433 |             options = None
434 | 
435 |         return tf.train.SessionRunArgs(fetches=self._global_step_tensor, options=options)
436 | 
437 |     def after_run(self, run_context, run_values):
438 |         global_step = run_values.results - 1
439 |         if self._trace:
440 |             self._trace = False
441 |             self.writer.add_run_metadata(run_values.run_metadata, '{}'.format(global_step))
442 |         if not (global_step + 1) % self.log_frequency:
443 |             self._trace = True
444 | 
445 | 
446 | class LoggerHook(tf.train.SessionRunHook):
447 |     """Log loss and runtime."""
448 | 
449 |     def __init__(self, loss_op):
450 |         self.loss_op = loss_op
451 |         self._global_step_tensor = None
452 |         self._start_time = 0
453 | 
454 |     def begin(self):
455 |         self._global_step_tensor = tf.train.get_global_step()
456 |         self._start_time = time.time()
457 | 
458 |     def before_run(self, run_context):
459 |         # Ask for the loss value and the global step.
460 |         return tf.train.SessionRunArgs(fetches=[self.loss_op, self._global_step_tensor])
461 | 
462 |     def after_run(self, run_context, run_values):
463 |         loss_value, global_step = run_values.results
464 | 
465 |         if global_step % FLAGS.log_frequency == 0:
466 |             current_time = time.time()
467 |             duration = current_time - self._start_time
468 |             self._start_time = current_time
469 | 
470 |             examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
471 |             sec_per_batch = duration / float(FLAGS.log_frequency)
472 |             batch_per_sec = float(FLAGS.log_frequency) / duration
473 | 
474 |             print('{:%Y-%m-%d %H:%M:%S}: (step={:,d}); loss={:.4f}; '
475 |                   '{:.1f} examples/sec ({:.3f} sec/batch) ({:.2f} batch/sec)'
476 |                   .format(datetime.now(), global_step, loss_value, examples_per_sec,
477 |                           sec_per_batch, batch_per_sec))
478 | 
--------------------------------------------------------------------------------
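As a closing note, here is a minimal sketch of how the pieces in this excerpt might
be wired together. It is not part of the repository and assumes the flags from
asr/params.py (e.g. `train_dir`, `train_csv`) are configured, as asr/train.py
(not shown here) would do:

    import tensorflow as tf

    from asr.input_functions import input_fn_generator
    from asr.model import CTCModel
    from asr.params import FLAGS

    # Hand the container's `model_fn` to a TF 1.x estimator.
    model = CTCModel()
    estimator = tf.estimator.Estimator(model_fn=model.model_fn,
                                       model_dir=FLAGS.train_dir)

    # One epoch of bucketed training data, then an evaluation pass on `dev.csv`.
    estimator.train(input_fn=input_fn_generator('train_bucket'))
    estimator.evaluate(input_fn=input_fn_generator('dev'))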