├── .gitignore ├── Dockerfile ├── LICENSE ├── Makefile ├── README.md ├── requirements.txt ├── setup.py ├── src └── deepspeech │ ├── __init__.py │ ├── data │ ├── __init__.py │ ├── alphabet.py │ ├── datasets │ │ ├── __init__.py │ │ ├── librispeech.py │ │ └── utils.py │ ├── loader.py │ └── preprocess.py │ ├── decoder │ ├── __init__.py │ ├── base.py │ ├── beam.py │ └── greedy.py │ ├── global_state.py │ ├── logging │ ├── __init__.py │ ├── log_level_action.py │ └── mixin.py │ ├── loss │ ├── __init__.py │ ├── ctc_loss.py │ └── eval.py │ ├── models │ ├── __init__.py │ ├── deepspeech.py │ ├── deepspeech2.py │ └── model.py │ ├── networks │ ├── __init__.py │ ├── deepspeech.py │ ├── deepspeech2.py │ └── utils.py │ ├── run.py │ └── utils │ ├── __init__.py │ └── singleton.py └── tests ├── __init__.py ├── data ├── __init__.py ├── test_alphabet.py └── test_preprocess.py ├── decoder ├── __init__.py └── test_decoder.py ├── models ├── __init__.py ├── test_deepspeech.py ├── test_deepspeech2.py └── test_model.py ├── networks └── test_deepspeech.py ├── test_utils.py └── utils ├── __init__.py └── test_singleton.py /.gitignore: -------------------------------------------------------------------------------- 1 | # DeepSpeech 2 | deps 3 | 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # pyenv 80 | .python-version 81 | 82 | # celery beat schedule file 83 | celerybeat-schedule 84 | 85 | # SageMath parsed files 86 | *.sage.py 87 | 88 | # Environments 89 | .env 90 | .venv 91 | env/ 92 | venv/ 93 | ENV/ 94 | env.bak/ 95 | venv.bak/ 96 | 97 | # Spyder project settings 98 | .spyderproject 99 | .spyproject 100 | 101 | # Rope project settings 102 | .ropeproject 103 | 104 | # mkdocs documentation 105 | /site 106 | 107 | # mypy 108 | .mypy_cache/ 109 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:9.1-cudnn7-devel-ubuntu16.04 2 | 3 | LABEL maintainer="sam@myrtle.ai" 4 | 5 | #- Upgrade system and install dependencies ------------------------------------- 6 | RUN apt-get update && \ 7 | apt-get upgrade -y && \ 8 | apt-get install -y --no-install-recommends \ 9 | build-essential \ 10 | cmake \ 11 | git \ 12 | libboost-program-options-dev \ 13 | libboost-system-dev \ 14 | libboost-test-dev \ 15 | libboost-thread-dev \ 16 | libbz2-dev \ 17 | libeigen3-dev \ 18 | liblzma-dev \ 19 | libsndfile1 \ 20 | python3 \ 21 | python3-dev \ 22 | python3-pip \ 23 | python3-setuptools \ 24 | python3-wheel \ 25 | sudo \ 26 | vim && \ 27 | rm -rf /var/lib/apt/lists 28 | 29 | #- Enable passwordless sudo for users under the "sudo" group ------------------- 30 | RUN sed -i.bkp -e \ 31 | 's/%sudo\s\+ALL=(ALL\(:ALL\)\?)\s\+ALL/%sudo ALL=NOPASSWD:ALL/g' \ 32 | /etc/sudoers 33 | 34 | #- Create data group for NFS mount -------------------------------------------- 35 | RUN groupadd --system data --gid 5555 36 | 37 | #- Create and switch to a non-root user ---------------------------------------- 38 | RUN groupadd -r ubuntu && \ 39 | useradd --no-log-init \ 40 | --create-home \ 41 | --gid ubuntu \ 42 | ubuntu && \ 43 | usermod -aG sudo ubuntu 44 | USER ubuntu 45 | WORKDIR /home/ubuntu 46 | 47 | #- Install Python packages ----------------------------------------------------- 48 | ARG WHEELDIR=/home/ubuntu/.cache/pip/wheels 49 | COPY --chown=ubuntu:ubuntu deps deps 50 | COPY --chown=ubuntu:ubuntu requirements.txt requirements.txt 51 | RUN pip3 install --find-links ${WHEELDIR} \ 52 | -r requirements.txt && \ 53 | rm requirements.txt && \ 54 | rm -r ${WHEELDIR} && \ 55 | rm -r /home/ubuntu/.cache/pip 56 | 57 | # warp-ctc 58 | RUN cd /home/ubuntu/deps/warp-ctc/pytorch_binding && \ 59 | git reset --hard 6f3f1cb7871f682e118c49788f5e54468b59c953 && \ 60 | python3 setup.py bdist_wheel && \ 61 | pip3 install dist/warpctc-0.0.0-cp35-cp35m-linux_x86_64.whl 62 | 63 | # ctcdecode bindings 64 | RUN cd /home/ubuntu/deps/ctcdecode && \ 65 | git reset --hard 6f4326b43b8dc49fd2e328ce231d1ba37f8e373f && \ 66 | pip3 install . 
67 | 68 | # kenlm: binaries for building LMs and a Python lib for easy analysis 69 | RUN cd /home/ubuntu/deps/kenlm && \ 70 | git reset --hard 328cc2995202e84d29e3773203d29cdd6cc07132 && \ 71 | mkdir build && \ 72 | cd build && \ 73 | cmake .. && \ 74 | make -j 4 && \ 75 | sudo mv bin/lmplz bin/build_binary /usr/bin/ && \ 76 | pip3 install /home/ubuntu/deps/kenlm 77 | 78 | RUN sudo rm -rf deps 79 | 80 | #- Install deepspeech package -------------------------------------------------- 81 | COPY --chown=ubuntu:ubuntu . deepspeech 82 | RUN pip3 install deepspeech/ 83 | RUN rm -rf deepspeech 84 | 85 | #- Setup Jupyter --------------------------------------------------------------- 86 | EXPOSE 9999 87 | ENV PATH /home/ubuntu/.local/bin:$PATH 88 | ENV SHELL /bin/bash 89 | CMD ["jupyter", "lab", \ 90 | "--ip=0.0.0.0", \ 91 | "--port=9999"] 92 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Contact: deepspeech@myrtle.ai 4 | 5 | Copyright (c) 2018, Myrtle Software Limited, www.myrtle.ai 6 | All rights reserved. 7 | 8 | 9 | Redistribution and use in source and binary forms, with or without 10 | modification, are permitted provided that the following conditions are met: 11 | 12 | * Redistributions of source code must retain the above copyright notice, this 13 | list of conditions and the following disclaimer. 14 | 15 | * Redistributions in binary form must reproduce the above copyright notice, 16 | this list of conditions and the following disclaimer in the documentation 17 | and/or other materials provided with the distribution. 18 | 19 | * Neither the name of the copyright holder nor the names of its 20 | contributors may be used to endorse or promote products derived from 21 | this software without specific prior written permission. 22 | 23 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 24 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 25 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 26 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 27 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 28 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 29 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 30 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 31 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 32 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 33 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | IMAGE_NAME=deepspeech 2 | 3 | # use branch name for tag if possible 4 | IMAGE_TAG=$(shell git symbolic-ref --short -q HEAD || echo 'dev') 5 | 6 | build: deps/ctcdecode deps/kenlm deps/warp-ctc 7 | sudo docker build -t $(IMAGE_NAME):$(IMAGE_TAG) .
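# Illustrative usage note: both variables above use `=` so they can be
# overridden at invocation time, e.g. `make build IMAGE_TAG=v0.1`.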
8 | 9 | deps/ctcdecode: 10 | git clone --recursive git@github.com:parlance/ctcdecode.git deps/ctcdecode 11 | 12 | deps/kenlm: 13 | git clone git@github.com:kpu/kenlm.git deps/kenlm 14 | 15 | deps/warp-ctc: 16 | git clone git@github.com:t-vi/warp-ctc.git deps/warp-ctc 17 | 18 | clean: 19 | sudo docker images -q $(IMAGE_NAME):$(IMAGE_TAG) | \ 20 | xargs --no-run-if-empty sudo docker rmi 21 | rm -rf deps 22 | 23 | .PHONY: build clean 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Myrtle Deep Speech 2 | 3 | A PyTorch implementation of [DeepSpeech](https://arxiv.org/abs/1412.5567) and 4 | [DeepSpeech2](https://arxiv.org/abs/1512.02595). 5 | 6 | This repository is intended as an evolving baseline for other implementations 7 | to compare their training performance against. 8 | 9 | Current roadmap: 10 | 1. ~Pre-trained weights for both networks and full performance statistics.~ 11 | - See v0.1 release: https://github.com/MyrtleSoftware/deepspeech/releases/tag/v0.1 12 | 1. Mixed-precision training. 13 | 14 | ## Running 15 | 16 | Build the Docker image: 17 | 18 | ``` 19 | make build 20 | ``` 21 | 22 | Run the Docker container (here using 23 | [nvidia-docker](https://github.com/NVIDIA/nvidia-docker)), making sure to publish 24 | the port of the JupyterLab session to the host: 25 | 26 | ``` 27 | sudo docker run --runtime=nvidia --shm-size 512M -p 9999:9999 deepspeech 28 | ``` 29 | 30 | The JupyterLab session can be accessed via `localhost:9999`. 31 | 32 | This Python package will be available in the running Docker container and is 33 | accessible through either the command-line interface: 34 | 35 | ``` 36 | deepspeech --help 37 | ``` 38 | 39 | or as a Python package: 40 | 41 | ```python 42 | import deepspeech 43 | ``` 44 | 45 | ## Examples 46 | 47 | `deepspeech --help` will print the configurable parameters (batch size, 48 | learning rate, log location, number of epochs...) - it aims to have reasonably 49 | sensible defaults. 50 | 51 | ### Training 52 | 53 | A Deep Speech training run can be started by the following command, adding 54 | flags as necessary: 55 | 56 | ``` 57 | deepspeech ds1 58 | ``` 59 | 60 | By default, the experimental data and logs are output to 61 | `/tmp/experiments/year_month_date-hour_minute_second_microsecond`. 62 | 63 | ### Inference 64 | 65 | A Deep Speech evaluation run can be started by the following command, adding 66 | flags as necessary: 67 | 68 | ``` 69 | deepspeech ds1 \ 70 | --state_dict_path $MODEL_PATH \ 71 | --log_file \ 72 | --decoder greedy \ 73 | --train_subsets \ 74 | --dev_log wer \ 75 | --dev_subsets dev-clean \ 76 | --dev_batch_size 1 77 | ``` 78 | 79 | Note that the lack of an argument to `--log_file` causes the WER results to be 80 | written to stderr. 81 | 82 | ## Dataset 83 | 84 | The package contains code to download and use the [LibriSpeech ASR 85 | corpus](http://www.openslr.org/12/). 86 | 87 | ## WER 88 | 89 | The word error rate (WER) is computed using the formula that is widely used in 90 | many open-source speech-to-text systems (Kaldi, PaddlePaddle, Mozilla 91 | DeepSpeech).
In pseudocode, where `N` is the number of validation or test 92 | samples: 93 | 94 | ``` 95 | sum_edits = sum([edit_distance(target, predict) 96 | for target, predict in zip(targets, predictions)]) 97 | sum_lens = sum([len(target) for target in targets]) 98 | WER = (1.0/N) * (sum_edits / sum_lens) 99 | ``` 100 | 101 | This reduces the impact on the WER of errors in short sentences. Toy example: 102 | 103 | | Target | Prediction | Edit Distance | Label Length | 104 | |--------------------------------|-------------------------------|---------------|--------------| 105 | | lectures | lectured | 1 | 1 | 106 | | i'm afraid he said | i am afraid he said | 2 | 4 | 107 | | nice to see you mister meeking | nice to see your mister makin | 2 | 6 | 108 | 109 | The mean WER of each sample considered individually is: 110 | 111 | ``` 112 | >>> (1.0/3) * ((1.0/1) + (2.0/4) + (2.0/6)) 113 | 0.611111111111111 114 | ``` 115 | 116 | Compared to the pseudocode version given above: 117 | 118 | ``` 119 | >>> (1.0/3) * ((1.0 + 2 + 2) / (1.0 + 4 + 6)) 120 | 0.1515151515151515 121 | ``` 122 | 123 | ## Maintainer 124 | 125 | Please contact `sam at myrtle dot ai`. 126 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jupyterlab 2 | librosa 3 | numpy 4 | psutil 5 | pysoundfile 6 | pytest 7 | python_speech_features 8 | requests 9 | tensorboardX 10 | tensorflow 11 | 12 | https://download.pytorch.org/whl/cu91/torch-0.4.0-cp35-cp35m-linux_x86_64.whl 13 | torchvision 14 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | setuptools.setup( 4 | name="deepspeech", 5 | version="0.2.1", 6 | description="train and evaluate a DeepSpeech or DeepSpeech2 network", 7 | author="Sam Davis", 8 | author_email="sam@myrtle.ai", 9 | packages=setuptools.find_packages('src'), 10 | package_dir={'': 'src'}, 11 | python_requires='>=3.5', 12 | entry_points={ 13 | 'console_scripts': ['deepspeech=deepspeech.run:main'] 14 | } 15 | ) 16 | -------------------------------------------------------------------------------- /src/deepspeech/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MyrtleSoftware/deepspeech/995ce2d69c8da9b513a2b4e5e3f0b3bae5682373/src/deepspeech/__init__.py -------------------------------------------------------------------------------- /src/deepspeech/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MyrtleSoftware/deepspeech/995ce2d69c8da9b513a2b4e5e3f0b3bae5682373/src/deepspeech/data/__init__.py -------------------------------------------------------------------------------- /src/deepspeech/data/alphabet.py: -------------------------------------------------------------------------------- 1 | class Alphabet: 2 | """An alphabet for a language. 3 | 4 | Args: 5 | symbols (sequence of str): Sequence of symbols in the alphabet. Each 6 | symbol will be assigned, in iteration order, an index (int) 7 | starting from 0. 8 | 9 | Raises: 10 | ValueError: Duplicate symbol in symbols. 11 | 12 | Attributes: 13 | symbols: The original sequence of symbols. 
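Example (an illustrative doctest using a toy four-symbol alphabet):

    >>> alphabet = Alphabet([' ', 'a', 'b', 'c'])
    >>> alphabet.get_index('b')
    2
    >>> alphabet.get_symbols([1, 2, 99])  # unknown indices are dropped
    ['a', 'b']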
14 | """ 15 | 16 | def __init__(self, symbols): 17 | if len(set(symbols)) != len(symbols): 18 | raise ValueError('Duplicate symbol in symbols.') 19 | 20 | self.symbols = symbols 21 | self._index_map = dict(enumerate(symbols)) 22 | self._symbol_map = {letter: i for i, letter in self._index_map.items()} 23 | 24 | def __repr__(self): 25 | return self.__class__.__name__ + ('(symbols=%r)' % self.symbols) 26 | 27 | def __len__(self): 28 | """Returns the number of symbols in the alphabet.""" 29 | return len(self.symbols) 30 | 31 | def __getitem__(self, index): 32 | symbol = self.get_symbol(index) 33 | if symbol is None: 34 | raise IndexError('Index %d is out of range') 35 | return symbol 36 | 37 | def get_symbol(self, index): 38 | """Returns the symbol for an index or None if index has no symbol.""" 39 | return self._index_map.get(index) 40 | 41 | def get_index(self, symbol): 42 | """Returns the index for a symbol or None if symbol not in Alphabet.""" 43 | return self._symbol_map.get(symbol) 44 | 45 | def get_symbols(self, indices): 46 | """Maps each index in a sequence of indices to it's symbol. 47 | 48 | Args: 49 | indices: A sequence of indices (int). 50 | 51 | Returns: 52 | A list of symbols (str). Indices in the sequence that do not have a 53 | corresponding symbol will be ignored. This means len(returned list) 54 | may be shorted than len(indices). 55 | """ 56 | symbols = [self.get_symbol(index) for index in indices] 57 | return list(filter(lambda x: x is not None, symbols)) 58 | 59 | def get_indices(self, sentence): 60 | """Maps each symbol in a sentence to it's index. 61 | 62 | Args: 63 | sentence: A sequence of symbols. 64 | 65 | Returns: 66 | A list of indices (int). Symbols in the sentence that are not in 67 | the alphabet will be ignored. This means len(returned list) may be 68 | shorter than len(sentence). 69 | """ 70 | indices = [self.get_index(symbol) for symbol in sentence] 71 | return list(filter(lambda x: x is not None, indices)) 72 | -------------------------------------------------------------------------------- /src/deepspeech/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from deepspeech.data.datasets.librispeech import LibriSpeech # noqa: F401 2 | -------------------------------------------------------------------------------- /src/deepspeech/data/datasets/librispeech.py: -------------------------------------------------------------------------------- 1 | import fnmatch 2 | import os 3 | import shutil 4 | import tarfile 5 | 6 | import requests 7 | import soundfile 8 | import torch 9 | 10 | from deepspeech.data.datasets import utils 11 | 12 | 13 | class LibriSpeech(torch.utils.data.Dataset): 14 | """LibriSpeech Dataset - http://openslr.org/12/ 15 | 16 | Args: 17 | root (string): Root directory of the dataset. This should contain the 18 | LibriSpeech directory which itself contains one directory per 19 | subset in subsets. These will be created if they do not exist, or 20 | are corrupt, and download is True. 21 | subsets (list of strings): List of subsets to create the dataset from. 22 | Subset names must be in: `['train-clean-100', 'train-clean-360', 23 | 'train-other-500', 'dev-clean', 'dev-other', 'test-clean', 24 | 'test-other']`. 25 | transform (callable, optional): A function that returns a transformed 26 | piece of audiodata. 27 | target_transform (callable, optional): A function that returns a 28 | transformed target. 
29 | download (bool, optional): If true, downloads the dataset from the 30 | internet and puts it in the root directory. If dataset is already 31 | downloaded, it is not downloaded again. See the `download` method 32 | for more information. 33 | """ 34 | 35 | base_dir = 'LibriSpeech' 36 | openslr_url = 'http://www.openslr.org/resources/12/' 37 | data_files = { 38 | 'train-clean-100': {'archive_md5': '2a93770f6d5c6c964bc36631d331a522', 39 | 'dir_md5': 'b1b762a7384c17c06eee975933c71739'}, 40 | 'train-clean-360': {'archive_md5': 'c0e676e450a7ff2f54aeade5171606fa', 41 | 'dir_md5': 'ef1c8b93522d89ae27116d64a12d1d2f'}, 42 | 'train-other-500': {'archive_md5': 'd1a0fd59409feb2c614ce4d30c387708', 43 | 'dir_md5': '851de4dff9bdfd1a89d9eab18bc675c6'}, 44 | 'dev-clean': {'archive_md5': '42e2234ba48799c1f50f24a7926300a1', 45 | 'dir_md5': '9e3b56b96e2cbbcc941c00f52f2fdcf9'}, 46 | 'dev-other': {'archive_md5': 'c8d0bcc9cca99d4f8b62fcc847357931', 47 | 'dir_md5': '1417e91c9f0a1c2c1c61af1178ffa94b'}, 48 | 'test-clean': {'archive_md5': '32fa31d27d2e1cad72775fee3f4849a9', 49 | 'dir_md5': '5fad2e72ec7af2659d50e4df720bc22b'}, 50 | 'test-other': {'archive_md5': 'fb5a50374b501bb3bac4815ee91d3135', 51 | 'dir_md5': 'ddcbdd339bd02c8d2b6c1bbde28c828c'} 52 | } 53 | 54 | def __init__(self, root, subsets, transform=None, target_transform=None, 55 | download=False): 56 | self.root = os.path.expanduser(root) 57 | self.subsets = self._validate_subsets(subsets) 58 | self._transform = transform 59 | self._target_transform = target_transform 60 | 61 | if download: 62 | self.download() 63 | 64 | if not self.check_integrity(): 65 | raise ValueError('Dataset subset(s) not found or corrupt.' 66 | ' Set `download=True` to download.') 67 | 68 | self.load_data() 69 | 70 | def __getitem__(self, index): 71 | """Returns the sample at index in the dataset. 72 | 73 | The samples are in ascending order by audio sample length. 74 | 75 | Transforms are applied if not None. 76 | 77 | Args: 78 | index (int): Index. 79 | 80 | Returns: 81 | tuple: (audiodata, target) where audiodata is a np.int16 numpy 82 | array and target is a string containing the target transcription. 83 | """ 84 | path = self.paths[index] 85 | audio, rate = soundfile.read(path, dtype='int16') 86 | 87 | assert rate == 16000, '%r sample rate != 16000' % path 88 | assert len(audio) / rate == self.durations[index], \ 89 | '%r sample duration != expected duration' % path 90 | 91 | if self._transform is not None: 92 | audio = self._transform(audio) 93 | 94 | target = self.transcriptions[index] 95 | if self._target_transform is not None: 96 | target = self._target_transform(target) 97 | 98 | return audio, target 99 | 100 | def __len__(self): 101 | return len(self.paths) 102 | 103 | def _validate_subsets(self, subsets): 104 | """Ensures subsets is non-empty and contains valid subsets only.""" 105 | if not subsets: 106 | raise ValueError('No subsets specified.') 107 | for subset in subsets: 108 | if subset not in self.data_files.keys(): 109 | raise ValueError('%r is not a valid subset.' % subset) 110 | return subsets 111 | 112 | def download(self): 113 | """Downloads and extracts self.subsets unless already cached. 114 | 115 | For each subset there are 3 possibilities: 116 | 1. The `LibriSpeech/subset` directory exists and is valid, making 117 | this function a noop. 118 | 2. If not 1. but the `subset.tar.gz` archive file exists and is 119 | valid then its contents are extracted. 120 | 3. If not 2. then `subset.tar.gz` is downloaded, checksum verified, 121 | and its contents extracted.
The archive file is then removed 122 | leaving behind the `LibriSpeech/subset` directory. 123 | """ 124 | os.makedirs(self.root, exist_ok=True) 125 | 126 | for subset in self.subsets: 127 | if self._check_subset_integrity(subset): 128 | print('%r already downloaded and verified.' % subset) 129 | continue 130 | path = os.path.join(self.root, subset + '.tar.gz') 131 | 132 | already_present = os.path.isfile(path) 133 | if not already_present: 134 | subset_url = self.openslr_url + subset + '.tar.gz' 135 | with requests.get(subset_url, stream=True) as r: 136 | r.raise_for_status() 137 | with open(path, 'wb') as f: 138 | shutil.copyfileobj(r.raw, f) 139 | 140 | archive_md5 = self.data_files[subset]['archive_md5'] 141 | if utils.checksum_file(path, 'md5') != archive_md5: 142 | raise utils.DownloadError('Invalid checksum for %r' % path) 143 | 144 | with tarfile.open(path, mode='r|gz') as tar: 145 | tar.extractall(self.root) 146 | 147 | if not already_present: 148 | os.remove(path) 149 | 150 | def _check_subset_integrity(self, subset): 151 | path = os.path.join(self.root, self.base_dir, subset) 152 | try: 153 | actual_md5 = utils.checksum_dir(path, 'md5') 154 | except (FileNotFoundError, NotADirectoryError, PermissionError): 155 | return False 156 | return actual_md5 == self.data_files[subset]['dir_md5'] 157 | 158 | def check_integrity(self): 159 | """Returns True if each subset is valid.""" 160 | return all([self._check_subset_integrity(s) for s in self.subsets]) 161 | 162 | def load_data(self): 163 | """Loads the data from disk.""" 164 | self.paths = [] 165 | self.durations = [] 166 | self.transcriptions = [] 167 | 168 | def raise_(err): 169 | """raises error if problem during os.walk""" 170 | raise err 171 | 172 | for subset in self.subsets: 173 | subset_path = os.path.join(self.root, self.base_dir, subset) 174 | for root, dirs, files in os.walk(subset_path, onerror=raise_): 175 | if not files: 176 | continue 177 | matches = fnmatch.filter(files, '*.trans.txt') 178 | assert len(matches) == 1, '> 1 transcription file found' 179 | self._parse_transcription_file(root, matches[0]) 180 | 181 | self._sort_by_duration() 182 | 183 | def _parse_transcription_file(self, root, name): 184 | """Parses each sample in a transcription file.""" 185 | trans_path = os.path.join(root, name) 186 | with open(trans_path, 'r', encoding='utf-8') as trans: 187 | # Each line has the form "ID THE TARGET TRANSCRIPTION" 188 | for line in trans: 189 | id_, transcript = line.split(maxsplit=1) 190 | self._process_audio(root, id_) 191 | self._process_transcript(transcript) 192 | 193 | def _process_audio(self, root, id): 194 | path = os.path.join(root, id + '.flac') 195 | self.paths.append(path) 196 | duration = soundfile.info(path).duration 197 | self.durations.append(duration) 198 | 199 | def _process_transcript(self, transcript): 200 | transcript = transcript.strip().upper() 201 | self.transcriptions.append(transcript) 202 | 203 | def _sort_by_duration(self): 204 | """Orders the loaded data by audio duration, shortest first.""" 205 | total_samples = len(self.paths) 206 | samples = zip(self.paths, self.durations, self.transcriptions) 207 | sorted_samples = sorted(samples, key=lambda sample: sample[1]) 208 | sorted_samples = [list(c) for c in zip(*sorted_samples)] 209 | self.paths, self.durations, self.transcriptions = sorted_samples 210 | assert (total_samples == 211 | len(self.paths) == 212 | len(self.durations) == 213 | len(self.transcriptions)), '_sort_by_duration len mis-match' 214 | 
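# Example usage (an illustrative sketch only; '/data' is an assumed root
# directory and `download=True` fetches and verifies the archive on first use):
#
#     dataset = LibriSpeech(root='/data', subsets=['dev-clean'], download=True)
#     audio, transcript = dataset[0]  # np.int16 audio array, target string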
-------------------------------------------------------------------------------- /src/deepspeech/data/datasets/utils.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | 4 | 5 | INTEGRITY_ALGS = {'md5': hashlib.md5} 6 | 7 | 8 | class DownloadError(Exception): 9 | """Exception raised for errors during downloading.""" 10 | pass 11 | 12 | 13 | def checksum_file(path, algorithm): 14 | """Returns the checksum of the file at path using the given algorithm. 15 | 16 | Args: 17 | path (string): Path of the file to compute the checksum for. 18 | algorithm (string): Hash algorithm to use when computing checksum. 19 | See INTEGRITY_ALGS.keys() for a list of supported algorithms. 20 | 21 | Raises: 22 | ValueError: The stated algorithm is not supported. 23 | FileNotFoundError: File does not exist at path. 24 | IsADirectoryError: Path refers to a directory. 25 | """ 26 | alg = _parse_algorithm(algorithm) 27 | accum = alg() 28 | with open(path, 'rb') as f: 29 | # read in 1MB chunks 30 | for chunk in iter(lambda: f.read(1024 * 1024), b''): 31 | accum.update(chunk) 32 | return accum.hexdigest() 33 | 34 | 35 | def checksum_dir(root, algorithm): 36 | """Returns the checksum of the directory at root using the given algorithm. 37 | 38 | The checksum is computed in a deterministic way over the directory contents 39 | at root. Subdirectories are included. Symlinks are not followed. 40 | 41 | Args: 42 | root (string): Root of the directory to compute the checksum for. 43 | algorithm (string): Hash algorithm to use when computing checksum. 44 | See INTEGRITY_ALGS.keys() for a list of supported algorithms. 45 | 46 | Raises: 47 | FileNotFoundError: root does not exist. 48 | NotADirectoryError: root is not a directory. 49 | PermissionError: Inadequate filesystem permissions. 50 | ValueError: The stated algorithm is not supported. 51 | """ 52 | # Hashing algorithms eat `bytes`. `os.walk` returns `bytes` objects when 53 | # called with a `bytes` object. 54 | try: 55 | b_root = root.encode('utf-8') 56 | except AttributeError: 57 | b_root = root # root is already bytes. 58 | alg = _parse_algorithm(algorithm) 59 | 60 | def raise_(err): 61 | raise err 62 | 63 | accum = alg() 64 | for root, dirs, files in os.walk(b_root, onerror=raise_): 65 | dirs.sort() # Ensures os.walk visits dirs in a deterministic order. 66 | for file in sorted(files): 67 | file_path = os.path.join(root, file) 68 | accum.update(file) 69 | accum.update(bytes.fromhex(checksum_file(file_path, algorithm))) 70 | for dir_ in dirs: 71 | accum.update(dir_) 72 | 73 | return accum.hexdigest() 74 | 75 | 76 | def _parse_algorithm(algorithm): 77 | """Returns a constructor for the given algorithm if supported.""" 78 | try: 79 | alg = INTEGRITY_ALGS[algorithm] 80 | except KeyError: 81 | raise ValueError('Algorithm %r is not supported.' % algorithm) 82 | return alg 83 | -------------------------------------------------------------------------------- /src/deepspeech/data/loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.utils.rnn import pad_sequence 3 | 4 | 5 | def collate_input_sequences(samples): 6 | """Returns a batch of data given a list of samples. 7 | 8 | Args: 9 | samples: List of (x, y) where: 10 | 11 | `x`: A tuple: 12 | - `torch.Tensor`: an input sequence to the network with size 13 | `(len(torch.Tensor), n_features)`. 14 | - `int`: the length of the corresponding output sequence 15 | produced by the network given the `torch.Tensor` as 16 | input.
17 | `y`: A `torch.Tensor` containing the target output sequence. 18 | 19 | Returns: 20 | A tuple of `((batch_x, batch_out_lens), batch_y)` where: 21 | 22 | batch_x: The concatenation of all `torch.Tensor`'s in `x` along a 23 | new dim in descending order by `torch.Tensor` length. 24 | 25 | This results in a `torch.Tensor` of size (L, N, D) where L is 26 | the maximum `torch.Tensor` length, N is the number of samples, 27 | and D is n_features. 28 | 29 | `torch.Tensor`'s shorter than L are extended by zero padding. 30 | 31 | batch_out_lens: A `torch.IntTensor` containing the `int` values 32 | from `x` in an order that corresponds to the samples in 33 | `batch_x`. 34 | 35 | batch_y: A list of `torch.Tensor` containing the `y` `torch.Tensor` 36 | sequences in an order that corresponds to the samples in 37 | `batch_x`. 38 | 39 | Example: 40 | >>> x = [# input seq, len 5, 2 features. output seq, len 2 41 | ... (torch.full((5, 2), 1.0), 2), 42 | ... # input seq, len 4, 2 features. output seq, len 3 43 | ... (torch.full((4, 2), 2.0), 3)] 44 | >>> y = [torch.full((4,), 1.0), # target seq, len 4 45 | ... torch.full((3,), 2.0)] # target seq, len 3 46 | >>> smps = list(zip(x, y)) 47 | >>> (batch_x, batch_out_lens), batch_y = collate_input_sequences(smps) 48 | >>> print('%r' % batch_x) 49 | tensor([[[ 1., 1.], 50 | [ 2., 2.]], 51 | 52 | [[ 1., 1.], 53 | [ 2., 2.]], 54 | 55 | [[ 1., 1.], 56 | [ 2., 2.]], 57 | 58 | [[ 1., 1.], 59 | [ 2., 2.]], 60 | 61 | [[ 1., 1.], 62 | [ 0., 0.]]]) 63 | >>> print('%r' % batch_out_lens) 64 | tensor([ 2, 3], dtype=torch.int32) 65 | >>> print('%r' % batch_y) 66 | [tensor([ 1., 1., 1., 1.]), tensor([ 2., 2., 2.])] 67 | """ 68 | 69 | samples = [(*x, y) for x, y in samples] 70 | sorted_samples = sorted(samples, key=lambda s: len(s[0]), reverse=True) 71 | 72 | seqs, seq_lens, labels = zip(*sorted_samples) 73 | 74 | x = (pad_sequence(seqs), torch.IntTensor(seq_lens)) 75 | y = list(labels) 76 | 77 | return x, y 78 | -------------------------------------------------------------------------------- /src/deepspeech/data/preprocess.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | # librosa causes a RuntimeWarning - apparently safe to ignore: 3 | # stackoverflow.com/questions/40845304/runtimewarning-numpy-dtype-size-changed-may-indicate-binary-incompatibility 4 | warnings.filterwarnings("ignore", message="numpy.dtype size changed") 5 | 6 | import librosa # noqa: E402 7 | import numpy as np # noqa: E402 8 | import python_speech_features # noqa: E402 9 | 10 | 11 | class MFCC: 12 | """Compute the Mel-frequency cepstral coefficients (MFCC) of audiodata.""" 13 | 14 | def __init__(self, numcep, winlen=0.025, winstep=0.02, sample_rate=16000): 15 | self.numcep = numcep 16 | self.winlen = winlen 17 | self.winstep = winstep 18 | self.sample_rate = sample_rate 19 | 20 | def __call__(self, audiodata): 21 | return python_speech_features.mfcc(audiodata, 22 | samplerate=self.sample_rate, 23 | winlen=self.winlen, 24 | winstep=self.winstep, 25 | numcep=self.numcep) 26 | 27 | def __repr__(self): 28 | params = '(numcep={0}, winlen={1}, winstep={2}, sample_rate={3})' 29 | params = params.format(self.numcep, 30 | self.winlen, 31 | self.winstep, 32 | self.sample_rate) 33 | return self.__class__.__name__ + params 34 | 35 | 36 | class LogMagnitudeSTFT: 37 | """Compute the log of the magnitude of the STFT of audiodata.""" 38 | 39 | def __init__(self, winlen=0.02, winstep=0.01, sample_rate=16000): 40 | self.winlen = winlen 41 | self.winstep = winstep 
42 | self.sample_rate = sample_rate 43 | 44 | self._n_fft = int(winlen * sample_rate) 45 | self._hop_length = int(winstep * sample_rate) 46 | 47 | def __call__(self, audiodata): 48 | """ 49 | Args: 50 | audiodata: numpy ndarray of audio data with shape (N,). 51 | 52 | Returns: 53 | numpy ndarray X with shape (N, M). For each step in (0...N-1), the 54 | array X[step, :] contains the M log(1 + magnitude) components where 55 | M is equal to (int(winlen * sample_rate) // 2 + 1). 56 | """ 57 | D = librosa.stft(audiodata, 58 | n_fft=self._n_fft, 59 | hop_length=self._hop_length, 60 | win_length=self._n_fft, 61 | window='hamming') 62 | mag, _ = librosa.magphase(D) 63 | return np.log1p(mag).T 64 | 65 | def __repr__(self): 66 | params = '(winlen={0}, winstep={1}, sample_rate={2})' 67 | params = params.format(self.winlen, 68 | self.winstep, 69 | self.sample_rate) 70 | return self.__class__.__name__ + params 71 | 72 | 73 | class AddContextFrames: 74 | """Add context frames to each step in the original signal. 75 | 76 | Args: 77 | n_context: Number of context frames to add to each frame in the 78 | original signal. 79 | 80 | Example: 81 | >>> signal = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) 82 | >>> n_context = 2 83 | >>> print(AddContextFrames(n_context)(signal)) 84 | [[0 0 0 0 0 0 1 2 3 4 5 6 7 8 9] 85 | [0 0 0 1 2 3 4 5 6 7 8 9 0 0 0] 86 | [1 2 3 4 5 6 7 8 9 0 0 0 0 0 0]] 87 | """ 88 | 89 | def __init__(self, n_context): 90 | self.n_context = n_context 91 | 92 | def __call__(self, signal): 93 | """ 94 | Args: 95 | signal: numpy ndarray with shape (steps, features). 96 | 97 | Returns: 98 | numpy ndarray with shape: 99 | (steps, features * (n_context + 1 + n_context)) 100 | """ 101 | # Pad to ensure first and last n_context frames in original signal have 102 | # at least n_context frames to their left and right respectively. 103 | steps, features = signal.shape 104 | padding = np.zeros((self.n_context, features), dtype=signal.dtype) 105 | signal = np.concatenate((padding, signal, padding)) 106 | 107 | window_size = self.n_context + 1 + self.n_context 108 | strided_signal = np.lib.stride_tricks.as_strided( 109 | signal, 110 | # Shape of the new array. 111 | (steps, window_size, features), 112 | # Strides of the new array (bytes to step in each dim). 113 | (signal.strides[0], signal.strides[0], signal.strides[1]), 114 | # Disable write to prevent accidental errors as elems share memory. 115 | writeable=False) 116 | 117 | # Flatten last dim and return a copy to permit writes.
118 | return strided_signal.reshape(steps, -1).copy() 119 | 120 | def __repr__(self): 121 | return self.__class__.__name__ + ('(n_context=%r)' % self.n_context) 122 | 123 | 124 | class Normalize: 125 | """Normalize a tensor to have zero mean and unit standard deviation.""" 126 | 127 | def __call__(self, tensor): 128 | """ 129 | Args: 130 | tensor: numpy ndarray 131 | """ 132 | return (tensor - tensor.mean()) / tensor.std() 133 | 134 | def __repr__(self): 135 | return self.__class__.__name__ + '()' 136 | -------------------------------------------------------------------------------- /src/deepspeech/decoder/__init__.py: -------------------------------------------------------------------------------- 1 | from deepspeech.decoder.beam import BeamCTCDecoder # noqa: F401 2 | from deepspeech.decoder.greedy import GreedyCTCDecoder # noqa: F401 3 | -------------------------------------------------------------------------------- /src/deepspeech/decoder/base.py: -------------------------------------------------------------------------------- 1 | from deepspeech.logging import LoggerMixin 2 | 3 | 4 | class Decoder(LoggerMixin): 5 | """Decoder base class. 6 | 7 | Args: 8 | alphabet: An Alphabet object. 9 | blank_symbol: The symbol in `alphabet` to use as the blank during CTC 10 | decoding. 11 | """ 12 | 13 | def __init__(self, alphabet, blank_symbol): 14 | self._alphabet = alphabet 15 | self._blank_symbol = blank_symbol 16 | 17 | def decode(self, logits, logit_lens): 18 | """Returns a list of sentences given the logits for a batch. 19 | 20 | Args: 21 | logits: tensor of size (seq_len, batch, out_features). 22 | logit_lens: list of int representing the length of each sequence in 23 | logits. 24 | 25 | Returns: 26 | list containing batch number of sentences (strings). 27 | """ 28 | raise NotImplementedError 29 | -------------------------------------------------------------------------------- /src/deepspeech/decoder/beam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from ctcdecode import CTCBeamDecoder 4 | 5 | from deepspeech.decoder.base import Decoder 6 | 7 | 8 | class BeamCTCDecoder(Decoder): 9 | """A beam search decoder with an optional language model. 10 | 11 | Args: 12 | alphabet: See `Decoder`. 13 | blank_symbol: See `Decoder`. 14 | model_path: Path to KenLM LM. If None, LM score is not included. 15 | alpha: Language model weighting. 16 | beta: Word bonus weighting. 17 | cutoff_prob: Affects the list of symbols to consider when extending a 18 | prefix at each step. The symbols are sorted in descending order by 19 | probability mass. The first N symbols are considered such that 20 | their total probability mass is less than this value. N is also 21 | bounded by `cutoff_top_n`. 22 | cutoff_top_n: The top `cutoff_top_n` symbols with highest probability 23 | will be considered at each step. Note: `cutoff_prob` must be less 24 | than 1.0 for this to be considered, else all symbols will be used. 25 | beam_width: Width of the beam search. 26 | num_processes: Number of threads for the beam search.
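Example (an illustrative sketch; the KenLM path and the alpha/beta
    values are placeholder assumptions, and `Model.ALPHABET` comes from
    `deepspeech.models`):

        decoder = BeamCTCDecoder(Model.ALPHABET, '_',
                                 model_path='/path/to/lm.binary',
                                 alpha=0.5, beta=0.85, beam_width=128)
        sentences = decoder.decode(logits, logit_lens)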
27 | """ 28 | 29 | def __init__(self, alphabet, blank_symbol, model_path=None, alpha=1.0, 30 | beta=1.0, cutoff_prob=1.0, cutoff_top_n=None, beam_width=128, 31 | num_processes=4): 32 | super().__init__(alphabet, blank_symbol) 33 | 34 | cutoff_top_n = cutoff_top_n or len(alphabet) 35 | blank_id = alphabet.get_index(blank_symbol) 36 | 37 | if model_path is None: 38 | self._logger.warning('language model will not be used as ' 39 | '`model_path` is None') 40 | if model_path is not None and alpha == 0.0: 41 | self._logger.warning("language model will not be used as it's " 42 | "weighting `alpha` is zero") 43 | 44 | self._decoder = CTCBeamDecoder(labels=alphabet, 45 | model_path=model_path, 46 | alpha=alpha, 47 | beta=beta, 48 | cutoff_top_n=cutoff_top_n, 49 | cutoff_prob=cutoff_prob, 50 | beam_width=beam_width, 51 | num_processes=num_processes, 52 | blank_id=blank_id) 53 | 54 | def decode(self, logits, logit_lens): 55 | """Returns a list of sentences given the logits for a batch. 56 | 57 | Args: 58 | logits: tensor of size (seq_len, batch, out_features). 59 | logit_lens: list of int representing the length of each sequence in 60 | logits. 61 | 62 | Returns: 63 | list containing batch number of sentences (strings). 64 | """ 65 | if logits.dtype != torch.float: 66 | self._logger.debug('casting logits to single-precision') 67 | logits = logits.float() 68 | 69 | logit_lens = torch.IntTensor(logit_lens).cpu() 70 | 71 | probs = F.softmax(logits, dim=2) 72 | probs.transpose_(0, 1) # decoder "expect[s] batch x seq x label_size" 73 | 74 | output, _, _, out_seq_len = self._decoder.decode(probs, logit_lens) 75 | 76 | batch_sentences = [] 77 | for b, batch in enumerate(output): 78 | size = out_seq_len[b][0] 79 | indices = batch[0][:size] 80 | sentence = ''.join(self._alphabet.get_symbols(indices.tolist())) 81 | batch_sentences.append(sentence) 82 | 83 | return batch_sentences 84 | -------------------------------------------------------------------------------- /src/deepspeech/decoder/greedy.py: -------------------------------------------------------------------------------- 1 | from deepspeech.decoder.base import Decoder 2 | 3 | 4 | class GreedyCTCDecoder(Decoder): 5 | """Selects the symbol with highest logit value at each step. 6 | 7 | Args: 8 | alphabet: See `Decoder`. 9 | blank_symbol: See `Decoder`. 10 | """ 11 | 12 | def __init__(self, alphabet, blank_symbol): 13 | super().__init__(alphabet, blank_symbol) 14 | 15 | def decode(self, logits, logit_lens): 16 | _, max_indices = logits.float().max(2) 17 | 18 | batch_sentences = [] 19 | 20 | for i, indices in enumerate(max_indices.t()): 21 | # Ignore predictions past input sequence length. 
22 | indices = indices[:logit_lens[i]] 23 | 24 | no_dups, prev = [], None 25 | for index in indices: 26 | if prev is None or index != prev: 27 | no_dups.append(index.item()) 28 | prev = index 29 | 30 | symbols = self._alphabet.get_symbols(no_dups) 31 | 32 | no_blanks = [s for s in symbols if s != self._blank_symbol] 33 | 34 | batch_sentences.append(''.join(no_blanks)) 35 | 36 | return batch_sentences 37 | -------------------------------------------------------------------------------- /src/deepspeech/global_state.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | from datetime import datetime 4 | 5 | import tensorboardX 6 | 7 | from deepspeech.utils.singleton import Singleton 8 | 9 | 10 | class GlobalState(metaclass=Singleton, check_args=True): 11 | """Contains all the global state for a single experiment. 12 | 13 | Warning: Global state is an anti-pattern - think before adding attributes! 14 | 15 | Args: 16 | exp_dir (string, optional): path to directory where all experimental 17 | data will be stored. If `None` it defaults to the concatenation of 18 | `tempfile.gettempdir()`, 'experiments', plus the current date and 19 | time. 20 | log_frequency (int): Controls how frequently the `log_step` method 21 | returns `True`. 22 | 23 | Attributes: 24 | exp_dir: Path to global log directory. 25 | writer: A `tensorboardX.SummaryWriter` instance set to write its files 26 | to `exp_dir`. 27 | step (int): An integer used to denote the current step during a 28 | training session. 29 | """ 30 | def __init__(self, exp_dir=None, log_frequency=1): 31 | if exp_dir is None: 32 | parent = os.path.join(tempfile.gettempdir(), 'experiments') 33 | os.makedirs(parent, exist_ok=True) 34 | exp_time = datetime.now().strftime('%Y_%m_%d-%H_%M_%S_%f') 35 | exp_dir = os.path.join(parent, exp_time) 36 | 37 | os.makedirs(exp_dir, exist_ok=True) # ensure it exists!
38 | 39 | self.exp_dir = exp_dir 40 | 41 | self._init_writer() 42 | 43 | self.log_frequency = log_frequency 44 | 45 | self.step = 0 46 | 47 | def _init_writer(self): 48 | self.writer = tensorboardX.SummaryWriter(log_dir=self.exp_dir) 49 | 50 | def log_step(self): 51 | """Returns `True` each time `self.step % log_frequency == 0`""" 52 | return self.step % self.log_frequency == 0 53 | 54 | def state_dict(self): 55 | state = {'log_frequency': self.log_frequency, 56 | 'step': self.step} 57 | return state 58 | 59 | def load_state_dict(self, state_dict): 60 | for name, value in state_dict.items(): 61 | setattr(self, name, value) 62 | self._init_writer() 63 | -------------------------------------------------------------------------------- /src/deepspeech/logging/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from deepspeech.logging.log_level_action import LogLevelAction 3 | from deepspeech.logging.mixin import log_call 4 | from deepspeech.logging.mixin import log_call_debug 5 | from deepspeech.logging.mixin import log_call_info 6 | from deepspeech.logging.mixin import log_call_warning 7 | from deepspeech.logging.mixin import log_call_error 8 | from deepspeech.logging.mixin import log_call_critical 9 | from deepspeech.logging.mixin import LoggerMixin 10 | -------------------------------------------------------------------------------- /src/deepspeech/logging/log_level_action.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | 4 | 5 | class LogLevelAction(argparse.Action): 6 | """An `argparse.Action` for levels in the `logging` module. 7 | 8 | Example usage: 9 | 10 | >>> import argparse 11 | >>> parser = argparse.ArgumentParser() 12 | >>> parser.add_argument('--loglevel', action=LogLevelAction) 13 | 14 | """ 15 | 16 | LEVELS = {'DEBUG': logging.DEBUG, 17 | 'INFO': logging.INFO, 18 | 'WARNING': logging.WARNING, 19 | 'ERROR': logging.ERROR, 20 | 'CRITICAL': logging.CRITICAL} 21 | 22 | def __init__(self, option_strings, dest, nargs=None, const=None, 23 | type=str, choices=LEVELS.keys(), default='DEBUG', **kwargs): 24 | if nargs is not None: 25 | raise ValueError('nargs must be None') 26 | if const is not None: 27 | raise ValueError('const must be None') 28 | if type is not str: 29 | raise ValueError('type must be str') 30 | if any([choice not in self.LEVELS for choice in choices]): 31 | raise ValueError('choices=%r must be a subset of %r' % ( 32 | choices, list(self.LEVELS.keys()))) 33 | if default not in choices: 34 | raise ValueError('default=%r must be in %r' % (default, 35 | list(choices))) 36 | 37 | super().__init__(option_strings, dest, nargs=nargs, const=const, 38 | type=type, choices=choices, default=default, **kwargs) 39 | 40 | def __call__(self, parser, namespace, value, option_string=None): 41 | if value not in self.choices: 42 | name = self.metavar 43 | if name is None: 44 | name = self.dest.upper() 45 | raise ValueError('%s must be in %s' % (name, self.choices)) 46 | setattr(namespace, self.dest, self._str_to_int(value)) 47 | 48 | @classmethod 49 | def _str_to_int(cls, str_level): 50 | return cls.LEVELS[str_level] 51 | -------------------------------------------------------------------------------- /src/deepspeech/logging/mixin.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | class LoggerMixin: 5 | """Adds a logging.Logger to a class. 
6 | 7 | Attributes: 8 | _logger: A logging.Logger with name equal to the concatenation of the 9 | class module and class name. 10 | """ 11 | @property 12 | def _logger(self): 13 | name = '.'.join([self.__module__, self.__class__.__name__]) 14 | return logging.getLogger(name) 15 | 16 | 17 | class log_call: 18 | """A decorator that logs a method call plus its args and kwargs. 19 | 20 | The class that the method belongs to must include LoggerMixin as a base. 21 | 22 | Args: 23 | level (int): The logging level (e.g. logging.DEBUG). 24 | """ 25 | def __init__(self, level): 26 | self._level = level 27 | 28 | def __call__(self, f): 29 | def wrapper(f_self, *args, **kwargs): 30 | msg = 'entering %r, args: %r, kwargs: %r' 31 | f_self._logger.log(self._level, msg, f.__name__, args, kwargs) 32 | 33 | result = f(f_self, *args, **kwargs) 34 | 35 | f_self._logger.log(self._level, 'exited %r', f.__name__) 36 | 37 | return result 38 | return wrapper 39 | 40 | 41 | def log_call_debug(f): 42 | return log_call(logging.DEBUG)(f) 43 | 44 | 45 | def log_call_info(f): 46 | return log_call(logging.INFO)(f) 47 | 48 | 49 | def log_call_warning(f): 50 | return log_call(logging.WARNING)(f) 51 | 52 | 53 | def log_call_error(f): 54 | return log_call(logging.ERROR)(f) 55 | 56 | 57 | def log_call_critical(f): 58 | return log_call(logging.CRITICAL)(f) 59 | -------------------------------------------------------------------------------- /src/deepspeech/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from deepspeech.loss.ctc_loss import CTCLoss # noqa: F401 2 | from deepspeech.loss.eval import levenshtein # noqa: F401 3 | -------------------------------------------------------------------------------- /src/deepspeech/loss/ctc_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import warpctc 4 | 5 | from deepspeech.logging import LoggerMixin 6 | 7 | 8 | class CTCLoss(torch.nn.Module, LoggerMixin): 9 | """Connectionist Temporal Classification (CTC) Loss. 10 | 11 | This computes the forward and backward pass in single-precision using a 12 | single call to an optimised kernel. 13 | 14 | Args: 15 | blank_index (int): The index of the blank symbol in logits. 16 | size_average (bool, optional): Normalise the loss by the batch size. 17 | length_average (bool, optional): Normalise the loss by the total number 18 | of input steps (i.e. `sum(logit_lens)`). 19 | """ 20 | 21 | def __init__(self, blank_index, size_average=False, length_average=False): 22 | super().__init__() 23 | self._ctc_loss = warpctc.CTCLoss(reduce=True, 24 | size_average=size_average, 25 | length_average=length_average, 26 | blank_label=blank_index) 27 | 28 | def forward(self, logits, labels, logit_lens): 29 | """Returns the CTC loss. 30 | 31 | Args: 32 | logits: A torch.Tensor of logits with shape 33 | `(seq_len, batch, alphabet_len)`. These will be cast to 34 | `float32`. 35 | labels: A list of torch.IntTensor containing the target labels for 36 | each sample in the batch. 37 | logit_lens: A torch.IntTensor containing the length of the input 38 | sequences.
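Example (an illustrative shape sketch; the tensors are arbitrary and
    the 29-symbol alphabet size matches `Model.ALPHABET`):

        loss_fn = CTCLoss(blank_index=28)
        logits = torch.randn(50, 2, 29)  # (seq_len, batch, alphabet_len)
        labels = [torch.IntTensor([1, 2, 3]), torch.IntTensor([4, 5])]
        logit_lens = torch.IntTensor([50, 48])
        loss = loss_fn(logits, labels, logit_lens)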
39 | """ 40 | if logits.dtype != torch.float: 41 | self._logger.debug('casting logits to single-precision') 42 | logits = logits.float() 43 | 44 | if not logits.is_cuda: 45 | self._logger.debug('using cpu ctc loss') 46 | 47 | label_lens = torch.IntTensor([len(label) for label in labels]) 48 | labels = torch.cat(labels) 49 | 50 | loss = self._ctc_loss(logits, logit_lens, labels, label_lens) 51 | 52 | return loss 53 | -------------------------------------------------------------------------------- /src/deepspeech/loss/eval.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | # This is a straightforward implementation of a well-known algorithm, and thus 4 | # probably shouldn't be covered by copyright to begin with. But in case it is, 5 | # the author (Magnus Lie Hetland) has, to the extent possible under law, 6 | # dedicated all copyright and related and neighboring rights to this software 7 | # to the public domain worldwide, by distributing it under the CC0 license, 8 | # version 1.0. This software is distributed without any warranty. For more 9 | # information, see 10 | def levenshtein(a, b): 11 | "Calculates the Levenshtein distance between a and b." 12 | n, m = len(a), len(b) 13 | if n > m: 14 | # Make sure n <= m, to use O(min(n,m)) space 15 | a, b = b, a 16 | n, m = m, n 17 | 18 | current = range(n + 1) 19 | for i in range(1, m + 1): 20 | previous, current = current, [i] + [0]*n 21 | for j in range(1, n + 1): 22 | add, delete = previous[j] + 1, current[j - 1] + 1 23 | change = previous[j - 1] 24 | if a[j - 1] != b[i - 1]: 25 | change = change + 1 26 | current[j] = min(add, delete, change) 27 | 28 | return current[n] 29 | -------------------------------------------------------------------------------- /src/deepspeech/models/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from deepspeech.models.deepspeech2 import DeepSpeech2 3 | from deepspeech.models.deepspeech import DeepSpeech 4 | from deepspeech.models.model import Model 5 | -------------------------------------------------------------------------------- /src/deepspeech/models/deepspeech.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.transforms import Compose 3 | 4 | from deepspeech.data import preprocess 5 | from deepspeech.models.model import Model 6 | from deepspeech.networks.deepspeech import Network 7 | 8 | 9 | class DeepSpeech(Model): 10 | """Deep Speech Model. 11 | 12 | Args: 13 | optimiser_cls: See `Model`. 14 | optimiser_kwargs: See `Model`. 15 | decoder_cls: See `Model`. 16 | decoder_kwargs: See `Model`. 17 | n_hidden (int): Internal hidden unit size. 18 | n_context (int): Number of context frames to use on each side of the 19 | current input frame. 20 | n_mfcc (int): Number of Mel-Frequency Cepstral Coefficients to use as 21 | input for a single frame. 22 | drop_prob (float): Dropout drop probability, [0.0, 1.0] inclusive. 23 | winlen (float): Window length in ms to compute input features over. 24 | winstep (float): Window step size in ms. 25 | sample_rate (int): Sample rate in Hz of input data. 26 | 27 | Attributes: 28 | See base class. 
29 | """ 30 | 31 | def __init__(self, optimiser_cls=None, optimiser_kwargs=None, 32 | decoder_cls=None, decoder_kwargs=None, 33 | n_hidden=2048, n_context=9, n_mfcc=26, drop_prob=0.25, 34 | winlen=0.025, winstep=0.02, sample_rate=16000): 35 | 36 | self._n_hidden = n_hidden 37 | self._n_context = n_context 38 | self._n_mfcc = n_mfcc 39 | self._drop_prob = drop_prob 40 | self._winlen = winlen 41 | self._winstep = winstep 42 | self._sample_rate = sample_rate 43 | 44 | network = self._get_network() 45 | 46 | super().__init__(network=network, 47 | optimiser_cls=optimiser_cls, 48 | optimiser_kwargs=optimiser_kwargs, 49 | decoder_cls=decoder_cls, 50 | decoder_kwargs=decoder_kwargs, 51 | clip_gradients=None) 52 | 53 | def _get_network(self): 54 | return Network(in_features=self._n_mfcc*(2*self._n_context + 1), 55 | n_hidden=self._n_hidden, 56 | out_features=len(self.ALPHABET), 57 | drop_prob=self._drop_prob) 58 | 59 | @property 60 | def transform(self): 61 | return Compose([preprocess.MFCC(self._n_mfcc), 62 | preprocess.AddContextFrames(self._n_context), 63 | preprocess.Normalize(), 64 | torch.FloatTensor, 65 | lambda t: (t, len(t))]) 66 | 67 | @property 68 | def target_transform(self): 69 | return Compose([str.lower, 70 | self.ALPHABET.get_indices, 71 | torch.IntTensor]) 72 | -------------------------------------------------------------------------------- /src/deepspeech/models/deepspeech2.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torchvision.transforms import Compose 4 | 5 | from deepspeech.data import preprocess 6 | from deepspeech.models.model import Model 7 | from deepspeech.networks.deepspeech2 import Network 8 | 9 | 10 | class DeepSpeech2(Model): 11 | """Deep Speech 2 Model. 12 | 13 | Args: 14 | optimiser_cls: See `Model`. 15 | optimiser_kwargs: See `Model`. 16 | decoder_cls: See `Model`. 17 | decoder_kwargs: See `Model`. 18 | n_hidden (int): Internal hidden unit size. 19 | rnn_layers (int): Number of recurrent layers to stack. 20 | winlen (float): Window length in ms to compute input features over. 21 | winstep (float): Window step size in ms. 22 | sample_rate (int): Sample rate in Hz of input data. 23 | clip_gradients: See `Model`. 24 | 25 | Attributes: 26 | See base class. 
27 | """ 28 | 29 | def __init__(self, optimiser_cls=None, optimiser_kwargs=None, 30 | decoder_cls=None, decoder_kwargs=None, n_hidden=800, 31 | rnn_layers=5, winlen=0.02, winstep=0.01, sample_rate=16000, 32 | clip_gradients=400): 33 | 34 | self._n_hidden = n_hidden 35 | self._rnn_layers = rnn_layers 36 | self._winlen = winlen 37 | self._winstep = winstep 38 | self._sample_rate = sample_rate 39 | 40 | network = self._get_network() 41 | 42 | super().__init__(network=network, 43 | optimiser_cls=optimiser_cls, 44 | optimiser_kwargs=optimiser_kwargs, 45 | decoder_cls=decoder_cls, 46 | decoder_kwargs=decoder_kwargs, 47 | clip_gradients=clip_gradients) 48 | 49 | def _get_network(self): 50 | return Network(in_features=int((self._sample_rate*self._winlen)//2+1), 51 | n_hidden=self._n_hidden, 52 | out_features=len(self.ALPHABET), 53 | rnn_type='gru', 54 | rnn_layers=self._rnn_layers, 55 | bidirectional=True) 56 | 57 | @property 58 | def transform(self): 59 | return Compose([lambda t: t.astype(np.float32), 60 | preprocess.LogMagnitudeSTFT( 61 | winlen=self._winlen, 62 | winstep=self._winstep, 63 | sample_rate=self._sample_rate), 64 | preprocess.Normalize(), 65 | torch.from_numpy, 66 | lambda t: (t, Network.output_len(len(t))), 67 | ]) 68 | 69 | @property 70 | def target_transform(self): 71 | return Compose([str.lower, 72 | self.ALPHABET.get_indices, 73 | torch.IntTensor]) 74 | -------------------------------------------------------------------------------- /src/deepspeech/models/model.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import psutil 4 | import time 5 | 6 | import torch 7 | from torch.nn.utils import clip_grad_norm_ 8 | 9 | from deepspeech.data.alphabet import Alphabet 10 | from deepspeech.decoder import GreedyCTCDecoder 11 | from deepspeech.global_state import GlobalState 12 | from deepspeech.logging import log_call_info 13 | from deepspeech.logging import LoggerMixin 14 | from deepspeech.loss import CTCLoss 15 | from deepspeech.loss import levenshtein 16 | from deepspeech.networks.utils import to_cuda 17 | 18 | 19 | _BLANK_SYMBOL = '_' 20 | 21 | 22 | def _gen_alphabet(): 23 | symbols = list(" abcdefghijklmnopqrstuvwxyz'") 24 | symbols.append(_BLANK_SYMBOL) 25 | return Alphabet(symbols) 26 | 27 | 28 | class Model(LoggerMixin): 29 | """A speech-to-text model. 30 | 31 | Args: 32 | network: A speech-to-text `torch.nn.Module`. 33 | optimiser_cls (callable, optional): If not None, this optimiser will be 34 | instantiated with an OrderedDict of the network parameters as the 35 | first argument and **optimiser_kwargs as the remaining arguments 36 | unless they are None. 37 | optimiser_kwargs (dict, optional): A dictionary of arguments to pass to 38 | the optimiser when it is created. Defaults to the empty dictionary 39 | if None. 40 | decoder_cls (callable, optional): A callable that implements the 41 | `deepspeech.decoder.Decoder` interface. Defaults to 42 | `DEFAULT_DECODER_CLS` if None. 43 | decoder_kwargs (dict): A dictionary of arguments to pass to the decoder 44 | when it is created. Defaults to `DEFAULT_DECODER_KWARGS` if 45 | `decoder_kwargs` is None. 46 | clip_gradients (int, optional): If None no gradient clipping is 47 | performed. If an int, it is used as the `max_norm` parameter to 48 | `torch.nn.utils.clip_grad_norm`. 49 | 50 | Attributes: 51 | BLANK_SYMBOL: The string that denotes the blank symbol in the CTC 52 | algorithm. 
53 | ALPHABET: A `deepspeech.data.alphabet.Alphabet` - contains 54 | `BLANK_SYMBOL`. 55 | DEFAULT_DECODER_CLS: See Args. 56 | DEFAULT_DECODER_KWARGS: See Args. 57 | completed_epochs: Number of epochs completed during training. 58 | network: See Args. 59 | decoder: A `deepspeech.decoder.BeamCTCDecoder` instance. 60 | optimiser: An `optimiser_cls` instance or None. 61 | loss: A `deepspeech.loss.CTCLoss` instance. 62 | transform: A function that returns a transformed piece of audio data. 63 | target_transform: A function that returns a transformed target. 64 | """ 65 | BLANK_SYMBOL = _BLANK_SYMBOL 66 | ALPHABET = _gen_alphabet() 67 | 68 | DEFAULT_DECODER_CLS = GreedyCTCDecoder 69 | DEFAULT_DECODER_KWARGS = {'alphabet': ALPHABET, 70 | 'blank_symbol': BLANK_SYMBOL} 71 | 72 | def __init__(self, network, optimiser_cls=None, optimiser_kwargs=None, 73 | decoder_cls=None, decoder_kwargs=None, clip_gradients=None): 74 | self.completed_epochs = 0 75 | 76 | self._optimiser_cls = optimiser_cls 77 | self._optimiser_kwargs = optimiser_kwargs 78 | self._clip_gradients = clip_gradients 79 | 80 | self._init_network(network) 81 | self._init_decoder(decoder_cls, decoder_kwargs) 82 | self._init_optimiser() 83 | self._init_loss() 84 | 85 | self._global_state = GlobalState.get_or_init_singleton() 86 | 87 | def _init_network(self, network): 88 | if not torch.cuda.is_available(): 89 | self._logger.info('CUDA not available') 90 | else: 91 | self._logger.info('CUDA available, moving network ' 92 | 'parameters and buffers to the GPU') 93 | to_cuda(network) 94 | 95 | self.network = network 96 | 97 | def _init_decoder(self, decoder_cls, decoder_kwargs): 98 | if decoder_cls is None: 99 | decoder_cls = self.DEFAULT_DECODER_CLS 100 | 101 | if decoder_kwargs is None: 102 | decoder_kwargs = copy.copy(self.DEFAULT_DECODER_KWARGS) 103 | 104 | self.decoder = decoder_cls(**decoder_kwargs) 105 | 106 | def _init_optimiser(self): 107 | self.reset_optimiser() 108 | 109 | def reset_optimiser(self): 110 | """Assigns a new `self.optimiser` using the current network params.""" 111 | if self._optimiser_cls is None: 112 | self.optimiser = None 113 | self._logger.debug('No optimiser specified') 114 | return 115 | 116 | kwargs = self._optimiser_kwargs or {} 117 | opt = self._optimiser_cls(self.network.parameters(), **kwargs) 118 | 119 | self.optimiser = opt 120 | 121 | def _init_loss(self): 122 | blank_index = self.ALPHABET.get_index(self.BLANK_SYMBOL) 123 | self.loss = CTCLoss(blank_index=blank_index, 124 | size_average=False, 125 | length_average=False) 126 | 127 | @property 128 | def transform(self): 129 | raise NotImplementedError 130 | 131 | @property 132 | def target_transform(self): 133 | raise NotImplementedError 134 | 135 | def state_dict(self): 136 | state = {'completed_epochs': self.completed_epochs, 137 | 'network': self.network.state_dict(), 138 | 'global_state': self._global_state.state_dict()} 139 | if self.optimiser is not None: 140 | state['optimiser'] = self.optimiser.state_dict() 141 | return state 142 | 143 | def load_state_dict(self, state_dict): 144 | self.completed_epochs = state_dict['completed_epochs'] 145 | self.network.load_state_dict(state_dict['network']) 146 | self._global_state.load_state_dict(state_dict['global_state']) 147 | if self.optimiser is not None: 148 | self.optimiser.load_state_dict(state_dict['optimiser']) 149 | 150 | @property 151 | def _zero_grad(self): 152 | return lambda: self.network.zero_grad() 153 | 154 | @property 155 | def _backward(self): 156 | return lambda batch_loss: 
batch_loss.backward() 157 | 158 | @property 159 | def _maybe_clip_gradients(self): 160 | if self._clip_gradients is None: 161 | return lambda: None 162 | 163 | return lambda: clip_grad_norm_(self.network.parameters(), 164 | self._clip_gradients) 165 | 166 | @log_call_info 167 | def train(self, loader): 168 | """Trains the Model for an epoch. 169 | 170 | Args: 171 | loader: A `torch.utils.data.DataLoader` that generates batches of 172 | training data. 173 | """ 174 | if self.optimiser is None: 175 | raise AttributeError('Cannot train when optimiser is None!') 176 | 177 | self.network.train() 178 | self._train_log_init() 179 | epoch_loss = 0.0 180 | total_samples = 0 181 | 182 | data_iter = iter(loader) # Explicit creation to log queue sizes. 183 | for step, ((x, logit_lens), y) in enumerate(data_iter): 184 | self._zero_grad() 185 | 186 | logits = self.network(x) 187 | 188 | batch_loss = self.loss(logits, y, logit_lens) 189 | 190 | epoch_loss += batch_loss.item() 191 | 192 | total_samples += len(logit_lens) 193 | 194 | self._backward(batch_loss) 195 | 196 | self._maybe_clip_gradients() 197 | 198 | self.optimiser.step() 199 | 200 | self._train_log_step(step, x, logits, logit_lens, batch_loss.item(), data_iter) # noqa: E501 201 | 202 | self._global_state.step += 1 203 | 204 | del logits, x, logit_lens, y 205 | 206 | self._train_log_end(epoch_loss, total_samples) 207 | self.completed_epochs += 1 208 | 209 | @log_call_info 210 | def eval_wer(self, loader): 211 | """Evaluates the WER of the Model. 212 | 213 | Args: 214 | loader: A `torch.utils.data.DataLoader` that generates batches of 215 | data. 216 | """ 217 | self.network.eval() 218 | 219 | total_lev = 0 220 | total_lab_len = 0 221 | n = 0 222 | 223 | self._logger.debug('idx,model_label_prediction,target,edit_distance') 224 | for i, ((x, logit_lens), y) in enumerate(loader): 225 | with torch.no_grad(): # Ensure the gradient isn't computed. 226 | logits = self.network(x) 227 | 228 | preds = self.decoder.decode(logits.cpu(), logit_lens) 229 | acts = [''.join(self.ALPHABET.get_symbols(yi.data.numpy())) 230 | for yi in y] 231 | 232 | for pred, act in zip(preds, acts): 233 | lev = levenshtein(pred.split(), act.split()) 234 | 235 | self._logger.debug('%d,%r,%r,%d', n, pred, act, lev) 236 | 237 | n += 1 238 | total_lev += lev 239 | total_lab_len += len(act.split()) 240 | 241 | wer = float(total_lev) / total_lab_len 242 | self._logger.debug('eval/wer: %r', wer) 243 | self._global_state.writer.add_scalar('eval/wer', 244 | wer, self._global_state.step) 245 | 246 | return wer 247 | 248 | @log_call_info 249 | def eval_loss(self, loader): 250 | """Evaluates the CTC loss of the Model. 251 | 252 | Args: 253 | loader: A `torch.utils.data.DataLoader` that generates batches of 254 | data. 255 | """ 256 | self.network.eval() 257 | 258 | total_loss = 0.0 259 | total_samples = 0 260 | 261 | self._logger.debug('idx,batch_mean_sample_loss') 262 | 263 | for i, ((x, logit_lens), y) in enumerate(loader): 264 | with torch.no_grad(): # Ensure the gradient isn't computed. 
265 | logits = self.network(x) 266 | 267 | batch_loss = self.loss(logits, y, logit_lens).item() 268 | batch_samples = len(logit_lens) 269 | 270 | total_loss += batch_loss 271 | total_samples += batch_samples 272 | 273 | self._logger.debug('%d,%f', i, batch_loss / batch_samples) 274 | 275 | mean_sample_loss = total_loss / max(1, total_samples) 276 | self._logger.debug('eval/mean_sample_loss: %f', mean_sample_loss) 277 | self._global_state.writer.add_scalar('eval/mean_sample_loss', 278 | mean_sample_loss, 279 | self._global_state.step) 280 | return mean_sample_loss 281 | 282 | def _train_log_init(self): 283 | header = 'step,global_step,completed_epochs,sum_logit_lens,loss' 284 | self._logger.debug(header) 285 | self._cum_batch_size = 0 286 | 287 | def _train_log_step(self, step, x, logits, logit_lens, loss, data_iter): 288 | start = time.time() 289 | 290 | total_steps = logit_lens.sum().item() 291 | 292 | self._logger.debug('%d,%d,%d,%d,%f', 293 | step, 294 | self._global_state.step, 295 | self.completed_epochs, 296 | total_steps, 297 | loss) 298 | 299 | self._global_state.writer.add_scalar('train/batch_loss', 300 | loss, 301 | self._global_state.step) 302 | self._global_state.writer.add_scalar('train/batch_size', 303 | len(logit_lens), 304 | self._global_state.step) 305 | 306 | self._cum_batch_size += len(logit_lens) 307 | self._global_state.writer.add_scalar('train/epoch_cum_batch_size', 308 | self._cum_batch_size, 309 | self._global_state.step) 310 | 311 | self._global_state.writer.add_scalar('train/batch_len-x-batch_size', 312 | x.size(0) * x.size(1), 313 | self._global_state.step) 314 | self._global_state.writer.add_scalar('train/sum_logit_lens', 315 | total_steps, 316 | self._global_state.step) 317 | self._global_state.writer.add_scalar('train/memory_percent', 318 | psutil.Process().memory_percent(), 319 | self._global_state.step) 320 | 321 | self._train_log_step_data_queue(data_iter) 322 | 323 | self._train_log_step_cuda_memory() 324 | 325 | self._train_log_step_grad_param_stats() 326 | 327 | self._global_state.writer.add_scalar('train/log_step_time', 328 | time.time() - start, 329 | self._global_state.step) 330 | 331 | def _train_log_step_data_queue(self, data_iter): 332 | """Logs the number of batches in the PyTorch DataLoader queue.""" 333 | # If num_workers is 0 then there is no Queue and each batch is loaded 334 | # when next is called. 335 | if data_iter.num_workers > 0: 336 | # Otherwise there exists a queue from which samples are read from. 337 | if data_iter.pin_memory or data_iter.timeout > 0: 338 | # The loader iterator in PyTorch 0.4 with pin_memory or a 339 | # timeout has a single thread fill a queue.Queue from a 340 | # multiprocessing.SimpleQueue that is filled by num_workers 341 | # other workers. The queue.Queue is used when next is called. 342 | # See: https://pytorch.org/docs/0.4.0/_modules/torch/utils/data/dataloader.html#DataLoader # noqa: E501 343 | self._global_state.writer.add_scalar( 344 | 'train/queue_size', 345 | data_iter.data_queue.qsize(), 346 | self._global_state.step) 347 | else: 348 | # Otherwise the loader iterator reads from a 349 | # multiprocessing.SimpleQueue. This has no size function... 
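                # Logging `empty()` is a coarse proxy for queue depth: a
                # True reading suggests the workers are not keeping ahead
                # of the training loop.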
350 | self._global_state.writer.add_scalar( 351 | 'train/queue_empty', 352 | data_iter.data_queue.empty(), 353 | self._global_state.step) 354 | 355 | def _train_log_step_cuda_memory(self): 356 | """Logs CUDA memory usage.""" 357 | if torch.cuda.is_available(): 358 | self._global_state.writer.add_scalar( 359 | 'train/memory_allocated', 360 | torch.cuda.memory_allocated(), 361 | self._global_state.step) 362 | self._global_state.writer.add_scalar( 363 | 'train/max_memory_allocated', 364 | torch.cuda.max_memory_allocated(), 365 | self._global_state.step) 366 | self._global_state.writer.add_scalar( 367 | 'train/memory_cached', 368 | torch.cuda.memory_cached(), 369 | self._global_state.step) 370 | self._global_state.writer.add_scalar( 371 | 'train/max_memory_cached', 372 | torch.cuda.max_memory_cached(), 373 | self._global_state.step) 374 | 375 | def _train_log_step_grad_param_stats(self): 376 | """Logs gradient and parameter values.""" 377 | if self._global_state.log_step(): 378 | for name, param in self.network.named_parameters(): 379 | self._global_state.writer.add_histogram( 380 | 'parameters/%s' % name, param, self._global_state.step) 381 | 382 | self._global_state.writer.add_histogram( 383 | 'gradients/%s' % name, param.grad, self._global_state.step) 384 | 385 | def _train_log_end(self, epoch_loss, total_samples): 386 | mean_sample_loss = float(epoch_loss) / total_samples 387 | self._logger.debug('train/mean_sample_loss: %r', mean_sample_loss) 388 | self._logger.info('epoch %d finished', self.completed_epochs) 389 | 390 | self._global_state.writer.add_scalar('train/mean_sample_loss', 391 | mean_sample_loss, 392 | self._global_state.step) 393 | -------------------------------------------------------------------------------- /src/deepspeech/networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MyrtleSoftware/deepspeech/995ce2d69c8da9b513a2b4e5e3f0b3bae5682373/src/deepspeech/networks/__init__.py -------------------------------------------------------------------------------- /src/deepspeech/networks/deepspeech.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from deepspeech.networks.utils import OverLastDim 4 | 5 | 6 | class Network(nn.Module): 7 | """A network with 3 FC layers, a Bi-LSTM, and 2 FC layers. 8 | 9 | Args: 10 | in_features: Number of input features per step per batch. 11 | n_hidden: Internal hidden unit size. 12 | out_features: Number of output features per step per batch. 13 | drop_prob: Dropout drop probability. 14 | relu_clip: ReLU clamp value: `min(max(0, x), relu_clip)`. 15 | forget_gate_bias: Total initialized value of the bias used in the 16 | forget gate. Set to None to use PyTorch's default initialisation. 
17 | (See: http://proceedings.mlr.press/v37/jozefowicz15.pdf) 18 | """ 19 | 20 | def __init__(self, in_features, n_hidden, out_features, drop_prob, 21 | relu_clip=20.0, forget_gate_bias=1.0): 22 | super().__init__() 23 | 24 | self._relu_clip = relu_clip 25 | self._drop_prob = drop_prob 26 | 27 | self.fc1 = self._fully_connected(in_features, n_hidden) 28 | self.fc2 = self._fully_connected(n_hidden, n_hidden) 29 | self.fc3 = self._fully_connected(n_hidden, 2*n_hidden) 30 | self.bi_lstm = self._bi_lstm(2*n_hidden, n_hidden, forget_gate_bias) 31 | self.fc4 = self._fully_connected(2*n_hidden, n_hidden) 32 | self.out = self._fully_connected(n_hidden, 33 | out_features, 34 | relu=False, 35 | dropout=False) 36 | 37 | def _fully_connected(self, in_f, out_f, relu=True, dropout=True): 38 | layers = [nn.Linear(in_f, out_f)] 39 | if relu: 40 | layers.append(nn.Hardtanh(0, self._relu_clip, inplace=True)) 41 | if dropout: 42 | layers.append(nn.Dropout(p=self._drop_prob)) 43 | return OverLastDim(nn.Sequential(*layers)) 44 | 45 | def _bi_lstm(self, input_size, hidden_size, forget_gate_bias): 46 | lstm = nn.LSTM(input_size=input_size, 47 | hidden_size=hidden_size, 48 | bidirectional=True) 49 | if forget_gate_bias is not None: 50 | for name in ['bias_ih_l0', 'bias_ih_l0_reverse']: 51 | bias = getattr(lstm, name) 52 | bias.data[hidden_size:2*hidden_size].fill_(forget_gate_bias) 53 | for name in ['bias_hh_l0', 'bias_hh_l0_reverse']: 54 | bias = getattr(lstm, name) 55 | bias.data[hidden_size:2*hidden_size].fill_(0) 56 | return lstm 57 | 58 | def forward(self, x): 59 | """Computes a single forward pass through the network. 60 | 61 | Args: 62 | x: A tensor of shape (seq_len, batch, in_features). 63 | 64 | Returns: 65 | Logits of shape (seq_len, batch, out_features). 66 | """ 67 | h = self.fc1(x) 68 | h = self.fc2(h) 69 | h = self.fc3(h) 70 | h, _ = self.bi_lstm(h) 71 | h = self.fc4(h) 72 | out = self.out(h) 73 | return out 74 | -------------------------------------------------------------------------------- /src/deepspeech/networks/deepspeech2.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import OrderedDict 3 | 4 | import torch 5 | from torch.autograd import Variable 6 | from torch import nn 7 | from torch.nn.parameter import Parameter 8 | 9 | from deepspeech.networks.utils import OverLastDim 10 | 11 | 12 | SUPPORTED_RNNS = { 13 | 'gru': nn.GRU, 14 | 'lstm': nn.LSTM, 15 | 'rnn': nn.RNN 16 | } 17 | 18 | 19 | class Network(nn.Module): 20 | """A network with 2 conv layers, N recurrent layers, and a FC layer. 21 | 22 | The architecture is based on the Deep Speech 2 paper: 23 | https://arxiv.org/abs/1512.02595 24 | 25 | The implementation is based on a PyTorch version: 26 | https://github.com/SeanNaren/deepspeech.pytorch 27 | 28 | Args: 29 | in_features: Number of input features per step per batch. 30 | n_hidden: Internal hidden unit size. 31 | out_features: Number of output features per step per batch. 32 | rnn_type: Type of recurrent neural network to use. See 33 | SUPPORTED_RNNS.keys() for a complete list. 34 | bidirectional: Recurrent neural networks are bidirectional if True. 35 | rnn_layers: Number of recurrent layers to stack. 36 | context: Number of look-ahead context frames to use if not 37 | bidirectional. 38 | relu_clip: ReLU clamp value: `min(max(0, x), relu_clip)`. 
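    Example:
        Illustrative use of the conv geometry above: `output_len` applies
        the two time-dimension convolutions (kernel 11, padding 10, stride 2,
        then kernel 11, padding 0, stride 1), so a 100-frame input gives:

            >>> Network.output_len(100)
            45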
39 | """ 40 | 41 | def __init__(self, in_features, n_hidden, out_features, rnn_type='lstm', 42 | bidirectional=True, rnn_layers=5, context=20, relu_clip=20.0): 43 | super().__init__() 44 | self._relu_clip = relu_clip 45 | 46 | self._conv_layers() 47 | conv_features = self._conv_layer_feature_size(in_features) 48 | 49 | self._rnn_layers(in_features=conv_features, 50 | n_hidden=n_hidden, 51 | rnn_layers=rnn_layers, 52 | rnn_type=SUPPORTED_RNNS[rnn_type], 53 | bidirectional=bidirectional, 54 | context=context) 55 | 56 | fully_connected = nn.Sequential( 57 | nn.BatchNorm1d(n_hidden), 58 | nn.Linear(n_hidden, out_features, bias=False) 59 | ) 60 | self.fc = OverLastDim(fully_connected) 61 | 62 | def _conv_layers(self): 63 | self.conv = nn.Sequential( 64 | nn.Conv2d(in_channels=1, 65 | out_channels=32, 66 | kernel_size=(41, 11), 67 | stride=(2, 2), 68 | padding=(0, 10)), 69 | nn.BatchNorm2d(32), 70 | nn.Hardtanh(0, self._relu_clip, inplace=True), 71 | nn.Conv2d(in_channels=32, 72 | out_channels=32, 73 | kernel_size=(21, 11), 74 | stride=(2, 1), 75 | padding=(0, 0)), 76 | nn.BatchNorm2d(32), 77 | nn.Hardtanh(0, self._relu_clip, inplace=True) 78 | ) 79 | 80 | @staticmethod 81 | def _conv_output_size(input_size, filter_size, padding, stride): 82 | """Returns the length of the output size for a convolution. 83 | 84 | Applies the standard formula: 85 | 86 | ((W - F + 2P) / S) + 1 87 | 88 | where: 89 | 90 | W: `input_size` 91 | F: `filter_size` 92 | P: `padding` on one side 93 | S: `stride` 94 | """ 95 | return (float(input_size - filter_size + 2*padding) / stride) + 1 96 | 97 | @staticmethod 98 | def output_len(input_len): 99 | """Returns the length of the CTC matrix for a given input seq. len.""" 100 | output_len = Network._conv_output_size(input_len, 11, 10, 2) 101 | output_len = Network._conv_output_size(output_len, 11, 0, 1) 102 | return int(output_len) 103 | 104 | def _conv_layer_feature_size(self, in_features): 105 | """Returns the number of features after processing with conv layers.""" 106 | rnn_input_size = Network._conv_output_size(in_features, 41, 0, 2) 107 | rnn_input_size = Network._conv_output_size(rnn_input_size, 21, 0, 2) 108 | rnn_input_size *= 32 # Collapse channels. 109 | return int(rnn_input_size) 110 | 111 | def _rnn_layers(self, in_features, n_hidden, rnn_layers, rnn_type, 112 | bidirectional, context): 113 | rnns = OrderedDict() 114 | for i in range(rnn_layers): 115 | rnn = RNNWrapper(input_size=in_features, 116 | hidden_size=n_hidden, 117 | rnn_type=rnn_type, 118 | bidirectional=bidirectional, 119 | batch_norm=i > 0) 120 | rnns[str(i)] = rnn 121 | in_features = n_hidden 122 | self.rnns = nn.Sequential(rnns) 123 | 124 | if not bidirectional: 125 | self.lookahead = nn.Sequential( 126 | Lookahead(n_hidden, context=context), 127 | nn.Hardtanh(0, self._relu_clip, inplace=True)) 128 | 129 | def forward(self, x): 130 | """Computes a single forward pass through the network. 131 | 132 | Args: 133 | x: A tensor of shape (seq_len, batch, in_features). 134 | 135 | Returns: 136 | Logits of shape (seq_len, batch, out_features). 
137 | """ 138 | # T, N, H = seq_len, batch, features 139 | x = x.permute(1, 2, 0) # TxNxH -> NxHxT 140 | x.unsqueeze_(dim=1) # NxHxT -> Nx1xHxT 141 | x = self.conv(x) 142 | 143 | N, H1, H2, T = x.size() 144 | x = x.view(N, H1*H2, T) 145 | x = x.permute(2, 0, 1) # NxHxT -> TxNxH 146 | x = self.rnns(x.contiguous()) 147 | 148 | if hasattr(self, 'lookahead'): 149 | x = self.lookahead(x) 150 | 151 | return self.fc(x) 152 | 153 | 154 | class RNNWrapper(nn.Module): 155 | def __init__(self, input_size, hidden_size, rnn_type=nn.LSTM, 156 | bidirectional=False, batch_norm=True): 157 | """Bias-free RNN wrapper with optional batch norm and bidir summation. 158 | 159 | Instantiates an RNN without bias parameters. Optionally applies a batch 160 | normalisation layer to the input with the statistics computed over all 161 | time steps. If the RNN is bidirectional, the output from the forward 162 | and backward units is summed before return. 163 | """ 164 | super().__init__() 165 | if batch_norm: 166 | self.batch_norm = OverLastDim(nn.BatchNorm1d(input_size)) 167 | self.bidirectional = bidirectional 168 | self.rnn = rnn_type(input_size=input_size, 169 | hidden_size=hidden_size, 170 | bidirectional=bidirectional, 171 | bias=False) 172 | 173 | def forward(self, x): 174 | if hasattr(self, 'batch_norm'): 175 | x = self.batch_norm(x) 176 | x, _ = self.rnn(x) 177 | if self.bidirectional: 178 | # TxNx(H*2) -> TxNxH by sum. 179 | seq_len, batch_size, _ = x.size() 180 | x = x.view(seq_len, batch_size, 2, -1) \ 181 | .sum(dim=2) \ 182 | .view(seq_len, batch_size, -1) 183 | return x 184 | 185 | 186 | class Lookahead(nn.Module): 187 | """Wang et al 2016: Lookahead Conv. Layer for Unidir. Rec. Neural Nets. 188 | 189 | (31/07/2018, sam) This should be updated as it looks old but we aren't 190 | using it. 191 | """ 192 | def __init__(self, n_features, context): 193 | # should we handle batch_first=True? 194 | super(Lookahead, self).__init__() 195 | self.n_features = n_features 196 | self.weight = Parameter(torch.Tensor(n_features, context + 1)) 197 | assert context > 0 198 | self.context = context 199 | self.register_parameter('bias', None) 200 | self.init_parameters() 201 | 202 | def init_parameters(self): # what's a better way initialiase this layer? 203 | stdv = 1. / math.sqrt(self.weight.size(1)) 204 | self.weight.data.uniform_(-stdv, stdv) 205 | 206 | def forward(self, input): 207 | """ 208 | Args: 209 | input: size [seq_len, batch_size, num_features] 210 | 211 | Returns: 212 | output shape - same as input 213 | """ 214 | 215 | seq_len = input.size(0) 216 | # pad the 0th dimension (T/sequence) with zeroes whose number = context 217 | # Once pytorch's padding functions have settled, should move to those. 
218 | padding = torch.zeros(self.context, *(input.size()[1:])) \ 219 | .type_as(input.data) 220 | x = torch.cat((input, Variable(padding)), 0) 221 | 222 | # add lookahead windows (with context+1 width) as a fourth dimension 223 | # for each seq-batch-feature combination 224 | 225 | # TxLxNxH - sequence, context, batch, feature 226 | x = [x[i:i + self.context + 1] for i in range(seq_len)] 227 | x = torch.stack(x) 228 | 229 | # TxNxHxL - sequence, batch, feature, context 230 | x = x.permute(0, 2, 3, 1) 231 | 232 | x = torch.mul(x, self.weight).sum(dim=3) 233 | return x 234 | 235 | def __repr__(self): 236 | return self.__class__.__name__ + '(' \ 237 | + 'n_features=' + str(self.n_features) \ 238 | + ', context=' + str(self.context) + ')' 239 | -------------------------------------------------------------------------------- /src/deepspeech/networks/utils.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | def to_cuda(network): 5 | """Calls network.cuda() and moves input(s) to the GPU before forward.""" 6 | network.cuda() 7 | 8 | network._to_cuda_forward_cache = network.forward 9 | 10 | def cuda_forward(x): 11 | return network._to_cuda_forward_cache(x.cuda(non_blocking=True)) 12 | 13 | network.forward = cuda_forward 14 | 15 | 16 | class OverLastDim(nn.Module): 17 | """Collapses a tensor to 2D, applies a module, and (re-)expands the tensor. 18 | 19 | An n-dimensional tensor of shape (s_1, s_2, ..., s_n) is first collapsed to 20 | a tensor with shape (s_1*s_2*...*s_{n-1}, s_n). The module is called with 21 | this as input producing (s_1*s_2*...*s_{n-1}, s_n') --- note that the final 22 | dimension can change. This is expanded to (s_1, s_2, ..., s_{n-1}, s_n') and 23 | returned. 24 | 25 | Args: 26 | module (nn.Module): Module to apply. Must accept a 2D tensor as input 27 | and produce a 2D tensor as output, optionally changing the size of 28 | the last dimension. 29 | """ 30 | 31 | def __init__(self, module): 32 | super().__init__() 33 | self.module = module 34 | 35 | def forward(self, x): 36 | *dims, input_size = x.size() 37 | 38 | reduced_dims = 1 39 | for dim in dims: 40 | reduced_dims *= dim 41 | 42 | x = x.view(reduced_dims, -1) 43 | x = self.module(x) 44 | x = x.view(*dims, -1) 45 | return x 46 | -------------------------------------------------------------------------------- /src/deepspeech/run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import re 5 | 6 | import torch 7 | from torch.utils.data import DataLoader 8 | 9 | from deepspeech.data.datasets import LibriSpeech 10 | from deepspeech.data.loader import collate_input_sequences 11 | from deepspeech.decoder import BeamCTCDecoder 12 | from deepspeech.decoder import GreedyCTCDecoder 13 | from deepspeech.global_state import GlobalState 14 | from deepspeech.logging import LogLevelAction 15 | from deepspeech.models import DeepSpeech 16 | from deepspeech.models import DeepSpeech2 17 | from deepspeech.models import Model 18 | 19 | 20 | MODEL_CHOICES = ['ds1', 'ds2'] 21 | 22 | 23 | def main(args=None): 24 | """Train and evaluate a DeepSpeech or DeepSpeech2 network. 25 | 26 | Args: 27 | args (list of str, optional): List of arguments to use. If `None`, 28 | defaults to `sys.argv`.
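        Example:
            Illustrative only - the flags are defined in `get_parser` below,
            and a real invocation downloads LibriSpeech to `--cachedir`:

                main(['ds1', '--n_epochs', '1', '--decoder', 'greedy'])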
29 | """ 30 | args = get_parser().parse_args(args) 31 | 32 | global_state = GlobalState(exp_dir=args.exp_dir, 33 | log_frequency=args.slow_log_freq) 34 | 35 | init_logger(global_state.exp_dir, args.log_file) 36 | 37 | logging.debug(args) 38 | 39 | if torch.cuda.is_available(): 40 | torch.backends.cudnn.benchmark = True 41 | 42 | decoder_cls, decoder_kwargs = get_decoder(args) 43 | 44 | model = get_model(args, decoder_cls, decoder_kwargs, global_state.exp_dir) 45 | 46 | train_loader = get_train_loader(args, model) 47 | 48 | dev_loader = get_dev_loader(args, model) 49 | 50 | if train_loader is not None: 51 | for epoch in range(model.completed_epochs, args.n_epochs): 52 | maybe_eval(model, dev_loader, args.dev_log) 53 | model.train(train_loader) 54 | _save_model(args.model, model, args.exp_dir) 55 | 56 | maybe_eval(model, dev_loader, args.dev_log) 57 | 58 | 59 | def maybe_eval(model, dev_loader, dev_log): 60 | """Evaluates `model` on `dev_loader` for each statistic in `dev_log`. 61 | 62 | Args: 63 | model: A `deepspeech.models.Model`. 64 | dev_loader (optional): A `torch.utils.data.DataLoader`. If `None`, 65 | evaluation is skipped. 66 | dev_log (optional): A list of strings, where each string refers to the 67 | name of a statistic to compute. Each statistic will be computed at 68 | most once. Supported statistics: ['loss', 'wer']. 69 | """ 70 | if dev_loader is not None: 71 | for stat in set(dev_log): 72 | if stat == 'loss': 73 | model.eval_loss(dev_loader) 74 | elif stat == 'wer': 75 | model.eval_wer(dev_loader) 76 | else: 77 | raise ValueError('unknown evaluation stat request: %r' % stat) 78 | 79 | 80 | def get_parser(): 81 | """Returns an `argparse.ArgumentParser`.""" 82 | parser = argparse.ArgumentParser( 83 | description='train and evaluate a DeepSpeech or DeepSpeech2 network', 84 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 85 | ) 86 | 87 | # logging ----------------------------------------------------------------- 88 | 89 | parser.add_argument('--log_level', 90 | action=LogLevelAction, 91 | default='DEBUG', 92 | help='logging level - see `logging` module') 93 | 94 | parser.add_argument('--slow_log_freq', 95 | default=500, 96 | type=int, 97 | help='run slow logs every `slow_log_freq` batches') 98 | 99 | parser.add_argument('--exp_dir', 100 | default=None, 101 | help='path to directory to keep experimental data - ' 102 | 'see `deepspeech.global_state.GlobalState`') 103 | 104 | parser.add_argument('--log_file', 105 | nargs='?', 106 | default='log.txt', 107 | const=None, 108 | help='filename to use for log file - logs written to ' 109 | 'stderr if empty') 110 | 111 | # model ------------------------------------------------------------------- 112 | 113 | parser.add_argument('model', 114 | choices=MODEL_CHOICES, 115 | help='model to train') 116 | 117 | parser.add_argument('--state_dict_path', 118 | default=None, 119 | help='path to initial state_dict to load into model - ' 120 | 'takes precedence over ' 121 | '`--no_resume_from_exp_dir`') 122 | 123 | parser.add_argument('--no_resume_from_exp_dir', 124 | action='store_true', 125 | default=False, 126 | help='do not load the last state_dict in exp_dir') 127 | 128 | # decoder ----------------------------------------------------------------- 129 | 130 | parser.add_argument('--decoder', 131 | default='greedy', 132 | choices=['beam', 'greedy'], 133 | help='decoder to use') 134 | 135 | parser.add_argument('--lm_path', 136 | default=None, 137 | help='path to language model - if None, no lm is used') 138 | 139 | 
parser.add_argument('--lm_weight', 140 | default=None, 141 | type=float, 142 | help='language model weight in loss (i.e. alpha)') 143 | 144 | parser.add_argument('--word_weight', 145 | default=None, 146 | type=float, 147 | help='word bonus weight in loss (i.e. beta)') 148 | 149 | parser.add_argument('--beam_width', 150 | default=None, 151 | type=int, 152 | help='width of beam search') 153 | 154 | # optimizer --------------------------------------------------------------- 155 | 156 | parser.add_argument('--learning_rate', 157 | default=0.0003, 158 | type=float, 159 | help='learning rate of Adam optimizer') 160 | 161 | # data -------------------------------------------------------------------- 162 | 163 | parser.add_argument('--cachedir', 164 | default='/tmp/data/cache/', 165 | help='location to download dataset(s)') 166 | 167 | # training 168 | 169 | TRAIN_SUBSETS = ['train-clean-100', 170 | 'train-clean-360', 171 | 'train-other-500'] 172 | parser.add_argument('--train_subsets', 173 | default=TRAIN_SUBSETS, 174 | choices=TRAIN_SUBSETS, 175 | help='LibriSpeech subsets to train on', 176 | nargs='*') 177 | 178 | parser.add_argument('--train_batch_size', 179 | default=16, 180 | type=int, 181 | help='number of samples in a training batch') 182 | 183 | parser.add_argument('--train_num_workers', 184 | default=4, 185 | type=int, 186 | help='number of subprocesses for train DataLoader') 187 | 188 | # validation 189 | 190 | parser.add_argument('--dev_log', 191 | default=['loss', 'wer'], 192 | choices=['loss', 'wer'], 193 | nargs='*', 194 | help='validation statistics to log') 195 | 196 | parser.add_argument('--dev_subsets', 197 | default=['dev-clean', 'dev-other'], 198 | choices=['dev-clean', 'dev-other', 199 | 'test-clean', 'test-other'], 200 | help='LibriSpeech subsets to evaluate loss and WER on', 201 | nargs='*') 202 | 203 | parser.add_argument('--dev_batch_size', 204 | default=16, 205 | type=int, 206 | help='number of samples in a validation batch') 207 | 208 | parser.add_argument('--dev_num_workers', 209 | default=4, 210 | type=int, 211 | help='number of subprocesses for dev DataLoader') 212 | 213 | # outer loop -------------------------------------------------------------- 214 | 215 | parser.add_argument('--n_epochs', 216 | default=15, 217 | type=int, 218 | help='number of epochs') 219 | 220 | # ------------------------------------------------------------------------- 221 | 222 | return parser 223 | 224 | 225 | def init_logger(exp_dir, log_file): 226 | """Initialises the `logging.Logger`.""" 227 | logger = logging.getLogger() 228 | 229 | formatter = logging.Formatter(fmt='%(asctime)s - %(name)s - %(funcName)s -' 230 | ' %(levelname)s: %(message)s') 231 | 232 | if log_file is not None: 233 | handler = logging.FileHandler(os.path.join(exp_dir, log_file)) 234 | else: 235 | handler = logging.StreamHandler() 236 | 237 | handler.setFormatter(formatter) 238 | 239 | logger.addHandler(handler) 240 | 241 | logger.setLevel(logging.DEBUG) 242 | 243 | 244 | def get_model(args, decoder_cls, decoder_kwargs, exp_dir): 245 | """Returns a `deepspeech.models.Model`. 246 | 247 | Args: 248 | args: An `argparse.Namespace` for the `argparse.ArgumentParser` 249 | returned by `get_parser`. 250 | decoder_cls: See `deepspeech.models.Model`. 251 | decoder_kwargs: See `deepspeech.models.Model`. 252 | exp_dir: path to directory where all experimental data will be stored. 
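        The restore precedence, as implemented below, is: an explicit
        `--state_dict_path`, else the newest state_dict in `exp_dir` (unless
        `--no_resume_from_exp_dir` is set), else random initialisation.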
253 | """ 254 | model_cls = {'ds1': DeepSpeech, 'ds2': DeepSpeech2}[args.model] 255 | 256 | model = model_cls(optimiser_cls=torch.optim.Adam, 257 | optimiser_kwargs={'lr': args.learning_rate}, 258 | decoder_cls=decoder_cls, 259 | decoder_kwargs=decoder_kwargs) 260 | 261 | state_dict_path = args.state_dict_path 262 | if state_dict_path is None and not args.no_resume_from_exp_dir: 263 | # Restore from last saved `state_dict` in `exp_dir`. 264 | state_dict_path = _get_last_state_dict_path(args.model, exp_dir) 265 | 266 | if state_dict_path is not None: 267 | # Restore from user-specified `state_dict`. 268 | logging.debug('restoring state_dict at %s' % state_dict_path) 269 | map_location = 'cpu' if not torch.cuda.is_available() else None 270 | model.load_state_dict(torch.load(state_dict_path, map_location)) 271 | else: 272 | logging.debug('using randomly initialised model') 273 | _save_model(args.model, model, exp_dir) 274 | 275 | return model 276 | 277 | 278 | def get_decoder(args): 279 | """Returns a `deepspeech.decoder.Decoder`. 280 | 281 | Args: 282 | args: An `argparse.Namespace` for the `argparse.ArgumentParser` 283 | returned by `get_parser`. 284 | """ 285 | decoder_kwargs = {'alphabet': Model.ALPHABET, 286 | 'blank_symbol': Model.BLANK_SYMBOL} 287 | 288 | if args.decoder == 'beam': 289 | decoder_cls = BeamCTCDecoder 290 | 291 | if args.lm_weight is not None: 292 | decoder_kwargs['alpha'] = args.lm_weight 293 | if args.word_weight is not None: 294 | decoder_kwargs['beta'] = args.word_weight 295 | if args.beam_width is not None: 296 | decoder_kwargs['beam_width'] = args.beam_width 297 | if args.lm_path is not None: 298 | decoder_kwargs['model_path'] = args.lm_path 299 | 300 | elif args.decoder == 'greedy': 301 | decoder_cls = GreedyCTCDecoder 302 | 303 | beam_args = ['lm_weight', 'word_weight', 'beam_width', 'lm_path'] 304 | for arg in beam_args: 305 | if getattr(args, arg) is not None: 306 | raise ValueError('greedy decoder selected but %r is not ' 307 | 'None' % arg) 308 | 309 | return decoder_cls, decoder_kwargs 310 | 311 | 312 | def all_state_dicts(model_str, exp_dir): 313 | """Returns a dict of (epoch, filename) for all state_dicts in `exp_dir`. 314 | 315 | Args: 316 | model_str: Model whose state_dicts to consider. 317 | exp_dir: path to directory where all experimental data will be stored. 318 | """ 319 | state_dicts = {} 320 | 321 | for f in os.listdir(exp_dir): 322 | match = re.match('(%s-([0-9]+).pt)' % model_str, f) 323 | if not match: 324 | continue 325 | 326 | groups = match.groups() 327 | name = groups[0] 328 | epoch = groups[1] 329 | 330 | state_dicts[epoch] = name 331 | 332 | return state_dicts 333 | 334 | 335 | def _get_last_state_dict_path(model_str, exp_dir): 336 | """Returns the absolute path of the last state_dict in `exp_dir` or `None`. 337 | 338 | Args: 339 | model_str: Model whose state_dicts to consider. 340 | exp_dir: path to directory where all experimental data will be stored. 341 | """ 342 | state_dicts = all_state_dicts(model_str, exp_dir) 343 | 344 | if len(state_dicts) == 0: 345 | return None 346 | 347 | last_epoch = sorted(state_dicts.keys())[-1] 348 | 349 | return os.path.join(exp_dir, state_dicts[last_epoch]) 350 | 351 | 352 | def _save_model(model_str, model, exp_dir): 353 | """Saves the model's `state_dict` in `exp_dir`. 354 | 355 | Args: 356 | model_str: Argument name of `model`. 357 | model: A `deepspeech.models.Model`. 358 | exp_dir: path to directory where the `model`'s `state_dict` will be 359 | stored. 
360 | """ 361 | save_name = '%s-%d.pt' % (model_str, model.completed_epochs) 362 | save_path = os.path.join(exp_dir, save_name) 363 | torch.save(model.state_dict(), save_path) 364 | 365 | 366 | def get_train_loader(args, model): 367 | """Returns a `torch.nn.DataLoader over the training data. 368 | 369 | Args: 370 | args: An `argparse.Namespace` for the `argparse.ArgumentParser` 371 | returned by `get_parser`. 372 | model: A `deepspeech.models.Model`. 373 | """ 374 | if len(args.train_subsets) == 0: 375 | logging.debug('no `train_subsets` specified') 376 | return 377 | 378 | todo_epochs = args.n_epochs - model.completed_epochs 379 | if todo_epochs <= 0: 380 | logging.debug('`n_epochs` <= `model.completed_epochs`') 381 | return 382 | 383 | train_cache = os.path.join(args.cachedir, 'train') 384 | train_dataset = LibriSpeech(root=train_cache, 385 | subsets=args.train_subsets, 386 | transform=model.transform, 387 | target_transform=model.target_transform, 388 | download=True) 389 | 390 | return DataLoader(train_dataset, 391 | collate_fn=collate_input_sequences, 392 | pin_memory=torch.cuda.is_available(), 393 | num_workers=args.train_num_workers, 394 | batch_size=args.train_batch_size, 395 | shuffle=True) 396 | 397 | 398 | def get_dev_loader(args, model): 399 | """Returns a `torch.nn.DataLoader over the validation data. 400 | 401 | Args: 402 | args: An `argparse.Namespace` for the `argparse.ArgumentParser` 403 | returned by `get_parser`. 404 | model: A `deepspeech.models.Model`. 405 | """ 406 | if len(args.dev_subsets) == 0: 407 | logging.debug('no `dev_subsets` specified') 408 | return 409 | 410 | if len(args.dev_log) == 0: 411 | logging.debug('no `dev_log` statistics specified') 412 | return 413 | 414 | dev_cache = os.path.join(args.cachedir, 'dev') 415 | dev_dataset = LibriSpeech(root=dev_cache, 416 | subsets=args.dev_subsets, 417 | transform=model.transform, 418 | target_transform=model.target_transform, 419 | download=True) 420 | 421 | return DataLoader(dev_dataset, 422 | collate_fn=collate_input_sequences, 423 | pin_memory=torch.cuda.is_available(), 424 | num_workers=args.dev_num_workers, 425 | batch_size=args.dev_batch_size, 426 | shuffle=False) 427 | -------------------------------------------------------------------------------- /src/deepspeech/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MyrtleSoftware/deepspeech/995ce2d69c8da9b513a2b4e5e3f0b3bae5682373/src/deepspeech/utils/__init__.py -------------------------------------------------------------------------------- /src/deepspeech/utils/singleton.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import sys 3 | 4 | 5 | class SingletonNotExistError(Exception): 6 | """Raised when getting a Singleton that does not exist. 7 | 8 | Args: 9 | singleton_cls_name: `__name__` of the Singleton class that has no 10 | instance. 11 | message (optional): Explanation of the error. 12 | """ 13 | def __init__(self, singleton_cls_name, message=None): 14 | self.singleton_cls_name = singleton_cls_name 15 | self.message = message 16 | 17 | 18 | class SingletonRefsExistError(Exception): 19 | """Raised when resetting Singleton class but refs to current instance exist. 20 | 21 | Args: 22 | singleton_cls_name: `__name__` of the Singleton class that still has 23 | refs to the current Singleton instance. 24 | message (optional): Explanation of the error. 
25 | """ 26 | def __init__(self, singleton_cls_name, message=None): 27 | self.singleton_cls_name = singleton_cls_name 28 | self.message = message 29 | 30 | 31 | class Singleton(type): 32 | """A Singleton metaclass ensures at most one instance of a class exists. 33 | 34 | Args: 35 | check_args (optional, bool): If `True` when passed as a kwd (see 36 | example below) then it verifies that each call to the classes 37 | `__init__` method has the same set of arguments. If not, a 38 | `ValueError` is raised. Default: `False`. 39 | 40 | Example: 41 | 42 | >>> class Foo(metaclass=Singleton, check_args=True): 43 | ... def __init__(self, val): 44 | ... self.val = val 45 | ... 46 | >>> Foo.get_singleton() 47 | Traceback (most recent call last): 48 | ... 49 | deepspeech.utils.singleton.SingletonNotExistError: Foo 50 | >>> a = Foo(val=6) 51 | >>> b = Foo(6) 52 | >>> a is b 53 | True 54 | >>> Foo(val=8) 55 | Traceback (most recent call last): 56 | ... 57 | ValueError: Foo instance already exists but previously initialised 58 | differently... 59 | >>> c = Foo.get_singleton() 60 | >>> a is c 61 | True 62 | >>> d = Foo.get_or_init_singleton(val=8) # `check_args` skipped! 63 | >>> a is d 64 | True 65 | """ 66 | def __new__(metacls, name, bases, namespace, **kwds): 67 | cls_methods = ['get_singleton', 68 | 'get_or_init_singleton', 69 | '_reset_singleton'] 70 | for cls_method in cls_methods: 71 | namespace[cls_method] = classmethod(getattr(metacls, cls_method)) 72 | return type.__new__(metacls, name, bases, namespace) 73 | 74 | def __init__(cls, name, bases, namespace, **kwds): 75 | cls.__check_args = 'check_args' in kwds 76 | cls.__instance = None 77 | 78 | def get_singleton(cls): 79 | """Returns the Singleton instance if it exists. 80 | 81 | Raises: 82 | SingletonNotExistError: Singleton instance yet to be created. 83 | """ 84 | if cls.__instance is None: 85 | raise SingletonNotExistError(cls.__name__) 86 | return cls.__instance 87 | 88 | def get_or_init_singleton(cls, *args, **kwargs): 89 | """Returns the Singleton instance if it exists else calls `__init__`. 90 | 91 | Warning: The arguments are not used, and hence not checked, if the 92 | Singleton instance already exists even if `check_args=True`. 93 | """ 94 | if cls.__instance is None: 95 | return Singleton.__call__(cls, *args, **kwargs) 96 | return cls.__instance 97 | 98 | def _reset_singleton(cls): 99 | """Removes the Singleton class's reference to the Singleton instance. 100 | 101 | Raises: 102 | SingletonRefsExistError: Raised if there exist objects that refer 103 | to the current Singleton instance if it exists. 
104 | """ 105 | if cls.__instance is not None: 106 | # sub 2 to remove `getrefcount` arg ref and `cls.__instance` ref 107 | n_refs = sys.getrefcount(cls.__instance) - 2 108 | if n_refs > 0: 109 | err_msg = ('failed to reset %s: %d ref(s) to the Singleton ' 110 | 'instance still exist') % (cls.__name__, n_refs) 111 | raise SingletonRefsExistError(cls.__name__, err_msg) 112 | 113 | cls.__instance = None 114 | if cls.__check_args: 115 | del cls.__args 116 | 117 | def __call__(cls, *args, **kwargs): 118 | if cls.__instance is None: 119 | cls.__instance = super(Singleton, cls).__call__(*args, **kwargs) 120 | if cls.__check_args: 121 | cls.__args = _get_init_arguments(cls, *args, **kwargs) 122 | elif cls.__check_args: 123 | err_msg = (cls.__name__ + ' instance already exists but ' 124 | 'previously initialised differently - ' 125 | 'instance: %s, call: %s') 126 | args = _get_init_arguments(cls, *args, **kwargs) 127 | if args != cls.__args: 128 | raise ValueError(err_msg % (cls.__args, args)) 129 | 130 | return cls.__instance 131 | 132 | 133 | def _get_init_arguments(cls, *args, **kwargs): 134 | """Returns an OrderedDict of args passed to cls.__init__ given [kw]args.""" 135 | init_args = inspect.signature(cls.__init__) 136 | bound_args = init_args.bind(None, *args, **kwargs) 137 | bound_args.apply_defaults() 138 | arg_dict = bound_args.arguments 139 | del arg_dict['self'] 140 | return arg_dict 141 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MyrtleSoftware/deepspeech/995ce2d69c8da9b513a2b4e5e3f0b3bae5682373/tests/__init__.py -------------------------------------------------------------------------------- /tests/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MyrtleSoftware/deepspeech/995ce2d69c8da9b513a2b4e5e3f0b3bae5682373/tests/data/__init__.py -------------------------------------------------------------------------------- /tests/data/test_alphabet.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import pytest 4 | 5 | from deepspeech.data.alphabet import Alphabet 6 | 7 | 8 | SYMBOLS = OrderedDict([(symbol, index) for index, symbol in enumerate('abcd')]) 9 | 10 | 11 | @pytest.fixture 12 | def alphabet(): 13 | return Alphabet(SYMBOLS.keys()) 14 | 15 | 16 | def test_duplicate_symbol_raise_valuerror(): 17 | with pytest.raises(ValueError): 18 | Alphabet('aa') 19 | 20 | 21 | def test_len(alphabet): 22 | assert len(alphabet) == len(SYMBOLS) 23 | 24 | 25 | def test_iterator(alphabet): 26 | exp_symbols = list(SYMBOLS.keys()) 27 | for index, symbol in enumerate(alphabet): 28 | assert symbol == exp_symbols[index] 29 | 30 | 31 | def test_get_symbol(alphabet): 32 | for symbol, index in SYMBOLS.items(): 33 | assert alphabet.get_symbol(index) == symbol 34 | 35 | 36 | def test_get_index(alphabet): 37 | for symbol, index in SYMBOLS.items(): 38 | assert alphabet.get_index(symbol) == index 39 | 40 | 41 | def test_get_symbols(alphabet): 42 | sentence = ['a', 'b', 'b', 'c'] 43 | indices = [0, 1, 1, 99, 2] 44 | actual = alphabet.get_symbols(indices) 45 | assert len(actual) == len(sentence) 46 | assert all([a == e for a, e in zip(actual, sentence)]) 47 | 48 | 49 | def test_get_indices(alphabet): 50 | sentence = ['a', 'b', 'b', 'invalid', 'c'] 51 | indices = [0, 1, 1, 2] 52 | 
actual = alphabet.get_indices(sentence) 53 | assert len(actual) == len(indices) 54 | assert all([a == e for a, e in zip(actual, indices)]) 55 | -------------------------------------------------------------------------------- /tests/data/test_preprocess.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from deepspeech.data import preprocess 4 | 5 | 6 | def test_add_context_frames(): 7 | """Ensures context frames are correctly added to each step in a signal.""" 8 | steps, features, n_context = 10, 5, 3 9 | 10 | signal_flat = np.arange(steps * features) 11 | signal = signal_flat.reshape(steps, features) 12 | add_context_frames = preprocess.AddContextFrames(n_context) 13 | context_signal = add_context_frames(signal) 14 | 15 | assert len(context_signal) == steps 16 | 17 | # context_signal should be equal to sliding a window of size 18 | # (features * # (n_context + 1 + n_context)) with a step size of (features) 19 | # over a flattened, padded version of the original signal. 20 | pad = np.zeros(features * n_context) 21 | signal_flat_pad = np.concatenate((pad, signal_flat, pad)) 22 | window_size = features * (n_context + 1 + n_context) 23 | 24 | for i, step in enumerate(context_signal): 25 | offset = i * features 26 | assert np.all(step == signal_flat_pad[offset:offset + window_size]) 27 | 28 | 29 | def test_normalize(): 30 | """Ensures normalized tensor has mean zero, std one.""" 31 | normalize = preprocess.Normalize() 32 | tensor = (np.random.rand(100, 10) - 0.5) * 100 33 | 34 | normed = normalize(tensor) 35 | assert np.allclose(normed.mean(), 0.0) 36 | assert np.allclose(normed.std(), 1.0) 37 | -------------------------------------------------------------------------------- /tests/decoder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MyrtleSoftware/deepspeech/995ce2d69c8da9b513a2b4e5e3f0b3bae5682373/tests/decoder/__init__.py -------------------------------------------------------------------------------- /tests/decoder/test_decoder.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import pytest 4 | import torch 5 | 6 | from deepspeech.data.alphabet import Alphabet 7 | from deepspeech.decoder import GreedyCTCDecoder 8 | 9 | 10 | BLANK = '' 11 | SYMBOLS = OrderedDict([(s, i) for i, s in enumerate([BLANK] + list('abcd'))]) 12 | 13 | 14 | @pytest.fixture 15 | def alphabet(): 16 | return Alphabet(SYMBOLS.keys()) 17 | 18 | 19 | def test_greedy_ctc_decoder_decode(alphabet): 20 | """Simple test to ensure GreedyCTCDecoder runs.""" 21 | decoder = GreedyCTCDecoder(alphabet, BLANK) 22 | 23 | logits = torch.tensor([[[0.0, 0.5, 9.0, 0.0]], # b 24 | [[0.1, 0.1, 0.1, 4.5]], # c 25 | [[0.7, 7.2, 0.4, 0.9]], # a 26 | [[0.3, 8.1, 0.9, 0.5]], # a 27 | [[0.3, 8.1, 0.9, 0.5]], # a 28 | [[0.9, 0.9, 0.9, 1.0]], # c 29 | [[1.0, 0.1, 0.1, 0.1]], # 30 | [[0.5, 0.5, 0.5, 1.8]]]) # c 31 | 32 | actual = decoder.decode(logits, [8]) 33 | assert len(actual) == 1 34 | assert actual[0] == 'bcacc' 35 | -------------------------------------------------------------------------------- /tests/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MyrtleSoftware/deepspeech/995ce2d69c8da9b513a2b4e5e3f0b3bae5682373/tests/models/__init__.py -------------------------------------------------------------------------------- 
/tests/models/test_deepspeech.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from deepspeech.models import DeepSpeech 5 | 6 | 7 | @pytest.fixture 8 | def model(): 9 | return DeepSpeech() 10 | 11 | 12 | def test_load_state_dict_restores_parameters(model): 13 | act_model = DeepSpeech() 14 | act_model.load_state_dict(model.state_dict()) 15 | 16 | # Naive equality check: all network parameters are equal. 17 | exp_params = dict(model.network.named_parameters()) 18 | print(exp_params.keys()) 19 | for act_name, act_param in act_model.network.named_parameters(): 20 | assert torch.allclose(act_param.float(), exp_params[act_name].float()) 21 | -------------------------------------------------------------------------------- /tests/models/test_deepspeech2.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from deepspeech.models import DeepSpeech2 5 | 6 | 7 | @pytest.fixture 8 | def model(): 9 | return DeepSpeech2() 10 | 11 | 12 | def test_load_state_dict_restores_parameters(model): 13 | act_model = DeepSpeech2() 14 | act_model.load_state_dict(model.state_dict()) 15 | 16 | # Naive equality check: all network parameters are equal. 17 | exp_params = dict(model.network.named_parameters()) 18 | print(exp_params.keys()) 19 | for act_name, act_param in act_model.network.named_parameters(): 20 | assert torch.allclose(act_param.float(), exp_params[act_name].float()) 21 | -------------------------------------------------------------------------------- /tests/models/test_model.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from deepspeech.models import Model 5 | 6 | 7 | class ModelStub(Model): 8 | def __init__(self): 9 | super().__init__(network=torch.nn.Linear(1, 1)) 10 | 11 | @property 12 | def transform(self): 13 | return lambda x: x 14 | 15 | @property 16 | def target_transform(self): 17 | return lambda x: x 18 | 19 | 20 | @pytest.fixture 21 | def model_stub(): 22 | return ModelStub() 23 | 24 | 25 | def test_model_attrs_exist(model_stub): 26 | attrs = ['completed_epochs', 'BLANK_SYMBOL', 'ALPHABET', 'network', 27 | 'decoder', 'optimiser', 'loss', 'transform', 'target_transform'] 28 | for attr in attrs: 29 | assert hasattr(model_stub, attr) 30 | -------------------------------------------------------------------------------- /tests/networks/test_deepspeech.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch 4 | 5 | from deepspeech.networks.deepspeech import Network 6 | 7 | 8 | def test_forget_gate_bias_correct_initialisation(): 9 | """Ensures the forget gate's bias is set to the desired value. 10 | 11 | NVIDIA's cuDNN LSTM has _two_ bias vectors for the forget gate. i.e. 12 | 13 | ``` 14 | f_t = sigmoid(W_x*x_t + b_x + W_h*h_t + b_h) 15 | ``` 16 | 17 | where for time step `t`: 18 | - `W_x` is the input matrix 19 | - `x_t` is the input 20 | - `b_x` is the input bias 21 | - `W_h` is the hidden matrix 22 | - `h_t` is the hidden state 23 | - `b_h` is the hidden bias 24 | 25 | The sum total of `b_x` and `b_h` should equal the desired 26 | `forget_gate_bias`. 
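    Concretely, `Network._bi_lstm` fills the forget-gate slice of `b_x` with
    `forget_gate_bias` and zeroes the same slice of `b_h`, so the elementwise
    sum checked below should equal `forget_gate_bias`.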
27 | """ 28 | for seed in range(10): 29 | # repeat test for a few different seeds 30 | random.seed(seed) 31 | torch.manual_seed(seed) 32 | 33 | # select some arbitrary hyperparameters 34 | in_features = random.randint(1, 6144) 35 | n_hidden = random.randint(1, 6144) 36 | out_features = random.randint(1, 6144) 37 | drop_prob = random.random() 38 | 39 | # select arbitrary forget gate bias in [-50.0, 50.0) 40 | forget_gate_bias = (random.random() - 0.5) * 100 41 | 42 | net = Network(in_features=in_features, 43 | n_hidden=n_hidden, 44 | out_features=out_features, 45 | drop_prob=drop_prob, 46 | forget_gate_bias=forget_gate_bias) 47 | 48 | params = dict(net.named_parameters()) 49 | 50 | for dir in ['', '_reverse']: 51 | bias_sum = params['bi_lstm.bias_ih_l0' + dir].data 52 | bias_sum += params['bi_lstm.bias_hh_l0' + dir].data 53 | 54 | assert torch.allclose(bias_sum[n_hidden:2*n_hidden], 55 | torch.tensor(forget_gate_bias)) 56 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import os 3 | 4 | 5 | @contextlib.contextmanager 6 | def environ(**env): 7 | """A context manager that temporarily changes `os.environ`.""" 8 | saved = os.environ.copy() 9 | os.environ.update(env) 10 | try: 11 | yield 12 | finally: 13 | os.environ.clear() 14 | os.environ.update(saved) 15 | -------------------------------------------------------------------------------- /tests/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MyrtleSoftware/deepspeech/995ce2d69c8da9b513a2b4e5e3f0b3bae5682373/tests/utils/__init__.py -------------------------------------------------------------------------------- /tests/utils/test_singleton.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from deepspeech.utils.singleton import Singleton 4 | from deepspeech.utils.singleton import SingletonNotExistError 5 | from deepspeech.utils.singleton import SingletonRefsExistError 6 | 7 | 8 | @pytest.fixture 9 | def singleton_cls(request): 10 | """Returns a unique Singleton class with `check_args=True` for each request. 11 | 12 | A Singleton class defined at the module level would be shared by all tests 13 | making them dependent on each other due to it's stateful nature. Using a 14 | unique class per test ensures isolation. 
15 | """ 16 | def __init__(self, val): 17 | self.val = val 18 | 19 | cls_name = request.function.__name__ + '_singleton' 20 | cls = Singleton(cls_name, (), {'__init__': __init__}, check_args=True) 21 | 22 | return cls 23 | 24 | 25 | def test_singleton_created_once(singleton_cls): 26 | a = singleton_cls(3) 27 | b = singleton_cls(3) 28 | assert a is b 29 | 30 | 31 | def test_singleton_check_args_matching_ok(singleton_cls): 32 | a = singleton_cls(5) 33 | b = singleton_cls(val=5) 34 | assert a is b 35 | 36 | 37 | def test_singleton_check_args_nonmatching_raises_value_error(singleton_cls): 38 | singleton_cls(val=5) 39 | with pytest.raises(ValueError): 40 | singleton_cls(6) 41 | 42 | 43 | def test_get_singleton_raises_error_when_not_exist(singleton_cls): 44 | with pytest.raises(SingletonNotExistError): 45 | singleton_cls.get_singleton() 46 | 47 | 48 | def test_get_singleton_returns_singleton(singleton_cls): 49 | a = singleton_cls(4) 50 | b = singleton_cls.get_singleton() 51 | assert a is b 52 | 53 | 54 | def test_get_or_init_singleton_creates_singleton_when_not_exist(singleton_cls): 55 | a = singleton_cls.get_or_init_singleton(val=10) 56 | assert a.val == 10 57 | 58 | 59 | def test_get_or_init_singleton_returns_singleton_when_exist(singleton_cls): 60 | a = singleton_cls.get_or_init_singleton(val=10) 61 | b = singleton_cls.get_or_init_singleton(val=15) 62 | assert a is b 63 | assert a.val == b.val == 10 64 | 65 | 66 | def test_reset_singleton_removes_class_ref_when_no_other_refs(singleton_cls): 67 | singleton_cls(val=5) # create Singleton instance but keep no reference 68 | 69 | # ref to Singleton instance in class should be deleted 70 | singleton_cls._reset_singleton() 71 | 72 | with pytest.raises(SingletonNotExistError): 73 | assert singleton_cls.get_singleton() 74 | 75 | # should be able to create a new Singleton instance 76 | a = singleton_cls(val=10) 77 | 78 | b = singleton_cls.get_singleton() 79 | assert a is b 80 | 81 | 82 | def test_reset_singleton_raises_error_when_other_refs(singleton_cls): 83 | # create Singleton instance and _keep_ reference 84 | a = singleton_cls(val=10) # noqa: F841 85 | with pytest.raises(SingletonRefsExistError): 86 | singleton_cls._reset_singleton() 87 | --------------------------------------------------------------------------------