├── .github ├── stale.yml └── workflows │ ├── ci-test.yml │ └── pull-request.yml ├── .gitignore ├── Dockerfile ├── LICENCE ├── README.md ├── configs ├── an4.yaml ├── commonvoice.yaml ├── librispeech.yaml └── tedlium.yaml ├── data ├── an4.py ├── common_voice.py ├── librispeech.py ├── merge_manifests.py ├── ted.py ├── verify_manifest.py └── voxforge.py ├── deepspeech_pytorch ├── __init__.py ├── checkpoint.py ├── configs │ ├── __init__.py │ ├── inference_config.py │ ├── lightning_config.py │ └── train_config.py ├── data │ ├── __init__.py │ ├── data_opts.py │ └── utils.py ├── decoder.py ├── enums.py ├── inference.py ├── loader │ ├── __init__.py │ ├── data_loader.py │ ├── data_module.py │ ├── sparse_image_warp.py │ └── spec_augment.py ├── model.py ├── testing.py ├── training.py ├── utils.py └── validation.py ├── kubernetes ├── README.md ├── data │ ├── persistent_volume.yaml │ ├── storage.yaml │ └── transfer_data.yaml └── train.yaml ├── labels.json ├── noise_inject.py ├── requirements.txt ├── search_lm_params.py ├── select_lm_params.py ├── server.py ├── setup.py ├── test.py ├── tests ├── __init__.py ├── pretrained_smoke_test.py └── smoke_test.py ├── train.py └── transcribe.py /.github/stale.yml: -------------------------------------------------------------------------------- 1 | # Number of days of inactivity before an issue becomes stale 2 | daysUntilStale: 30 3 | # Number of days of inactivity before a stale issue is closed 4 | daysUntilClose: 30 5 | # Issues with these labels will never be considered stale 6 | exemptLabels: 7 | - high priority 8 | - nostale 9 | # Label to use when marking an issue as stale 10 | staleLabel: stale 11 | # Comment to post when marking an issue as stale. Set to `false` to disable 12 | markComment: > 13 | This issue has been automatically marked as stale because it has not had 14 | recent activity. It will be closed if no further activity occurs. Thank you 15 | for your contributions. 16 | # Comment to post when closing a stale issue. Set to `false` to disable 17 | closeComment: false 18 | -------------------------------------------------------------------------------- /.github/workflows/ci-test.yml: -------------------------------------------------------------------------------- 1 | name: ci-test 2 | on: 3 | push: 4 | branches: [ master ] 5 | jobs: 6 | build: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - name: Check out code 10 | uses: actions/checkout@v2 11 | - name: Set up node 12 | uses: actions/setup-node@v1 13 | - name: Install dependencies 14 | run: | 15 | python --version 16 | python -m pip install --upgrade --user pip 17 | pip --version 18 | pip install -r requirements.txt --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --upgrade 19 | sudo apt-get update -y 20 | sudo apt-get install -y git curl ca-certificates sox libsox-dev libsox-fmt-all 21 | git clone --recursive https://github.com/parlance/ctcdecode.git 22 | cd ctcdecode; pip install . 
23 | shell: bash 24 | - name: Run tests 25 | run: pytest tests/ -vv -s 26 | -------------------------------------------------------------------------------- /.github/workflows/pull-request.yml: -------------------------------------------------------------------------------- 1 | name: pull-request 2 | on: 3 | pull_request: 4 | branches: [ master ] 5 | jobs: 6 | build: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - name: Check out code 10 | uses: actions/checkout@v2 11 | - name: Set up node 12 | uses: actions/setup-node@v1 13 | - name: Install dependencies 14 | run: | 15 | python --version 16 | python -m pip install --upgrade --user pip 17 | pip --version 18 | pip install -r requirements.txt --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --upgrade 19 | sudo apt-get update -y 20 | sudo apt-get install -y git curl ca-certificates sox libsox-dev libsox-fmt-all 21 | git clone --recursive https://github.com/parlance/ctcdecode.git 22 | cd ctcdecode; pip install . 23 | shell: bash 24 | - name: Run tests 25 | run: pytest tests/ -vv -s 26 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *,cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # IntelliJ IDEA 74 | .idea/ 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # dotenv 86 | .env 87 | 88 | # virtualenv 89 | .venv 90 | venv/ 91 | ENV/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # deepspeech.pytorch-specific dirs 104 | models/ 105 | data/*.csv 106 | data/*_dataset/ 107 | .vscode/ 108 | outputs/ 109 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:22.03-py3 2 | ENV DEBIAN_FRONTEND=noninteractive 3 | 4 | ENV LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH 5 | 6 | WORKDIR /workspace/ 7 | 8 | # install basics 9 | RUN apt-get update -y 10 | RUN apt-get install -y git curl ca-certificates bzip2 cmake tree htop bmon iotop sox libsox-dev libsox-fmt-all vim 11 | 12 | # install ctcdecode 13 | RUN git clone --recursive https://github.com/parlance/ctcdecode.git 14 | RUN cd ctcdecode; pip install . 15 | 16 | # install deepspeech.pytorch 17 | ADD . /workspace/deepspeech.pytorch 18 | RUN cd deepspeech.pytorch; pip install -r requirements.txt && pip install -e . 19 | 20 | # launch jupyter 21 | WORKDIR /workspace/deepspeech.pytorch 22 | RUN mkdir data; mkdir notebooks; 23 | CMD jupyter-notebook --ip="*" --no-browser --allow-root 24 | -------------------------------------------------------------------------------- /LICENCE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Sean Naren 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # deepspeech.pytorch 2 | ![Tests](https://github.com/SeanNaren/deepspeech.pytorch/actions/workflows/ci-test.yml/badge.svg) 3 | 4 | Implementation of DeepSpeech2 for PyTorch using [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning). The repo supports training/testing and inference using the [DeepSpeech2](http://arxiv.org/pdf/1512.02595v1.pdf) model. Optionally a [kenlm](https://github.com/kpu/kenlm) language model can be used at inference time. 5 | 6 | ## Install 7 | 8 | Several libraries need to be installed for training to work. I will assume that everything is installed in 9 | an Anaconda environment on Ubuntu, with PyTorch installed. 10 | 11 | Install [PyTorch](https://github.com/pytorch/pytorch#installation) if you haven't already. 12 | 13 | If you want decoding to support beam search with an optional language model, install ctcdecode: 14 | ``` 15 | git clone --recursive https://github.com/parlance/ctcdecode.git 16 | cd ctcdecode && pip install . 17 | ``` 18 | 19 | Finally clone this repo and run this within the repo: 20 | ``` 21 | pip install -r requirements.txt 22 | pip install -e . # Dev install 23 | ``` 24 | 25 | If you plan to use multi-node training, you'll need etcd. Below is the command to install it on Ubuntu. 26 | ``` 27 | sudo apt-get install etcd 28 | ``` 29 | 30 | ### Docker 31 | 32 | To use the image with a GPU you'll need to have [nvidia-docker](https://github.com/NVIDIA/nvidia-docker) installed. 33 | 34 | ```bash 35 | sudo docker run -ti --gpus all -v `pwd`/data:/workspace/data --tmpfs /tmp -p 8888:8888 --net=host --ipc=host seannaren/deepspeech.pytorch:latest # Opens a Jupyter notebook, mounting the /data drive in the container 36 | ``` 37 | 38 | Optionally you can use the command line by changing the entrypoint: 39 | 40 | ```bash 41 | sudo docker run -ti --gpus all -v `pwd`/data:/workspace/data --tmpfs /tmp --entrypoint=/bin/bash --net=host --ipc=host seannaren/deepspeech.pytorch:latest 42 | ``` 43 | 44 | ## Training 45 | 46 | ### Datasets 47 | 48 | Currently supports [AN4](http://www.speech.cs.cmu.edu/databases/an4/), [TEDLIUM](https://www.openslr.org/51/), [Voxforge](http://www.voxforge.org/), [Common Voice](https://commonvoice.mozilla.org/en/datasets) and [LibriSpeech](https://www.openslr.org/12). Scripts will set up the dataset and create manifest files used in data-loading. The scripts can be found in the data/ folder. Many of the scripts allow you to download the raw datasets separately if you choose to. 49 | 50 | ### Training Commands 51 | 52 | ##### AN4 53 | 54 | ```bash 55 | cd data/ && python an4.py && cd .. 56 | 57 | python train.py +configs=an4 58 | ``` 59 | 60 | ##### LibriSpeech 61 | 62 | ```bash 63 | cd data/ && python librispeech.py && cd .. 64 | 65 | python train.py +configs=librispeech 66 | ``` 67 | 68 | ##### Common Voice 69 | 70 | ```bash 71 | cd data/ && python common_voice.py && cd .. 72 | 73 | python train.py +configs=commonvoice 74 | ``` 75 | ##### TEDLIUM 76 | 77 | ```bash 78 | cd data/ && python ted.py && cd .. 79 | 80 | python train.py +configs=tedlium 81 | ``` 82 | 83 | #### Custom Dataset 84 | 85 | To create a custom dataset you must create a JSON file containing the locations of the training/testing data.
This has to be in the following format: 86 | ```json 87 | { 88 | "root_path":"path/to", 89 | "samples":[ 90 | {"wav_path":"audio.wav","transcript_path":"text.txt"}, 91 | {"wav_path":"audio2.wav","transcript_path":"text2.txt"}, 92 | ... 93 | ] 94 | } 95 | ``` 96 | Here `root_path` is the root directory, `wav_path` is the path to the audio file relative to `root_path`, and `transcript_path` is the path, also relative to `root_path`, to a text file containing the transcript on one line. The manifest can then be used as described below. 97 | 98 | ##### Note on CSV files ... 99 | Up until release [V2.1](https://github.com/SeanNaren/deepspeech.pytorch/releases/tag/V2.1), deepspeech.pytorch used CSV manifest files instead of JSON. 100 | These manifest files were formatted as a two-column table: 101 | ``` 102 | /path/to/audio.wav,/path/to/text.txt 103 | /path/to/audio2.wav,/path/to/text2.txt 104 | ... 105 | ``` 106 | Note that this format is incompatible with [V3.0](https://github.com/SeanNaren/deepspeech.pytorch/releases/tag/V3.0) onwards. 107 | 108 | #### Merging multiple manifest files 109 | 110 | To create larger manifest files (for training/testing on multiple datasets at once), merge manifest files together as shown below. 111 | 112 | ``` 113 | cd data/ 114 | python merge_manifests.py manifest_1.json manifest_2.json --out new_manifest_dir 115 | ``` 116 | 117 | ### Modifying Training Configs 118 | 119 | Configuration is done via [Hydra](https://github.com/facebookresearch/hydra). 120 | 121 | Defaults can be seen in [config.py](deepspeech_pytorch/configs/train_config.py). Below is an example of overriding the default values: 122 | 123 | ``` 124 | python train.py data.train_path=data/train_manifest.json data.val_path=data/val_manifest.json 125 | ``` 126 | 127 | Use `python train.py --help` for all parameters and options. 128 | 129 | You can also keep parameters stored in a YAML config file, like so: 130 | 131 | Create a folder `experiment/` and a file `experiment/an4.yaml`: 132 | ```yaml 133 | data: 134 | train_path: data/an4_train_manifest.json 135 | val_path: data/an4_val_manifest.json 136 | ``` 137 | 138 | ``` 139 | python train.py +experiment=an4 140 | ``` 141 | 142 | To see the options available, check [here](./deepspeech_pytorch/configs/train_config.py). 143 | 144 | ### Multi-GPU Training 145 | 146 | We support single-machine multi-GPU training via [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning). 147 | 148 | Below is an example command when training on a machine with 4 local GPUs: 149 | 150 | ``` 151 | python train.py +configs=an4 trainer.gpus=4 152 | ``` 153 | 154 | ### Multi-Node Training 155 | 156 | Multi-machine training is also supported using TorchElastic. This requires a node to act as an explicit etcd host (which could be one of the GPU nodes, but this isn't recommended), a shared mount across your cluster to load/save checkpoints, and communication between the nodes. 157 | 158 | Below is an example where we've set one of our GPU nodes as the etcd host. If you're scaling up, however, we suggest running etcd on a separate instance from your GPU nodes, since it is a single point of failure. 159 | 160 | The example below assumes a shared drive called `/share` where we save our checkpoints and data.
161 | 162 | Run on the etcd host: 163 | ``` 164 | PUBLIC_HOST_NAME=127.0.0.1 # Change to public host name for all nodes to connect 165 | etcd --enable-v2 \ 166 | --listen-client-urls http://$PUBLIC_HOST_NAME:4377 \ 167 | --advertise-client-urls http://$PUBLIC_HOST_NAME:4377 \ 168 | --listen-peer-urls http://$PUBLIC_HOST_NAME:4379 169 | ``` 170 | 171 | Run on each GPU node: 172 | ``` 173 | python -m torchelastic.distributed.launch \ 174 | --nnodes=2 \ 175 | --nproc_per_node=4 \ 176 | --rdzv_id=123 \ 177 | --rdzv_backend=etcd \ 178 | --rdzv_endpoint=$PUBLIC_HOST_NAME:4377 \ 179 | train.py data.train_path=/share/data/an4_train_manifest.json \ 180 | data.val_path=/share/data/an4_val_manifest.json model.precision=half \ 181 | data.num_workers=8 checkpoint.save_folder=/share/checkpoints/ \ 182 | checkpoint.checkpoint=true checkpoint.load_auto_checkpoint=true checkpointing.save_n_recent_models=3 \ 183 | data.batch_size=8 trainer.max_epochs=70 \ 184 | trainer.accelerator=ddp trainer.gpus=4 trainer.num_nodes=2 185 | ``` 186 | 187 | Using the `load_auto_checkpoint=true` flag, we can resume training from the latest saved checkpoint. 188 | 189 | Currently it is expected that an NFS drive/shared mount exists across all nodes within the cluster to load the latest checkpoint from. 190 | 191 | ### Augmentation 192 | 193 | There is support for three different types of augmentation: SpecAugment, noise injection and random tempo/gain perturbations. 194 | 195 | #### SpecAugment 196 | 197 | Applies simple Spectral Augmentation techniques directly to Mel spectrogram features to make the model more robust to variations in input data. To enable SpecAugment, use the `--spec-augment` flag when training. 198 | 199 | The SpecAugment implementation was adapted from [this](https://github.com/DemisEom/SpecAugment) project. 200 | 201 | #### Noise Injection 202 | 203 | Dynamically adds noise to the training data to increase robustness. To use it, first fill a directory with all the noise files you want to sample from. 204 | The dataloader will randomly pick samples from this directory. 205 | 206 | To enable noise injection, use `--noise-dir /path/to/noise/dir/` to specify where your noise files are. There are a few noise parameters to tweak, such as 207 | `--noise_prob` to determine the probability that noise is added, and the `--noise-min`, `--noise-max` parameters to determine the minimum and maximum noise levels added during training. 208 | 209 | Included is a script to inject noise into an audio file to hear what different noise levels/files would sound like. Useful for curating the noise dataset. 210 | 211 | ``` 212 | python noise_inject.py --input-path /path/to/input.wav --noise-path /path/to/noise.wav --output-path /path/to/input_injected.wav --noise-level 0.5 # higher levels mean more noise 213 | ``` 214 | 215 | #### Tempo/Gain Perturbation 216 | 217 | Applies small changes to the tempo and gain when loading audio to increase robustness. To enable it, use the `--speed-volume-perturb` flag when training. 218 | 219 | ### Checkpoints 220 | 221 | Typically checkpoints are stored in `lightning_logs/` in the current working directory of the script. 222 | 223 | This can be adjusted: 224 | 225 | ``` 226 | python train.py checkpoint.file_path=save_dir/ 227 | ``` 228 | 229 | To load a previously saved checkpoint: 230 | 231 | ``` 232 | python train.py trainer.resume_from_checkpoint=lightning_logs/deepspeech_checkpoint_epoch_N_iter_N.ckpt 233 | ``` 234 | 235 | This continues from the same training state.
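If you need to locate a checkpoint programmatically, the repository's `FileCheckpointHandler` (defined in `deepspeech_pytorch/checkpoint.py`) finds the most recent checkpoint in a folder by file creation time. Below is a minimal sketch, assuming the package has been installed with `pip install -e .` and that a previous run has already written checkpoints to `lightning_logs/`:

```python
from deepspeech_pytorch.checkpoint import FileCheckpointHandler
from deepspeech_pytorch.configs.lightning_config import ModelCheckpointConf

# Point the handler at the directory that training writes checkpoints to.
handler = FileCheckpointHandler(ModelCheckpointConf(dirpath="lightning_logs/"))

latest = handler.find_latest_checkpoint()  # returns None if no checkpoints exist
print(f"Latest checkpoint: {latest}")
```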
236 | 237 | ## Testing/Inference 238 | 239 | To evaluate a trained model on a test set (it has to be in the same format as the training set): 240 | 241 | ``` 242 | python test.py model.model_path=models/deepspeech.pth test_path=/path/to/test_manifest.json 243 | ``` 244 | 245 | An example script to output a transcription has been provided: 246 | 247 | ``` 248 | python transcribe.py \ 249 | model.model_path=models/deepspeech.pth \ 250 | model.cuda=True \ 251 | chunk_size_seconds=-1 \ 252 | audio_path=/path/to/audio.wav 253 | ``` 254 | 255 | If you used mixed precision or half precision when training the model, you can use `model.precision=half` for a speed/memory benefit. If you want to transcribe a long audio file that does not fit on the GPU, change `chunk_size_seconds` to a positive number representing the chunk size, in seconds, used to segment the long audio file. 256 | 257 | ## Inference Server 258 | 259 | Included is a basic server script that allows POST requests to be sent to the server to transcribe files. 260 | 261 | ``` 262 | python server.py --host 0.0.0.0 --port 8000 # Run on one window 263 | 264 | curl -X POST http://0.0.0.0:8000/transcribe -H "Content-type: multipart/form-data" -F "file=@/path/to/input.wav" 265 | ``` 266 | 267 | ## Using an ARPA LM 268 | 269 | We support using KenLM-based LMs. Below are instructions on how to take the LibriSpeech LMs found [here](http://www.openslr.org/11/) and tune the model to give you the best parameters when decoding, based on LibriSpeech. 270 | 271 | ### Tuning the LibriSpeech LMs 272 | 273 | First ensure you've set up the LibriSpeech datasets from the data/ folder. 274 | In addition, download the latest pre-trained LibriSpeech model from the releases page, as well as the ARPA model you want to tune from [here](http://www.openslr.org/11/). For the below we use the 3-gram ARPA model (3e-7 prune). 275 | 276 | First we need to generate the acoustic output to be used to evaluate the model on the LibriSpeech validation set. 277 | ``` 278 | python test.py test_path=data/librispeech_val_manifest.json model.model_path=librispeech_pretrained_v2.pth save_output=librispeech_val_output.npy 279 | ``` 280 | 281 | We use a beam width of 128, which gives reasonable results. We suggest using a CPU-intensive node to carry out the grid search. 282 | 283 | ``` 284 | python search_lm_params.py --num-workers 16 --saved-output librispeech_val_output.npy --output-path libri_tune_output.json --lm-alpha-from 0 --lm-alpha-to 5 --lm-beta-from 0 --lm-beta-to 3 --lm-path 3-gram.pruned.3e-7.arpa --model-path librispeech_pretrained_v2.pth --beam-width 128 --lm-workers 16 285 | ``` 286 | 287 | This will run a grid search across the alpha/beta parameters using a beam width of 128. Use the script below to find the best alpha/beta params: 288 | 289 | ``` 290 | python select_lm_params.py --input-path libri_tune_output.json 291 | ``` 292 | 293 | Use these alpha/beta parameters when using the beam decoder. 294 | 295 | ### Building your own LM 296 | 297 | To build your own LM you need to use the KenLM repo found [here](https://github.com/kpu/kenlm). Have a read of the documentation to get a sense of how to train your own LM. Once your LM is trained, the steps above can be used to find the appropriate decoding parameters. 298 | 299 | ### Alternate Decoders 300 | By default, `test.py` and `transcribe.py` use a `GreedyDecoder` which picks the highest-likelihood output label at each timestep; a conceptual sketch of this decoding step is shown below.
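As an illustration only (this is a simplified sketch of the idea, not the repository's `GreedyDecoder` implementation), greedy CTC decoding takes the argmax over the vocabulary at every timestep, then collapses repeated labels and drops the CTC blank token. The `labels` list and `blank_index` below are assumptions for the example; in this repo the output tokens live in `labels.json`.

```python
import torch


def greedy_ctc_decode(log_probs: torch.Tensor, labels: list, blank_index: int = 0) -> str:
    """Decode a single utterance; log_probs has shape (time, vocab_size)."""
    best_path = log_probs.argmax(dim=-1).tolist()  # highest-likelihood label per timestep
    decoded, prev = [], None
    for idx in best_path:
        if idx != prev and idx != blank_index:  # collapse repeats and skip the blank token
            decoded.append(labels[idx])
        prev = idx
    return "".join(decoded)
```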
Repeated and blank symbols are then filtered to give the final output. 301 | 302 | A beam search decoder can optionally be used with the installation of the `ctcdecode` library as described in the Installation section. The `test` and `transcribe` scripts have a `lm` config. To use the beam decoder, add `lm.decoder_type=beam`. The beam decoder enables additional decoding parameters: 303 | - **lm.beam_width** how many beams to consider at each timestep 304 | - **lm.lm_path** optional binary KenLM language model to use for decoding 305 | - **lm.alpha** weight for language model 306 | - **lm.beta** bonus weight for words 307 | 308 | ### Time offsets 309 | 310 | Use the `offsets=true` flag to get positional information of each character in the transcription when using `transcribe.py` script. The offsets are based on the size 311 | of the output tensor, which you need to convert into a format required. 312 | For example, based on default parameters you could multiply the offsets by a scalar (duration of file in seconds / size of output) to get the offsets in seconds. 313 | 314 | ## Pre-trained models 315 | 316 | Pre-trained models can be found under releases [here](https://github.com/SeanNaren/deepspeech.pytorch/releases). 317 | 318 | ## Acknowledgements 319 | 320 | Thanks to [Egor](https://github.com/EgorLakomkin) and [Ryan](https://github.com/ryanleary) for their contributions! 321 | -------------------------------------------------------------------------------- /configs/an4.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | data: 3 | train_path: data/an4_train_manifest.json 4 | val_path: data/an4_val_manifest.json 5 | batch_size: 8 6 | num_workers: 8 7 | trainer: 8 | max_epochs: 70 9 | accelerator: 'auto' 10 | devices: 1 11 | precision: 16 12 | gradient_clip_val: 400 # Norm cutoff to prevent explosion of gradients 13 | strategy: ddp 14 | enable_checkpointing: True 15 | checkpoint: 16 | save_top_k: 1 17 | monitor: "wer" 18 | verbose: True -------------------------------------------------------------------------------- /configs/commonvoice.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | data: 3 | train_path: data/commonvoice_train_manifest.json 4 | val_path: data/commonvoice_dev_manifest.json 5 | num_workers: 8 6 | augmentation: 7 | spec_augment: True 8 | trainer: 9 | max_epochs: 16 10 | accelerator: 'auto' 11 | devices: 1 12 | precision: 16 13 | gradient_clip_val: 400 # Norm cutoff to prevent explosion of gradients 14 | strategy: ddp 15 | enable_checkpointing: True 16 | checkpoint: 17 | save_top_k: 1 18 | monitor: "wer" 19 | verbose: True -------------------------------------------------------------------------------- /configs/librispeech.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | data: 3 | train_path: data/libri_train_manifest.json 4 | val_path: data/libri_val_manifest.json 5 | num_workers: 8 6 | augmentation: 7 | spec_augment: True 8 | trainer: 9 | max_epochs: 16 10 | accelerator: 'auto' 11 | devices: 1 12 | precision: 16 13 | gradient_clip_val: 400 # Norm cutoff to prevent explosion of gradients 14 | strategy: ddp 15 | enable_checkpointing: True 16 | checkpoint: 17 | save_top_k: 1 18 | monitor: "wer" 19 | verbose: True -------------------------------------------------------------------------------- /configs/tedlium.yaml: 
-------------------------------------------------------------------------------- 1 | # @package _global_ 2 | data: 3 | train_path: data/ted_train_manifest.json 4 | val_path: data/ted_val_manifest.json 5 | num_workers: 8 6 | augmentation: 7 | spec_augment: True 8 | trainer: 9 | max_epochs: 16 10 | accelerator: 'auto' 11 | devices: 1 12 | precision: 16 13 | gradient_clip_val: 400 # Norm cutoff to prevent explosion of gradients 14 | strategy: ddp 15 | enable_checkpointing: True 16 | checkpoint: 17 | save_top_k: 1 18 | monitor: "wer" 19 | verbose: True -------------------------------------------------------------------------------- /data/an4.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import tarfile 4 | 5 | import wget 6 | 7 | from deepspeech_pytorch.data.data_opts import add_data_opts 8 | from deepspeech_pytorch.data.utils import create_manifest 9 | 10 | 11 | def download_an4(target_dir: str, 12 | manifest_dir: str, 13 | min_duration: float, 14 | max_duration: float, 15 | num_workers: int): 16 | raw_tar_path = 'an4.tar.gz' 17 | if not os.path.exists(raw_tar_path): 18 | wget.download('https://github.com/SeanNaren/deepspeech.pytorch/releases/download/V3.0/an4.tar.gz') 19 | tar = tarfile.open('an4.tar.gz') 20 | os.makedirs(target_dir, exist_ok=True) 21 | tar.extractall(target_dir) 22 | train_path = target_dir + '/train/' 23 | val_path = target_dir + '/val/' 24 | test_path = target_dir + '/test/' 25 | 26 | print('Creating manifests...') 27 | create_manifest(data_path=train_path, 28 | output_name='an4_train_manifest.json', 29 | manifest_path=manifest_dir, 30 | min_duration=min_duration, 31 | max_duration=max_duration, 32 | num_workers=num_workers) 33 | create_manifest(data_path=val_path, 34 | output_name='an4_val_manifest.json', 35 | manifest_path=manifest_dir, 36 | min_duration=min_duration, 37 | max_duration=max_duration, 38 | num_workers=num_workers) 39 | create_manifest(data_path=test_path, 40 | output_name='an4_test_manifest.json', 41 | manifest_path=manifest_dir, 42 | num_workers=num_workers) 43 | 44 | 45 | if __name__ == '__main__': 46 | parser = argparse.ArgumentParser(description='Processes and downloads an4.') 47 | parser = add_data_opts(parser) 48 | parser.add_argument('--target-dir', default='an4_dataset/', help='Path to save dataset') 49 | args = parser.parse_args() 50 | assert args.sample_rate == 16000, "AN4 only supports sample rate of 16000 currently." 
51 | download_an4( 52 | target_dir=args.target_dir, 53 | manifest_dir=args.manifest_dir, 54 | min_duration=args.min_duration, 55 | max_duration=args.max_duration, 56 | num_workers=args.num_workers 57 | ) 58 | -------------------------------------------------------------------------------- /data/common_voice.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import os 4 | import tarfile 5 | from multiprocessing.pool import ThreadPool 6 | 7 | from sox import Transformer 8 | import tqdm 9 | import wget 10 | 11 | from deepspeech_pytorch.data.data_opts import add_data_opts 12 | from deepspeech_pytorch.data.utils import create_manifest 13 | 14 | parser = argparse.ArgumentParser(description='Downloads and processes Mozilla Common Voice dataset.') 15 | parser = add_data_opts(parser) 16 | parser.add_argument("--target-dir", default='CommonVoice_dataset/', type=str, help="Directory to store the dataset.") 17 | parser.add_argument("--tar-path", type=str, help="Path to the Common Voice *.tar file if downloaded (Optional).") 18 | parser.add_argument("--language-dir", default='en', type=str, help="Language dir to process.") 19 | parser.add_argument('--files-to-process', nargs='+', default=['test.tsv', 'dev.tsv', 'train.tsv'], 20 | type=str, help='list of *.csv file names to process') 21 | args = parser.parse_args() 22 | VERSION = 'cv-corpus-5.1-2020-06-22' 23 | COMMON_VOICE_URL = "https://voice-prod-bundler-ee1969a6ce8178826482b88e843c335139bd3fb4.s3.amazonaws.com/" \ 24 | "{}/en.tar.gz".format(VERSION) 25 | 26 | 27 | def convert_to_wav(csv_file, target_dir, num_workers): 28 | """ Read *.csv file description, convert mp3 to wav, process text. 29 | Save results to target_dir. 30 | 31 | Args: 32 | csv_file: str, path to *.csv file with data description, usually start from 'cv-' 33 | target_dir: str, path to dir to save results; wav/ and txt/ dirs will be created 34 | """ 35 | wav_dir = os.path.join(target_dir, 'wav/') 36 | txt_dir = os.path.join(target_dir, 'txt/') 37 | os.makedirs(wav_dir, exist_ok=True) 38 | os.makedirs(txt_dir, exist_ok=True) 39 | audio_clips_path = os.path.dirname(csv_file) + '/clips/' 40 | 41 | def process(x): 42 | file_path, text = x 43 | file_name = os.path.splitext(os.path.basename(file_path))[0] 44 | text = text.strip().upper() 45 | with open(os.path.join(txt_dir, file_name + '.txt'), 'w') as f: 46 | f.write(text) 47 | audio_path = os.path.join(audio_clips_path, file_path) 48 | output_wav_path = os.path.join(wav_dir, file_name + '.wav') 49 | 50 | tfm = Transformer() 51 | tfm.rate(samplerate=args.sample_rate) 52 | tfm.build( 53 | input_filepath=audio_path, 54 | output_filepath=output_wav_path 55 | ) 56 | 57 | print('Converting mp3 to wav for {}.'.format(csv_file)) 58 | with open(csv_file) as csvfile: 59 | reader = csv.DictReader(csvfile, delimiter='\t') 60 | next(reader, None) # skip the headers 61 | data = [(row['path'], row['sentence']) for row in reader] 62 | with ThreadPool(num_workers) as pool: 63 | list(tqdm.tqdm(pool.imap(process, data), total=len(data))) 64 | 65 | 66 | def main(): 67 | target_dir = args.target_dir 68 | language_dir = args.language_dir 69 | 70 | os.makedirs(target_dir, exist_ok=True) 71 | 72 | target_unpacked_dir = os.path.join(target_dir, "CV_unpacked") 73 | 74 | if os.path.exists(target_unpacked_dir): 75 | print('Find existing folder {}'.format(target_unpacked_dir)) 76 | else: 77 | print("Could not find Common Voice, Downloading corpus...") 78 | 79 | filename = 
wget.download(COMMON_VOICE_URL, target_dir) 80 | target_file = os.path.join(target_dir, os.path.basename(filename)) 81 | 82 | os.makedirs(target_unpacked_dir, exist_ok=True) 83 | print("Unpacking corpus to {} ...".format(target_unpacked_dir)) 84 | tar = tarfile.open(target_file) 85 | tar.extractall(target_unpacked_dir) 86 | tar.close() 87 | 88 | folder_path = os.path.join(target_unpacked_dir, VERSION + '/{}/'.format(language_dir)) 89 | 90 | for csv_file in args.files_to_process: 91 | convert_to_wav( 92 | csv_file=os.path.join(folder_path, csv_file), 93 | target_dir=os.path.join(target_dir, os.path.splitext(csv_file)[0]), 94 | num_workers=args.num_workers 95 | ) 96 | 97 | print('Creating manifests...') 98 | for csv_file in args.files_to_process: 99 | create_manifest( 100 | data_path=os.path.join(target_dir, os.path.splitext(csv_file)[0]), 101 | output_name='commonvoice_' + os.path.splitext(csv_file)[0] + '_manifest.json', 102 | manifest_path=args.manifest_dir, 103 | min_duration=args.min_duration, 104 | max_duration=args.max_duration, 105 | num_workers=args.num_workers 106 | ) 107 | 108 | 109 | if __name__ == "__main__": 110 | main() 111 | -------------------------------------------------------------------------------- /data/librispeech.py: -------------------------------------------------------------------------------- 1 | import os 2 | import wget 3 | import tarfile 4 | import argparse 5 | import subprocess 6 | 7 | from deepspeech_pytorch.data.data_opts import add_data_opts 8 | from tqdm import tqdm 9 | import shutil 10 | 11 | from deepspeech_pytorch.data.utils import create_manifest 12 | 13 | parser = argparse.ArgumentParser(description='Processes and downloads LibriSpeech dataset.') 14 | parser = add_data_opts(parser) 15 | parser.add_argument("--target-dir", default='LibriSpeech_dataset/', type=str, help="Directory to store the dataset.") 16 | parser.add_argument('--files-to-use', default="train-clean-100.tar.gz," 17 | "train-clean-360.tar.gz,train-other-500.tar.gz," 18 | "dev-clean.tar.gz,dev-other.tar.gz," 19 | "test-clean.tar.gz,test-other.tar.gz", type=str, 20 | help='list of file names to download') 21 | args = parser.parse_args() 22 | 23 | LIBRI_SPEECH_URLS = { 24 | "train": ["http://www.openslr.org/resources/12/train-clean-100.tar.gz", 25 | "http://www.openslr.org/resources/12/train-clean-360.tar.gz", 26 | "http://www.openslr.org/resources/12/train-other-500.tar.gz"], 27 | 28 | "val": ["http://www.openslr.org/resources/12/dev-clean.tar.gz", 29 | "http://www.openslr.org/resources/12/dev-other.tar.gz"], 30 | 31 | "test_clean": ["http://www.openslr.org/resources/12/test-clean.tar.gz"], 32 | "test_other": ["http://www.openslr.org/resources/12/test-other.tar.gz"] 33 | } 34 | 35 | 36 | def _preprocess_transcript(phrase): 37 | return phrase.strip().upper() 38 | 39 | 40 | def _process_file(wav_dir, txt_dir, base_filename, root_dir): 41 | full_recording_path = os.path.join(root_dir, base_filename) 42 | assert os.path.exists(full_recording_path) and os.path.exists(root_dir) 43 | wav_recording_path = os.path.join(wav_dir, base_filename.replace(".flac", ".wav")) 44 | subprocess.call(["sox {} -r {} -b 16 -c 1 {}".format(full_recording_path, str(args.sample_rate), 45 | wav_recording_path)], shell=True) 46 | # process transcript 47 | txt_transcript_path = os.path.join(txt_dir, base_filename.replace(".flac", ".txt")) 48 | transcript_file = os.path.join(root_dir, "-".join(base_filename.split('-')[:-1]) + ".trans.txt") 49 | assert os.path.exists(transcript_file), "Transcript file {} does not 
exist.".format(transcript_file) 50 | transcriptions = open(transcript_file).read().strip().split("\n") 51 | transcriptions = {t.split()[0].split("-")[-1]: " ".join(t.split()[1:]) for t in transcriptions} 52 | with open(txt_transcript_path, "w") as f: 53 | key = base_filename.replace(".flac", "").split("-")[-1] 54 | assert key in transcriptions, "{} is not in the transcriptions".format(key) 55 | f.write(_preprocess_transcript(transcriptions[key])) 56 | f.flush() 57 | 58 | 59 | def main(): 60 | target_dl_dir = args.target_dir 61 | if not os.path.exists(target_dl_dir): 62 | os.makedirs(target_dl_dir) 63 | files_to_dl = args.files_to_use.strip().split(',') 64 | for split_type, lst_libri_urls in LIBRI_SPEECH_URLS.items(): 65 | split_dir = os.path.join(target_dl_dir, split_type) 66 | if not os.path.exists(split_dir): 67 | os.makedirs(split_dir) 68 | split_wav_dir = os.path.join(split_dir, "wav") 69 | if not os.path.exists(split_wav_dir): 70 | os.makedirs(split_wav_dir) 71 | split_txt_dir = os.path.join(split_dir, "txt") 72 | if not os.path.exists(split_txt_dir): 73 | os.makedirs(split_txt_dir) 74 | extracted_dir = os.path.join(split_dir, "LibriSpeech") 75 | if os.path.exists(extracted_dir): 76 | shutil.rmtree(extracted_dir) 77 | for url in lst_libri_urls: 78 | # check if we want to dl this file 79 | dl_flag = False 80 | for f in files_to_dl: 81 | if url.find(f) != -1: 82 | dl_flag = True 83 | if not dl_flag: 84 | print("Skipping url: {}".format(url)) 85 | continue 86 | filename = url.split("/")[-1] 87 | target_filename = os.path.join(split_dir, filename) 88 | if not os.path.exists(target_filename): 89 | wget.download(url, split_dir) 90 | print("Unpacking {}...".format(filename)) 91 | tar = tarfile.open(target_filename) 92 | tar.extractall(split_dir) 93 | tar.close() 94 | os.remove(target_filename) 95 | print("Converting flac files to wav and extracting transcripts...") 96 | assert os.path.exists(extracted_dir), "Archive {} was not properly uncompressed.".format(filename) 97 | for root, subdirs, files in tqdm(os.walk(extracted_dir)): 98 | for f in files: 99 | if f.find(".flac") != -1: 100 | _process_file(wav_dir=split_wav_dir, txt_dir=split_txt_dir, 101 | base_filename=f, root_dir=root) 102 | 103 | print("Finished {}".format(url)) 104 | shutil.rmtree(extracted_dir) 105 | if split_type == 'train': # Prune to min/max duration 106 | create_manifest( 107 | data_path=split_dir, 108 | output_name='libri_' + split_type + '_manifest.json', 109 | manifest_path=args.manifest_dir, 110 | min_duration=args.min_duration, 111 | max_duration=args.max_duration, 112 | num_workers=args.num_workers 113 | ) 114 | else: 115 | create_manifest( 116 | data_path=split_dir, 117 | output_name='libri_' + split_type + '_manifest.json', 118 | manifest_path=args.manifest_dir, 119 | num_workers=args.num_workers 120 | ) 121 | 122 | 123 | if __name__ == "__main__": 124 | main() 125 | -------------------------------------------------------------------------------- /data/merge_manifests.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import io 3 | import json 4 | import os 5 | from pathlib import Path 6 | 7 | from tqdm import tqdm 8 | 9 | 10 | parser = argparse.ArgumentParser(description="Merges all manifest files in specified folder.") 11 | parser.add_argument("manifests", metavar="m", nargs="+", help="Path to all manifest files you want to merge.") 12 | parser.add_argument("-e", "--extension", default="wav", type=str, help="Audio file extension") 13 | 
parser.add_argument("--name", default="merged_manifest", type=str, help="Merged dataset name") 14 | parser.add_argument("--out", default="./", type=str, help="Output directory") 15 | args = parser.parse_args() 16 | 17 | 18 | def main(): 19 | new_manifest_path = Path(args.out) / Path(args.name) 20 | new_manifest_path.mkdir(parents=True, exist_ok=True) 21 | (new_manifest_path / args.extension).mkdir(parents=True, exist_ok=True) 22 | (new_manifest_path / 'txt').mkdir(parents=True, exist_ok=True) 23 | 24 | new_manifest = { 25 | 'root_path': new_manifest_path.absolute().as_posix(), 26 | 'samples': [] 27 | } 28 | for manifest in tqdm(args.manifests, desc="Manifests"): 29 | with open(manifest, "r") as manifest_file: 30 | manifest_json = json.load(manifest_file) 31 | 32 | root_path = Path(manifest_json['root_path']) 33 | for sample in tqdm(manifest_json['samples'], desc="Samples"): 34 | try: 35 | old_audio_path = root_path / Path(sample['wav_path']) 36 | new_audio_path = new_manifest_path.absolute() / Path(sample['wav_path']) 37 | os.symlink(old_audio_path, new_audio_path) 38 | old_txt_path = root_path / Path(sample['transcript_path']) 39 | new_txt_path = new_manifest_path.absolute() / Path(sample['transcript_path']) 40 | os.symlink(old_txt_path, new_txt_path) 41 | except FileExistsError: 42 | continue 43 | 44 | new_manifest['samples'] += manifest_json['samples'] 45 | 46 | with open(f"{args.name}_manifest.json", "w") as json_file: 47 | json.dump(new_manifest, json_file) 48 | 49 | 50 | if __name__ == "__main__": 51 | main() 52 | -------------------------------------------------------------------------------- /data/ted.py: -------------------------------------------------------------------------------- 1 | import os 2 | import wget 3 | import tarfile 4 | import argparse 5 | import subprocess 6 | import unicodedata 7 | import io 8 | 9 | from deepspeech_pytorch.data.data_opts import add_data_opts 10 | from tqdm import tqdm 11 | 12 | from deepspeech_pytorch.data.utils import create_manifest 13 | 14 | parser = argparse.ArgumentParser(description='Processes and downloads TED-LIUMv2 dataset.') 15 | parser = add_data_opts(parser) 16 | parser.add_argument("--target-dir", default='TEDLIUM_dataset/', type=str, help="Directory to store the dataset.") 17 | parser.add_argument("--tar-path", type=str, help="Path to the TEDLIUM_release tar if downloaded (Optional).") 18 | args = parser.parse_args() 19 | 20 | TED_LIUM_V2_DL_URL = "http://www.openslr.org/resources/19/TEDLIUM_release2.tar.gz" 21 | 22 | 23 | def get_utterances_from_stm(stm_file): 24 | """ 25 | Return list of entries containing phrase and its start/end timings 26 | :param stm_file: 27 | :return: 28 | """ 29 | res = [] 30 | with io.open(stm_file, "r", encoding='utf-8') as f: 31 | for stm_line in f: 32 | tokens = stm_line.split() 33 | start_time = float(tokens[3]) 34 | end_time = float(tokens[4]) 35 | filename = tokens[0] 36 | transcript = unicodedata.normalize("NFKD", 37 | " ".join(t for t in tokens[6:]).strip()). 
\ 38 | encode("utf-8", "ignore").decode("utf-8", "ignore") 39 | if transcript != "ignore_time_segment_in_scoring": 40 | res.append({ 41 | "start_time": start_time, "end_time": end_time, 42 | "filename": filename, "transcript": transcript 43 | }) 44 | return res 45 | 46 | 47 | def cut_utterance(src_sph_file, target_wav_file, start_time, end_time, sample_rate=16000): 48 | subprocess.call(["sox {} -r {} -b 16 -c 1 {} trim {} ={}".format(src_sph_file, str(sample_rate), 49 | target_wav_file, start_time, end_time)], 50 | shell=True) 51 | 52 | 53 | def _preprocess_transcript(phrase): 54 | return phrase.strip().upper() 55 | 56 | 57 | def filter_short_utterances(utterance_info, min_len_sec=1.0): 58 | return utterance_info["end_time"] - utterance_info["start_time"] > min_len_sec 59 | 60 | 61 | def prepare_dir(ted_dir): 62 | converted_dir = os.path.join(ted_dir, "converted") 63 | # directories to store converted wav files and their transcriptions 64 | wav_dir = os.path.join(converted_dir, "wav") 65 | if not os.path.exists(wav_dir): 66 | os.makedirs(wav_dir) 67 | txt_dir = os.path.join(converted_dir, "txt") 68 | if not os.path.exists(txt_dir): 69 | os.makedirs(txt_dir) 70 | counter = 0 71 | entries = os.listdir(os.path.join(ted_dir, "sph")) 72 | for sph_file in tqdm(entries, total=len(entries)): 73 | speaker_name = sph_file.split('.sph')[0] 74 | 75 | sph_file_full = os.path.join(ted_dir, "sph", sph_file) 76 | stm_file_full = os.path.join(ted_dir, "stm", "{}.stm".format(speaker_name)) 77 | 78 | assert os.path.exists(sph_file_full) and os.path.exists(stm_file_full) 79 | all_utterances = get_utterances_from_stm(stm_file_full) 80 | 81 | all_utterances = filter(filter_short_utterances, all_utterances) 82 | for utterance_id, utterance in enumerate(all_utterances): 83 | target_wav_file = os.path.join(wav_dir, "{}_{}.wav".format(utterance["filename"], str(utterance_id))) 84 | target_txt_file = os.path.join(txt_dir, "{}_{}.txt".format(utterance["filename"], str(utterance_id))) 85 | cut_utterance(sph_file_full, target_wav_file, utterance["start_time"], utterance["end_time"], 86 | sample_rate=args.sample_rate) 87 | with io.FileIO(target_txt_file, "w") as f: 88 | f.write(_preprocess_transcript(utterance["transcript"]).encode('utf-8')) 89 | counter += 1 90 | 91 | 92 | def main(): 93 | target_dl_dir = args.target_dir 94 | if not os.path.exists(target_dl_dir): 95 | os.makedirs(target_dl_dir) 96 | 97 | target_unpacked_dir = os.path.join(target_dl_dir, "TEDLIUM_release2") 98 | if args.tar_path and os.path.exists(args.tar_path): 99 | target_file = args.tar_path 100 | else: 101 | print("Could not find downloaded TEDLIUM archive, Downloading corpus...") 102 | wget.download(TED_LIUM_V2_DL_URL, target_dl_dir) 103 | target_file = os.path.join(target_dl_dir, "TEDLIUM_release2.tar.gz") 104 | 105 | if not os.path.exists(target_unpacked_dir): 106 | print("Unpacking corpus...") 107 | tar = tarfile.open(target_file) 108 | tar.extractall(target_dl_dir) 109 | tar.close() 110 | else: 111 | print("Found TEDLIUM directory, skipping unpacking of tar files") 112 | 113 | train_ted_dir = os.path.join(target_unpacked_dir, "train") 114 | val_ted_dir = os.path.join(target_unpacked_dir, "dev") 115 | test_ted_dir = os.path.join(target_unpacked_dir, "test") 116 | 117 | prepare_dir(train_ted_dir) 118 | prepare_dir(val_ted_dir) 119 | prepare_dir(test_ted_dir) 120 | print('Creating manifests...') 121 | 122 | create_manifest( 123 | data_path=train_ted_dir, 124 | output_name='ted_train_manifest.json', 125 | manifest_path=args.manifest_dir, 126 | 
min_duration=args.min_duration, 127 | max_duration=args.max_duration, 128 | num_workers=args.num_workers 129 | ) 130 | create_manifest( 131 | data_path=val_ted_dir, 132 | output_name='ted_val_manifest.json', 133 | manifest_path=args.manifest_dir, 134 | num_workers=args.num_workers 135 | 136 | ) 137 | create_manifest( 138 | data_path=test_ted_dir, 139 | output_name='ted_test_manifest.json', 140 | manifest_path=args.manifest_dir, 141 | num_workers=args.num_workers 142 | ) 143 | 144 | 145 | if __name__ == "__main__": 146 | main() 147 | -------------------------------------------------------------------------------- /data/verify_manifest.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from pathlib import Path 4 | 5 | from tqdm import tqdm 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("manifests", metavar="m", nargs="+", help="Manifests to verify") 9 | args = parser.parse_args() 10 | 11 | def main(): 12 | for manifest_path in tqdm(args.manifests): 13 | with open(manifest_path, "r") as manifest_file: 14 | manifest_json = json.load(manifest_file) 15 | 16 | root_path = Path(manifest_json['root_path']) 17 | for sample in tqdm(manifest_json['samples']): 18 | assert (root_path / Path(sample['wav_path'])).exists(), f"{sample['wav_path']} does not exist" 19 | assert (root_path / Path(sample['transcript_path'])).exists(), f"{sample['transcript_path']} does not exist" 20 | 21 | 22 | if __name__ == "__main__": 23 | main() 24 | 25 | -------------------------------------------------------------------------------- /data/voxforge.py: -------------------------------------------------------------------------------- 1 | import os 2 | from six.moves import urllib 3 | import argparse 4 | import re 5 | import tempfile 6 | import shutil 7 | import subprocess 8 | import tarfile 9 | import io 10 | from tqdm import tqdm 11 | 12 | from deepspeech_pytorch.data.data_opts import add_data_opts 13 | from deepspeech_pytorch.data.utils import create_manifest 14 | 15 | VOXFORGE_URL_16kHz = 'http://www.repository.voxforge1.org/downloads/SpeechCorpus/Trunk/Audio/Main/16kHz_16bit/' 16 | 17 | parser = argparse.ArgumentParser(description='Processes and downloads VoxForge dataset.') 18 | parser = add_data_opts(parser) 19 | parser.add_argument("--target-dir", default='voxforge_dataset/', type=str, help="Directory to store the dataset.") 20 | args = parser.parse_args() 21 | 22 | 23 | def _get_recordings_dir(sample_dir, recording_name): 24 | wav_dir = os.path.join(sample_dir, recording_name, "wav") 25 | if os.path.exists(wav_dir): 26 | return "wav", wav_dir 27 | flac_dir = os.path.join(sample_dir, recording_name, "flac") 28 | if os.path.exists(flac_dir): 29 | return "flac", flac_dir 30 | raise Exception("wav or flac directory was not found for recording name: {}".format(recording_name)) 31 | 32 | 33 | def prepare_sample(recording_name, url, target_folder): 34 | """ 35 | Downloads and extracts a sample from VoxForge and puts the wav and txt files into :target_folder. 
36 | """ 37 | wav_dir = os.path.join(target_folder, "wav") 38 | if not os.path.exists(wav_dir): 39 | os.makedirs(wav_dir) 40 | txt_dir = os.path.join(target_folder, "txt") 41 | if not os.path.exists(txt_dir): 42 | os.makedirs(txt_dir) 43 | # check if sample is processed 44 | filename_set = set(['_'.join(wav_file.split('_')[:-1]) for wav_file in os.listdir(wav_dir)]) 45 | if recording_name in filename_set: 46 | return 47 | 48 | request = urllib.request.Request(url) 49 | response = urllib.request.urlopen(request) 50 | content = response.read() 51 | response.close() 52 | with tempfile.NamedTemporaryFile(suffix=".tgz", mode='wb') as target_tgz: 53 | target_tgz.write(content) 54 | target_tgz.flush() 55 | dirpath = tempfile.mkdtemp() 56 | 57 | tar = tarfile.open(target_tgz.name) 58 | tar.extractall(dirpath) 59 | tar.close() 60 | 61 | recordings_type, recordings_dir = _get_recordings_dir(dirpath, recording_name) 62 | tgz_prompt_file = os.path.join(dirpath, recording_name, "etc", "PROMPTS") 63 | 64 | if os.path.exists(recordings_dir) and os.path.exists(tgz_prompt_file): 65 | transcriptions = open(tgz_prompt_file).read().strip().split("\n") 66 | transcriptions = {t.split()[0]: " ".join(t.split()[1:]) for t in transcriptions} 67 | for wav_file in os.listdir(recordings_dir): 68 | recording_id = wav_file.split('.{}'.format(recordings_type))[0] 69 | transcription_key = recording_name + "/mfc/" + recording_id 70 | if transcription_key not in transcriptions: 71 | continue 72 | utterance = transcriptions[transcription_key] 73 | 74 | target_wav_file = os.path.join(wav_dir, "{}_{}.wav".format(recording_name, recording_id)) 75 | target_txt_file = os.path.join(txt_dir, "{}_{}.txt".format(recording_name, recording_id)) 76 | with io.FileIO(target_txt_file, "w") as file: 77 | file.write(utterance.encode('utf-8')) 78 | original_wav_file = os.path.join(recordings_dir, wav_file) 79 | subprocess.call(["sox {} -r {} -b 16 -c 1 {}".format(original_wav_file, str(args.sample_rate), 80 | target_wav_file)], shell=True) 81 | 82 | shutil.rmtree(dirpath) 83 | 84 | 85 | if __name__ == '__main__': 86 | target_dir = args.target_dir 87 | sample_rate = args.sample_rate 88 | 89 | if not os.path.isdir(target_dir): 90 | os.makedirs(target_dir) 91 | request = urllib.request.Request(VOXFORGE_URL_16kHz) 92 | response = urllib.request.urlopen(request) 93 | content = response.read() 94 | all_files = re.findall("href\=\"(.*\.tgz)\"", content.decode("utf-8")) 95 | for f in tqdm(all_files, total=len(all_files)): 96 | prepare_sample(f.replace(".tgz", ""), VOXFORGE_URL_16kHz + f, target_dir) 97 | print('Creating manifests...') 98 | create_manifest( 99 | data_path=target_dir, 100 | output_name='voxforge_train_manifest.json', 101 | manifest_path=args.manifest_dir, 102 | min_duration=args.min_duration, 103 | max_duration=args.max_duration, 104 | num_workers=args.num_workers 105 | ) 106 | -------------------------------------------------------------------------------- /deepspeech_pytorch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeanNaren/deepspeech.pytorch/b88a631ef6b96553e16bbaa5bd6a621435bed278/deepspeech_pytorch/__init__.py -------------------------------------------------------------------------------- /deepspeech_pytorch/checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | from pytorch_lightning.callbacks import ModelCheckpoint 5 | 6 | from 
deepspeech_pytorch.configs.lightning_config import ModelCheckpointConf 7 | 8 | 9 | class CheckpointHandler(ModelCheckpoint): 10 | 11 | def __init__(self, cfg: ModelCheckpointConf): 12 | super().__init__( 13 | dirpath=cfg.dirpath, 14 | filename=cfg.filename, 15 | monitor=cfg.monitor, 16 | verbose=cfg.verbose, 17 | save_last=cfg.save_last, 18 | save_top_k=cfg.save_top_k, 19 | save_weights_only=cfg.save_weights_only, 20 | mode=cfg.mode, 21 | auto_insert_metric_name=cfg.auto_insert_metric_name, 22 | every_n_train_steps=cfg.every_n_train_steps, 23 | train_time_interval=cfg.train_time_interval, 24 | every_n_epochs=cfg.every_n_epochs, 25 | save_on_train_epoch_end=cfg.save_on_train_epoch_end, 26 | ) 27 | 28 | def find_latest_checkpoint(self): 29 | raise NotImplementedError 30 | 31 | 32 | class FileCheckpointHandler(CheckpointHandler): 33 | 34 | def find_latest_checkpoint(self): 35 | """ 36 | Finds the latest checkpoint in a folder based on the timestamp of the file. 37 | If there are no checkpoints, returns None. 38 | :return: The latest checkpoint path, or None if no checkpoints are found. 39 | """ 40 | paths = list(Path(self.dirpath).rglob('*')) 41 | if paths: 42 | paths.sort(key=os.path.getctime) 43 | latest_checkpoint_path = paths[-1] 44 | return latest_checkpoint_path 45 | else: 46 | return None 47 | -------------------------------------------------------------------------------- /deepspeech_pytorch/configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeanNaren/deepspeech.pytorch/b88a631ef6b96553e16bbaa5bd6a621435bed278/deepspeech_pytorch/configs/__init__.py -------------------------------------------------------------------------------- /deepspeech_pytorch/configs/inference_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from deepspeech_pytorch.enums import DecoderType 4 | 5 | 6 | @dataclass 7 | class LMConfig: 8 | decoder_type: DecoderType = DecoderType.greedy 9 | lm_path: str = '' # Path to an (optional) kenlm language model for use with beam search (req\'d with trie) 10 | top_paths: int = 1 # Number of beams to return 11 | alpha: float = 0.0 # Language model weight 12 | beta: float = 0.0 # Language model word bonus (all words) 13 | cutoff_top_n: int = 40 # Cutoff_top_n characters with highest probs in vocabulary will be used in beam search 14 | cutoff_prob: float = 1.0 # Cutoff probability in pruning,default 1.0, no pruning. 
15 | beam_width: int = 10 # Beam width to use 16 | lm_workers: int = 4 # Number of LM processes to use 17 | 18 | 19 | @dataclass 20 | class ModelConfig: 21 | precision: int = 32 # Set to 16 to use mixed-precision for inference 22 | cuda: bool = True 23 | model_path: str = '' 24 | 25 | 26 | @dataclass 27 | class InferenceConfig: 28 | lm: LMConfig = LMConfig() 29 | model: ModelConfig = ModelConfig() 30 | 31 | 32 | @dataclass 33 | class TranscribeConfig(InferenceConfig): 34 | audio_path: str = '' # Audio file to predict on 35 | offsets: bool = False # Returns time offset information 36 | chunk_size_seconds: float = -1 # default value to transcribe the whole audio at once 37 | 38 | 39 | @dataclass 40 | class EvalConfig(InferenceConfig): 41 | test_path: str = '' # Path to validation manifest csv or folder 42 | verbose: bool = True # Print out decoded output and error of each sample 43 | save_output: str = '' # Saves output of model from test to this file_path 44 | batch_size: int = 20 # Batch size for testing 45 | num_workers: int = 4 46 | 47 | 48 | @dataclass 49 | class ServerConfig(InferenceConfig): 50 | host: str = '0.0.0.0' 51 | port: int = 8888 52 | -------------------------------------------------------------------------------- /deepspeech_pytorch/configs/lightning_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Any 3 | from typing import Optional 4 | 5 | 6 | @dataclass 7 | class ModelCheckpointConf: 8 | _target_: str = "pytorch_lightning.callbacks.ModelCheckpoint" 9 | filepath: Optional[str] = None 10 | monitor: Optional[str] = None 11 | verbose: bool = False 12 | save_last: Optional[bool] = None 13 | save_top_k: Optional[int] = 1 14 | save_weights_only: bool = False 15 | mode: str = "min" 16 | dirpath: Any = None # Union[str, Path, NoneType] 17 | filename: Optional[str] = None 18 | auto_insert_metric_name: bool = True 19 | every_n_train_steps: Optional[int] = None 20 | train_time_interval: Optional[str] = None 21 | every_n_epochs: Optional[int] = None 22 | save_on_train_epoch_end: Optional[bool] = None 23 | 24 | 25 | @dataclass 26 | class TrainerConf: 27 | _target_: str = "pytorch_lightning.trainer.Trainer" 28 | logger: Any = ( 29 | True # Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] 30 | ) 31 | enable_checkpointing: bool = True 32 | default_root_dir: Optional[str] = None 33 | gradient_clip_val: float = 0 34 | callbacks: Any = None 35 | num_nodes: int = 1 36 | num_processes: int = 1 37 | gpus: Any = None # Union[int, str, List[int], NoneType] 38 | auto_select_gpus: bool = False 39 | tpu_cores: Any = None # Union[int, str, List[int], NoneType] 40 | overfit_batches: Any = 0.0 # Union[int, float] 41 | track_grad_norm: Any = -1 # Union[int, float, str] 42 | check_val_every_n_epoch: int = 1 43 | fast_dev_run: Any = False # Union[int, bool] 44 | accumulate_grad_batches: Any = 1 # Union[int, Dict[int, int], List[list]] 45 | max_epochs: int = 1000 46 | min_epochs: int = 1 47 | limit_train_batches: Any = 1.0 # Union[int, float] 48 | limit_val_batches: Any = 1.0 # Union[int, float] 49 | limit_test_batches: Any = 1.0 # Union[int, float] 50 | val_check_interval: Any = 1.0 # Union[int, float] 51 | log_every_n_steps: int = 50 52 | accelerator: Any = None # Union[str, Accelerator, NoneType] 53 | sync_batchnorm: bool = False 54 | precision: int = 32 55 | weights_save_path: Optional[str] = None 56 | num_sanity_val_steps: int = 2 57 | resume_from_checkpoint: Any = None # 
Union[str, Path, NoneType] 58 | profiler: Any = None # Union[BaseProfiler, bool, str, NoneType] 59 | benchmark: bool = False 60 | deterministic: bool = False 61 | auto_lr_find: Any = False # Union[bool, str] 62 | replace_sampler_ddp: bool = True 63 | detect_anomaly: bool = False 64 | auto_scale_batch_size: Any = False # Union[str, bool] 65 | plugins: Any = None # Union[str, list, NoneType] 66 | amp_backend: str = "native" 67 | amp_level: Any = None 68 | move_metrics_to_cpu: bool = False 69 | gradient_clip_algorithm: Optional[str] = None 70 | devices: Any = None 71 | ipus: Optional[int] = None 72 | enable_progress_bar: bool = True 73 | max_time: Optional[str] = None 74 | limit_predict_batches: float = 1.0 75 | strategy: Optional[str] = None 76 | enable_model_summary: bool = True 77 | reload_dataloaders_every_n_epochs: int = 0 78 | multiple_trainloader_mode: str = "max_size_cycle" 79 | -------------------------------------------------------------------------------- /deepspeech_pytorch/configs/train_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Any, List 3 | 4 | from omegaconf import MISSING 5 | 6 | from deepspeech_pytorch.configs.lightning_config import TrainerConf, ModelCheckpointConf 7 | from deepspeech_pytorch.enums import SpectrogramWindow, RNNType 8 | 9 | defaults = [ 10 | {"optim": "adam"}, 11 | {"model": "bidirectional"}, 12 | {"checkpoint": "file"} 13 | ] 14 | 15 | 16 | @dataclass 17 | class SpectConfig: 18 | sample_rate: int = 16000 # The sample rate for the data/model features 19 | window_size: float = .02 # Window size for spectrogram generation (seconds) 20 | window_stride: float = .01 # Window stride for spectrogram generation (seconds) 21 | window: SpectrogramWindow = SpectrogramWindow.hamming # Window type for spectrogram generation 22 | 23 | 24 | @dataclass 25 | class AugmentationConfig: 26 | speed_volume_perturb: bool = False # Use random tempo and gain perturbations. 27 | spec_augment: bool = False # Use simple spectral augmentation on mel spectograms. 28 | noise_dir: str = '' # Directory to inject noise into audio. If default, noise Inject not added 29 | noise_prob: float = 0.4 # Probability of noise being added per sample 30 | noise_min: float = 0.0 # Minimum noise level to sample from. (1.0 means all noise, not original signal) 31 | noise_max: float = 0.5 # Maximum noise levels to sample from. 
Maximum 1.0 32 | 33 | 34 | @dataclass 35 | class DataConfig: 36 | train_path: str = 'data/train_manifest.csv' 37 | val_path: str = 'data/val_manifest.csv' 38 | batch_size: int = 64 # Batch size for training 39 | num_workers: int = 4 # Number of workers used in data-loading 40 | labels_path: str = 'labels.json' # Contains tokens for model output 41 | spect: SpectConfig = SpectConfig() 42 | augmentation: AugmentationConfig = AugmentationConfig() 43 | prepare_data_per_node: bool = True 44 | 45 | 46 | @dataclass 47 | class BiDirectionalConfig: 48 | rnn_type: RNNType = RNNType.lstm # Type of RNN to use in model 49 | hidden_size: int = 1024 # Hidden size of RNN Layer 50 | hidden_layers: int = 5 # Number of RNN layers 51 | 52 | 53 | @dataclass 54 | class UniDirectionalConfig(BiDirectionalConfig): 55 | lookahead_context: int = 20 # The lookahead context for convolution after RNN layers 56 | 57 | 58 | @dataclass 59 | class OptimConfig: 60 | learning_rate: float = 1.5e-4 # Initial Learning Rate 61 | learning_anneal: float = 0.99 # Annealing applied to learning rate after each epoch 62 | weight_decay: float = 1e-5 # Initial Weight Decay 63 | 64 | 65 | @dataclass 66 | class SGDConfig(OptimConfig): 67 | momentum: float = 0.9 68 | 69 | 70 | @dataclass 71 | class AdamConfig(OptimConfig): 72 | eps: float = 1e-8 # Adam eps 73 | betas: tuple = (0.9, 0.999) # Adam betas 74 | 75 | 76 | @dataclass 77 | class DeepSpeechTrainerConf(TrainerConf): 78 | callbacks: Any = MISSING 79 | 80 | 81 | @dataclass 82 | class DeepSpeechConfig: 83 | defaults: List[Any] = field(default_factory=lambda: defaults) 84 | optim: Any = MISSING 85 | model: Any = MISSING 86 | checkpoint: ModelCheckpointConf = MISSING 87 | trainer: DeepSpeechTrainerConf = DeepSpeechTrainerConf() 88 | data: DataConfig = DataConfig() 89 | augmentation: AugmentationConfig = AugmentationConfig() 90 | seed: int = 123456 # Seed for generators 91 | load_auto_checkpoint: bool = False # Automatically load the latest checkpoint from save folder 92 | -------------------------------------------------------------------------------- /deepspeech_pytorch/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeanNaren/deepspeech.pytorch/b88a631ef6b96553e16bbaa5bd6a621435bed278/deepspeech_pytorch/data/__init__.py -------------------------------------------------------------------------------- /deepspeech_pytorch/data/data_opts.py: -------------------------------------------------------------------------------- 1 | def add_data_opts(parser): 2 | data_opts = parser.add_argument_group("General Data Options") 3 | data_opts.add_argument('--manifest-dir', default='./', type=str, 4 | help='Output directory for manifests') 5 | data_opts.add_argument('--min-duration', default=1, type=int, 6 | help='Prunes training samples shorter than the min duration (given in seconds, default 1)') 7 | data_opts.add_argument('--max-duration', default=15, type=int, 8 | help='Prunes training samples longer than the max duration (given in seconds, default 15)') 9 | parser.add_argument('--num-workers', default=4, type=int, help='Number of workers for processing data.') 10 | parser.add_argument('--sample-rate', default=16000, type=int, help='Sample rate') 11 | return parser 12 | -------------------------------------------------------------------------------- /deepspeech_pytorch/data/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | 
import json 4 | import os 5 | from multiprocessing import Pool 6 | from pathlib import Path 7 | from typing import Optional 8 | 9 | import sox 10 | from tqdm import tqdm 11 | 12 | 13 | def create_manifest( 14 | data_path: str, 15 | output_name: str, 16 | manifest_path: str, 17 | num_workers: int, 18 | min_duration: Optional[float] = None, 19 | max_duration: Optional[float] = None, 20 | file_extension: str = "wav"): 21 | data_path = os.path.abspath(data_path) 22 | file_paths = list(Path(data_path).rglob(f"*.{file_extension}")) 23 | file_paths = order_and_prune_files( 24 | file_paths=file_paths, 25 | min_duration=min_duration, 26 | max_duration=max_duration, 27 | num_workers=num_workers 28 | ) 29 | 30 | output_path = Path(manifest_path) / output_name 31 | output_path.parent.mkdir(exist_ok=True, parents=True) 32 | 33 | manifest = { 34 | 'root_path': data_path, 35 | 'samples': [] 36 | } 37 | for wav_path in tqdm(file_paths, total=len(file_paths)): 38 | wav_path = wav_path.relative_to(data_path) 39 | transcript_path = wav_path.parent.with_name("txt") / wav_path.with_suffix(".txt").name 40 | manifest['samples'].append({ 41 | 'wav_path': wav_path.as_posix(), 42 | 'transcript_path': transcript_path.as_posix() 43 | }) 44 | 45 | output_path.write_text(json.dumps(manifest), encoding='utf8') 46 | 47 | 48 | def _duration_file_path(path): 49 | return path, sox.file_info.duration(path) 50 | 51 | 52 | def order_and_prune_files( 53 | file_paths, 54 | min_duration, 55 | max_duration, 56 | num_workers): 57 | print("Gathering durations...") 58 | with Pool(processes=num_workers) as p: 59 | duration_file_paths = list(tqdm(p.imap(_duration_file_path, file_paths), total=len(file_paths))) 60 | print("Sorting manifests...") 61 | if min_duration and max_duration: 62 | print("Pruning manifests between %d and %d seconds" % (min_duration, max_duration)) 63 | duration_file_paths = [(path, duration) for path, duration in duration_file_paths if 64 | min_duration <= duration <= max_duration] 65 | 66 | total_duration = sum([x[1] for x in duration_file_paths]) 67 | print(f"Total duration of split: {total_duration:.4f}s") 68 | return [x[0] for x in duration_file_paths] # Remove durations 69 | -------------------------------------------------------------------------------- /deepspeech_pytorch/decoder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # ---------------------------------------------------------------------------- 3 | # Copyright 2015-2016 Nervana Systems Inc. 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ---------------------------------------------------------------------------- 16 | # Modified to support pytorch Tensors 17 | 18 | import torch 19 | from six.moves import xrange 20 | 21 | 22 | class Decoder(object): 23 | """ 24 | Basic decoder class from which all other decoders inherit. Implements several 25 | helper functions. Subclasses should implement the decode() method. 
26 | 27 | Arguments: 28 | labels (list): mapping from integers to characters. 29 | blank_index (int, optional): index for the blank '_' character. Defaults to 0. 30 | """ 31 | 32 | def __init__(self, labels, blank_index=0): 33 | self.labels = labels 34 | self.int_to_char = dict([(i, c) for (i, c) in enumerate(labels)]) 35 | self.blank_index = blank_index 36 | space_index = len(labels) # To prevent errors in decode, we add an out of bounds index for the space 37 | if ' ' in labels: 38 | space_index = labels.index(' ') 39 | self.space_index = space_index 40 | 41 | def decode(self, probs, sizes=None): 42 | """ 43 | Given a matrix of character probabilities, returns the decoder's 44 | best guess of the transcription 45 | 46 | Arguments: 47 | probs: Tensor of character probabilities, where probs[c,t] 48 | is the probability of character c at time t 49 | sizes(optional): Size of each sequence in the mini-batch 50 | Returns: 51 | string: sequence of the model's best guess for the transcription 52 | """ 53 | raise NotImplementedError 54 | 55 | 56 | class BeamCTCDecoder(Decoder): 57 | def __init__(self, 58 | labels, 59 | lm_path=None, 60 | alpha=0, 61 | beta=0, 62 | cutoff_top_n=40, 63 | cutoff_prob=1.0, 64 | beam_width=100, 65 | num_processes=4, 66 | blank_index=0): 67 | super(BeamCTCDecoder, self).__init__(labels) 68 | try: 69 | from ctcdecode import CTCBeamDecoder 70 | except ImportError: 71 | raise ImportError("BeamCTCDecoder requires paddledecoder package.") 72 | labels = list(labels) # Ensure labels are a list before passing to decoder 73 | self._decoder = CTCBeamDecoder(labels, lm_path, alpha, beta, cutoff_top_n, cutoff_prob, beam_width, 74 | num_processes, blank_index) 75 | 76 | def convert_to_strings(self, out, seq_len): 77 | results = [] 78 | for b, batch in enumerate(out): 79 | utterances = [] 80 | for p, utt in enumerate(batch): 81 | size = seq_len[b][p] 82 | if size > 0: 83 | transcript = ''.join(map(lambda x: self.int_to_char[x.item()], utt[0:size])) 84 | else: 85 | transcript = '' 86 | utterances.append(transcript) 87 | results.append(utterances) 88 | return results 89 | 90 | def convert_tensor(self, offsets, sizes): 91 | results = [] 92 | for b, batch in enumerate(offsets): 93 | utterances = [] 94 | for p, utt in enumerate(batch): 95 | size = sizes[b][p] 96 | if sizes[b][p] > 0: 97 | utterances.append(utt[0:size]) 98 | else: 99 | utterances.append(torch.tensor([], dtype=torch.int)) 100 | results.append(utterances) 101 | return results 102 | 103 | def decode(self, probs, sizes=None): 104 | """ 105 | Decodes probability output using ctcdecode package. 
106 | Arguments: 107 | probs: Tensor of character probabilities, where probs[c,t] 108 | is the probability of character c at time t 109 | sizes: Size of each sequence in the mini-batch 110 | Returns: 111 | string: sequences of the model's best guess for the transcription 112 | """ 113 | probs = probs.cpu() 114 | out, scores, offsets, seq_lens = self._decoder.decode(probs, sizes) 115 | 116 | strings = self.convert_to_strings(out, seq_lens) 117 | offsets = self.convert_tensor(offsets, seq_lens) 118 | return strings, offsets 119 | 120 | 121 | class GreedyDecoder(Decoder): 122 | def __init__(self, labels, blank_index=0): 123 | super(GreedyDecoder, self).__init__(labels, blank_index) 124 | 125 | def convert_to_strings(self, 126 | sequences, 127 | sizes=None, 128 | remove_repetitions=False, 129 | return_offsets=False): 130 | """Given a list of numeric sequences, returns the corresponding strings""" 131 | strings = [] 132 | offsets = [] if return_offsets else None 133 | for x in xrange(len(sequences)): 134 | seq_len = sizes[x] if sizes is not None else len(sequences[x]) 135 | string, string_offsets = self.process_string(sequences[x], seq_len, remove_repetitions) 136 | strings.append([string]) # We only return one path 137 | if return_offsets: 138 | offsets.append([string_offsets]) 139 | if return_offsets: 140 | return strings, offsets 141 | else: 142 | return strings 143 | 144 | def process_string(self, 145 | sequence, 146 | size, 147 | remove_repetitions=False): 148 | string = '' 149 | offsets = [] 150 | for i in range(size): 151 | char = self.int_to_char[sequence[i].item()] 152 | if char != self.int_to_char[self.blank_index]: 153 | # if this char is a repetition and remove_repetitions=true, then skip 154 | if remove_repetitions and i != 0 and char == self.int_to_char[sequence[i - 1].item()]: 155 | pass 156 | elif char == self.labels[self.space_index]: 157 | string += ' ' 158 | offsets.append(i) 159 | else: 160 | string = string + char 161 | offsets.append(i) 162 | return string, torch.tensor(offsets, dtype=torch.int) 163 | 164 | def decode(self, probs, sizes=None): 165 | """ 166 | Returns the argmax decoding given the probability matrix. Removes 167 | repeated elements in the sequence, as well as blanks. 168 | 169 | Arguments: 170 | probs: Tensor of character probabilities from the network. 
Expected shape of batch x seq_length x output_dim 171 | sizes(optional): Size of each sequence in the mini-batch 172 | Returns: 173 | strings: sequences of the model's best guess for the transcription on inputs 174 | offsets: time step per character predicted 175 | """ 176 | _, max_probs = torch.max(probs, 2) 177 | strings, offsets = self.convert_to_strings(max_probs.view(max_probs.size(0), max_probs.size(1)), 178 | sizes, 179 | remove_repetitions=True, 180 | return_offsets=True) 181 | return strings, offsets 182 | -------------------------------------------------------------------------------- /deepspeech_pytorch/enums.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | from torch import nn 4 | 5 | 6 | class DecoderType(Enum): 7 | greedy: str = 'greedy' 8 | beam: str = 'beam' 9 | 10 | 11 | class SpectrogramWindow(Enum): 12 | hamming = 'hamming' 13 | hann = 'hann' 14 | blackman = 'blackman' 15 | bartlett = 'bartlett' 16 | 17 | 18 | class RNNType(Enum): 19 | lstm = nn.LSTM 20 | rnn = nn.RNN 21 | gru = nn.GRU 22 | -------------------------------------------------------------------------------- /deepspeech_pytorch/inference.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import List 3 | 4 | import hydra 5 | import torch 6 | from torch.cuda.amp import autocast 7 | 8 | from deepspeech_pytorch.configs.inference_config import TranscribeConfig 9 | from deepspeech_pytorch.decoder import Decoder 10 | from deepspeech_pytorch.loader.data_loader import ChunkSpectrogramParser 11 | from deepspeech_pytorch.model import DeepSpeech 12 | from deepspeech_pytorch.utils import load_decoder, load_model 13 | 14 | 15 | def decode_results(decoded_output: List, 16 | decoded_offsets: List, 17 | cfg: TranscribeConfig): 18 | results = { 19 | "output": [], 20 | "_meta": { 21 | "acoustic_model": { 22 | "path": cfg.model.model_path 23 | }, 24 | "language_model": { 25 | "path": cfg.lm.lm_path 26 | }, 27 | "decoder": { 28 | "alpha": cfg.lm.alpha, 29 | "beta": cfg.lm.beta, 30 | "type": cfg.lm.decoder_type.value, 31 | } 32 | } 33 | } 34 | 35 | for b in range(len(decoded_output)): 36 | for pi in range(min(cfg.lm.top_paths, len(decoded_output[b]))): 37 | result = {'transcription': decoded_output[b][pi]} 38 | if cfg.offsets: 39 | result['offsets'] = decoded_offsets[b][pi].tolist() 40 | results['output'].append(result) 41 | return results 42 | 43 | 44 | def transcribe(cfg: TranscribeConfig): 45 | device = torch.device("cuda" if cfg.model.cuda else "cpu") 46 | 47 | model = load_model( 48 | device=device, 49 | model_path=cfg.model.model_path 50 | ) 51 | 52 | decoder = load_decoder( 53 | labels=model.labels, 54 | cfg=cfg.lm 55 | ) 56 | 57 | spect_parser = ChunkSpectrogramParser( 58 | audio_conf=model.spect_cfg, 59 | normalize=True 60 | ) 61 | 62 | decoded_output, decoded_offsets = run_transcribe( 63 | audio_path=hydra.utils.to_absolute_path(cfg.audio_path), 64 | spect_parser=spect_parser, 65 | model=model, 66 | decoder=decoder, 67 | device=device, 68 | precision=cfg.model.precision, 69 | chunk_size_seconds=cfg.chunk_size_seconds 70 | ) 71 | results = decode_results( 72 | decoded_output=decoded_output, 73 | decoded_offsets=decoded_offsets, 74 | cfg=cfg 75 | ) 76 | print(json.dumps(results)) 77 | 78 | 79 | def run_transcribe(audio_path: str, 80 | spect_parser: ChunkSpectrogramParser, 81 | model: DeepSpeech, 82 | decoder: Decoder, 83 | device: torch.device, 84 | precision: int, 85 | chunk_size_seconds: 
float): 86 | hs = None # means that the initial RNN hidden states are set to zeros 87 | all_outs = [] 88 | with torch.no_grad(): 89 | for spect in spect_parser.parse_audio(audio_path, chunk_size_seconds): 90 | spect = spect.contiguous() 91 | spect = spect.view(1, 1, spect.size(0), spect.size(1)) 92 | spect = spect.to(device) 93 | input_sizes = torch.IntTensor([spect.size(3)]).int() 94 | with autocast(enabled=precision == 16): 95 | out, output_sizes, hs = model(spect, input_sizes, hs) 96 | all_outs.append(out.cpu()) 97 | all_outs = torch.cat(all_outs, axis=1) # combine outputs of chunks in one tensor 98 | decoded_output, decoded_offsets = decoder.decode(all_outs) 99 | return decoded_output, decoded_offsets 100 | -------------------------------------------------------------------------------- /deepspeech_pytorch/loader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeanNaren/deepspeech.pytorch/b88a631ef6b96553e16bbaa5bd6a621435bed278/deepspeech_pytorch/loader/__init__.py -------------------------------------------------------------------------------- /deepspeech_pytorch/loader/data_loader.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import os 4 | from pathlib import Path 5 | from tempfile import NamedTemporaryFile 6 | 7 | import librosa 8 | import numpy as np 9 | import sox 10 | import torch 11 | from torch.utils.data import Dataset, Sampler, DistributedSampler, DataLoader 12 | import torchaudio 13 | 14 | from deepspeech_pytorch.configs.train_config import SpectConfig, AugmentationConfig 15 | from deepspeech_pytorch.loader.spec_augment import spec_augment 16 | 17 | torchaudio.set_audio_backend("sox_io") 18 | 19 | 20 | def load_audio(path): 21 | sound, sample_rate = torchaudio.load(path) 22 | if sound.shape[0] == 1: 23 | sound = sound.squeeze() 24 | else: 25 | sound = sound.mean(axis=0) # multiple channels, average 26 | return sound.numpy() 27 | 28 | 29 | class AudioParser(object): 30 | def __init__(self, 31 | audio_conf: SpectConfig, 32 | normalize: bool = False): 33 | """ 34 | Parses audio file into spectrogram with optional normalization and various augmentations 35 | :param audio_conf: Dictionary containing the sample rate, window and the window length/stride in seconds 36 | :param normalize(default False): Apply standard mean and deviation normalization to audio tensor 37 | """ 38 | self.window_stride = audio_conf.window_stride 39 | self.window_size = audio_conf.window_size 40 | self.sample_rate = audio_conf.sample_rate 41 | self.window = audio_conf.window.value 42 | self.normalize = normalize 43 | 44 | def parse_transcript(self, transcript_path): 45 | """ 46 | :param transcript_path: Path where transcript is stored from the manifest file 47 | :return: Transcript in training/testing format 48 | """ 49 | raise NotImplementedError 50 | 51 | def parse_audio(self, audio_path): 52 | """ 53 | :param audio_path: Path where audio is stored from the manifest file 54 | :return: Audio in training/testing format 55 | """ 56 | raise NotImplementedError 57 | 58 | def get_chunks(self, y, chunk_size_seconds=-1): 59 | """ 60 | :param y: Audio signal as an array of float numbers 61 | :param chunk_size_seconds: Chunk size in seconds 62 | :return: Chunk from the give audio signal of duration `chunk_size_seconds` 63 | """ 64 | total_duration_seconds = math.ceil(len(y) / self.sample_rate) 65 | chunk_size_seconds = total_duration_seconds if 
chunk_size_seconds <= 0 else chunk_size_seconds 66 | num_of_chunks = math.ceil(total_duration_seconds / chunk_size_seconds) 67 | for i in range(num_of_chunks): 68 | chunk_start = int(i * chunk_size_seconds * self.sample_rate) 69 | chunk_end = chunk_start + int(chunk_size_seconds * self.sample_rate) 70 | chunk = y[chunk_start:chunk_end] 71 | yield chunk 72 | 73 | def compute_spectrogram(self, y): 74 | """ 75 | :param y: Audio signal as an array of float numbers 76 | :return: Spectrogram of the signal 77 | """ 78 | n_fft = int(self.sample_rate * self.window_size) 79 | win_length = n_fft 80 | hop_length = int(self.sample_rate * self.window_stride) 81 | # STFT 82 | D = librosa.stft(y, n_fft=n_fft, hop_length=hop_length, 83 | win_length=win_length, window=self.window) 84 | spect, phase = librosa.magphase(D) 85 | # S = log(S+1) 86 | spect = np.log1p(spect) 87 | spect = torch.FloatTensor(spect) 88 | if self.normalize: 89 | mean = spect.mean() 90 | std = spect.std() 91 | spect.add_(-mean) 92 | spect.div_(std) 93 | 94 | return spect 95 | 96 | 97 | class NoiseInjection(object): 98 | def __init__(self, 99 | path=None, 100 | sample_rate=16000, 101 | noise_levels=(0, 0.5)): 102 | """ 103 | Adds noise to an input signal with specific SNR. Higher the noise level, the more noise added. 104 | Modified code from https://github.com/willfrey/audio/blob/master/torchaudio/transforms.py 105 | """ 106 | if not os.path.exists(path): 107 | print("Directory doesn't exist: {}".format(path)) 108 | raise IOError 109 | self.paths = path is not None and librosa.util.find_files(path) 110 | self.sample_rate = sample_rate 111 | self.noise_levels = noise_levels 112 | 113 | def inject_noise(self, data): 114 | noise_path = np.random.choice(self.paths) 115 | noise_level = np.random.uniform(*self.noise_levels) 116 | return self.inject_noise_sample(data, noise_path, noise_level) 117 | 118 | def inject_noise_sample(self, data, noise_path, noise_level): 119 | noise_len = sox.file_info.duration(noise_path) 120 | data_len = len(data) / self.sample_rate 121 | noise_start = np.random.rand() * (noise_len - data_len) 122 | noise_end = noise_start + data_len 123 | noise_dst = audio_with_sox(noise_path, self.sample_rate, noise_start, noise_end) 124 | assert len(data) == len(noise_dst) 125 | noise_energy = np.sqrt(noise_dst.dot(noise_dst) / noise_dst.size) 126 | data_energy = np.sqrt(data.dot(data) / data.size) 127 | data += noise_level * noise_dst * data_energy / noise_energy 128 | return data 129 | 130 | 131 | class SpectrogramParser(AudioParser): 132 | def __init__(self, 133 | audio_conf: SpectConfig, 134 | normalize: bool = False, 135 | augmentation_conf: AugmentationConfig = None): 136 | """ 137 | Parses audio file into spectrogram with optional normalization and various augmentations 138 | :param audio_conf: Dictionary containing the sample rate, window and the window length/stride in seconds 139 | :param normalize(default False): Apply standard mean and deviation normalization to audio tensor 140 | :param augmentation_conf(Optional): Config containing the augmentation parameters 141 | """ 142 | super(SpectrogramParser, self).__init__(audio_conf, normalize) 143 | self.aug_conf = augmentation_conf 144 | if augmentation_conf and augmentation_conf.noise_dir: 145 | self.noise_injector = NoiseInjection(path=augmentation_conf.noise_dir, 146 | sample_rate=self.sample_rate, 147 | noise_levels=augmentation_conf.noise_levels) 148 | else: 149 | self.noise_injector = None 150 | 151 | def parse_audio(self, audio_path): 152 | if self.aug_conf and 
self.aug_conf.speed_volume_perturb: 153 | y = load_randomly_augmented_audio(audio_path, self.sample_rate) 154 | else: 155 | y = load_audio(audio_path) 156 | if self.noise_injector: 157 | add_noise = np.random.binomial(1, self.aug_conf.noise_prob) 158 | if add_noise: 159 | y = self.noise_injector.inject_noise(y) 160 | 161 | spect = self.compute_spectrogram(y) 162 | if self.aug_conf and self.aug_conf.spec_augment: 163 | spect = spec_augment(spect) 164 | 165 | return spect 166 | 167 | def parse_transcript(self, transcript_path): 168 | raise NotImplementedError 169 | 170 | 171 | class ChunkSpectrogramParser(AudioParser): 172 | def __init__(self, 173 | audio_conf: SpectConfig, 174 | normalize: bool = False): 175 | """ 176 | Parses audio file into spectrogram with optional normalization and various augmentations 177 | :param audio_conf: Dictionary containing the sample rate, window and the window length/stride in seconds 178 | :param normalize(default False): Apply standard mean and deviation normalization to audio tensor 179 | """ 180 | super(ChunkSpectrogramParser, self).__init__(audio_conf, normalize) 181 | 182 | def parse_audio(self, audio_path, chunk_size_seconds=-1): 183 | y = load_audio(audio_path) 184 | for y_chunk in self.get_chunks(y, chunk_size_seconds): 185 | spect = self.compute_spectrogram(y_chunk) 186 | yield spect 187 | 188 | 189 | class SpectrogramDataset(Dataset, SpectrogramParser): 190 | def __init__(self, 191 | audio_conf: SpectConfig, 192 | input_path: str, 193 | labels: list, 194 | normalize: bool = False, 195 | aug_cfg: AugmentationConfig = None): 196 | """ 197 | Dataset that loads tensors via a csv containing file paths to audio files and transcripts separated by 198 | a comma. Each new line is a different sample. Example below: 199 | 200 | /path/to/audio.wav,/path/to/audio.txt 201 | ... 202 | You can also pass the directory of dataset. 203 | :param audio_conf: Config containing the sample rate, window and the window length/stride in seconds 204 | :param input_path: Path to input. 
205 | :param labels: List containing all the possible characters to map to 206 | :param normalize: Apply standard mean and deviation normalization to audio tensor 207 | :param augmentation_conf(Optional): Config containing the augmentation parameters 208 | """ 209 | self.ids = self._parse_input(input_path) 210 | self.size = len(self.ids) 211 | self.labels_map = dict([(labels[i], i) for i in range(len(labels))]) 212 | super(SpectrogramDataset, self).__init__(audio_conf, normalize, aug_cfg) 213 | 214 | def __getitem__(self, index): 215 | sample = self.ids[index] 216 | audio_path, transcript_path = sample[0], sample[1] 217 | spect = self.parse_audio(audio_path) 218 | transcript = self.parse_transcript(transcript_path) 219 | return spect, transcript 220 | 221 | def _parse_input(self, input_path): 222 | ids = [] 223 | if os.path.isdir(input_path): 224 | for wav_path in Path(input_path).rglob('*.wav'): 225 | transcript_path = str(wav_path).replace('/wav/', '/txt/').replace('.wav', '.txt') 226 | ids.append((wav_path, transcript_path)) 227 | else: 228 | # Assume it is a manifest file 229 | with open(input_path) as f: 230 | manifest = json.load(f) 231 | for sample in manifest['samples']: 232 | wav_path = os.path.join(manifest['root_path'], sample['wav_path']) 233 | transcript_path = os.path.join(manifest['root_path'], sample['transcript_path']) 234 | ids.append((wav_path, transcript_path)) 235 | return ids 236 | 237 | def parse_transcript(self, transcript_path): 238 | with open(transcript_path, 'r', encoding='utf8') as transcript_file: 239 | transcript = transcript_file.read().replace('\n', '') 240 | transcript = list(filter(None, [self.labels_map.get(x) for x in list(transcript)])) 241 | return transcript 242 | 243 | def __len__(self): 244 | return self.size 245 | 246 | 247 | def _collate_fn(batch): 248 | def func(p): 249 | return p[0].size(1) 250 | 251 | batch = sorted(batch, key=lambda sample: sample[0].size(1), reverse=True) 252 | longest_sample = max(batch, key=func)[0] 253 | freq_size = longest_sample.size(0) 254 | minibatch_size = len(batch) 255 | max_seqlength = longest_sample.size(1) 256 | inputs = torch.zeros(minibatch_size, 1, freq_size, max_seqlength) 257 | input_percentages = torch.FloatTensor(minibatch_size) 258 | target_sizes = torch.IntTensor(minibatch_size) 259 | targets = [] 260 | for x in range(minibatch_size): 261 | sample = batch[x] 262 | tensor = sample[0] 263 | target = sample[1] 264 | seq_length = tensor.size(1) 265 | inputs[x][0].narrow(1, 0, seq_length).copy_(tensor) 266 | input_percentages[x] = seq_length / float(max_seqlength) 267 | target_sizes[x] = len(target) 268 | targets.extend(target) 269 | targets = torch.tensor(targets, dtype=torch.long) 270 | return inputs, targets, input_percentages, target_sizes 271 | 272 | 273 | class AudioDataLoader(DataLoader): 274 | def __init__(self, *args, **kwargs): 275 | """ 276 | Creates a data loader for AudioDatasets. 277 | """ 278 | super(AudioDataLoader, self).__init__(*args, **kwargs) 279 | self.collate_fn = _collate_fn 280 | 281 | 282 | class DSRandomSampler(Sampler): 283 | """ 284 | Implementation of a Random Sampler for sampling the dataset. 285 | Added to ensure we reset the start index when an epoch is finished. 286 | This is essential since we support saving/loading state during an epoch. 
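    Illustrative usage (a sketch mirroring loader/data_module.py; batch size and
    num_epochs are placeholders):

        # train_dataset: a SpectrogramDataset as defined above
        sampler = DSRandomSampler(dataset=train_dataset, batch_size=32)
        loader = AudioDataLoader(dataset=train_dataset, num_workers=4, batch_sampler=sampler)
        for epoch in range(num_epochs):
            sampler.set_epoch(epoch)  # deterministic reshuffle per epoch
            for inputs, targets, input_percentages, target_sizes in loader:
                ...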
287 | """ 288 | 289 | def __init__(self, dataset, batch_size=1): 290 | super().__init__(data_source=dataset) 291 | 292 | self.dataset = dataset 293 | self.start_index = 0 294 | self.epoch = 0 295 | self.batch_size = batch_size 296 | ids = list(range(len(self.dataset))) 297 | self.bins = [ids[i:i + self.batch_size] for i in range(0, len(ids), self.batch_size)] 298 | 299 | def __iter__(self): 300 | # deterministically shuffle based on epoch 301 | g = torch.Generator() 302 | g.manual_seed(self.epoch) 303 | indices = ( 304 | torch.randperm(len(self.bins) - self.start_index, generator=g) 305 | .add(self.start_index) 306 | .tolist() 307 | ) 308 | for x in indices: 309 | batch_ids = self.bins[x] 310 | np.random.shuffle(batch_ids) 311 | yield batch_ids 312 | 313 | def __len__(self): 314 | return len(self.bins) - self.start_index 315 | 316 | def set_epoch(self, epoch): 317 | self.epoch = epoch 318 | 319 | 320 | class DSElasticDistributedSampler(DistributedSampler): 321 | """ 322 | Overrides the ElasticDistributedSampler to ensure we reset the start index when an epoch is finished. 323 | This is essential since we support saving/loading state during an epoch. 324 | """ 325 | 326 | def __init__(self, dataset, num_replicas=None, rank=None, batch_size=1): 327 | super().__init__(dataset=dataset, num_replicas=num_replicas, rank=rank) 328 | self.start_index = 0 329 | self.batch_size = batch_size 330 | ids = list(range(len(dataset))) 331 | self.bins = [ids[i:i + self.batch_size] for i in range(0, len(ids), self.batch_size)] 332 | self.num_samples = int( 333 | math.ceil(float(len(self.bins) - self.start_index) / self.num_replicas) 334 | ) 335 | self.total_size = self.num_samples * self.num_replicas 336 | 337 | def __iter__(self): 338 | # deterministically shuffle based on epoch 339 | g = torch.Generator() 340 | g.manual_seed(self.epoch) 341 | indices = ( 342 | torch.randperm(len(self.bins) - self.start_index, generator=g) 343 | .add(self.start_index) 344 | .tolist() 345 | ) 346 | 347 | # add extra samples to make it evenly divisible 348 | indices += indices[: (self.total_size - len(indices))] 349 | assert len(indices) == self.total_size 350 | 351 | # subsample 352 | indices = indices[self.rank: self.total_size: self.num_replicas] 353 | assert len(indices) == self.num_samples 354 | for x in indices: 355 | batch_ids = self.bins[x] 356 | np.random.shuffle(batch_ids) 357 | yield batch_ids 358 | 359 | def __len__(self): 360 | return self.num_samples 361 | 362 | 363 | def audio_with_sox(path, sample_rate, start_time, end_time): 364 | """ 365 | crop and resample the recording with sox and loads it. 366 | """ 367 | with NamedTemporaryFile(suffix=".wav") as tar_file: 368 | tar_filename = tar_file.name 369 | sox_params = "sox \"{}\" -r {} -c 1 -b 16 -e si {} trim {} ={} >/dev/null 2>&1".format(path, sample_rate, 370 | tar_filename, start_time, 371 | end_time) 372 | os.system(sox_params) 373 | y = load_audio(tar_filename) 374 | return y 375 | 376 | 377 | def augment_audio_with_sox(path, sample_rate, tempo, gain): 378 | """ 379 | Changes tempo and gain of the recording with sox and loads it. 
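    For example (illustrative; the temporary output path is a placeholder), a call like
    augment_audio_with_sox('utt.wav', 16000, tempo=1.1, gain=-3.0) shells out to roughly:

        sox "utt.wav" -r 16000 -c 1 -b 16 -e si /tmp/tmpXXXX.wav tempo 1.100 gain -3.000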
380 | """ 381 | with NamedTemporaryFile(suffix=".wav") as augmented_file: 382 | augmented_filename = augmented_file.name 383 | sox_augment_params = ["tempo", "{:.3f}".format(tempo), "gain", "{:.3f}".format(gain)] 384 | sox_params = "sox \"{}\" -r {} -c 1 -b 16 -e si {} {} >/dev/null 2>&1".format(path, sample_rate, 385 | augmented_filename, 386 | " ".join(sox_augment_params)) 387 | os.system(sox_params) 388 | y = load_audio(augmented_filename) 389 | return y 390 | 391 | 392 | def load_randomly_augmented_audio(path, sample_rate=16000, tempo_range=(0.85, 1.15), 393 | gain_range=(-6, 8)): 394 | """ 395 | Picks tempo and gain uniformly, applies it to the utterance by using sox utility. 396 | Returns the augmented utterance. 397 | """ 398 | low_tempo, high_tempo = tempo_range 399 | tempo_value = np.random.uniform(low=low_tempo, high=high_tempo) 400 | low_gain, high_gain = gain_range 401 | gain_value = np.random.uniform(low=low_gain, high=high_gain) 402 | audio = augment_audio_with_sox(path=path, sample_rate=sample_rate, 403 | tempo=tempo_value, gain=gain_value) 404 | return audio 405 | -------------------------------------------------------------------------------- /deepspeech_pytorch/loader/data_module.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | from hydra.utils import to_absolute_path 3 | 4 | from deepspeech_pytorch.configs.train_config import DataConfig 5 | from deepspeech_pytorch.loader.data_loader import SpectrogramDataset, DSRandomSampler, AudioDataLoader, \ 6 | DSElasticDistributedSampler 7 | 8 | 9 | class DeepSpeechDataModule(pl.LightningDataModule): 10 | 11 | def __init__(self, 12 | labels: list, 13 | data_cfg: DataConfig, 14 | normalize: bool): 15 | super().__init__() 16 | self.train_path = to_absolute_path(data_cfg.train_path) 17 | self.val_path = to_absolute_path(data_cfg.val_path) 18 | self.labels = labels 19 | self.data_cfg = data_cfg 20 | self.spect_cfg = data_cfg.spect 21 | self.aug_cfg = data_cfg.augmentation 22 | self.normalize = normalize 23 | 24 | @property 25 | def is_distributed(self): 26 | return self.trainer.devices > 1 27 | 28 | def train_dataloader(self): 29 | train_dataset = self._create_dataset(self.train_path) 30 | if self.is_distributed: 31 | train_sampler = DSElasticDistributedSampler( 32 | dataset=train_dataset, 33 | batch_size=self.data_cfg.batch_size 34 | ) 35 | else: 36 | train_sampler = DSRandomSampler( 37 | dataset=train_dataset, 38 | batch_size=self.data_cfg.batch_size 39 | ) 40 | train_loader = AudioDataLoader( 41 | dataset=train_dataset, 42 | num_workers=self.data_cfg.num_workers, 43 | batch_sampler=train_sampler 44 | ) 45 | return train_loader 46 | 47 | def val_dataloader(self): 48 | val_dataset = self._create_dataset(self.val_path) 49 | val_loader = AudioDataLoader( 50 | dataset=val_dataset, 51 | num_workers=self.data_cfg.num_workers, 52 | batch_size=self.data_cfg.batch_size 53 | ) 54 | return val_loader 55 | 56 | def _create_dataset(self, input_path): 57 | dataset = SpectrogramDataset( 58 | audio_conf=self.spect_cfg, 59 | input_path=input_path, 60 | labels=self.labels, 61 | normalize=True, 62 | aug_cfg=self.aug_cfg 63 | ) 64 | return dataset 65 | -------------------------------------------------------------------------------- /deepspeech_pytorch/loader/sparse_image_warp.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 RnD at Spoon Radio 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # 
you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | # import torch 16 | # import numpy as np 17 | # from torch.autograd import Variable 18 | # import librosa 19 | import random 20 | import numpy as np 21 | # import scipy.signal 22 | import torch 23 | # import torchaudio 24 | # from torchaudio import transforms 25 | # import math 26 | # from torch.utils.data import DataLoader 27 | # from torch.utils.data import Dataset 28 | 29 | 30 | def time_warp(spec, W=5): 31 | spec = spec.view(1, spec.shape[0], spec.shape[1]) 32 | num_rows = spec.shape[1] 33 | spec_len = spec.shape[2] 34 | 35 | y = num_rows // 2 36 | horizontal_line_at_ctr = spec[0][y] 37 | assert len(horizontal_line_at_ctr) == spec_len 38 | 39 | point_to_warp = horizontal_line_at_ctr[random.randrange(W, spec_len - W)] 40 | assert isinstance(point_to_warp, torch.Tensor) 41 | 42 | # Uniform distribution from (0,W) with chance to be up to W negative 43 | dist_to_warp = random.randrange(-W, W) 44 | src_pts, dest_pts = torch.tensor([[[y, point_to_warp]]]), torch.tensor([[[y, point_to_warp + dist_to_warp]]]) 45 | warped_spectro, dense_flows = SparseImageWarp.sparse_image_warp(spec, src_pts, dest_pts) 46 | return warped_spectro.squeeze(3) 47 | 48 | 49 | def freq_mask(spec, F=15, num_masks=1, replace_with_zero=False): 50 | cloned = spec.clone() 51 | num_mel_channels = cloned.shape[1] 52 | 53 | for i in range(0, num_masks): 54 | f = random.randrange(0, F) 55 | f_zero = random.randrange(0, num_mel_channels - f) 56 | 57 | # avoids randrange error if values are equal and range is empty 58 | if (f_zero == f_zero + f): return cloned 59 | 60 | mask_end = random.randrange(f_zero, f_zero + f) 61 | if (replace_with_zero): 62 | cloned[0][f_zero:mask_end] = 0 63 | else: 64 | cloned[0][f_zero:mask_end] = cloned.mean() 65 | 66 | return cloned 67 | 68 | 69 | def time_mask(spec, T=15, num_masks=1, replace_with_zero=False): 70 | cloned = spec.clone() 71 | len_spectro = cloned.shape[2] 72 | 73 | for i in range(0, num_masks): 74 | t = random.randrange(0, T) 75 | t_zero = random.randrange(0, len_spectro - t) 76 | 77 | # avoids randrange error if values are equal and range is empty 78 | if (t_zero == t_zero + t): return cloned 79 | 80 | mask_end = random.randrange(t_zero, t_zero + t) 81 | if (replace_with_zero): 82 | cloned[0][:, t_zero:mask_end] = 0 83 | else: 84 | cloned[0][:, t_zero:mask_end] = cloned.mean() 85 | return cloned 86 | 87 | 88 | def sparse_image_warp(img_tensor, 89 | source_control_point_locations, 90 | dest_control_point_locations, 91 | interpolation_order=2, 92 | regularization_weight=0.0, 93 | num_boundaries_points=0): 94 | control_point_flows = (dest_control_point_locations - source_control_point_locations) 95 | 96 | batch_size, image_height, image_width = img_tensor.shape 97 | grid_locations = get_grid_locations(image_height, image_width) 98 | flattened_grid_locations = torch.tensor(flatten_grid_locations(grid_locations, image_height, image_width)) 99 | 100 | flattened_flows = interpolate_spline( 101 | 
dest_control_point_locations, 102 | control_point_flows, 103 | flattened_grid_locations, 104 | interpolation_order, 105 | regularization_weight) 106 | 107 | dense_flows = create_dense_flows(flattened_flows, batch_size, image_height, image_width) 108 | 109 | warped_image = dense_image_warp(img_tensor, dense_flows) 110 | 111 | return warped_image, dense_flows 112 | 113 | 114 | def get_grid_locations(image_height, image_width): 115 | """Wrapper for np.meshgrid.""" 116 | 117 | y_range = np.linspace(0, image_height - 1, image_height) 118 | x_range = np.linspace(0, image_width - 1, image_width) 119 | y_grid, x_grid = np.meshgrid(y_range, x_range, indexing='ij') 120 | return np.stack((y_grid, x_grid), -1) 121 | 122 | 123 | def flatten_grid_locations(grid_locations, image_height, image_width): 124 | return np.reshape(grid_locations, [image_height * image_width, 2]) 125 | 126 | 127 | def create_dense_flows(flattened_flows, batch_size, image_height, image_width): 128 | # possibly .view 129 | return torch.reshape(flattened_flows, [batch_size, image_height, image_width, 2]) 130 | 131 | 132 | def interpolate_spline(train_points, train_values, query_points, order, regularization_weight=0.0, ): 133 | # First, fit the spline to the observed data. 134 | w, v = solve_interpolation(train_points, train_values, order, regularization_weight) 135 | # Then, evaluate the spline at the query locations. 136 | query_values = apply_interpolation(query_points, train_points, w, v, order) 137 | 138 | return query_values 139 | 140 | 141 | def solve_interpolation(train_points, train_values, order, regularization_weight): 142 | b, n, d = train_points.shape 143 | k = train_values.shape[-1] 144 | 145 | # First, rename variables so that the notation (c, f, w, v, A, B, etc.) 146 | # follows https://en.wikipedia.org/wiki/Polyharmonic_spline. 147 | # To account for python style guidelines we use 148 | # matrix_a for A and matrix_b for B. 149 | 150 | c = train_points 151 | f = train_values.float() 152 | 153 | matrix_a = phi(cross_squared_distance_matrix(c, c), order).unsqueeze(0) # [b, n, n] 154 | # if regularization_weight > 0: 155 | # batch_identity_matrix = array_ops.expand_dims( 156 | # linalg_ops.eye(n, dtype=c.dtype), 0) 157 | # matrix_a += regularization_weight * batch_identity_matrix 158 | 159 | # Append ones to the feature values for the bias term in the linear model. 160 | ones = torch.ones(1, dtype=train_points.dtype).view([-1, 1, 1]) 161 | matrix_b = torch.cat((c, ones), 2).float() # [b, n, d + 1] 162 | 163 | # [b, n + d + 1, n] 164 | left_block = torch.cat((matrix_a, torch.transpose(matrix_b, 2, 1)), 1) 165 | 166 | num_b_cols = matrix_b.shape[2] # d + 1 167 | 168 | # In Tensorflow, zeros are used here. Pytorch gesv fails with zeros for some reason we don't understand. 169 | # So instead we use very tiny randn values (variance of one, zero mean) on one side of our multiplication. 170 | lhs_zeros = torch.randn((b, num_b_cols, num_b_cols)) / 1e10 171 | right_block = torch.cat((matrix_b, lhs_zeros), 172 | 1) # [b, n + d + 1, d + 1] 173 | lhs = torch.cat((left_block, right_block), 174 | 2) # [b, n + d + 1, n + d + 1] 175 | 176 | rhs_zeros = torch.zeros((b, d + 1, k), dtype=train_points.dtype).float() 177 | rhs = torch.cat((f, rhs_zeros), 1) # [b, n + d + 1, k] 178 | 179 | # Then, solve the linear system and unpack the results. 
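    # Sketch of the system assembled above (standard polyharmonic-spline form):
    #
    #     [ A    B ] [w]   [f]
    #     [ B^T ~0 ] [v] = [0]
    #
    # where A holds the kernel values phi(||c_i - c_j||^2) between control points,
    # B is the control points with a column of ones appended (the affine/bias part),
    # w are the RBF weights, v the linear coefficients, and ~0 is the tiny random
    # block used above in place of exact zeros for solver stability.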
180 | X = torch.linalg.solve(lhs, rhs) 181 | w = X[:, :n, :] 182 | v = X[:, n:, :] 183 | 184 | return w, v 185 | 186 | 187 | def cross_squared_distance_matrix(x, y): 188 | """Pairwise squared distance between two (batch) matrices' rows (2nd dim). 189 | Computes the pairwise distances between rows of x and rows of y 190 | Args: 191 | x: [batch_size, n, d] float `Tensor` 192 | y: [batch_size, m, d] float `Tensor` 193 | Returns: 194 | squared_dists: [batch_size, n, m] float `Tensor`, where 195 | squared_dists[b,i,j] = ||x[b,i,:] - y[b,j,:]||^2 196 | """ 197 | x_norm_squared = torch.sum(torch.mul(x, x)) 198 | y_norm_squared = torch.sum(torch.mul(y, y)) 199 | 200 | x_y_transpose = torch.matmul(x.squeeze(0), y.squeeze(0).transpose(0, 1)) 201 | 202 | # squared_dists[b,i,j] = ||x_bi - y_bj||^2 = x_bi'x_bi- 2x_bi'x_bj + x_bj'x_bj 203 | squared_dists = x_norm_squared - 2 * x_y_transpose + y_norm_squared 204 | 205 | return squared_dists.float() 206 | 207 | 208 | def phi(r, order): 209 | """Coordinate-wise nonlinearity used to define the order of the interpolation. 210 | See https://en.wikipedia.org/wiki/Polyharmonic_spline for the definition. 211 | Args: 212 | r: input op 213 | order: interpolation order 214 | Returns: 215 | phi_k evaluated coordinate-wise on r, for k = r 216 | """ 217 | EPSILON = torch.tensor(1e-10) 218 | # using EPSILON prevents log(0), sqrt0), etc. 219 | # sqrt(0) is well-defined, but its gradient is not 220 | if order == 1: 221 | r = torch.max(r, EPSILON) 222 | r = torch.sqrt(r) 223 | return r 224 | elif order == 2: 225 | return 0.5 * r * torch.log(torch.max(r, EPSILON)) 226 | elif order == 4: 227 | return 0.5 * torch.square(r) * torch.log(torch.max(r, EPSILON)) 228 | elif order % 2 == 0: 229 | r = torch.max(r, EPSILON) 230 | return 0.5 * torch.pow(r, 0.5 * order) * torch.log(r) 231 | else: 232 | r = torch.max(r, EPSILON) 233 | return torch.pow(r, 0.5 * order) 234 | 235 | 236 | def apply_interpolation(query_points, train_points, w, v, order): 237 | """Apply polyharmonic interpolation model to data. 238 | Given coefficients w and v for the interpolation model, we evaluate 239 | interpolated function values at query_points. 240 | Args: 241 | query_points: `[b, m, d]` x values to evaluate the interpolation at 242 | train_points: `[b, n, d]` x values that act as the interpolation centers 243 | ( the c variables in the wikipedia article) 244 | w: `[b, n, k]` weights on each interpolation center 245 | v: `[b, d, k]` weights on each input dimension 246 | order: order of the interpolation 247 | Returns: 248 | Polyharmonic interpolation evaluated at points defined in query_points. 249 | """ 250 | query_points = query_points.unsqueeze(0) 251 | # First, compute the contribution from the rbf term. 252 | pairwise_dists = cross_squared_distance_matrix(query_points.float(), train_points.float()) 253 | phi_pairwise_dists = phi(pairwise_dists, order) 254 | 255 | rbf_term = torch.matmul(phi_pairwise_dists, w) 256 | 257 | # Then, compute the contribution from the linear term. 258 | # Pad query_points with ones, for the bias term in the linear model. 259 | ones = torch.ones_like(query_points[..., :1]) 260 | query_points_pad = torch.cat(( 261 | query_points, 262 | ones 263 | ), 2).float() 264 | linear_term = torch.matmul(query_points_pad, v) 265 | 266 | return rbf_term + linear_term 267 | 268 | 269 | def dense_image_warp(image, flow): 270 | """Image warping using per-pixel flow vectors. 
271 | Apply a non-linear warp to the image, where the warp is specified by a dense 272 | flow field of offset vectors that define the correspondences of pixel values 273 | in the output image back to locations in the source image. Specifically, the 274 | pixel value at output[b, j, i, c] is 275 | images[b, j - flow[b, j, i, 0], i - flow[b, j, i, 1], c]. 276 | The locations specified by this formula do not necessarily map to an int 277 | index. Therefore, the pixel value is obtained by bilinear 278 | interpolation of the 4 nearest pixels around 279 | (b, j - flow[b, j, i, 0], i - flow[b, j, i, 1]). For locations outside 280 | of the image, we use the nearest pixel values at the image boundary. 281 | Args: 282 | image: 4-D float `Tensor` with shape `[batch, height, width, channels]`. 283 | flow: A 4-D float `Tensor` with shape `[batch, height, width, 2]`. 284 | name: A name for the operation (optional). 285 | Note that image and flow can be of type tf.half, tf.float32, or tf.float64, 286 | and do not necessarily have to be the same type. 287 | Returns: 288 | A 4-D float `Tensor` with shape`[batch, height, width, channels]` 289 | and same type as input image. 290 | Raises: 291 | ValueError: if height < 2 or width < 2 or the inputs have the wrong number 292 | of dimensions. 293 | """ 294 | image = image.unsqueeze(3) # add a single channel dimension to image tensor 295 | batch_size, height, width, channels = image.shape 296 | 297 | # The flow is defined on the image grid. Turn the flow into a list of query 298 | # points in the grid space. 299 | grid_x, grid_y = torch.meshgrid( 300 | torch.arange(width), torch.arange(height)) 301 | 302 | stacked_grid = torch.stack((grid_y, grid_x), dim=2).float() 303 | 304 | batched_grid = stacked_grid.unsqueeze(-1).permute(3, 1, 0, 2) 305 | 306 | query_points_on_grid = batched_grid - flow 307 | query_points_flattened = torch.reshape(query_points_on_grid, 308 | [batch_size, height * width, 2]) 309 | # Compute values at the query points, then reshape the result back to the 310 | # image grid. 311 | interpolated = interpolate_bilinear(image, query_points_flattened) 312 | interpolated = torch.reshape(interpolated, 313 | [batch_size, height, width, channels]) 314 | return interpolated 315 | 316 | 317 | def interpolate_bilinear(grid, 318 | query_points, 319 | name='interpolate_bilinear', 320 | indexing='ij'): 321 | """Similar to Matlab's interp2 function. 322 | Finds values for query points on a grid using bilinear interpolation. 323 | Args: 324 | grid: a 4-D float `Tensor` of shape `[batch, height, width, channels]`. 325 | query_points: a 3-D float `Tensor` of N points with shape `[batch, N, 2]`. 326 | name: a name for the operation (optional). 327 | indexing: whether the query points are specified as row and column (ij), 328 | or Cartesian coordinates (xy). 329 | Returns: 330 | values: a 3-D `Tensor` with shape `[batch, N, channels]` 331 | Raises: 332 | ValueError: if the indexing mode is invalid, or if the shape of the inputs 333 | invalid. 334 | """ 335 | if indexing != 'ij' and indexing != 'xy': 336 | raise ValueError('Indexing mode must be \'ij\' or \'xy\'') 337 | 338 | shape = grid.shape 339 | if len(shape) != 4: 340 | msg = 'Grid must be 4 dimensional. 
Received size: ' 341 | raise ValueError(msg + str(grid.shape)) 342 | 343 | batch_size, height, width, channels = grid.shape 344 | 345 | shape = [batch_size, height, width, channels] 346 | query_type = query_points.dtype 347 | grid_type = grid.dtype 348 | 349 | num_queries = query_points.shape[1] 350 | 351 | alphas = [] 352 | floors = [] 353 | ceils = [] 354 | index_order = [0, 1] if indexing == 'ij' else [1, 0] 355 | unstacked_query_points = query_points.unbind(2) 356 | 357 | for dim in index_order: 358 | queries = unstacked_query_points[dim] 359 | 360 | size_in_indexing_dimension = shape[dim + 1] 361 | 362 | # max_floor is size_in_indexing_dimension - 2 so that max_floor + 1 363 | # is still a valid index into the grid. 364 | max_floor = torch.tensor(size_in_indexing_dimension - 2, dtype=query_type) 365 | min_floor = torch.tensor(0.0, dtype=query_type) 366 | maxx = torch.max(min_floor, torch.floor(queries)) 367 | floor = torch.min(maxx, max_floor) 368 | int_floor = floor.long() 369 | floors.append(int_floor) 370 | ceil = int_floor + 1 371 | ceils.append(ceil) 372 | 373 | # alpha has the same type as the grid, as we will directly use alpha 374 | # when taking linear combinations of pixel values from the image. 375 | alpha = queries.clone().detach() - floor.clone().detach() 376 | min_alpha = torch.tensor(0.0, dtype=grid_type) 377 | max_alpha = torch.tensor(1.0, dtype=grid_type) 378 | alpha = torch.min(torch.max(min_alpha, alpha), max_alpha) 379 | 380 | # Expand alpha to [b, n, 1] so we can use broadcasting 381 | # (since the alpha values don't depend on the channel). 382 | alpha = torch.unsqueeze(alpha, 2) 383 | alphas.append(alpha) 384 | 385 | flattened_grid = torch.reshape( 386 | grid, [batch_size * height * width, channels]) 387 | batch_offsets = torch.reshape( 388 | torch.arange(batch_size) * height * width, [batch_size, 1]) 389 | 390 | # This wraps array_ops.gather. We reshape the image data such that the 391 | # batch, y, and x coordinates are pulled into the first dimension. 392 | # Then we gather. Finally, we reshape the output back. It's possible this 393 | # code would be made simpler by using array_ops.gather_nd. 394 | def gather(y_coords, x_coords, name): 395 | linear_coordinates = batch_offsets + y_coords * width + x_coords 396 | gathered_values = torch.gather(flattened_grid.t(), 1, linear_coordinates) 397 | return torch.reshape(gathered_values, 398 | [batch_size, num_queries, channels]) 399 | 400 | # grab the pixel values in the 4 corners around each query point 401 | top_left = gather(floors[0], floors[1], 'top_left') 402 | top_right = gather(floors[0], ceils[1], 'top_right') 403 | bottom_left = gather(ceils[0], floors[1], 'bottom_left') 404 | bottom_right = gather(ceils[0], ceils[1], 'bottom_right') 405 | 406 | interp_top = alphas[1] * (top_right - top_left) + top_left 407 | interp_bottom = alphas[1] * (bottom_right - bottom_left) + bottom_left 408 | interp = alphas[0] * (interp_bottom - interp_top) + interp_top 409 | 410 | return interp 411 | -------------------------------------------------------------------------------- /deepspeech_pytorch/loader/spec_augment.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 RnD at Spoon Radio 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """SpecAugment Implementation for Tensorflow. 16 | Related paper : https://arxiv.org/pdf/1904.08779.pdf 17 | In this paper, show summarized parameters by each open datasets in Tabel 1. 18 | ----------------------------------------- 19 | Policy | W | F | m_F | T | p | m_T 20 | ----------------------------------------- 21 | None | 0 | 0 | - | 0 | - | - 22 | ----------------------------------------- 23 | LB | 80 | 27 | 1 | 100 | 1.0 | 1 24 | ----------------------------------------- 25 | LD | 80 | 27 | 2 | 100 | 1.0 | 2 26 | ----------------------------------------- 27 | SM | 40 | 15 | 2 | 70 | 0.2 | 2 28 | ----------------------------------------- 29 | SS | 40 | 27 | 2 | 70 | 0.2 | 2 30 | ----------------------------------------- 31 | LB : LibriSpeech basic 32 | LD : LibriSpeech double 33 | SM : Switchboard mild 34 | SS : Switchboard strong 35 | """ 36 | 37 | import librosa 38 | import librosa.display 39 | import numpy as np 40 | import random 41 | import matplotlib 42 | matplotlib.use('Agg') 43 | import matplotlib.pyplot as plt 44 | from .sparse_image_warp import sparse_image_warp 45 | import torch 46 | 47 | 48 | def time_warp(spec, W=5): 49 | num_rows = spec.shape[1] 50 | spec_len = spec.shape[2] 51 | 52 | y = num_rows // 2 53 | horizontal_line_at_ctr = spec[0][y] 54 | # assert len(horizontal_line_at_ctr) == spec_len 55 | 56 | point_to_warp = horizontal_line_at_ctr[random.randrange(W, spec_len-W)] 57 | # assert isinstance(point_to_warp, torch.Tensor) 58 | 59 | # Uniform distribution from (0,W) with chance to be up to W negative 60 | dist_to_warp = random.randrange(-W, W) 61 | src_pts = torch.tensor([[[y, point_to_warp]]]) 62 | dest_pts = torch.tensor([[[y, point_to_warp + dist_to_warp]]]) 63 | warped_spectro, dense_flows = sparse_image_warp(spec, src_pts, dest_pts) 64 | 65 | return warped_spectro.squeeze(3) 66 | 67 | 68 | def spec_augment(mel_spectrogram, time_warping_para=40, frequency_masking_para=27, 69 | time_masking_para=70, frequency_mask_num=1, time_mask_num=1): 70 | """Spec augmentation Calculation Function. 71 | 'SpecAugment' have 3 steps for audio data augmentation. 72 | first step is time warping using Tensorflow's image_sparse_warp function. 73 | Second step is frequency masking, last step is time masking. 74 | # Arguments: 75 | mel_spectrogram(numpy array): audio file path of you want to warping and masking. 76 | time_warping_para(float): Augmentation parameter, "time warp parameter W". 77 | If none, default = 40. 78 | frequency_masking_para(float): Augmentation parameter, "frequency mask parameter F" 79 | If none, default = 27. 80 | time_masking_para(float): Augmentation parameter, "time mask parameter T" 81 | If none, default = 70. 82 | frequency_mask_num(float): number of frequency masking lines, "m_F". 83 | If none, default = 1. 84 | time_mask_num(float): number of time masking lines, "m_T". 85 | If none, default = 1. 86 | # Returns 87 | mel_spectrogram(numpy array): warped and masked mel spectrogram. 
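    # Example (illustrative; shapes are placeholders — in this repo the input is the
    # torch.FloatTensor spectrogram produced by SpectrogramParser):
        spect = torch.rand(161, 300)       # (freq, time)
        augmented = spec_augment(spect)    # warped + masked copy, same shape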
88 | """ 89 | mel_spectrogram = mel_spectrogram.unsqueeze(0) 90 | 91 | v = mel_spectrogram.shape[1] 92 | tau = mel_spectrogram.shape[2] 93 | 94 | # Step 1 : Time warping 95 | warped_mel_spectrogram = time_warp(mel_spectrogram) 96 | 97 | # Step 2 : Frequency masking 98 | for i in range(frequency_mask_num): 99 | f = np.random.uniform(low=0.0, high=frequency_masking_para) 100 | f = int(f) 101 | if v - f < 0: 102 | continue 103 | f0 = random.randint(0, v-f) 104 | warped_mel_spectrogram[:, f0:f0+f, :] = 0 105 | 106 | # Step 3 : Time masking 107 | for i in range(time_mask_num): 108 | t = np.random.uniform(low=0.0, high=time_masking_para) 109 | t = int(t) 110 | if tau - t < 0: 111 | continue 112 | t0 = random.randint(0, tau-t) 113 | warped_mel_spectrogram[:, :, t0:t0+t] = 0 114 | 115 | return warped_mel_spectrogram.squeeze() 116 | 117 | 118 | def visualization_spectrogram(mel_spectrogram, title): 119 | """visualizing result of SpecAugment 120 | # Arguments: 121 | mel_spectrogram(ndarray): mel_spectrogram to visualize. 122 | title(String): plot figure's title 123 | """ 124 | # Show mel-spectrogram using librosa's specshow. 125 | plt.figure(figsize=(10, 4)) 126 | librosa.display.specshow(librosa.power_to_db(mel_spectrogram[0, :, :], ref=np.max), y_axis='mel', fmax=8000, x_axis='time') 127 | # plt.colorbar(format='%+2.0f dB') 128 | plt.title(title) 129 | plt.tight_layout() 130 | plt.show() 131 | -------------------------------------------------------------------------------- /deepspeech_pytorch/model.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List, Union 3 | 4 | import pytorch_lightning as pl 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from omegaconf import OmegaConf 9 | from torch.cuda.amp import autocast 10 | from torch.nn import CTCLoss 11 | 12 | from deepspeech_pytorch.configs.train_config import SpectConfig, BiDirectionalConfig, OptimConfig, AdamConfig, \ 13 | SGDConfig, UniDirectionalConfig 14 | from deepspeech_pytorch.decoder import GreedyDecoder 15 | from deepspeech_pytorch.validation import CharErrorRate, WordErrorRate 16 | 17 | 18 | class SequenceWise(nn.Module): 19 | def __init__(self, module): 20 | """ 21 | Collapses input of dim T*N*H to (T*N)*H, and applies to a module. 22 | Allows handling of variable sequence lengths and minibatch sizes. 23 | :param module: Module to apply input to. 24 | """ 25 | super(SequenceWise, self).__init__() 26 | self.module = module 27 | 28 | def forward(self, x): 29 | t, n = x.size(0), x.size(1) 30 | x = x.view(t * n, -1) 31 | x = self.module(x) 32 | x = x.view(t, n, -1) 33 | return x 34 | 35 | def __repr__(self): 36 | tmpstr = self.__class__.__name__ + ' (\n' 37 | tmpstr += self.module.__repr__() 38 | tmpstr += ')' 39 | return tmpstr 40 | 41 | 42 | class MaskConv(nn.Module): 43 | def __init__(self, seq_module): 44 | """ 45 | Adds padding to the output of the module based on the given lengths. This is to ensure that the 46 | results of the model do not change when batch sizes change during inference. 47 | Input needs to be in the shape of (BxCxDxT) 48 | :param seq_module: The sequential module containing the conv stack. 
49 | """ 50 | super(MaskConv, self).__init__() 51 | self.seq_module = seq_module 52 | 53 | def forward(self, x, lengths): 54 | """ 55 | :param x: The input of size BxCxDxT 56 | :param lengths: The actual length of each sequence in the batch 57 | :return: Masked output from the module 58 | """ 59 | for module in self.seq_module: 60 | x = module(x) 61 | mask = torch.BoolTensor(x.size()).fill_(0) 62 | if x.is_cuda: 63 | mask = mask.cuda() 64 | for i, length in enumerate(lengths): 65 | length = length.item() 66 | if (mask[i].size(2) - length) > 0: 67 | mask[i].narrow(2, length, mask[i].size(2) - length).fill_(1) 68 | x = x.masked_fill(mask, 0) 69 | return x, lengths 70 | 71 | 72 | class InferenceBatchSoftmax(nn.Module): 73 | def forward(self, input_): 74 | if not self.training: 75 | return F.softmax(input_, dim=-1) 76 | else: 77 | return input_ 78 | 79 | 80 | class BatchRNN(nn.Module): 81 | def __init__(self, input_size, hidden_size, rnn_type=nn.LSTM, bidirectional=False, batch_norm=True): 82 | super(BatchRNN, self).__init__() 83 | self.input_size = input_size 84 | self.hidden_size = hidden_size 85 | self.bidirectional = bidirectional 86 | self.batch_norm = SequenceWise(nn.BatchNorm1d(input_size)) if batch_norm else None 87 | self.rnn = rnn_type(input_size=input_size, hidden_size=hidden_size, 88 | bidirectional=bidirectional, bias=True) 89 | self.num_directions = 2 if bidirectional else 1 90 | 91 | def flatten_parameters(self): 92 | self.rnn.flatten_parameters() 93 | 94 | def forward(self, x, output_lengths, h=None): 95 | if self.batch_norm is not None: 96 | x = self.batch_norm(x) 97 | x = nn.utils.rnn.pack_padded_sequence(x, output_lengths) 98 | x, h = self.rnn(x, h) 99 | x, _ = nn.utils.rnn.pad_packed_sequence(x) 100 | if self.bidirectional: 101 | x = x.view(x.size(0), x.size(1), 2, -1).sum(2).view(x.size(0), x.size(1), -1) # (TxNxH*2) -> (TxNxH) by sum 102 | return x, h 103 | 104 | 105 | class Lookahead(nn.Module): 106 | # Wang et al 2016 - Lookahead Convolution Layer for Unidirectional Recurrent Neural Networks 107 | # input shape - sequence, batch, feature - TxNxH 108 | # output shape - same as input 109 | def __init__(self, n_features, context): 110 | super(Lookahead, self).__init__() 111 | assert context > 0 112 | self.context = context 113 | self.n_features = n_features 114 | self.pad = (0, self.context - 1) 115 | self.conv = nn.Conv1d( 116 | self.n_features, 117 | self.n_features, 118 | kernel_size=self.context, 119 | stride=1, 120 | groups=self.n_features, 121 | padding=0, 122 | bias=False 123 | ) 124 | 125 | def forward(self, x): 126 | x = x.transpose(0, 1).transpose(1, 2) 127 | x = F.pad(x, pad=self.pad, value=0) 128 | x = self.conv(x) 129 | x = x.transpose(1, 2).transpose(0, 1).contiguous() 130 | return x 131 | 132 | def __repr__(self): 133 | return self.__class__.__name__ + '(' \ 134 | + 'n_features=' + str(self.n_features) \ 135 | + ', context=' + str(self.context) + ')' 136 | 137 | 138 | class DeepSpeech(pl.LightningModule): 139 | def __init__(self, 140 | labels: List, 141 | model_cfg: Union[UniDirectionalConfig, BiDirectionalConfig], 142 | precision: int, 143 | optim_cfg: Union[AdamConfig, SGDConfig], 144 | spect_cfg: SpectConfig 145 | ): 146 | super().__init__() 147 | self.save_hyperparameters() 148 | self.model_cfg = model_cfg 149 | self.precision = precision 150 | self.optim_cfg = optim_cfg 151 | self.spect_cfg = spect_cfg 152 | self.bidirectional = True if OmegaConf.get_type(model_cfg) is BiDirectionalConfig else False 153 | 154 | self.labels = labels 155 | num_classes = 
len(self.labels) 156 | 157 | self.conv = MaskConv(nn.Sequential( 158 | nn.Conv2d(1, 32, kernel_size=(41, 11), stride=(2, 2), padding=(20, 5)), 159 | nn.BatchNorm2d(32), 160 | nn.Hardtanh(0, 20, inplace=True), 161 | nn.Conv2d(32, 32, kernel_size=(21, 11), stride=(2, 1), padding=(10, 5)), 162 | nn.BatchNorm2d(32), 163 | nn.Hardtanh(0, 20, inplace=True) 164 | )) 165 | # Based on above convolutions and spectrogram size using conv formula (W - F + 2P)/ S+1 166 | rnn_input_size = int(math.floor((self.spect_cfg.sample_rate * self.spect_cfg.window_size) / 2) + 1) 167 | rnn_input_size = int(math.floor(rnn_input_size + 2 * 20 - 41) / 2 + 1) 168 | rnn_input_size = int(math.floor(rnn_input_size + 2 * 10 - 21) / 2 + 1) 169 | rnn_input_size *= 32 170 | 171 | self.rnns = nn.Sequential( 172 | BatchRNN( 173 | input_size=rnn_input_size, 174 | hidden_size=self.model_cfg.hidden_size, 175 | rnn_type=self.model_cfg.rnn_type.value, 176 | bidirectional=self.bidirectional, 177 | batch_norm=False 178 | ), 179 | *( 180 | BatchRNN( 181 | input_size=self.model_cfg.hidden_size, 182 | hidden_size=self.model_cfg.hidden_size, 183 | rnn_type=self.model_cfg.rnn_type.value, 184 | bidirectional=self.bidirectional 185 | ) for x in range(self.model_cfg.hidden_layers - 1) 186 | ) 187 | ) 188 | 189 | self.lookahead = nn.Sequential( 190 | # consider adding batch norm? 191 | Lookahead(self.model_cfg.hidden_size, context=self.model_cfg.lookahead_context), 192 | nn.Hardtanh(0, 20, inplace=True) 193 | ) if not self.bidirectional else None 194 | 195 | fully_connected = nn.Sequential( 196 | nn.BatchNorm1d(self.model_cfg.hidden_size), 197 | nn.Linear(self.model_cfg.hidden_size, num_classes, bias=False) 198 | ) 199 | self.fc = nn.Sequential( 200 | SequenceWise(fully_connected), 201 | ) 202 | self.inference_softmax = InferenceBatchSoftmax() 203 | self.criterion = CTCLoss(blank=self.labels.index('_'), reduction='sum', zero_infinity=True) 204 | self.evaluation_decoder = GreedyDecoder(self.labels) # Decoder used for validation 205 | self.wer = WordErrorRate( 206 | decoder=self.evaluation_decoder, 207 | target_decoder=self.evaluation_decoder 208 | ) 209 | self.cer = CharErrorRate( 210 | decoder=self.evaluation_decoder, 211 | target_decoder=self.evaluation_decoder 212 | ) 213 | 214 | def forward(self, x, lengths, hs=None): 215 | lengths = lengths.cpu().int() 216 | output_lengths = self.get_seq_lens(lengths) 217 | x, _ = self.conv(x, output_lengths) 218 | 219 | sizes = x.size() 220 | x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3]) # Collapse feature dimension 221 | x = x.transpose(1, 2).transpose(0, 1).contiguous() # TxNxH 222 | 223 | # if hs is None, create a list of None values corresponding to the number of rnn layers 224 | if hs is None: 225 | hs = [None] * len(self.rnns) 226 | 227 | new_hs = [] 228 | for i, rnn in enumerate(self.rnns): 229 | x, h = rnn(x, output_lengths, hs[i]) 230 | new_hs.append(h) 231 | 232 | if not self.bidirectional: # no need for lookahead layer in bidirectional 233 | x = self.lookahead(x) 234 | 235 | x = self.fc(x) 236 | x = x.transpose(0, 1) 237 | # identity in training mode, softmax in eval mode 238 | x = self.inference_softmax(x) 239 | return x, output_lengths, new_hs 240 | 241 | def training_step(self, batch, batch_idx): 242 | inputs, targets, input_percentages, target_sizes = batch 243 | input_sizes = input_percentages.mul_(int(inputs.size(3))).int() 244 | out, output_sizes, hs = self(inputs, input_sizes) 245 | out = out.transpose(0, 1) # TxNxH 246 | out = out.log_softmax(-1) 247 | 248 | loss = 
self.criterion(out, targets, output_sizes, target_sizes) 249 | return loss 250 | 251 | def validation_step(self, batch, batch_idx): 252 | inputs, targets, input_percentages, target_sizes = batch 253 | input_sizes = input_percentages.mul_(int(inputs.size(3))).int() 254 | inputs = inputs.to(self.device) 255 | with autocast(enabled=self.precision == 16): 256 | out, output_sizes, hs = self(inputs, input_sizes) 257 | decoded_output, _ = self.evaluation_decoder.decode(out, output_sizes) 258 | self.wer( 259 | preds=out, 260 | preds_sizes=output_sizes, 261 | targets=targets, 262 | target_sizes=target_sizes 263 | ) 264 | self.cer( 265 | preds=out, 266 | preds_sizes=output_sizes, 267 | targets=targets, 268 | target_sizes=target_sizes 269 | ) 270 | self.log('wer', self.wer.compute(), prog_bar=True, on_epoch=True) 271 | self.log('cer', self.cer.compute(), prog_bar=True, on_epoch=True) 272 | 273 | def configure_optimizers(self): 274 | if OmegaConf.get_type(self.optim_cfg) is SGDConfig: 275 | optimizer = torch.optim.SGD( 276 | params=self.parameters(), 277 | lr=self.optim_cfg.learning_rate, 278 | momentum=self.optim_cfg.momentum, 279 | nesterov=True, 280 | weight_decay=self.optim_cfg.weight_decay 281 | ) 282 | elif OmegaConf.get_type(self.optim_cfg) is AdamConfig: 283 | optimizer = torch.optim.AdamW( 284 | params=self.parameters(), 285 | lr=self.optim_cfg.learning_rate, 286 | betas=self.optim_cfg.betas, 287 | eps=self.optim_cfg.eps, 288 | weight_decay=self.optim_cfg.weight_decay 289 | ) 290 | else: 291 | raise ValueError("Optimizer has not been specified correctly.") 292 | 293 | scheduler = torch.optim.lr_scheduler.ExponentialLR( 294 | optimizer=optimizer, 295 | gamma=self.optim_cfg.learning_anneal 296 | ) 297 | return [optimizer], [scheduler] 298 | 299 | def get_seq_lens(self, input_length): 300 | """ 301 | Given a 1D Tensor or Variable containing integer sequence lengths, return a 1D tensor or variable 302 | containing the size sequences that will be output by the network. 
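        Worked example (input length chosen arbitrarily): with the conv stack defined in
        __init__ (time-axis kernel 11, padding 5, strides 2 and 1), an input of 171
        spectrogram frames gives (171 + 2*5 - (11 - 1) - 1) // 2 + 1 = 86 output steps
        after the first convolution and (86 + 2*5 - (11 - 1) - 1) // 1 + 1 = 86 after
        the second.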
303 | :param input_length: 1D Tensor 304 | :return: 1D Tensor scaled by model 305 | """ 306 | seq_len = input_length 307 | for m in self.conv.modules(): 308 | if type(m) == nn.modules.conv.Conv2d: 309 | seq_len = ((seq_len + 2 * m.padding[1] - m.dilation[1] * (m.kernel_size[1] - 1) - 1) // m.stride[1] + 1) 310 | return seq_len.int() 311 | -------------------------------------------------------------------------------- /deepspeech_pytorch/testing.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import torch 3 | 4 | from deepspeech_pytorch.configs.inference_config import EvalConfig 5 | from deepspeech_pytorch.decoder import GreedyDecoder 6 | from deepspeech_pytorch.loader.data_loader import SpectrogramDataset, AudioDataLoader 7 | from deepspeech_pytorch.utils import load_model, load_decoder 8 | from deepspeech_pytorch.validation import run_evaluation 9 | 10 | 11 | @torch.no_grad() 12 | def evaluate(cfg: EvalConfig): 13 | device = torch.device("cuda" if cfg.model.cuda else "cpu") 14 | 15 | model = load_model( 16 | device=device, 17 | model_path=cfg.model.model_path 18 | ) 19 | 20 | decoder = load_decoder( 21 | labels=model.labels, 22 | cfg=cfg.lm 23 | ) 24 | target_decoder = GreedyDecoder( 25 | labels=model.labels, 26 | blank_index=model.labels.index('_') 27 | ) 28 | test_dataset = SpectrogramDataset( 29 | audio_conf=model.spect_cfg, 30 | input_path=hydra.utils.to_absolute_path(cfg.test_path), 31 | labels=model.labels, 32 | normalize=True 33 | ) 34 | test_loader = AudioDataLoader( 35 | test_dataset, 36 | batch_size=cfg.batch_size, 37 | num_workers=cfg.num_workers 38 | ) 39 | wer, cer = run_evaluation( 40 | test_loader=test_loader, 41 | device=device, 42 | model=model, 43 | decoder=decoder, 44 | target_decoder=target_decoder, 45 | precision=cfg.model.precision 46 | ) 47 | 48 | print('Test Summary \t' 49 | 'Average WER {wer:.3f}\t' 50 | 'Average CER {cer:.3f}\t'.format(wer=wer, cer=cer)) 51 | -------------------------------------------------------------------------------- /deepspeech_pytorch/training.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import hydra 4 | from hydra.utils import to_absolute_path 5 | from pytorch_lightning import seed_everything 6 | 7 | from deepspeech_pytorch.checkpoint import FileCheckpointHandler 8 | from deepspeech_pytorch.configs.train_config import DeepSpeechConfig 9 | from deepspeech_pytorch.loader.data_module import DeepSpeechDataModule 10 | from deepspeech_pytorch.model import DeepSpeech 11 | 12 | 13 | def train(cfg: DeepSpeechConfig): 14 | seed_everything(cfg.seed) 15 | 16 | with open(to_absolute_path(cfg.data.labels_path)) as label_file: 17 | labels = json.load(label_file) 18 | 19 | if cfg.trainer.enable_checkpointing: 20 | checkpoint_callback = FileCheckpointHandler( 21 | cfg=cfg.checkpoint 22 | ) 23 | if cfg.load_auto_checkpoint: 24 | resume_from_checkpoint = checkpoint_callback.find_latest_checkpoint() 25 | if resume_from_checkpoint: 26 | cfg.trainer.resume_from_checkpoint = resume_from_checkpoint 27 | 28 | data_loader = DeepSpeechDataModule( 29 | labels=labels, 30 | data_cfg=cfg.data, 31 | normalize=True, 32 | ) 33 | 34 | model = DeepSpeech( 35 | labels=labels, 36 | model_cfg=cfg.model, 37 | optim_cfg=cfg.optim, 38 | precision=cfg.trainer.precision, 39 | spect_cfg=cfg.data.spect 40 | ) 41 | 42 | trainer = hydra.utils.instantiate( 43 | config=cfg.trainer, 44 | replace_sampler_ddp=False, 45 | callbacks=[checkpoint_callback] if 
cfg.trainer.enable_checkpointing else None, 46 | ) 47 | trainer.fit(model, data_loader) 48 | -------------------------------------------------------------------------------- /deepspeech_pytorch/utils.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import torch 3 | 4 | from deepspeech_pytorch.configs.inference_config import LMConfig 5 | from deepspeech_pytorch.decoder import GreedyDecoder 6 | from deepspeech_pytorch.enums import DecoderType 7 | from deepspeech_pytorch.model import DeepSpeech 8 | 9 | 10 | def check_loss(loss, loss_value): 11 | """ 12 | Check that warp-ctc loss is valid and will not break training 13 | :return: Return if loss is valid, and the error in case it is not 14 | """ 15 | loss_valid = True 16 | error = '' 17 | if loss_value == float("inf") or loss_value == float("-inf"): 18 | loss_valid = False 19 | error = "WARNING: received an inf loss" 20 | elif torch.isnan(loss).sum() > 0: 21 | loss_valid = False 22 | error = 'WARNING: received a nan loss, setting loss value to 0' 23 | elif loss_value < 0: 24 | loss_valid = False 25 | error = "WARNING: received a negative loss" 26 | return loss_valid, error 27 | 28 | 29 | def load_model(device, 30 | model_path): 31 | model = DeepSpeech.load_from_checkpoint(hydra.utils.to_absolute_path(model_path)) 32 | model.eval() 33 | model = model.to(device) 34 | return model 35 | 36 | 37 | def load_decoder(labels, cfg: LMConfig): 38 | if cfg.decoder_type == DecoderType.beam: 39 | from deepspeech_pytorch.decoder import BeamCTCDecoder 40 | if cfg.lm_path: 41 | cfg.lm_path = hydra.utils.to_absolute_path(cfg.lm_path) 42 | decoder = BeamCTCDecoder(labels=labels, 43 | lm_path=cfg.lm_path, 44 | alpha=cfg.alpha, 45 | beta=cfg.beta, 46 | cutoff_top_n=cfg.cutoff_top_n, 47 | cutoff_prob=cfg.cutoff_prob, 48 | beam_width=cfg.beam_width, 49 | num_processes=cfg.lm_workers, 50 | blank_index=labels.index('_')) 51 | else: 52 | decoder = GreedyDecoder(labels=labels, 53 | blank_index=labels.index('_')) 54 | return decoder 55 | 56 | 57 | def remove_parallel_wrapper(model): 58 | """ 59 | Return the model or extract the model out of the parallel wrapper 60 | :param model: The training model 61 | :return: The model without parallel wrapper 62 | """ 63 | # Take care of distributed/data-parallel wrapper 64 | model_no_wrapper = model.module if hasattr(model, "module") else model 65 | return model_no_wrapper 66 | -------------------------------------------------------------------------------- /deepspeech_pytorch/validation.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import torch 4 | from torch.cuda.amp import autocast 5 | from torchmetrics import Metric 6 | from tqdm import tqdm 7 | 8 | from deepspeech_pytorch.decoder import Decoder, GreedyDecoder 9 | 10 | import Levenshtein as Lev 11 | 12 | 13 | class ErrorRate(Metric, ABC): 14 | def __init__(self, 15 | decoder: Decoder, 16 | target_decoder: GreedyDecoder, 17 | save_output: bool = False, 18 | dist_sync_on_step: bool = False): 19 | super().__init__(dist_sync_on_step=dist_sync_on_step) 20 | self.decoder = decoder 21 | self.target_decoder = target_decoder 22 | self.save_output = save_output 23 | 24 | @abstractmethod 25 | def calculate_metric(self, transcript, reference): 26 | raise NotImplementedError 27 | 28 | def update(self, preds: torch.Tensor, 29 | preds_sizes: torch.Tensor, 30 | targets: torch.Tensor, 31 | target_sizes: torch.Tensor): 32 | # unflatten targets 33 | 
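        # `targets` is a single flat 1-D tensor of concatenated label indices for every
        # utterance in the batch; `target_sizes` holds each transcript's length, so the
        # loop below slices the flat tensor back into one label sequence per utterance.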
split_targets = [] 34 | offset = 0 35 | for size in target_sizes: 36 | split_targets.append(targets[offset:offset + size]) 37 | offset += size 38 | decoded_output, _ = self.decoder.decode(preds, preds_sizes) 39 | target_strings = self.target_decoder.convert_to_strings(split_targets) 40 | for x in range(len(target_strings)): 41 | transcript, reference = decoded_output[x][0], target_strings[x][0] 42 | self.calculate_metric( 43 | transcript=transcript, 44 | reference=reference 45 | ) 46 | 47 | 48 | class CharErrorRate(ErrorRate): 49 | def __init__(self, 50 | decoder: Decoder, 51 | target_decoder: GreedyDecoder, 52 | save_output: bool = False, 53 | dist_sync_on_step: bool = False): 54 | super().__init__( 55 | decoder=decoder, 56 | target_decoder=target_decoder, 57 | save_output=save_output, 58 | dist_sync_on_step=dist_sync_on_step 59 | ) 60 | self.decoder = decoder 61 | self.target_decoder = target_decoder 62 | self.save_output = save_output 63 | self.add_state("cer", default=torch.tensor(0), dist_reduce_fx="sum") 64 | self.add_state("n_chars", default=torch.tensor(0), dist_reduce_fx="sum") 65 | 66 | def calculate_metric(self, transcript, reference): 67 | cer_inst = self.cer_calc(transcript, reference) 68 | self.cer += cer_inst 69 | self.n_chars += len(reference.replace(' ', '')) 70 | 71 | def compute(self): 72 | cer = float(self.cer) / self.n_chars 73 | return cer.item() * 100 74 | 75 | def cer_calc(self, s1, s2): 76 | """ 77 | Computes the Character Error Rate, defined as the edit distance. 78 | 79 | Arguments: 80 | s1 (string): space-separated sentence 81 | s2 (string): space-separated sentence 82 | """ 83 | s1, s2, = s1.replace(' ', ''), s2.replace(' ', '') 84 | return Lev.distance(s1, s2) 85 | 86 | 87 | class WordErrorRate(ErrorRate): 88 | def __init__(self, 89 | decoder: Decoder, 90 | target_decoder: GreedyDecoder, 91 | save_output: bool = False, 92 | dist_sync_on_step: bool = False): 93 | super().__init__( 94 | decoder=decoder, 95 | target_decoder=target_decoder, 96 | save_output=save_output, 97 | dist_sync_on_step=dist_sync_on_step 98 | ) 99 | self.decoder = decoder 100 | self.target_decoder = target_decoder 101 | self.save_output = save_output 102 | self.add_state("wer", default=torch.tensor(0), dist_reduce_fx="sum") 103 | self.add_state("n_tokens", default=torch.tensor(0), dist_reduce_fx="sum") 104 | 105 | def calculate_metric(self, transcript, reference): 106 | wer_inst = self.wer_calc(transcript, reference) 107 | self.wer += wer_inst 108 | self.n_tokens += len(reference.split()) 109 | 110 | def compute(self): 111 | wer = float(self.wer) / self.n_tokens 112 | return wer.item() * 100 113 | 114 | def wer_calc(self, s1, s2): 115 | """ 116 | Computes the Word Error Rate, defined as the edit distance between the 117 | two provided sentences after tokenizing to words. 
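        For example, wer_calc('the cat sat', 'the cat on') returns 1 (a single word
        substitution); the value is a raw edit distance, which WordErrorRate.compute()
        later normalises by the total number of reference words.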
118 | Arguments: 119 | s1 (string): space-separated sentence 120 | s2 (string): space-separated sentence 121 | """ 122 | 123 | # build mapping of words to integers 124 | b = set(s1.split() + s2.split()) 125 | word2char = dict(zip(b, range(len(b)))) 126 | 127 | # map the words to a char array (Levenshtein packages only accepts 128 | # strings) 129 | w1 = [chr(word2char[w]) for w in s1.split()] 130 | w2 = [chr(word2char[w]) for w in s2.split()] 131 | 132 | return Lev.distance(''.join(w1), ''.join(w2)) 133 | 134 | 135 | @torch.no_grad() 136 | def run_evaluation(test_loader, 137 | model, 138 | decoder: Decoder, 139 | device: torch.device, 140 | target_decoder: Decoder, 141 | precision: int): 142 | model.eval() 143 | wer = WordErrorRate( 144 | decoder=decoder, 145 | target_decoder=target_decoder 146 | ) 147 | cer = CharErrorRate( 148 | decoder=decoder, 149 | target_decoder=target_decoder 150 | ) 151 | for i, (batch) in tqdm(enumerate(test_loader), total=len(test_loader)): 152 | inputs, targets, input_percentages, target_sizes = batch 153 | input_sizes = input_percentages.mul_(int(inputs.size(3))).int() 154 | inputs = inputs.to(device) 155 | with autocast(enabled=precision == 16): 156 | out, output_sizes, hs = model(inputs, input_sizes) 157 | decoded_output, _ = decoder.decode(out, output_sizes) 158 | wer.update( 159 | preds=out, 160 | preds_sizes=output_sizes, 161 | targets=targets, 162 | target_sizes=target_sizes 163 | ) 164 | cer.update( 165 | preds=out, 166 | preds_sizes=output_sizes, 167 | targets=targets, 168 | target_sizes=target_sizes 169 | ) 170 | return wer.compute(), cer.compute() 171 | -------------------------------------------------------------------------------- /kubernetes/README.md: -------------------------------------------------------------------------------- 1 | # Training deepspeech.pytorch on Kubernetes using TorchElastic 2 | 3 | Below are instructions to train a model using a GKE cluster and take advantage of pre-emptible VMs using [TorchElastic](https://pytorch.org/elastic/master/index.html). 4 | 5 | ``` 6 | gcloud container clusters create torchelastic \ 7 | --machine-type=n1-standard-2 \ 8 | --disk-size=15Gi \ 9 | --zone=us-west1-b \ 10 | --cluster-version=1.15 \ 11 | --num-nodes=3 --min-nodes=0 --max-nodes=3 \ 12 | --enable-autoscaling \ 13 | --scopes=storage-full 14 | 15 | # Add GPU pool 16 | gcloud container node-pools create gpu-pool --cluster torchelastic \ 17 | --accelerator type=nvidia-tesla-v100,count=1\ 18 | --machine-type=n1-standard-4 \ 19 | --disk-size=25Gi \ 20 | --zone=us-west1-b \ 21 | --preemptible \ 22 | --num-nodes=1 --min-nodes=0 --max-nodes=1 \ 23 | --enable-autoscaling \ 24 | --scopes=storage-full 25 | ``` 26 | 27 | We use pre-emptive nodes to reduce costs. The code handles interruptions by saving state to GCS periodically. 28 | 29 | ## Set up ElasticJob 30 | 31 | ``` 32 | git clone https://github.com/pytorch/elastic.git 33 | cd elastic/kubernetes 34 | 35 | kubectl apply -k config/default 36 | ``` 37 | 38 | ### Setup Volume 39 | 40 | First we create a drive to store our data. The drive is fairly small and can be managed under the volumes tab on GCP. Modify the config as needs be. 41 | 42 | ``` 43 | cd deepspeech.pytorch/kubernetes/ 44 | gcloud compute disks create --size 10Gi audio-data --zone us-west1-b 45 | kubectl apply -f data/storage.yaml 46 | kubectl apply -f data/persistent_volume.yaml 47 | ``` 48 | 49 | ### Download Data 50 | 51 | We run a job to download and extract the data onto our drive. 
Modify the config to match whatever data you'd like to download. 52 | 53 | In our example we download and extract the AN4 data using the deepspeech.pytorch docker image. 54 | 55 | ``` 56 | kubectl apply -f data/transfer_data.yaml 57 | kubectl logs transfer-data --namespace=elastic-job # Monitor logs to determine when the process is complete 58 | kubectl delete -f data/transfer_data.yaml # Delete to free up resources 59 | ``` 60 | 61 | ### Install CRD/CUDA for Training 62 | 63 | ``` 64 | kubectl apply -f https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/master/nvidia-driver-installer/cos/daemonset-preloaded.yaml 65 | kubectl apply -f https://raw.githubusercontent.com/pytorch/elastic/master/kubernetes/config/samples/etcd.yaml 66 | kubectl get svc -n elastic-job 67 | ``` 68 | 69 | ### Training 70 | 71 | #### GCS Model Store 72 | 73 | To store the checkpoint models, we use [Google Cloud Storage](https://cloud.google.com/storage). Create a bucket and make sure to modify `checkpointing.gcs_bucket=deepspeech-1234` to `train.yaml` to point to the bucket that the cluster has access to. 74 | 75 | ``` 76 | kubectl apply -f train.yaml 77 | ``` 78 | -------------------------------------------------------------------------------- /kubernetes/data/persistent_volume.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: PersistentVolume 3 | metadata: 4 | name: audio-data 5 | namespace: elastic-job 6 | spec: 7 | storageClassName: "storage" 8 | capacity: 9 | storage: 10Gi 10 | accessModes: 11 | - ReadWriteOnce 12 | gcePersistentDisk: 13 | pdName: audio-data 14 | fsType: ext4 15 | --- 16 | apiVersion: v1 17 | kind: PersistentVolumeClaim 18 | metadata: 19 | name: audio-data 20 | namespace: elastic-job 21 | spec: 22 | storageClassName: "storage" 23 | accessModes: 24 | - ReadWriteOnce 25 | resources: 26 | requests: 27 | storage: 10Gi -------------------------------------------------------------------------------- /kubernetes/data/storage.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: storage.k8s.io/v1 2 | kind: StorageClass 3 | provisioner: pd.csi.storage.gke.io 4 | metadata: 5 | name: storage 6 | namespace: elastic-job 7 | parameters: 8 | type: pd-ssd 9 | fstype: ext4 10 | replication-type: none -------------------------------------------------------------------------------- /kubernetes/data/transfer_data.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Pod 3 | metadata: 4 | name: transfer-data 5 | namespace: elastic-job 6 | spec: 7 | containers: 8 | - image: seannaren/deepspeech.pytorch:latest 9 | imagePullPolicy: Always 10 | name: deepspeech 11 | command: ["python"] 12 | args: 13 | - "/workspace/deepspeech.pytorch/data/an4.py" 14 | - "--target-dir=/audio-data/an4_dataset/" 15 | - "--manifest-dir=/audio-data/an4_manifests/" 16 | volumeMounts: 17 | - mountPath: /audio-data/ 18 | name: audio-data 19 | restartPolicy: Never 20 | volumes: 21 | - name: audio-data 22 | persistentVolumeClaim: 23 | claimName: audio-data -------------------------------------------------------------------------------- /kubernetes/train.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: elastic.pytorch.org/v1alpha1 2 | kind: ElasticJob 3 | metadata: 4 | name: deepspeech 5 | namespace: elastic-job 6 | spec: 7 | # Use "etcd-service:2379" if you already apply etcd.yaml 8 | 
rdzvEndpoint: "etcd-service:2379" 9 | minReplicas: 1 10 | maxReplicas: 1 11 | replicaSpecs: 12 | Worker: 13 | replicas: 1 14 | restartPolicy: ExitCode 15 | template: 16 | apiVersion: v1 17 | kind: Pod 18 | spec: 19 | containers: 20 | - name: deepspeech 21 | image: seannaren/deepspeech.pytorch:latest 22 | imagePullPolicy: Always 23 | command: ["python", "-m", "torchelastic.distributed.launch"] 24 | args: 25 | - "--nproc_per_node=1" 26 | - "/workspace/deepspeech.pytorch/train.py" 27 | - "data.train_path=/audio-data/an4_manifests/an4_train_manifest.csv" 28 | - "data.val_path=/audio-data/an4_manifests/an4_val_manifest.csv" 29 | - "data.labels_path=/workspace/deepspeech.pytorch/labels.json" 30 | - "data.num_workers=8" 31 | - "training.epochs=70" 32 | - "data.batch_size=8" 33 | - "training.multigpu=distributed" 34 | - "model.precision=half" 35 | - "checkpointing=gcs" 36 | - "checkpointing.gcs_bucket=deepspeech-1234" # Swap this to point to the appropriate GCS bucket 37 | - "checkpointing.gcs_save_folder=models/" 38 | - "checkpointing.load_auto_checkpoint=true" 39 | resources: 40 | limits: 41 | nvidia.com/gpu: 1 42 | volumeMounts: 43 | - mountPath: /audio-data/ 44 | name: audio-data 45 | readOnly: true 46 | volumes: 47 | - name: audio-data 48 | persistentVolumeClaim: 49 | claimName: audio-data 50 | readOnly: true 51 | nodeSelector: 52 | cloud.google.com/gke-nodepool: gpu-pool 53 | -------------------------------------------------------------------------------- /labels.json: -------------------------------------------------------------------------------- 1 | [ 2 | "_", 3 | "'", 4 | "A", 5 | "B", 6 | "C", 7 | "D", 8 | "E", 9 | "F", 10 | "G", 11 | "H", 12 | "I", 13 | "J", 14 | "K", 15 | "L", 16 | "M", 17 | "N", 18 | "O", 19 | "P", 20 | "Q", 21 | "R", 22 | "S", 23 | "T", 24 | "U", 25 | "V", 26 | "W", 27 | "X", 28 | "Y", 29 | "Z", 30 | " " 31 | ] -------------------------------------------------------------------------------- /noise_inject.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from scipy.io.wavfile import write 5 | 6 | from deepspeech_pytorch.loader.data_loader import load_audio, NoiseInjection 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--input-path', default='input.wav', help='The input audio to inject noise into') 10 | parser.add_argument('--noise-path', default='noise.wav', help='The noise file to mix in') 11 | parser.add_argument('--output-path', default='output.wav', help='The noise file to mix in') 12 | parser.add_argument('--sample-rate', default=16000, help='Sample rate to save output as') 13 | parser.add_argument('--noise-level', type=float, default=1.0, 14 | help='The Signal to Noise ratio (higher means more noise)') 15 | args = parser.parse_args() 16 | 17 | noise_injector = NoiseInjection() 18 | data = load_audio(args.input_path) 19 | mixed_data = noise_injector.inject_noise_sample(data, args.noise_path, args.noise_level) 20 | mixed_data = torch.tensor(mixed_data, dtype=torch.float).unsqueeze(1) # Add channels dim 21 | write(filename=args.output_path, 22 | data=mixed_data.numpy(), 23 | rate=args.sample_rate) 24 | print('Saved mixed file to %s' % args.output_path) 25 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | flask 2 | hydra-core 3 | jupyter 4 | librosa 5 | matplotlib 6 | numpy 7 | optuna 8 | pytest 9 | python-levenshtein 10 | 
pytorch-lightning>=1.7.0 11 | scipy 12 | sklearn 13 | sox 14 | torch 15 | torchaudio 16 | torchelastic 17 | tqdm 18 | wget 19 | git+https://github.com/romesco/hydra-lightning/#subdirectory=hydra-configs-pytorch-lightning 20 | -------------------------------------------------------------------------------- /search_lm_params.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import hydra 4 | from hydra.core.config_store import ConfigStore 5 | import optuna 6 | import torch 7 | 8 | from deepspeech_pytorch.configs.train_config import SpectConfig 9 | from deepspeech_pytorch.decoder import BeamCTCDecoder, GreedyDecoder 10 | from deepspeech_pytorch.loader.data_loader import AudioDataLoader, SpectrogramDataset 11 | from deepspeech_pytorch.utils import load_model 12 | from deepspeech_pytorch.validation import run_evaluation 13 | 14 | 15 | @dataclass 16 | class OptimizerConfig: 17 | model_path: str = '' 18 | test_path: str = '' # Path to test manifest or csv 19 | is_character_based: bool = True # Use CER or WER for finding optimal parameters 20 | lm_path: str = '' 21 | beam_width: int = 10 22 | alpha_from: float = 0.0 23 | alpha_to: float = 3.0 24 | beta_from: float = 0.0 25 | beta_to: float = 1.0 26 | n_trials: int = 500 # Number of trials for optuna 27 | n_jobs: int = 2 # Number of parallel jobs for optuna 28 | precision: int = 16 29 | batch_size: int = 1 # For dataloader 30 | num_workers: int = 1 # For dataloader 31 | spect_cfg: SpectConfig = SpectConfig() 32 | 33 | 34 | cs = ConfigStore.instance() 35 | cs.store(name="config", node=OptimizerConfig) 36 | 37 | 38 | class Objective(object): 39 | def __init__(self, cfg): 40 | self.cfg = cfg 41 | 42 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 43 | self.model = load_model( 44 | self.device, 45 | hydra.utils.to_absolute_path(self.cfg.model_path) 46 | ) 47 | self.ckpt = torch.load( 48 | hydra.utils.to_absolute_path(self.cfg.model_path), 49 | map_location=self.device 50 | ) 51 | self.labels = self.ckpt['hyper_parameters']['labels'] 52 | 53 | self.decoder = BeamCTCDecoder( 54 | labels=self.labels, 55 | lm_path=hydra.utils.to_absolute_path(self.cfg.lm_path), 56 | beam_width=self.cfg.beam_width, 57 | num_processes=self.cfg.num_workers, 58 | blank_index=self.labels.index('_') 59 | ) 60 | self.target_decoder = GreedyDecoder( 61 | labels=self.labels, 62 | blank_index=self.labels.index('_') 63 | ) 64 | 65 | test_dataset = SpectrogramDataset( 66 | audio_conf=self.cfg.spect_cfg, 67 | input_path=hydra.utils.to_absolute_path(cfg.test_path), 68 | labels=self.labels, 69 | normalize=True 70 | ) 71 | self.test_loader = AudioDataLoader( 72 | test_dataset, 73 | batch_size=self.cfg.batch_size, 74 | num_workers=self.cfg.num_workers 75 | ) 76 | 77 | def __call__(self, trial): 78 | alpha = trial.suggest_uniform('alpha', self.cfg.alpha_from, self.cfg.alpha_to) 79 | beta = trial.suggest_uniform('beta', self.cfg.beta_from, self.cfg.beta_to) 80 | self.decoder._decoder.reset_params(alpha, beta) 81 | 82 | wer, cer = run_evaluation( 83 | test_loader=self.test_loader, 84 | device=self.device, 85 | model=self.model, 86 | decoder=self.decoder, 87 | target_decoder=self.target_decoder, 88 | precision=self.cfg.precision 89 | ) 90 | return cer if self.cfg.is_character_based else wer 91 | 92 | 93 | @hydra.main(config_name="config") 94 | def main(cfg: OptimizerConfig) -> None: 95 | study = optuna.create_study() 96 | study.optimize(Objective(cfg), 97 | n_trials=cfg.n_trials, 98 | 
n_jobs=cfg.n_jobs, 99 | show_progress_bar=True) 100 | print(f"Best Params\n" 101 | f"alpha: {study.best_params['alpha']}\n" 102 | f"beta: {study.best_params['beta']}\n" 103 | f"{'cer' if cfg.is_character_based else 'wer'}: {study.best_value}") 104 | 105 | 106 | if __name__ == "__main__": 107 | main() 108 | -------------------------------------------------------------------------------- /select_lm_params.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import matplotlib 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | 8 | parser = argparse.ArgumentParser(description='Select the best parameters based on the WER') 9 | parser.add_argument('--input-path', type=str, help='Output json file from search_lm_params') 10 | args = parser.parse_args() 11 | 12 | with open(args.input_path) as f: 13 | results = json.load(f) 14 | 15 | min_results = min(results, key=lambda x: x[2]) # Find the minimum WER (alpha, beta, WER, CER) 16 | print("Alpha: %f \nBeta: %f \nWER: %f\nCER: %f" % tuple(min_results)) 17 | 18 | alpha, beta, *_ = list(zip(*results)) 19 | alpha = np.array(sorted(list(set(alpha)))) 20 | beta = np.array(sorted(list(set(beta)))) 21 | X, Y = np.meshgrid(alpha, beta) 22 | results = {(a, b): (w, c) for a, b, w, c in results} 23 | WER = np.array([[results[(a, b)][0] for a in alpha] for b in beta]) 24 | 25 | fig = plt.figure() 26 | ax = fig.gca(projection='3d') 27 | surf = ax.plot_surface( 28 | X, 29 | Y, 30 | WER, 31 | cmap=matplotlib.cm.rainbow, 32 | linewidth=0, 33 | antialiased=False 34 | ) 35 | ax.set_xlabel('Alpha') 36 | ax.set_ylabel('Beta') 37 | ax.set_zlabel('WER') 38 | ax.set_zlim(5., 101.) 39 | fig.colorbar(surf, shrink=0.5, aspect=5) 40 | plt.show() 41 | -------------------------------------------------------------------------------- /server.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from tempfile import NamedTemporaryFile 4 | 5 | import hydra 6 | import torch 7 | from flask import Flask, request, jsonify 8 | from hydra.core.config_store import ConfigStore 9 | 10 | from deepspeech_pytorch.configs.inference_config import ServerConfig 11 | from deepspeech_pytorch.inference import run_transcribe 12 | from deepspeech_pytorch.loader.data_loader import SpectrogramParser 13 | from deepspeech_pytorch.utils import load_model, load_decoder 14 | 15 | app = Flask(__name__) 16 | ALLOWED_EXTENSIONS = set(['.wav', '.mp3', '.ogg', '.webm']) 17 | 18 | cs = ConfigStore.instance() 19 | cs.store(name="config", node=ServerConfig) 20 | 21 | 22 | @app.route('/transcribe', methods=['POST']) 23 | def transcribe_file(): 24 | if request.method == 'POST': 25 | res = {} 26 | if 'file' not in request.files: 27 | res['status'] = "error" 28 | res['message'] = "audio file should be passed for the transcription" 29 | return jsonify(res) 30 | file = request.files['file'] 31 | filename = file.filename 32 | _, file_extension = os.path.splitext(filename) 33 | if file_extension.lower() not in ALLOWED_EXTENSIONS: 34 | res['status'] = "error" 35 | res['message'] = "{} is not supported format.".format(file_extension) 36 | return jsonify(res) 37 | with NamedTemporaryFile(suffix=file_extension) as tmp_saved_audio_file: 38 | file.save(tmp_saved_audio_file.name) 39 | logging.info('Transcribing file...') 40 | transcription, _ = run_transcribe( 41 | audio_path=tmp_saved_audio_file, 42 | spect_parser=spect_parser, 43 | model=model, 44 | decoder=decoder, 45 | 
device=device, 46 | precision=config.model.precision 47 | ) 48 | logging.info('File transcribed') 49 | res['status'] = "OK" 50 | res['transcription'] = transcription 51 | return jsonify(res) 52 | 53 | 54 | @hydra.main(config_name="config") 55 | def main(cfg: ServerConfig): 56 | global model, spect_parser, decoder, config, device 57 | config = cfg 58 | logging.getLogger().setLevel(logging.DEBUG) 59 | 60 | logging.info('Setting up server...') 61 | device = torch.device("cuda" if cfg.model.cuda else "cpu") 62 | 63 | model = load_model( 64 | device=device, 65 | model_path=cfg.model.model_path 66 | ) 67 | 68 | decoder = load_decoder( 69 | labels=model.labels, 70 | cfg=cfg.lm 71 | ) 72 | 73 | spect_parser = SpectrogramParser( 74 | audio_conf=model.spect_cfg, 75 | normalize=True 76 | ) 77 | 78 | logging.info('Server initialised') 79 | app.run( 80 | host=cfg.host, 81 | port=cfg.port, 82 | debug=True, 83 | use_reloader=False 84 | ) 85 | 86 | 87 | if __name__ == "__main__": 88 | main() 89 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup(name='deepspeech_pytorch', 4 | version='0.1', 5 | author='SeanNaren', 6 | packages=find_packages(), 7 | zip_safe=False, 8 | ) 9 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from hydra.core.config_store import ConfigStore 3 | 4 | from deepspeech_pytorch.configs.inference_config import EvalConfig 5 | from deepspeech_pytorch.testing import evaluate 6 | 7 | cs = ConfigStore.instance() 8 | cs.store(name="config", node=EvalConfig) 9 | 10 | 11 | @hydra.main(config_path='.', config_name="config") 12 | def hydra_main(cfg: EvalConfig): 13 | evaluate(cfg=cfg) 14 | 15 | 16 | if __name__ == '__main__': 17 | hydra_main() 18 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeanNaren/deepspeech.pytorch/b88a631ef6b96553e16bbaa5bd6a621435bed278/tests/__init__.py -------------------------------------------------------------------------------- /tests/pretrained_smoke_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | 4 | import wget 5 | 6 | from deepspeech_pytorch.configs.inference_config import LMConfig 7 | from deepspeech_pytorch.enums import DecoderType 8 | from tests.smoke_test import DatasetConfig, DeepSpeechSmokeTest 9 | 10 | pretrained_urls = [ 11 | 'https://github.com/SeanNaren/deepspeech.pytorch/releases/download/V3.0/an4_pretrained_v3.ckpt', 12 | 'https://github.com/SeanNaren/deepspeech.pytorch/releases/download/V3.0/librispeech_pretrained_v3.ckpt', 13 | 'https://github.com/SeanNaren/deepspeech.pytorch/releases/download/V3.0/ted_pretrained_v3.ckpt' 14 | ] 15 | 16 | lm_path = 'http://www.openslr.org/resources/11/3-gram.pruned.3e-7.arpa.gz' 17 | 18 | 19 | class PretrainedSmokeTest(DeepSpeechSmokeTest): 20 | 21 | def test_pretrained_eval_inference(self): 22 | # Disabled GPU due to using TravisCI 23 | cuda, precision = False, 32 24 | train_manifest, val_manifest, test_manifest = self.download_data( 25 | DatasetConfig( 26 | target_dir=self.target_dir, 27 | manifest_dir=self.manifest_dir 28 | ), 29 | 
folders=False 30 | ) 31 | wget.download(lm_path) 32 | for pretrained_url in pretrained_urls: 33 | print("Running Pre-trained Smoke test for: ", pretrained_url) 34 | wget.download(pretrained_url) 35 | file_path = os.path.basename(pretrained_url) 36 | pretrained_path = os.path.abspath(file_path) 37 | 38 | lm_configs = [ 39 | LMConfig(), # Greedy 40 | LMConfig( 41 | decoder_type=DecoderType.beam 42 | ), # Test Beam Decoder 43 | LMConfig( 44 | decoder_type=DecoderType.beam, 45 | lm_path=os.path.basename(lm_path), 46 | alpha=1, 47 | beta=1 48 | ) # Test Beam Decoder with LM 49 | ] 50 | 51 | for lm_config in lm_configs: 52 | self.eval_model( 53 | model_path=pretrained_path, 54 | test_path=test_manifest, 55 | cuda=cuda, 56 | lm_config=lm_config, 57 | precision=precision 58 | ) 59 | self.inference( 60 | test_path=test_manifest, 61 | model_path=pretrained_path, 62 | cuda=cuda, 63 | lm_config=lm_config, 64 | precision=precision 65 | ) 66 | 67 | 68 | if __name__ == '__main__': 69 | unittest.main() 70 | -------------------------------------------------------------------------------- /tests/smoke_test.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import shutil 4 | import tempfile 5 | import unittest 6 | from dataclasses import dataclass 7 | from pathlib import Path 8 | 9 | from pytorch_lightning.utilities import _module_available 10 | 11 | from data.an4 import download_an4 12 | from deepspeech_pytorch.configs.inference_config import EvalConfig, ModelConfig, TranscribeConfig, LMConfig 13 | from deepspeech_pytorch.configs.lightning_config import ModelCheckpointConf 14 | from deepspeech_pytorch.configs.train_config import DeepSpeechConfig, AdamConfig, BiDirectionalConfig, \ 15 | DataConfig, DeepSpeechTrainerConf 16 | from deepspeech_pytorch.enums import DecoderType 17 | from deepspeech_pytorch.inference import transcribe 18 | from deepspeech_pytorch.testing import evaluate 19 | from deepspeech_pytorch.training import train 20 | 21 | 22 | @dataclass 23 | class DatasetConfig: 24 | target_dir: str = '' 25 | manifest_dir: str = '' 26 | min_duration: float = 0 27 | max_duration: float = 15 28 | val_fraction: float = 0.1 29 | sample_rate: int = 16000 30 | num_workers: int = 4 31 | 32 | 33 | class DeepSpeechSmokeTest(unittest.TestCase): 34 | def setUp(self): 35 | self.target_dir = tempfile.mkdtemp() 36 | self.manifest_dir = tempfile.mkdtemp() 37 | self.model_dir = tempfile.mkdtemp() 38 | 39 | def tearDown(self): 40 | shutil.rmtree(self.target_dir) 41 | shutil.rmtree(self.manifest_dir) 42 | shutil.rmtree(self.model_dir) 43 | 44 | def build_train_evaluate_model(self, 45 | limit_train_batches: int, 46 | limit_val_batches: int, 47 | epoch: int, 48 | batch_size: int, 49 | model_config: BiDirectionalConfig, 50 | precision: int, 51 | gpus: int, 52 | folders: bool): 53 | cuda = gpus > 0 54 | 55 | train_path, val_path, test_path = self.download_data( 56 | DatasetConfig( 57 | target_dir=self.target_dir, 58 | manifest_dir=self.manifest_dir 59 | ), 60 | folders=folders 61 | ) 62 | 63 | train_cfg = self.create_training_config( 64 | limit_train_batches=limit_train_batches, 65 | limit_val_batches=limit_val_batches, 66 | max_epochs=epoch, 67 | batch_size=batch_size, 68 | train_path=train_path, 69 | val_path=val_path, 70 | model_config=model_config, 71 | precision=precision, 72 | gpus=gpus 73 | ) 74 | print("Running Training DeepSpeech Model Smoke Test") 75 | train(train_cfg) 76 | 77 | # Expected final model path after training 78 | 
print(os.listdir(self.model_dir)) 79 | model_path = self.model_dir + '/last.ckpt' 80 | assert os.path.exists(model_path) 81 | 82 | lm_configs = [LMConfig()] 83 | 84 | if _module_available('ctcdecode'): 85 | lm_configs.append( 86 | LMConfig( 87 | decoder_type=DecoderType.beam 88 | ) 89 | ) 90 | print("Running Inference Smoke Tests") 91 | for lm_config in lm_configs: 92 | self.eval_model( 93 | model_path=model_path, 94 | test_path=test_path, 95 | cuda=cuda, 96 | precision=precision, 97 | lm_config=lm_config 98 | ) 99 | 100 | self.inference(test_path=test_path, 101 | model_path=model_path, 102 | cuda=cuda, 103 | precision=precision, 104 | lm_config=lm_config) 105 | 106 | def eval_model(self, 107 | model_path: str, 108 | test_path: str, 109 | cuda: bool, 110 | precision: int, 111 | lm_config: LMConfig): 112 | # Due to using TravisCI with no GPU support we have to disable cuda 113 | eval_cfg = EvalConfig( 114 | model=ModelConfig( 115 | cuda=cuda, 116 | model_path=model_path, 117 | precision=precision 118 | ), 119 | lm=lm_config, 120 | test_path=test_path 121 | ) 122 | evaluate(eval_cfg) 123 | 124 | def inference(self, 125 | test_path: str, 126 | model_path: str, 127 | cuda: bool, 128 | precision: int, 129 | lm_config: LMConfig): 130 | # Select one file from our test manifest to run inference 131 | if os.path.isdir(test_path): 132 | file_path = next(Path(test_path).rglob('*.wav')) 133 | else: 134 | with open(test_path) as f: 135 | # select a file to use for inference test 136 | manifest = json.load(f) 137 | file_name = manifest['samples'][0]['wav_path'] 138 | directory = manifest['root_path'] 139 | file_path = os.path.join(directory, file_name) 140 | 141 | transcribe_cfg = TranscribeConfig( 142 | model=ModelConfig( 143 | cuda=cuda, 144 | model_path=model_path, 145 | precision=precision 146 | ), 147 | lm=lm_config, 148 | audio_path=file_path 149 | ) 150 | transcribe(transcribe_cfg) 151 | 152 | def download_data(self, 153 | cfg: DatasetConfig, 154 | folders: bool): 155 | download_an4( 156 | target_dir=cfg.target_dir, 157 | manifest_dir=cfg.manifest_dir, 158 | min_duration=cfg.min_duration, 159 | max_duration=cfg.max_duration, 160 | num_workers=cfg.num_workers 161 | ) 162 | 163 | # Expected output paths 164 | if folders: 165 | train_path = os.path.join(self.target_dir, 'train/') 166 | val_path = os.path.join(self.target_dir, 'val/') 167 | test_path = os.path.join(self.target_dir, 'test/') 168 | else: 169 | train_path = os.path.join(self.manifest_dir, 'an4_train_manifest.json') 170 | val_path = os.path.join(self.manifest_dir, 'an4_val_manifest.json') 171 | test_path = os.path.join(self.manifest_dir, 'an4_test_manifest.json') 172 | 173 | # Assert manifest paths exists 174 | assert os.path.exists(train_path) 175 | assert os.path.exists(val_path) 176 | assert os.path.exists(test_path) 177 | return train_path, val_path, test_path 178 | 179 | def create_training_config(self, 180 | limit_train_batches: int, 181 | limit_val_batches: int, 182 | max_epochs: int, 183 | batch_size: int, 184 | train_path: str, 185 | val_path: str, 186 | model_config: BiDirectionalConfig, 187 | precision: int, 188 | gpus: int): 189 | return DeepSpeechConfig( 190 | trainer=DeepSpeechTrainerConf( 191 | max_epochs=max_epochs, 192 | precision=precision, 193 | gpus=gpus, 194 | enable_checkpointing=True, 195 | limit_train_batches=limit_train_batches, 196 | limit_val_batches=limit_val_batches 197 | ), 198 | data=DataConfig( 199 | train_path=train_path, 200 | val_path=val_path, 201 | batch_size=batch_size 202 | ), 203 | 
optim=AdamConfig(), 204 | model=model_config, 205 | checkpoint=ModelCheckpointConf( 206 | dirpath=self.model_dir, 207 | save_last=True, 208 | verbose=True 209 | ) 210 | ) 211 | 212 | 213 | class AN4SmokeTest(DeepSpeechSmokeTest): 214 | 215 | def test_train_eval_inference(self): 216 | # Hardcoded sizes to reduce memory/time, and disabled GPU due to using TravisCI 217 | model_cfg = BiDirectionalConfig( 218 | hidden_size=10, 219 | hidden_layers=1 220 | ) 221 | self.build_train_evaluate_model( 222 | limit_train_batches=1, 223 | limit_val_batches=1, 224 | epoch=1, 225 | batch_size=10, 226 | model_config=model_cfg, 227 | precision=32, 228 | gpus=0, 229 | folders=False 230 | ) 231 | 232 | def test_train_eval_inference_folder(self): 233 | """Test train/eval/inference using folder directories rather than manifest files""" 234 | # Hardcoded sizes to reduce memory/time, and disabled GPU due to using TravisCI 235 | model_cfg = BiDirectionalConfig( 236 | hidden_size=10, 237 | hidden_layers=1 238 | ) 239 | self.build_train_evaluate_model( 240 | limit_train_batches=1, 241 | limit_val_batches=1, 242 | epoch=1, 243 | batch_size=10, 244 | model_config=model_cfg, 245 | precision=32, 246 | gpus=0, 247 | folders=True 248 | ) 249 | 250 | 251 | if __name__ == '__main__': 252 | unittest.main() 253 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from hydra.core.config_store import ConfigStore 3 | 4 | from deepspeech_pytorch.configs.lightning_config import ModelCheckpointConf 5 | from deepspeech_pytorch.configs.train_config import DeepSpeechConfig, AdamConfig, SGDConfig, BiDirectionalConfig, \ 6 | UniDirectionalConfig 7 | from deepspeech_pytorch.training import train 8 | 9 | cs = ConfigStore.instance() 10 | cs.store(name="config", node=DeepSpeechConfig) 11 | cs.store(group="optim", name="sgd", node=SGDConfig) 12 | cs.store(group="optim", name="adam", node=AdamConfig) 13 | cs.store(group="checkpoint", name="file", node=ModelCheckpointConf) 14 | cs.store(group="model", name="bidirectional", node=BiDirectionalConfig) 15 | cs.store(group="model", name="unidirectional", node=UniDirectionalConfig) 16 | 17 | 18 | @hydra.main(config_path='.', config_name="config") 19 | def hydra_main(cfg: DeepSpeechConfig): 20 | train(cfg=cfg) 21 | 22 | 23 | if __name__ == '__main__': 24 | hydra_main() 25 | -------------------------------------------------------------------------------- /transcribe.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from hydra.core.config_store import ConfigStore 3 | 4 | from deepspeech_pytorch.configs.inference_config import TranscribeConfig 5 | from deepspeech_pytorch.inference import transcribe 6 | 7 | cs = ConfigStore.instance() 8 | cs.store(name="config", node=TranscribeConfig) 9 | 10 | 11 | @hydra.main(config_path='.', config_name="config") 12 | def hydra_main(cfg: TranscribeConfig): 13 | transcribe(cfg=cfg) 14 | 15 | 16 | if __name__ == '__main__': 17 | hydra_main() 18 | --------------------------------------------------------------------------------
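A usage sketch for the transcription entry point above (the paths are placeholders; the
override keys mirror the TranscribeConfig/ModelConfig fields exercised in the smoke tests):

```
python transcribe.py model.model_path=/path/to/model.ckpt audio_path=/path/to/audio.wav
```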