├── .gitignore ├── README.md ├── base_asr_models.py ├── configuration ├── audio │ └── standard_16k.yaml ├── config.yaml ├── model │ ├── jasper.yaml │ └── wav2letter.yaml └── optimizer │ └── exp_lr_optimizer.yaml ├── data ├── __init__.py ├── augmentations.py ├── data_loader.py ├── label_sets.py ├── language_specific_tools.py └── prepare_librispeech.py ├── decoder.py ├── examples ├── check_requirements.py └── librispeech.sh ├── jasper.py ├── novograd.py ├── requirements.txt ├── train.py ├── unit_tests ├── __init__.py └── decoder_test.py └── wav2letter.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Folders created by scripts 138 | visualize/ 139 | models/ 140 | 141 | # PyCharm 142 | .idea -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Wav2Letter_pytorch 2 | 3 | Implementation of Wav2Letter using PyTorch. 4 | Creates a network based on the [Wav2Letter](https://arxiv.org/abs/1609.03193) architecture, trained with CTC loss. 5 | 6 | ## Features 7 | 8 | * Minimalist code, designed to be a white box - dive into the code! 9 | * Train End-To-End ASR models, including Wav2Letter and Jasper. 10 | * Uses [Hydra](https://hydra.cc/docs/intro) for easy configuration and usage 11 | * Uses [PyTorch Lightning](https://www.pytorchlightning.ai/) for simplify training 12 | * Beam search decoding integrated with kenlm language models. 13 | 14 | # Installation 15 | 16 | Install Python 3.6 or higher. 17 | 18 | Clone the repository (or download it) and install according to the requirements file. 19 | ``` 20 | pip install -r requirements.txt 21 | ``` 22 | 23 | # Usage 24 | 25 | ## LibriSpeech Example 26 | Run ```examples/librispeech.sh``` to download and prepare the data, and start training with a single script. 27 | 28 | You can use the ```data/prepare_librispeech.py``` script to prepare other subsets of the Librispeech dataset. 29 | Run it with ```--help``` for more information. 30 | 31 | 32 | 33 | ## Training 34 | Most simple example: 35 | ``` 36 | python train.py data.train_manifest TRAIN.csv data.val_manifest VALID.csv 37 | ``` 38 | Training a Jasper model is as simple as: ```python train.py model.name=jasper ...``` 39 | 40 | To train with multiple GPU's, mixed precision, or many other options, see the [Pytorch-Lightning Trainer](https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-class-api) documentation. 41 | 42 | Many elements of the model and training are managed via configuration files and command line flags via Hydra. 43 | This includes the audio preprocessing configuration, the optimizer and learning-rate scheduler, and the number and configuration of layers. See the configuration directory for more details. 44 | To see the entire configuration, run ```python train.py [optional overrides] --cfg=job``` 45 | 46 | ## Testing/Inference 47 | WIP! 48 | To evaluate a trained model on a test set (has to be in the same format as the training set): 49 | 50 | ``` 51 | python test.py --model-path models/wav2Letter.pth --test-manifest /path/to/test_manifest.csv --cuda 52 | ``` 53 | 54 | To see the decoded outputs compared to the test data, run with either ```--print-samples``` or ```print-all```. 55 | 56 | You can use a LM during decoding. The LM is expected to be a valid ARPA model, loaded with kenlm. Add ```--lm-path``` to use it. 
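For example, assuming a kenlm ARPA model saved as ```lm.arpa``` and the flags shown above:
```
python test.py --model-path models/wav2Letter.pth --test-manifest /path/to/test_manifest.csv --lm-path lm.arpa --cuda
```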
See ```--beam-search-params``` to fine tune your parameters for beam search. 57 | 58 | ### Custom Datasets 59 | To create a custom dataset, create a Pandas Dataframe with the columns ```audio_filepath, text``` and save it using ``` df.to_csv(path) ```. 60 | Alternatively, you can create a .json file - each line contains a json of a sample with at least ```audio_filepath, text``` as fields. 61 | You can add reading specific sections of audio files by adding ```offset``` and ```duration``` fields (in seconds). The values 0 and -1 are the default values, respectively, and cause reading the entire audio file. 62 | If you use a sample rate other than 16K, specify it using ```model.audio_conf.sample_rate=8000``` for example. 63 | 64 | ### Different languages 65 | In addition to English, Hebrew is supported. 66 | 67 | To use, run with ```--labels hebrew```. Note that some terminals and consoles do not display UTF-8 properly. 68 | 69 | 70 | ## Differences from article 71 | 72 | There are some subtle differences between this implementation and the original. 73 | We use CTC loss instead of ASG, which leads to a small difference in labels definition and output interpretation. 74 | We currently use spectrogram features instead of MFCC, which achieved the best results in the original article. 75 | Some of the network hyperparameters are different - convolution kernel sizes, strides, and default sample rate. 76 | 77 | ## Acknowledgements 78 | This work was originally based off [Silversparro's Wav2Letter](https://github.com/silversparro/wav2letter.pytorch). 79 | That work was inspired by the [deepspeech.pytorch](https://github.com/SeanNaren/deepspeech.pytorch) repository of [Sean Naren](https://github.com/SeanNaren). 80 | The prefix-beam-search algorithm is based on [corticph](https://github.com/corticph/prefix-beam-search) with minor edits. 81 | -------------------------------------------------------------------------------- /base_asr_models.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Dec 14 11:59:45 2020 4 | 5 | @author: Assaf Mushkin 6 | """ 7 | import random 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import pytorch_lightning as ptl 13 | import numpy as np 14 | from hydra.utils import instantiate 15 | 16 | class ConvCTCASR(ptl.LightningModule): 17 | def __init__(self,cfg): 18 | super().__init__() 19 | self._cfg = cfg 20 | self.audio_conf = cfg.audio_conf 21 | self.labels = cfg.labels 22 | self.ctc_decoder = instantiate(cfg.decoder) 23 | self.criterion = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True) 24 | self.print_decoded_prob = cfg.get('print_decoded_prob',0) 25 | self.example_input_array = self.create_example_input_array() 26 | 27 | def create_example_input_array(self): 28 | batch_size = 4 29 | min_length,max_length = 100,200 30 | lengths = torch.randint(min_length,max_length,(4,)) 31 | return (torch.rand(batch_size,self._cfg.input_size,max_length),lengths) 32 | 33 | def compute_output_lengths(self,input_lengths): 34 | ''' 35 | Compute the output lengths given the input lengths. 36 | Override if ratio is not strictly proportional (can happen with unpadded convolutions) 37 | ''' 38 | output_lengths = input_lengths // self.scaling_factor 39 | return output_lengths 40 | 41 | @property 42 | def scaling_factor(self): 43 | ''' 44 | Returns a ratio between input lengths and output lengths. 
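        For example, a factor of 2 means an input of 1000 feature frames yields roughly 500 output frames (see compute_output_lengths above).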
45 | In convolutional models, depends on kernel size, padding, stride, and dilation. 46 | ''' 47 | raise NotImplementedError() 48 | 49 | def forward(inputs,input_lengths): 50 | raise NotImplementedError() 51 | # returns output, output_lengths 52 | 53 | def add_string_metrics(self, out, output_lengths, texts, prefix): 54 | decoded_texts = self.ctc_decoder.decode(out, output_lengths) 55 | if random.random() < self.print_decoded_prob: 56 | print(f'reference: {texts[0]}') 57 | print(f'decoded : {decoded_texts[0]}') 58 | wer_sum, cer_sum,wer_denom_sum,cer_denom_sum = 0,0,0,0 59 | for expected, predicted in zip(texts, decoded_texts): 60 | cer_value, cer_denom = self.ctc_decoder.cer_ratio(expected, predicted) 61 | wer_value, wer_denom = self.ctc_decoder.wer_ratio(expected, predicted) 62 | cer_sum+= cer_value 63 | cer_denom_sum+=cer_denom 64 | wer_sum+= wer_value 65 | wer_denom_sum+=wer_denom 66 | cer = cer_sum / cer_denom_sum 67 | wer = wer_sum / wer_denom_sum 68 | lengths_ratio = sum(map(len, decoded_texts)) / sum(map(len, texts)) 69 | return {prefix+'_cer':cer, prefix+'_wer':wer, prefix+'_len_ratio':lengths_ratio} 70 | 71 | 72 | #PyTorch Lightning methods 73 | def configure_optimizers(self): 74 | optimizer = instantiate(self._cfg.optimizer, params=self.parameters()) 75 | scheduler = instantiate(self._cfg.scheduler,optimizer=optimizer) 76 | return [optimizer],[scheduler] 77 | 78 | def training_step(self, batch, batch_idx): 79 | inputs, input_lengths, targets, target_lengths, file_paths, texts = batch 80 | out, output_lengths = self.forward(inputs,input_lengths) 81 | loss = self.criterion(out.transpose(0,1), targets, output_lengths, target_lengths) 82 | logs = {'train_loss':loss,'learning_rate':self.optimizers().param_groups[0]['lr']} 83 | logs.update(self.add_string_metrics(out, output_lengths, texts, 'train')) 84 | self.log_dict(logs) 85 | return loss 86 | 87 | def validation_step(self, batch, batch_idx): 88 | inputs, input_lengths, targets, target_lengths, file_paths, texts = batch 89 | out, output_lengths = self.forward(inputs,input_lengths) 90 | loss = self.criterion(out.transpose(0,1), targets, output_lengths, target_lengths) 91 | logs = {'val_loss':loss} 92 | logs.update(self.add_string_metrics(out, output_lengths, texts, 'val')) 93 | self.log_dict(logs) 94 | return loss 95 | -------------------------------------------------------------------------------- /configuration/audio/standard_16k.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | audio_conf: 3 | window: hamming 4 | window_stride: 0.01 5 | window_size: 0.02 6 | sample_rate: 16000 7 | -------------------------------------------------------------------------------- /configuration/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - audio: standard_16k 3 | - optimizer: exp_lr_optimizer 4 | - model: wav2letter 5 | data: 6 | train_manifest: ??? 7 | val_manifest: ??? 8 | batch_size: 4 9 | mel_spec: ${model.input_size} 10 | audio_conf: ${model.audio_conf} 11 | model: 12 | input_size: 64 13 | labels: english_lowercase 14 | decoder: 15 | _target_: decoder.GreedyDecoder 16 | labels: ${model.labels} 17 | trainer: 18 | default_root_dir: . 
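  # the trainer options below are described in the PyTorch Lightning Trainer documentation linked in the README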
19 | max_epochs: 5 20 | max_steps: null 21 | gpus: 0 22 | 23 | hydra: 24 | run: 25 | dir: ${trainer.default_root_dir} 26 | job_logging: 27 | root: 28 | handlers: null -------------------------------------------------------------------------------- /configuration/model/jasper.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | name: jasper 3 | mid_layers: 1 4 | jasper_blocks: 5 | - layer_size: 256 6 | kernel_size: 32 7 | stride: 2 8 | residual: False 9 | separable: True 10 | - layer_size: 256 11 | kernel_size: 32 12 | stride: 1 13 | residual: True 14 | separable: True 15 | - layer_size: 256 16 | kernel_size: 32 17 | stride: 1 18 | residual: True 19 | separable: True 20 | - layer_size: 256 21 | kernel_size: 32 22 | stride: 1 23 | residual: True 24 | separable: True 25 | - layer_size: 256 26 | kernel_size: 38 27 | stride: 1 28 | residual: True 29 | separable: True 30 | - layer_size: 256 31 | kernel_size: 38 32 | stride: 1 33 | residual: True 34 | separable: True 35 | - layer_size: 256 36 | kernel_size: 38 37 | stride: 1 38 | residual: True 39 | separable: True 40 | - layer_size: 512 41 | kernel_size: 50 42 | stride: 1 43 | residual: True 44 | separable: True 45 | - layer_size: 512 46 | kernel_size: 50 47 | stride: 1 48 | residual: True 49 | separable: True 50 | - layer_size: 512 51 | kernel_size: 50 52 | stride: 1 53 | residual: True 54 | separable: True 55 | - layer_size: 512 56 | kernel_size: 62 57 | stride: 1 58 | residual: True 59 | separable: True 60 | - layer_size: 512 61 | kernel_size: 62 62 | stride: 1 63 | residual: True 64 | separable: True 65 | - layer_size: 512 66 | kernel_size: 62 67 | stride: 1 68 | residual: True 69 | separable: True 70 | - layer_size: 512 71 | kernel_size: 74 72 | stride: 1 73 | residual: True 74 | separable: True 75 | - layer_size: 1024 76 | kernel_size: 1 77 | stride: 1 78 | residual: False 79 | separable: False -------------------------------------------------------------------------------- /configuration/model/wav2letter.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | name: wav2letter 3 | mid_layers: 1 4 | layers: 5 | - output_size: 256 6 | kernel_size: 11 7 | stride: 2 8 | dilation: 1 9 | dropout: 0.2 10 | - output_size: 256 11 | kernel_size: 11 12 | stride: 1 13 | dilation: 1 14 | dropout: 0.2 15 | - output_size: 256 16 | kernel_size: 11 17 | stride: 1 18 | dilation: 1 19 | dropout: 0.2 20 | - output_size: 256 21 | kernel_size: 11 22 | stride: 1 23 | dilation: 1 24 | dropout: 0.2 25 | - output_size: 384 26 | kernel_size: 13 27 | stride: 1 28 | dilation: 1 29 | dropout: 0.2 30 | - output_size: 384 31 | kernel_size: 13 32 | stride: 1 33 | dilation: 1 34 | dropout: 0.2 35 | - output_size: 384 36 | kernel_size: 13 37 | stride: 1 38 | dilation: 1 39 | dropout: 0.2 40 | - output_size: 512 41 | kernel_size: 17 42 | stride: 1 43 | dilation: 1 44 | dropout: 0.2 45 | - output_size: 512 46 | kernel_size: 17 47 | stride: 1 48 | dilation: 1 49 | dropout: 0.2 50 | - output_size: 512 51 | kernel_size: 17 52 | stride: 1 53 | dilation: 1 54 | dropout: 0.2 55 | - output_size: 640 56 | kernel_size: 21 57 | stride: 1 58 | dilation: 1 59 | dropout: 0.3 60 | - output_size: 640 61 | kernel_size: 21 62 | stride: 1 63 | dilation: 1 64 | dropout: 0.3 65 | - output_size: 640 66 | kernel_size: 21 67 | stride: 1 68 | dilation: 1 69 | dropout: 0.3 70 | - output_size: 768 71 | kernel_size: 25 72 | stride: 1 73 | dilation: 1 74 | dropout: 0.3 75 | - output_size: 
768 76 | kernel_size: 25 77 | stride: 1 78 | dilation: 1 79 | dropout: 0.3 80 | - output_size: 768 81 | kernel_size: 25 82 | stride: 1 83 | dilation: 1 84 | dropout: 0.3 85 | - output_size: 896 86 | kernel_size: 29 87 | stride: 1 88 | dilation: 2 89 | dropout: 0.4 90 | - output_size: 896 91 | kernel_size: 29 92 | stride: 1 93 | dilation: 2 94 | dropout: 0.4 95 | - output_size: 896 96 | kernel_size: 29 97 | stride: 1 98 | dilation: 2 99 | dropout: 0.4 100 | - output_size: 1024 101 | kernel_size: 1 102 | stride: 1 103 | dilation: 1 104 | dropout: 0.4 105 | -------------------------------------------------------------------------------- /configuration/optimizer/exp_lr_optimizer.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | optimizer: 3 | _target_: torch.optim.SGD 4 | lr: 1e-5 5 | momentum: 0.9 6 | nesterov: True 7 | weight_decay: 1e-5 8 | scheduler: 9 | _target_: torch.optim.lr_scheduler.ExponentialLR 10 | gamma: 0.999 11 | #scheduler: 12 | # _target_: torch.optim.lr_scheduler.OneCycleLR 13 | # max_lr: ${model.optimizer.lr} 14 | # total_steps: ${trainer.max_steps} -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from . import data_loader -------------------------------------------------------------------------------- /data/augmentations.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | Based on nvidia nemo implementation for ASR 4 | ''' 5 | import random 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | class SpecAugment(nn.Module): 12 | """ 13 | Zeroes out(cuts) random continuous horisontal or 14 | vertical segments of the spectrogram as described in 15 | SpecAugment (https://arxiv.org/abs/1904.08779). 16 | params: 17 | freq_masks - how many frequency segments should be cut 18 | time_masks - how many time segments should be cut 19 | freq_width - maximum number of frequencies to be cut in one segment 20 | time_width - maximum number of time steps to be cut in one segment 21 | """ 22 | 23 | def __init__( 24 | self, freq_masks=1, time_masks=1, freq_width=15, time_width=50, rng=None, 25 | ): 26 | super(SpecAugment, self).__init__() 27 | 28 | self._rng = random.Random() if rng is None else rng 29 | 30 | self.freq_masks = freq_masks 31 | self.time_masks = time_masks 32 | 33 | self.freq_width = freq_width 34 | self.time_width = time_width 35 | 36 | @torch.no_grad() 37 | def forward(self, x): 38 | sh = x.shape 39 | 40 | mask = torch.zeros(x.shape).byte() 41 | 42 | for idx in range(sh[0]): 43 | for i in range(self.freq_masks): 44 | x_left = int(self._rng.uniform(0, sh[1] - self.freq_width)) 45 | 46 | w = int(self._rng.uniform(0, self.freq_width)) 47 | 48 | mask[idx, x_left : x_left + w, :] = 1 49 | 50 | for i in range(self.time_masks): 51 | y_left = int(self._rng.uniform(0, sh[2] - self.time_width)) 52 | 53 | w = int(self._rng.uniform(0, self.time_width)) 54 | 55 | mask[idx, :, y_left : y_left + w] = 1 56 | 57 | x = x.masked_fill(mask.type(torch.bool).to(device=x.device), 0) 58 | 59 | return x 60 | 61 | 62 | class SpecCutout(nn.Module): 63 | """ 64 | Zeroes out(cuts) random rectangles in the spectrogram 65 | as described in (). 
66 | params: 67 | rect_masks - how many rectangular masks should be cut 68 | rect_freq - maximum size of cut rectangles along the frequency dimension 69 | rect_time - maximum size of cut rectangles along the time dimension 70 | """ 71 | 72 | def __init__(self, rect_masks=5, rect_time=60, rect_freq=25, rng=None): 73 | super(SpecCutout, self).__init__() 74 | 75 | self._rng = random.Random() if rng is None else rng 76 | 77 | self.rect_masks = rect_masks 78 | self.rect_time = rect_time 79 | self.rect_freq = rect_freq 80 | 81 | @torch.no_grad() 82 | def forward(self, x): 83 | sh = x.shape 84 | 85 | mask = torch.zeros(x.shape).byte() 86 | 87 | for idx in range(sh[0]): 88 | for i in range(self.rect_masks): 89 | rect_x = int(self._rng.uniform(0, sh[1] - self.rect_freq)) 90 | rect_y = int(self._rng.uniform(0, sh[2] - self.rect_time)) 91 | 92 | w_x = int(self._rng.uniform(0, self.rect_time)) 93 | w_y = int(self._rng.uniform(0, self.rect_freq)) 94 | 95 | mask[idx, rect_x : rect_x + w_x, rect_y : rect_y + w_y] = 1 96 | 97 | x = x.masked_fill(mask.type(torch.bool).to(device=x.device), 0) 98 | 99 | return x 100 | 101 | class Identity(nn.Module): 102 | """ 103 | Placeholder module. 104 | """ 105 | @torch.no_grad() 106 | def forward(self, x): 107 | return x -------------------------------------------------------------------------------- /data/data_loader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import division 4 | 5 | import json 6 | import math 7 | 8 | import librosa 9 | import numpy as np 10 | import scipy.signal 11 | from scipy.io import wavfile 12 | import soundfile as sf 13 | import torch 14 | import torch.nn 15 | from torch.utils.data import Dataset,DataLoader 16 | import pandas as pd 17 | 18 | windows = {'hamming': scipy.signal.hamming, 'hann': scipy.signal.hann, 'blackman': scipy.signal.blackman,'bartlett':scipy.signal.bartlett} 19 | 20 | def load_audio(path,duration=-1,offset=0): 21 | with sf.SoundFile(path, 'r') as f: 22 | dtype = 'float32' 23 | sample_rate = f.samplerate 24 | if offset > 0: 25 | f.seek(int(offset * sample_rate)) 26 | if duration > 0: 27 | samples = f.read(int(duration * sample_rate), dtype=dtype) 28 | else: 29 | samples = f.read(dtype=dtype) 30 | samples = samples.transpose() 31 | return samples 32 | 33 | class SpectrogramExtractor(torch.nn.Module): 34 | def __init__(self, audio_conf, mel_spec=64,use_cuda=False): 35 | super().__init__() 36 | window_size_samples = int(audio_conf.sample_rate * audio_conf.window_size) 37 | window_stride_samples = int(audio_conf.sample_rate * audio_conf.window_stride) 38 | self.n_fft = 2 ** math.ceil(math.log2(window_size_samples)) 39 | filterbanks = torch.tensor( 40 | librosa.filters.mel(audio_conf.sample_rate, 41 | n_fft=self.n_fft, 42 | n_mels=mel_spec, fmin=0, fmax=audio_conf.sample_rate / 2), 43 | dtype=torch.float 44 | ).unsqueeze(0) 45 | self.register_buffer("fb", filterbanks) 46 | torch_windows = { 47 | 'hann': torch.hann_window, 48 | 'hamming': torch.hamming_window, 49 | 'blackman': torch.blackman_window, 50 | 'bartlett': torch.bartlett_window, 51 | 'none': None, 52 | } 53 | window_fn = torch_windows.get(audio_conf.window, None) 54 | window_tensor = window_fn(window_size_samples, periodic=False) if window_fn else None 55 | self.register_buffer("window", window_tensor) 56 | self.stft = lambda x: torch.stft( 57 | x, 58 | n_fft=self.n_fft, 59 | hop_length=window_stride_samples, 60 | win_length=window_size_samples, 61 | center=True, 62 | 
window=self.window.to(dtype=torch.float), 63 | return_complex=False, 64 | ) 65 | def _get_spect(self, audio): 66 | dithering = 1e-5 67 | preemph = 0.97 68 | x = torch.Tensor(audio) + torch.randn(audio.shape) * dithering # dithering 69 | x = torch.cat((x[0].unsqueeze(0), x[1:] - preemph * x[:-1]), dim=0) # preemphasi 70 | x = self.stft(x.to(device=self.fb.device)) 71 | x = torch.sqrt(x.pow(2).sum(-1)) # get magnitudes 72 | x = x.pow(2) # power magnitude 73 | x = torch.matmul(self.fb.to(x.dtype), x) #apply filterbanks 74 | return x 75 | 76 | 77 | def extract(self,signal): 78 | epsilon = 1e-5 79 | log_zero_guard_value=2 ** -24 80 | spect = self._get_spect(signal) 81 | spect = np.log1p(spect + log_zero_guard_value) 82 | # normlize across time, per feature 83 | mean = spect.mean(axis=2) 84 | std = spect.std(axis=2) 85 | std += epsilon 86 | spect = spect - mean.reshape(1, -1, 1) 87 | spect = spect / std.reshape(1, -1, 1) 88 | return spect.squeeze() 89 | 90 | class SpectrogramDataset(Dataset): 91 | def __init__(self, manifest_filepath, audio_conf, labels, mel_spec=None, use_cuda=False): 92 | ''' 93 | Create a dataset for ASR. Audio conf and labels can be re-used from the model. 94 | Arguments: 95 | manifest_filepath (string): path to the manifest. Each line must be a json containing fields "audio_filepath" and "text". 96 | audio_conf (dict): dict containing sample rate, and window size stride and type. 97 | labels (list): list containing all valid labels in the text. 98 | mel_spec(int or None): if not None, use mel spectrogram with that many channels. 99 | use_cuda(bool): Use torch and torchaudio for stft. Can speed up extraction on GPU. 100 | ''' 101 | super(SpectrogramDataset, self).__init__() 102 | if manifest_filepath.endswith('.csv'): 103 | self.df = pd.read_csv(manifest_filepath,index_col=0) 104 | else: 105 | with open(manifest_filepath) as f: 106 | lines = f.readlines() 107 | self.df = pd.DataFrame(map(json.loads,lines)) 108 | if not 'offset' in self.df.columns: 109 | self.df['offset'] = 0 110 | if not 'duration' in self.df.columns: 111 | self.df['duration'] = -1 112 | self.size = len(self.df) 113 | self.window_stride = audio_conf['window_stride'] 114 | self.window_size = audio_conf['window_size'] 115 | self.sample_rate = audio_conf['sample_rate'] 116 | self.window = windows.get(audio_conf['window'], windows['hamming']) 117 | self.use_cuda = use_cuda 118 | self.mel_spec = mel_spec 119 | self.labels_map = dict([(labels[i],i) for i in range(len(labels))]) 120 | self.validate_sample_rate() 121 | self.extractor = SpectrogramExtractor(audio_conf,mel_spec,use_cuda) 122 | 123 | def __getitem__(self, index): 124 | sample = self.df.iloc[index] 125 | audio_path, transcript = sample.audio_filepath, sample.text 126 | spect = self.parse_audio(audio_path, sample.duration, sample.offset) 127 | target = list(filter(None,[self.labels_map.get(x) for x in list(transcript)])) 128 | return spect, target, audio_path, transcript 129 | 130 | def parse_audio(self,audio_path, duration, offset): 131 | y = load_audio(audio_path, duration, offset) 132 | spect = self.extractor.extract(y) 133 | return spect 134 | 135 | def validate_sample_rate(self): 136 | audio_filepath = self.df.iloc[0].audio_filepath 137 | sound, sr = sf.read(audio_filepath) 138 | assert sr == self.sample_rate, 'Expected sample rate %d but found %d in first file' % (self.sample_rate,sr) 139 | 140 | def __len__(self): 141 | return self.size 142 | 143 | def data_channels(self): 144 | ''' 145 | How many channels are returned in each example. 
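        Equals mel_spec when mel features are used; otherwise it is derived from the window size as sample_rate * window_size / 2 + 1.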
146 | ''' 147 | return self.mel_spec or int(1+(int(self.sample_rate * self.window_size)/2)) 148 | 149 | def _collator(batch): 150 | inputs, targets, file_paths, texts = zip(*batch) 151 | input_lengths = torch.IntTensor(list(map(lambda input: input.shape[1], inputs))) 152 | target_lengths = torch.IntTensor(list(map(len,targets))) 153 | longest_input = max(input_lengths).item() 154 | longest_target = max(target_lengths).item() 155 | pad_function = lambda x:np.pad(x,((0,0),(0,longest_input-x.shape[1])),mode='constant') 156 | inputs = torch.FloatTensor(list(map(pad_function,inputs))) 157 | targets = torch.IntTensor([np.pad(np.array(t),(0,longest_target-len(t)),mode='constant') for t in targets]) 158 | return inputs, input_lengths, targets, target_lengths, file_paths, texts 159 | 160 | class BatchAudioDataLoader(DataLoader): 161 | def __init__(self, *args, **kwargs): 162 | super(BatchAudioDataLoader, self).__init__(*args,**kwargs) 163 | self.collate_fn = _collator 164 | -------------------------------------------------------------------------------- /data/label_sets.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | english_labels = ["'",'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 3 | 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 4 | 'Z'] 5 | english_lowercase_labels = [s.lower() for s in english_labels] 6 | 7 | hebrew_labels = ['א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 8 | 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת', 'ן', 'ף', 'ץ', 'ם', 'ך'] 9 | 10 | labels_map = {'english':english_labels,'hebrew':hebrew_labels,'english_lowercase':english_lowercase_labels} 11 | for lang in labels_map: 12 | labels = labels_map[lang] 13 | labels.insert(0,'_') # CTC blank label. By default, blank index is 0. 
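    # a space label is appended last so word boundaries can be represented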
14 | labels.append(' ') -------------------------------------------------------------------------------- /data/language_specific_tools.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | hebrew_replacements = [('צ ','ץ '), 3 | ('פ ','ף '), 4 | ('כ ','ך '), 5 | ('מ ','ם '), 6 | ('נ ','ן ')] 7 | def hebrew_normal_to_final(strings): 8 | return _hebrew_convert(strings,hebrew_replacements) 9 | 10 | def hebrew_final_to_normal(strings): 11 | return _hebrew_convert(strings,{(k[1],k[0])for k in hebrew_replacements}) 12 | 13 | def _hebrew_convert(strings,replacements): 14 | if type(strings) is list: 15 | return [hebrew_normal_to_final(s) for s in strings] 16 | 17 | res = strings + ' ' 18 | for replacement in replacements: 19 | res = res.replace(replacement[0],replacement[1]) 20 | return res[:-1] 21 | 22 | -------------------------------------------------------------------------------- /data/prepare_librispeech.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import argparse 3 | import os 4 | import os.path 5 | import pandas as pd 6 | import subprocess 7 | import shutil 8 | import glob 9 | import wget 10 | 11 | def download_librispeech_subset(subset_name, download_dir): 12 | if os.path.exists(f'{download_dir}/{subset_name}.tar.gz'): 13 | print(f'{download_dir}/{subset_name} already exists - skipping download') 14 | return 15 | url = f'https://www.openslr.org/resources/12/{subset_name}.tar.gz' 16 | print(f'Downloading from {url} to {download_dir}/{subset_name}.tar.gz') 17 | wget.download(url, out=download_dir) 18 | 19 | def extract_subset(subset_name, download_dir, extracted_dir): 20 | if os.path.exists(f'{extracted_dir}/LibriSpeech/{subset_name}'): 21 | print(f'{extracted_dir}/LibriSpeech/{subset_name} already exists, skipping extraction') 22 | return 23 | os.makedirs(args.extracted_dir,exist_ok=True) 24 | print('Unpacking .tar file') 25 | shutil.unpack_archive(f'{download_dir}/{subset_name}.tar.gz', extracted_dir) 26 | 27 | def read_transcriptions(subset_name,extracted_dir): 28 | all_lines = [] 29 | transcript_glob = os.path.join(args.extracted_dir,'LibriSpeech',subset_name,'*/*/*.txt') 30 | print(transcript_glob) 31 | for transcript_file in glob.glob(transcript_glob): 32 | with open(transcript_file,'r') as f: 33 | lines = f.readlines() 34 | lines = [line.split(' ',1) for line in lines] 35 | for line in lines: 36 | line[0] = os.path.join(os.path.dirname(transcript_file),line[0]+'.flac') 37 | all_lines.extend(lines) 38 | return all_lines 39 | 40 | if __name__ == '__main__': 41 | parser = argparse.ArgumentParser('Librispeech data preparation.') 42 | parser.add_argument('--subset',type=str,default='dev-clean',help='Subset of Librispeech to download.') 43 | parser.add_argument('--download_dir',type=str,default='.',help='Directory to download Librispeech to. Will be created if not exists.') 44 | parser.add_argument('--extracted_dir',type=str,default='./extracted', help='Directory to extract Librispeech to. Will be created if not exists.') 45 | parser.add_argument('--manifest_path',type=str,default='df.csv', help='Filename of the manifest to create. 
This is the path required by the "train.py" script') 46 | parser.add_argument('--use_relative_path',default=True,action='store_false',help='Use relative paths in resulting manifest.') 47 | args = parser.parse_args() 48 | 49 | progress_function = lambda gen : gen 50 | try: 51 | import tqdm 52 | progress_function = tqdm.tqdm 53 | except: 54 | print('tqdm not available, will not show progress') 55 | 56 | download_librispeech_subset(args.subset, args.download_dir) 57 | extract_subset(args.subset, args.download_dir, args.extracted_dir) 58 | all_lines = read_transcriptions(args.subset, args.extracted_dir) 59 | 60 | #os.makedirs(args.,exist_ok=True) 61 | 62 | df = pd.DataFrame(all_lines,columns=['audio_filepath','text']) 63 | if not args.use_relative_path: 64 | df.filepath = df.filepath.apply(os.path.abspath) 65 | df.text = df.text.apply(str.strip) 66 | df.to_csv(args.manifest_path) 67 | print(f'Done - manifest created at {args.manifest_path}') -------------------------------------------------------------------------------- /decoder.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*-import torch 2 | import torch 3 | from six.moves import xrange 4 | import Levenshtein as Lev 5 | import re 6 | from collections import defaultdict, Counter 7 | import numpy as np 8 | 9 | import data.label_sets 10 | 11 | class Decoder(object): 12 | """ 13 | Basic decoder class from which all other decoders inherit. Implements several 14 | helper functions. Subclasses should implement the decode() method. 15 | 16 | Arguments: 17 | labels (string): mapping from integers to characters. 18 | blank_index (int, optional): index for the blank '_' character. Defaults to 0. 19 | space_index (int, optional): index for the space ' ' character. Defaults to 28. 20 | """ 21 | 22 | def __init__(self, labels, blank_index=0): 23 | self.labels = data.label_sets.labels_map[labels] if type(labels) is str else labels 24 | self.int_to_char = dict([(i, c) for (i, c) in enumerate(labels)]) 25 | self.blank_index = blank_index 26 | space_index = len(labels) # To prevent errors in decode, we add an out of bounds index for the space 27 | if ' ' in labels: 28 | space_index = labels.index(' ') 29 | self.space_index = space_index 30 | 31 | def wer(self, s1, s2): 32 | """ 33 | Computes the Word Error Rate, defined as the edit distance between the 34 | two provided sentences after tokenizing to words. 35 | Arguments: 36 | s1 (string): space-separated sentence 37 | s2 (string): space-separated sentence 38 | """ 39 | 40 | # build mapping of words to integers 41 | b = set(s1.split() + s2.split()) 42 | word2char = dict(zip(b, range(len(b)))) 43 | 44 | # map the words to a char array (Levenshtein packages only accepts 45 | # strings) 46 | w1 = [chr(word2char[w]) for w in s1.split()] 47 | w2 = [chr(word2char[w]) for w in s2.split()] 48 | 49 | return Lev.distance(''.join(w1), ''.join(w2)) 50 | 51 | def cer(self, s1, s2): 52 | """ 53 | Computes the Character Error Rate, defined as the edit distance. 
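        Spaces are removed from both strings before the distance is computed.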
54 | 55 | Arguments: 56 | s1 (string): space-separated sentence 57 | s2 (string): space-separated sentence 58 | """ 59 | s1, s2, = s1.replace(' ', ''), s2.replace(' ', '') 60 | return Lev.distance(s1, s2) 61 | 62 | def cer_ratio(self, expected, predicted): 63 | return self.cer(expected,predicted) , len(expected.replace(' ','')) 64 | 65 | def wer_ratio(self, expected, predicted): 66 | return self.wer(expected,predicted) , len(expected.split()) 67 | 68 | 69 | 70 | def decode(self, probs, sizes=None): 71 | """ 72 | Given a matrix of character probabilities, returns the decoder's 73 | best guess of the transcription 74 | 75 | Arguments: 76 | probs: Tensor of character probabilities, where probs[c,t] 77 | is the probability of character c at time t 78 | sizes(optional): Size of each sequence in the mini-batch 79 | Returns: 80 | string: sequence of the model's best guess for the transcription 81 | """ 82 | raise NotImplementedError 83 | 84 | 85 | class GreedyDecoder(Decoder): 86 | def __init__(self, labels, blank_index=0): 87 | super(GreedyDecoder, self).__init__(labels, blank_index) 88 | 89 | def convert_to_strings(self, sequences, sizes=None, remove_repetitions=False, return_offsets=False): 90 | """Given a list of numeric sequences, returns the corresponding strings""" 91 | strings = [] 92 | offsets = [] if return_offsets else None 93 | for x in xrange(len(sequences)): 94 | seq_len = sizes[x] if sizes is not None else len(sequences[x]) 95 | string, string_offsets = self.process_string(sequences[x], seq_len, remove_repetitions) 96 | strings.append([string]) # We only return one path 97 | if return_offsets: 98 | offsets.append([string_offsets]) 99 | if return_offsets: 100 | return strings, offsets 101 | else: 102 | return strings 103 | 104 | def process_string(self, sequence, size, remove_repetitions=False): 105 | string = '' 106 | offsets = [] 107 | for i in range(size): 108 | char = self.int_to_char[sequence[i].item()] 109 | if char != self.int_to_char[self.blank_index]: 110 | # if this char is a repetition and remove_repetitions=true, then skip 111 | if remove_repetitions and i != 0 and char == self.int_to_char[sequence[i - 1].item()]: 112 | pass 113 | elif char == self.labels[self.space_index]: 114 | string += ' ' 115 | offsets.append(i) 116 | else: 117 | string = string + char 118 | offsets.append(i) 119 | return string, torch.IntTensor(offsets) 120 | 121 | def decode(self, probs, sizes=None, return_offsets=False): 122 | """ 123 | Returns the argmax decoding given the probability matrix. Removes 124 | repeated elements in the sequence, as well as blanks. 125 | 126 | Arguments: 127 | probs: Tensor of character probabilities from the network. Expected shape of batch x seq_length x output_dim 128 | sizes(optional): Size of each sequence in the mini-batch 129 | Returns: 130 | strings: sequences of the model's best guess for the transcription on inputs 131 | offsets: time step per character predicted 132 | """ 133 | if len(probs.shape) == 2: 134 | return self.decode(probs.unsqueeze(0), sizes, return_offsets) 135 | 136 | _, max_probs = torch.max(probs, 2) 137 | strings, offsets = self.convert_to_strings(max_probs.view(max_probs.size(0), max_probs.size(1)), sizes, 138 | remove_repetitions=True, return_offsets=True) 139 | strings = [s[0] for s in strings] #This feels a bit hacky. 
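        # convert_to_strings wraps each transcription in a single-element list (one decoded path per sample); flatten to a plain list of strings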
140 | #if probs.shape[0] == 1: 141 | # strings = strings[0] 142 | # offsets = offsets[0] 143 | if return_offsets: 144 | return strings, offsets 145 | return strings 146 | 147 | def prefix_beam_search(ctc, labels, blank_index=0, lm=None,k=5,alpha=0.3,beta=5,prune=0.001,end_char='>',return_weights=False): 148 | """ 149 | Performs prefix beam search on the output of a CTC network. 150 | Originally from https://github.com/corticph/prefix-beam-search, with minor edits. 151 | Args: 152 | ctc (np.ndarray): The CTC output. Should be a 2D array (timesteps x alphabet_size) 153 | lm (func): Language model function. Should take as input a string and output a probability. 154 | k (int): The beam width. Will keep the 'k' most likely candidates at each timestep. 155 | alpha (float): The language model weight. Should usually be between 0 and 1. 156 | beta (float): The language model compensation term. The higher the 'alpha', the higher the 'beta'. 157 | prune (float): Only extend prefixes with chars with an emission probability higher than 'prune'. 158 | return_weights(bool): return the confidence of the decoded string. 159 | Returns: 160 | string: The decoded CTC output. 161 | """ 162 | assert (ctc.shape[1] == len(labels)), "ctc size:%d, labels: %d" % (ctc.shape[1], len(labels)) 163 | assert ctc.shape[0] > 1, "ctc length: %d was too short" % ctc.shape[0] 164 | assert (ctc >= 0).all(), 'ctc output contains negative numbers' 165 | lm = (lambda l: 1) if lm is None else lm # if no LM is provided, just set to function returning 1 166 | word_count_re = re.compile(r'\w+[\s|>]') 167 | W = lambda l: word_count_re.findall(l) 168 | F = ctc.shape[1] 169 | 170 | ctc = np.vstack((np.zeros(F), ctc)) # just add an imaginative zero'th step (will make indexing more intuitive) 171 | T = ctc.shape[0] 172 | blank_char = labels[blank_index] 173 | 174 | # STEP 1: Initiliazation 175 | O = '' 176 | Pb, Pnb = defaultdict(Counter), defaultdict(Counter) 177 | Pb[0][O] = 1 178 | Pnb[0][O] = 0 179 | A_prev = [O] 180 | # END: STEP 1 181 | 182 | # STEP 2: Iterations and pruning 183 | for t in range(1, T): 184 | pruned_alphabet = [labels[i] for i in np.where(ctc[t] > prune)[0]] 185 | for l in A_prev: 186 | 187 | if len(l) > 0 and l[-1] == end_char: 188 | Pb[t][l] = Pb[t - 1][l] 189 | Pnb[t][l] = Pnb[t - 1][l] 190 | continue 191 | 192 | for c in pruned_alphabet: 193 | c_ix = labels.index(c) 194 | # END: STEP 2 195 | 196 | # STEP 3: “Extending” with a blank 197 | if c == blank_char: 198 | Pb[t][l] += ctc[t][blank_index] * (Pb[t - 1][l] + Pnb[t - 1][l]) 199 | # END: STEP 3 200 | 201 | # STEP 4: Extending with the end character 202 | else: 203 | l_plus = l + c 204 | if len(l) > 0 and c == l[-1]: 205 | Pnb[t][l_plus] += ctc[t][c_ix] * Pb[t - 1][l] 206 | Pnb[t][l] += ctc[t][c_ix] * Pnb[t - 1][l] 207 | # END: STEP 4 208 | 209 | # STEP 5: Extending with any other non-blank character and LM constraints 210 | elif len(l.replace(' ', '')) > 0 and c in (' ', end_char): 211 | lm_prob = lm(l_plus.strip(' '+end_char)) ** alpha 212 | Pnb[t][l_plus] += lm_prob * ctc[t][c_ix] * (Pb[t - 1][l] + Pnb[t - 1][l]) 213 | else: 214 | Pnb[t][l_plus] += ctc[t][c_ix] * (Pb[t - 1][l] + Pnb[t - 1][l]) 215 | # END: STEP 5 216 | 217 | # STEP 6: Make use of discarded prefixes 218 | if l_plus not in A_prev: 219 | Pb[t][l_plus] += ctc[t][blank_index] * (Pb[t - 1][l_plus] + Pnb[t - 1][l_plus]) 220 | Pnb[t][l_plus] += ctc[t][c_ix] * Pnb[t - 1][l_plus] 221 | # END: STEP 6 222 | 223 | # STEP 7: Select most probable prefixes 224 | A_next = Pb[t] + Pnb[t] 225 | sorter = lambda 
l: A_next[l] * (len(W(l)) + 1) ** beta 226 | A_prev = sorted(A_next, key=sorter, reverse=True)[:k] 227 | # END: STEP 7 228 | if len(A_prev) ==0: 229 | A_prev=[''] 230 | if return_weights: 231 | return A_prev[0],A_next[A_prev[0]] * (len(W(A_prev[0])) + 1) ** beta 232 | return A_prev[0] 233 | #For N-best decode, return A_prev[0:N] - not tested yet. 234 | 235 | class PrefixBeamSearchLMDecoder(Decoder): 236 | def __init__(self,lm_path,labels,blank_index=0,k=5,alpha=0.3,beta=5,prune=1e-3): 237 | """ 238 | Args: 239 | lm_path (str): The path to the kenlm language model. 240 | labels (list(str)): A list of the characters. 241 | blank_index (int): The index of the blank character in the `labels` parameter. 242 | k (int): The beam width. Will keep the 'k' most likely candidates at each timestep. 243 | alpha (float): The language model weight. Should usually be between 0 and 1. 244 | beta (float): The language model compensation term. The higher the 'alpha', the higher the 'beta'. 245 | prune (float): Only extend prefixes with chars with an emission probability higher than 'prune'. 246 | """ 247 | super(PrefixBeamSearchLMDecoder, self).__init__(labels,blank_index) 248 | if lm_path: 249 | import kenlm 250 | self.lm = kenlm.Model(lm_path) 251 | self.lm_weigh = lambda f: 10**(self.lm.score(f)) 252 | else: 253 | self.lm_weigh = lambda s: 1 254 | self.k =k 255 | self.alpha=alpha 256 | self.beta=beta 257 | self.prune=prune 258 | 259 | def decode(self, probs, sizes=None, return_offsets=False): 260 | if return_offsets: 261 | raise NotImplementedError("Prefix beam search does not support offsets (yet).") 262 | if len(probs.shape) == 2: # Single 263 | return prefix_beam_search(probs,self.labels,self.blank_index,self.lm_weigh,self.k,self.alpha,self.beta,self.prune) 264 | elif len(probs.shape) == 3: # Batch 265 | return [self.decode(prob) for prob in probs] 266 | else: 267 | raise RuntimeError('Decoding with wrong shape: %s, expected either [Batch X Frames X Labels] or [Frames X Labels]' % str(probs.shape)) 268 | 269 | 270 | def get_time_per_word(predictions, offsets, ratio=1.0): 271 | """ 272 | Compute the start and end time for each word outputed by the model (and decoder), based on offsets. 273 | 274 | Note that end time per word consider only the first instance of the last character in the word - This might result in slightly earlier end timings when the model predicts repetitions. 275 | Args: 276 | predictions (list(str)): The list of characters predicted. 277 | offsets (list(int)): the list of offsets for each character. 278 | ratio (float, optional): The ratio between output sequence length and input seconds. Can be computed as (sample rate) * (window stride). 
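        Returns:
            list of (word, start_time, end_time) tuples, where the times are character offsets scaled by 'ratio'.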
279 | """ 280 | word_times = [] 281 | assert len(predictions) == len(offsets) 282 | current_word = '' 283 | start_time = -1 284 | end_time = -1 285 | for letter,offset in zip(predictions,offsets): 286 | if letter == ' ' and current_word: 287 | word_times.append((current_word,start_time,end_time)) 288 | current_word = '' 289 | start_time = -1 290 | end_time = -1 291 | if letter == ' ' and not current_word: 292 | continue # Nothing to do 293 | if current_word: 294 | end_time = offset * ratio 295 | current_word += letter 296 | if not current_word: 297 | start_time = offset * ratio 298 | end_time = offset * ratio 299 | current_word = letter 300 | if current_word: 301 | word_times.append((current_word,start_time,end_time)) 302 | return word_times 303 | 304 | 305 | if __name__ == '__main__': 306 | my_decoder = GreedyDecoder(['_','a','b',' ']) 307 | a = my_decoder.decode( torch.Tensor([[[0.4,0.6,0,0]]])) 308 | space = my_decoder.decode( torch.Tensor([[[0.4,0.1,0,0.5]]])) 309 | aba_and_space = my_decoder.decode(torch.Tensor([[[0.0,0.6,0.3,0.1],[0.0,0.6,0.3,0.1],[0.0,0.3,0.6,0.1],[0.0,0.6,0.3,0.1]], 310 | [[0.4,0.1,0,0.5],[0.4,0.1,0,0.5],[0.4,0.1,0,0.5],[0.4,0.1,0,0.5]]]), 311 | sizes=[4,1]) -------------------------------------------------------------------------------- /examples/check_requirements.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Feb 24 10:56:07 2021 4 | 5 | @author: User 6 | """ 7 | 8 | import numpy 9 | import torch 10 | import pytorch_lightning 11 | import hydra 12 | import soundfile 13 | import glob 14 | import pandas 15 | import tqdm 16 | import soundfile 17 | 18 | print("All python modules requirements available.") -------------------------------------------------------------------------------- /examples/librispeech.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | echo "This script downloads the LibriSpeech dev-clean and test-clean subsets, and trains a model for 1 epoch." 3 | echo "Usage: bash librispeech.sh [wav2letter_pytorch directory - optional]" 4 | base_dir=${1:-$(dirname $(dirname "$(readlink -f "$0")"))} 5 | python ${base_dir}/examples/check_requirements.py 6 | python ${base_dir}/data/prepare_librispeech.py --subset dev-clean --manifest_path dev_clean.csv 7 | python ${base_dir}/data/prepare_librispeech.py --subset test-clean --manifest_path test_clean.csv 8 | python ${base_dir}/train.py data.train_manifest=dev_clean.csv data.val_manifest=test_clean.csv trainer.max_epochs=1 9 | echo Training finished successfully! 10 | echo Tensorboard logs were saved to directory "./lightning_logs". Call 'tensorboard --logdir ./lightning_logs' to view the logs generated during training. -------------------------------------------------------------------------------- /jasper.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | Copied from https://github.com/NVIDIA/DeepLearningExamples. Specifically, 4 | https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechRecognition/Jasper/model.py 5 | 6 | Not original code! 7 | 8 | Minor edits were made. 
9 | ''' 10 | 11 | from typing import List, Optional, Tuple 12 | 13 | import torch 14 | import torch.nn as nn 15 | from torch import Tensor 16 | import torch.nn.functional as F 17 | import numpy as np 18 | import pickle 19 | 20 | from base_asr_models import ConvCTCASR 21 | 22 | jasper_activations = { 23 | "hardtanh": nn.Hardtanh, 24 | "relu": nn.ReLU, 25 | "selu": nn.SELU, 26 | } 27 | 28 | 29 | def init_weights(m, mode='xavier_uniform'): 30 | if isinstance(m, MaskedConv1d): 31 | init_weights(m.conv, mode) 32 | if isinstance(m, (nn.Conv1d, nn.Linear)): 33 | if mode == 'xavier_uniform': 34 | nn.init.xavier_uniform_(m.weight, gain=1.0) 35 | elif mode == 'xavier_normal': 36 | nn.init.xavier_normal_(m.weight, gain=1.0) 37 | elif mode == 'kaiming_uniform': 38 | nn.init.kaiming_uniform_(m.weight, nonlinearity="relu") 39 | elif mode == 'kaiming_normal': 40 | nn.init.kaiming_normal_(m.weight, nonlinearity="relu") 41 | else: 42 | raise ValueError("Unknown Initialization mode: {0}".format(mode)) 43 | elif isinstance(m, nn.BatchNorm1d): 44 | if m.track_running_stats: 45 | m.running_mean.zero_() 46 | m.running_var.fill_(1) 47 | m.num_batches_tracked.zero_() 48 | if m.affine: 49 | nn.init.ones_(m.weight) 50 | nn.init.zeros_(m.bias) 51 | 52 | 53 | def compute_new_kernel_size(kernel_size, kernel_width): 54 | new_kernel_size = max(int(kernel_size * kernel_width), 1) 55 | # If kernel is even shape, round up to make it odd 56 | if new_kernel_size % 2 == 0: 57 | new_kernel_size += 1 58 | return new_kernel_size 59 | 60 | 61 | def get_same_padding(kernel_size, stride, dilation): 62 | if stride > 1 and dilation > 1: 63 | raise ValueError("Only stride OR dilation may be greater than 1") 64 | if dilation > 1: 65 | return (dilation * kernel_size) // 2 - 1 66 | return kernel_size // 2 67 | 68 | 69 | class MaskedConv1d(nn.Module): 70 | __constants__ = ["use_conv_mask", "real_out_channels", "heads"] 71 | 72 | def __init__( 73 | self, 74 | in_channels, 75 | out_channels, 76 | kernel_size, 77 | stride=1, 78 | padding=0, 79 | dilation=1, 80 | groups=1, 81 | heads=-1, 82 | bias=False, 83 | use_mask=True, 84 | ): 85 | super(MaskedConv1d, self).__init__() 86 | 87 | if not (heads == -1 or groups == in_channels): 88 | raise ValueError("Only use heads for depthwise convolutions") 89 | 90 | self.real_out_channels = out_channels 91 | if heads != -1: 92 | in_channels = heads 93 | out_channels = heads 94 | groups = heads 95 | 96 | self.conv = nn.Conv1d( 97 | in_channels, 98 | out_channels, 99 | kernel_size, 100 | stride=stride, 101 | padding=padding, 102 | dilation=dilation, 103 | groups=groups, 104 | bias=bias, 105 | ) 106 | self.use_mask = use_mask 107 | self.heads = heads 108 | 109 | def get_seq_len(self, lens): 110 | return ( 111 | lens + 2 * self.conv.padding[0] - self.conv.dilation[0] * (self.conv.kernel_size[0] - 1) - 1 112 | ) / self.conv.stride[0] + 1 113 | 114 | def forward(self, x, lens): 115 | if self.use_mask: 116 | lens = lens.to(dtype=torch.long) 117 | max_len = x.size(2) 118 | mask = torch.arange(max_len).to(lens.device).expand(len(lens), max_len) >= lens.unsqueeze(1) 119 | x = x.masked_fill(mask.unsqueeze(1).to(device=x.device), 0) 120 | # del mask 121 | lens = self.get_seq_len(lens) 122 | 123 | sh = x.shape 124 | if self.heads != -1: 125 | x = x.view(-1, self.heads, sh[-1]) 126 | 127 | out = self.conv(x) 128 | 129 | if self.heads != -1: 130 | out = out.view(sh[0], self.real_out_channels, -1) 131 | 132 | return out, lens 133 | 134 | 135 | class GroupShuffle(nn.Module): 136 | def __init__(self, groups, channels): 
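        # channel shuffle: forward() reshapes to (batch, groups, channels_per_group, time), swaps the group and channel axes, and flattens back so that channels are interleaved across groups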
137 | super(GroupShuffle, self).__init__() 138 | 139 | self.groups = groups 140 | self.channels_per_group = channels // groups 141 | 142 | def forward(self, x): 143 | sh = x.shape 144 | 145 | x = x.view(-1, self.groups, self.channels_per_group, sh[-1]) 146 | 147 | x = torch.transpose(x, 1, 2).contiguous() 148 | 149 | x = x.view(-1, self.groups * self.channels_per_group, sh[-1]) 150 | 151 | return x 152 | 153 | 154 | class JasperBlock(nn.Module): 155 | __constants__ = ["conv_mask", "separable", "residual_mode", "res", "mconv"] 156 | 157 | def __init__( 158 | self, 159 | inplanes, 160 | planes, 161 | repeat=3, 162 | kernel_size=11, 163 | kernel_size_factor=1, 164 | stride=1, 165 | dilation=1, 166 | padding='same', 167 | dropout=0, 168 | activation=None, 169 | residual=True, 170 | groups=1, 171 | separable=False, 172 | heads=-1, 173 | normalization="batch", 174 | norm_groups=1, 175 | residual_mode='add', 176 | residual_panes=[], 177 | conv_mask=False 178 | ): 179 | super(JasperBlock, self).__init__() 180 | 181 | if padding != "same": 182 | raise ValueError("currently only 'same' padding is supported") 183 | 184 | kernel_size_factor = float(kernel_size_factor) 185 | if type(kernel_size) in (list, tuple): 186 | kernel_size = [compute_new_kernel_size(k, kernel_size_factor) for k in kernel_size] 187 | else: 188 | kernel_size = compute_new_kernel_size(kernel_size, kernel_size_factor) 189 | 190 | padding_val = get_same_padding(kernel_size, stride, dilation) 191 | self.conv_mask = conv_mask 192 | self.separable = separable 193 | self.residual_mode = residual_mode 194 | 195 | inplanes_loop = inplanes 196 | conv = nn.ModuleList() 197 | 198 | for _ in range(repeat - 1): 199 | conv.extend( 200 | self._get_conv_bn_layer( 201 | inplanes_loop, 202 | planes, 203 | kernel_size=kernel_size, 204 | stride=stride, 205 | dilation=dilation, 206 | padding=padding_val, 207 | groups=groups, 208 | heads=heads, 209 | separable=separable, 210 | normalization=normalization, 211 | norm_groups=norm_groups, 212 | ) 213 | ) 214 | 215 | conv.extend(self._get_act_dropout_layer(drop_prob=dropout, activation=activation)) 216 | 217 | 218 | inplanes_loop = planes 219 | 220 | conv.extend( 221 | self._get_conv_bn_layer( 222 | inplanes_loop, 223 | planes, 224 | kernel_size=kernel_size, 225 | stride=stride, 226 | dilation=dilation, 227 | padding=padding_val, 228 | groups=groups, 229 | heads=heads, 230 | separable=separable, 231 | normalization=normalization, 232 | norm_groups=norm_groups, 233 | ) 234 | ) 235 | 236 | self.mconv = conv 237 | 238 | res_panes = residual_panes.copy() 239 | self.dense_residual = residual 240 | 241 | if residual: 242 | res_list = nn.ModuleList() 243 | if len(residual_panes) == 0: 244 | res_panes = [inplanes] 245 | self.dense_residual = False 246 | for ip in res_panes: 247 | res = nn.ModuleList( 248 | self._get_conv_bn_layer( 249 | ip, planes, kernel_size=1, normalization=normalization, norm_groups=norm_groups, 250 | ) 251 | ) 252 | 253 | res_list.append(res) 254 | 255 | self.res = res_list 256 | else: 257 | self.res = None 258 | 259 | self.mout = nn.Sequential(*self._get_act_dropout_layer(drop_prob=dropout, activation=activation)) 260 | 261 | def _get_conv( 262 | self, 263 | in_channels, 264 | out_channels, 265 | kernel_size=11, 266 | stride=1, 267 | dilation=1, 268 | padding=0, 269 | bias=False, 270 | groups=1, 271 | heads=-1, 272 | separable=False, 273 | ): 274 | use_mask = self.conv_mask 275 | if use_mask: 276 | return MaskedConv1d( 277 | in_channels, 278 | out_channels, 279 | kernel_size, 280 | 
stride=stride, 281 | dilation=dilation, 282 | padding=padding, 283 | bias=bias, 284 | groups=groups, 285 | heads=heads, 286 | use_mask=use_mask, 287 | ) 288 | else: 289 | return nn.Conv1d( 290 | in_channels, 291 | out_channels, 292 | kernel_size, 293 | stride=stride, 294 | dilation=dilation, 295 | padding=padding, 296 | bias=bias, 297 | groups=groups, 298 | ) 299 | 300 | def _get_conv_bn_layer( 301 | self, 302 | in_channels, 303 | out_channels, 304 | kernel_size=11, 305 | stride=1, 306 | dilation=1, 307 | padding=0, 308 | bias=False, 309 | groups=1, 310 | heads=-1, 311 | separable=False, 312 | normalization="batch", 313 | norm_groups=1, 314 | ): 315 | if norm_groups == -1: 316 | norm_groups = out_channels 317 | 318 | if separable: 319 | layers = [ 320 | self._get_conv( 321 | in_channels, 322 | in_channels, 323 | kernel_size, 324 | stride=stride, 325 | dilation=dilation, 326 | padding=padding, 327 | bias=bias, 328 | groups=in_channels, 329 | heads=heads, 330 | ), 331 | self._get_conv( 332 | in_channels, 333 | out_channels, 334 | kernel_size=1, 335 | stride=1, 336 | dilation=1, 337 | padding=0, 338 | bias=bias, 339 | groups=groups, 340 | ), 341 | ] 342 | else: 343 | layers = [ 344 | self._get_conv( 345 | in_channels, 346 | out_channels, 347 | kernel_size, 348 | stride=stride, 349 | dilation=dilation, 350 | padding=padding, 351 | bias=bias, 352 | groups=groups, 353 | ) 354 | ] 355 | 356 | if normalization == "group": 357 | layers.append(nn.GroupNorm(num_groups=norm_groups, num_channels=out_channels)) 358 | elif normalization == "instance": 359 | layers.append(nn.GroupNorm(num_groups=out_channels, num_channels=out_channels)) 360 | elif normalization == "layer": 361 | layers.append(nn.GroupNorm(num_groups=1, num_channels=out_channels)) 362 | elif normalization == "batch": 363 | layers.append(nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.1)) 364 | else: 365 | raise ValueError( 366 | f"Normalization method ({normalization}) does not match" f" one of [batch, layer, group, instance]." 
367 | ) 368 | 369 | if groups > 1: 370 | layers.append(GroupShuffle(groups, out_channels)) 371 | return layers 372 | 373 | def _get_act_dropout_layer(self, drop_prob=0.2, activation=None): 374 | if activation is None: 375 | activation = nn.Hardtanh(min_val=0.0, max_val=20.0) 376 | layers = [activation, nn.Dropout(p=drop_prob)] 377 | return layers 378 | 379 | def forward(self, input_: Tuple[List[Tensor], Optional[Tensor]]): 380 | # type: (Tuple[List[Tensor], Optional[Tensor]]) -> Tuple[List[Tensor], Optional[Tensor]] # nopep8 381 | lens_orig = None 382 | xs = input_[0] 383 | if len(input_) == 2: 384 | xs, lens_orig = input_ 385 | 386 | # compute forward convolutions 387 | out = xs#[-1] 388 | 389 | lens = lens_orig 390 | for i, l in enumerate(self.mconv): 391 | # if we're doing masked convolutions, we need to pass in and 392 | # possibly update the sequence lengths 393 | # if (i % 4) == 0 and self.conv_mask: 394 | if isinstance(l, MaskedConv1d): 395 | out, lens = l(out, lens) 396 | else: 397 | out = l(out) 398 | 399 | # compute the residuals 400 | if self.res is not None: 401 | for i, layer in enumerate(self.res): 402 | res_out = xs#[i] 403 | for j, res_layer in enumerate(layer): 404 | if isinstance(res_layer, MaskedConv1d): 405 | res_out, _ = res_layer(res_out, lens_orig) 406 | else: 407 | res_out = res_layer(res_out) 408 | 409 | if self.residual_mode == 'add': 410 | out = out + res_out 411 | else: 412 | out = torch.max(out, res_out) 413 | 414 | # compute the output 415 | out = self.mout(out) 416 | if self.res is not None and self.dense_residual: 417 | return xs + [out], lens 418 | 419 | return out, lens 420 | 421 | 422 | class Jasper(ConvCTCASR): 423 | def __init__(self,cfg): 424 | super(Jasper,self).__init__(cfg) 425 | self.mid_layers = cfg.mid_layers 426 | if not cfg.input_size: 427 | nfft = (self.audio_conf['sample_rate'] * self.audio_conf['window_size']) 428 | self.input_size = int(1+(nfft/2)) 429 | else: 430 | self.input_size = cfg.input_size 431 | self._build_encoder(cfg) 432 | last_layer_input_size = self.jasper_encoder[-1].mconv[-1].num_features 433 | self.final_layer = nn.Sequential(nn.Conv1d(last_layer_input_size,len(self.labels),kernel_size=1,stride=1)) 434 | self.final_layer.apply(init_weights) 435 | 436 | def _build_encoder(self,cfg): 437 | layer_size = self.input_size 438 | encoder_layers = [] 439 | for l in cfg.jasper_blocks[:cfg.mid_layers]: 440 | layer = JasperBlock(inplanes = layer_size, planes = l.layer_size, 441 | kernel_size = l.kernel_size, 442 | stride = l.get('stride',1), 443 | dilation = l.get('dilation',1), 444 | residual = l.residual, 445 | repeat = l.get('repeat',1), 446 | conv_mask = l.get('conv_mask',True), 447 | separable = l.get('separable',True), 448 | activation = torch.nn.ReLU(), 449 | dropout = l.get('dropout',0)) 450 | encoder_layers.append(layer) 451 | layer_size = l.layer_size 452 | self.jasper_encoder = nn.Sequential(*encoder_layers) 453 | self.jasper_encoder.apply(init_weights) 454 | 455 | @property 456 | def scaling_factor(self): 457 | if not hasattr(self, '_scaling_factor'): 458 | self._scaling_factor = int(np.prod([block.mconv[0].conv.stride[0] for block in self.jasper_encoder])) 459 | return self._scaling_factor 460 | 461 | 462 | def forward(self,xs,input_lengths): 463 | ''' 464 | [Batches X channels X length], lengths 465 | ''' 466 | encoder_res = self.jasper_encoder((xs,input_lengths)) 467 | output_lengths = encoder_res[1].to(dtype=int) 468 | jasper_res = self.final_layer(encoder_res[0]) 469 | jasper_res = jasper_res.transpose(2,1) # For 
consistency with other models. 470 | if self.training: 471 | jasper_res = F.log_softmax(jasper_res,dim=-1) 472 | else: 473 | jasper_res = F.softmax(jasper_res,dim=-1) 474 | assert not (jasper_res != jasper_res).any() # is there any NAN in result? 475 | return jasper_res, output_lengths # [Batches X Labels X Time (padded to max)], [Batches] 476 | 477 | -------------------------------------------------------------------------------- /novograd.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | Copied from https://github.com/NVIDIA/DeepLearningExamples. Specifically, 4 | https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechRecognition/Jasper/optimizers.py 5 | 6 | Not original code! 7 | ''' 8 | 9 | import torch 10 | from torch.optim import Optimizer 11 | import math 12 | class Novograd(Optimizer): 13 | """ 14 | Implements Novograd algorithm. 15 | Args: 16 | params (iterable): iterable of parameters to optimize or dicts defining 17 | parameter groups 18 | lr (float, optional): learning rate (default: 1e-3) 19 | betas (Tuple[float, float], optional): coefficients used for computing 20 | running averages of gradient and its square (default: (0.95, 0)) 21 | eps (float, optional): term added to the denominator to improve 22 | numerical stability (default: 1e-8) 23 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 24 | grad_averaging: gradient averaging 25 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 26 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 27 | (default: False) 28 | """ 29 | 30 | def __init__(self, params, lr=1e-3, betas=(0.95, 0), eps=1e-8, 31 | weight_decay=0, grad_averaging=False, amsgrad=False): 32 | if not 0.0 <= lr: 33 | raise ValueError("Invalid learning rate: {}".format(lr)) 34 | if not 0.0 <= eps: 35 | raise ValueError("Invalid epsilon value: {}".format(eps)) 36 | if not 0.0 <= betas[0] < 1.0: 37 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 38 | if not 0.0 <= betas[1] < 1.0: 39 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 40 | defaults = dict(lr=lr, betas=betas, eps=eps, 41 | weight_decay=weight_decay, 42 | grad_averaging=grad_averaging, 43 | amsgrad=amsgrad) 44 | 45 | super(Novograd, self).__init__(params, defaults) 46 | 47 | def __setstate__(self, state): 48 | super(Novograd, self).__setstate__(state) 49 | for group in self.param_groups: 50 | group.setdefault('amsgrad', False) 51 | 52 | def step(self, closure=None): 53 | """Performs a single optimization step. 54 | Arguments: 55 | closure (callable, optional): A closure that reevaluates the model 56 | and returns the loss. 57 | """ 58 | loss = None 59 | if closure is not None: 60 | loss = closure() 61 | 62 | for group in self.param_groups: 63 | for p in group['params']: 64 | if p.grad is None: 65 | continue 66 | grad = p.grad.data 67 | if grad.is_sparse: 68 | raise RuntimeError('Sparse gradients are not supported.') 69 | amsgrad = group['amsgrad'] 70 | 71 | state = self.state[p] 72 | 73 | # State initialization 74 | if len(state) == 0: 75 | state['step'] = 0 76 | # Exponential moving average of gradient values 77 | state['exp_avg'] = torch.zeros_like(p.data) 78 | # Exponential moving average of squared gradient values 79 | state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device) 80 | if amsgrad: 81 | # Maintains max of all exp. moving avg. of sq. grad. 
values 82 | state['max_exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device) 83 | 84 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 85 | if amsgrad: 86 | max_exp_avg_sq = state['max_exp_avg_sq'] 87 | beta1, beta2 = group['betas'] 88 | 89 | state['step'] += 1 90 | 91 | norm = torch.sum(torch.pow(grad, 2)) 92 | 93 | if exp_avg_sq == 0: 94 | exp_avg_sq.copy_(norm) 95 | else: 96 | exp_avg_sq.mul_(beta2).add_(1 - beta2, norm) 97 | 98 | if amsgrad: 99 | # Maintains the maximum of all 2nd moment running avg. till now 100 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 101 | # Use the max. for normalizing running avg. of gradient 102 | denom = max_exp_avg_sq.sqrt().add_(group['eps']) 103 | else: 104 | denom = exp_avg_sq.sqrt().add_(group['eps']) 105 | 106 | grad.div_(denom) 107 | if group['weight_decay'] != 0: 108 | grad.add_(group['weight_decay'], p.data) 109 | if group['grad_averaging']: 110 | grad.mul_(1 - beta1) 111 | exp_avg.mul_(beta1).add_(grad) 112 | 113 | p.data.add_(-group['lr'], exp_avg) 114 | 115 | return loss -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pandas 3 | librosa 4 | tqdm 5 | torch 6 | six 7 | scipy 8 | glob2 9 | pytorch-lightning 10 | # The following libraries require Microsoft Visual C++ Build Tools on Windows. Requirements on Linux should still be checked. 11 | python-levenshtein 12 | # Install this version of kenlm; it seems to be easier to install on Windows. 13 | https://github.com/kpu/kenlm/archive/master.zip -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import sys 5 | 6 | import pytorch_lightning 7 | import hydra 8 | from omegaconf import DictConfig, OmegaConf 9 | from hydra.utils import instantiate 10 | 11 | from data import label_sets 12 | from wav2letter import Wav2Letter 13 | from jasper import Jasper 14 | from data.data_loader import SpectrogramDataset, BatchAudioDataLoader 15 | 16 | name_to_model = { 17 | "jasper":Jasper, 18 | "wav2letter":Wav2Letter 19 | } 20 | 21 | def get_data_loaders(labels, cfg): 22 | train_dataset = SpectrogramDataset(cfg.train_manifest, cfg.audio_conf, labels,mel_spec=cfg.mel_spec) 23 | train_batch_loader = BatchAudioDataLoader(train_dataset, batch_size=cfg.batch_size) 24 | eval_dataset = SpectrogramDataset(cfg.val_manifest, cfg.audio_conf, labels,mel_spec=cfg.mel_spec) 25 | val_batch_loader = BatchAudioDataLoader(eval_dataset,batch_size=cfg.batch_size) 26 | return train_batch_loader, val_batch_loader 27 | 28 | @hydra.main(config_path='configuration', config_name='config') 29 | def main(cfg: DictConfig): 30 | if type(cfg.model.labels) is str: 31 | cfg.model.labels = label_sets.labels_map[cfg.model.labels] 32 | train_loader, val_loader = get_data_loaders(cfg.model.labels,cfg.data) 33 | model = name_to_model[cfg.model.name](cfg.model) 34 | trainer = pytorch_lightning.Trainer(**cfg.trainer) 35 | 36 | 37 | trainer.fit(model, train_loader, val_loader) 38 | 39 | 40 | if __name__ == '__main__': 41 | main() 42 | -------------------------------------------------------------------------------- /unit_tests/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 |
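For reference, a minimal sketch of driving the Novograd optimizer defined in novograd.py above outside of the Lightning training loop. The toy model, data, and hyperparameter values here are illustrative placeholders, not taken from this repository's configuration files:

```python
import torch
import torch.nn as nn

from novograd import Novograd

# Toy regression model and batch; shapes and values are placeholders.
model = nn.Linear(10, 2)
optimizer = Novograd(model.parameters(), lr=1e-2, betas=(0.95, 0.5), weight_decay=1e-3)

x, y = torch.randn(8, 10), torch.randn(8, 2)
for _ in range(5):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()  # gradients are normalized per parameter tensor inside step()
```

Note that Novograd keeps a single second-moment scalar per parameter tensor (rather than per element, as Adam does), so its optimizer state stays small.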
-------------------------------------------------------------------------------- /unit_tests/decoder_test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import pytest 4 | import numpy as np 5 | 6 | from decoder import PrefixBeamSearchLMDecoder, prefix_beam_search, GreedyDecoder 7 | from data.label_sets import english_labels 8 | 9 | def greedy_decode(samples, labels, blank_index=0,sizes=None): 10 | greedy_decoder = GreedyDecoder(labels, blank_index=blank_index) 11 | res = greedy_decoder.decode(torch.FloatTensor(samples).unsqueeze(0), sizes=sizes) 12 | return res[0] 13 | 14 | def test_sanity(): 15 | sample = np.zeros((10,len(english_labels))) 16 | sample[0,2] = 0.5 17 | sample[1,20]=0.5 18 | sample[2,19]=0.5 19 | sample[3:,0]=0.5 20 | res = prefix_beam_search(sample,english_labels) 21 | assert res == 'ASR' 22 | 23 | def test_inconsistent_sizes(): 24 | sample = np.zeros((10,len(english_labels) - 1)) 25 | with pytest.raises(AssertionError) as exc_info: 26 | _ = prefix_beam_search(sample,english_labels) 27 | assert exc_info is not None 28 | 29 | 30 | def test_beam_is_not_greedy(): 31 | ''' 32 | Example from https://towardsdatascience.com/beam-search-decoding-in-ctc-trained-neural-networks-51889a3d85a7 33 | Shows that beam search can find a path that greedy decoding can not. 34 | ''' 35 | labels = ['_','A','B',' '] 36 | samples = np.array([[0.8,0.2,0,0],[0.6,0.4,0,0]]) 37 | res = prefix_beam_search(samples,labels,blank_index=0,return_weights=True) 38 | assert res == ('A',0.52) 39 | 40 | greedy_decoder = GreedyDecoder(labels, blank_index=0) 41 | greedy_res = greedy_decoder.decode(torch.FloatTensor(samples).unsqueeze(0), sizes=None) 42 | assert greedy_res == [''] 43 | 44 | def test_beam_width_changes(): 45 | def the_lm(s): 46 | if s == 'A': 47 | return 0.5 48 | return 1 49 | 50 | labels = ['_','A',' '] 51 | samples = np.array([[0.8,0.2,0], 52 | [0.7,0.3,0], 53 | [0.6,0.4,0], 54 | [0.0,0.0,1]]) 55 | res = prefix_beam_search(samples,labels,lm=the_lm,return_weights=False,k=25,alpha=1,beta=0) 56 | res2 = prefix_beam_search(samples,labels,lm=the_lm,return_weights=False,k=1,alpha=1,beta=0) 57 | 58 | assert res == ' ' 59 | assert res2 == 'A ' 60 | 61 | def test_class_wrapper(): 62 | 63 | sample = np.zeros((10,len(english_labels))) 64 | sample[0,2] = 0.5 65 | sample[1,20]=0.5 66 | sample[2,19]=0.5 67 | sample[3:,0]=0.5 68 | decoder = PrefixBeamSearchLMDecoder('',english_labels) 69 | res = decoder.decode(sample) 70 | assert res == 'ASR' 71 | 72 | def test_pbs_batch_dimensions(): 73 | sample = torch.zeros((10,len(english_labels))) 74 | sample[0,2] = 0.5 75 | sample[1,20]=0.5 76 | sample[2,19]=0.5 77 | sample[3:,0]=0.5 78 | sample = sample.unsqueeze(0) 79 | decoder = PrefixBeamSearchLMDecoder('',english_labels) 80 | res = decoder.decode(sample,english_labels) 81 | assert res == ['ASR'] -------------------------------------------------------------------------------- /wav2letter.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from collections import OrderedDict 3 | import librosa 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import pytorch_lightning as ptl 8 | import numpy as np 9 | 10 | from base_asr_models import ConvCTCASR 11 | 12 | class Conv1dBlock(nn.Module): 13 | def __init__(self,input_channels,output_channels,kernel_size,stride,drop_out_prob=-1.0,dilation=1,bn=True,activation_use=True): 14 | super(Conv1dBlock, 
self).__init__() 15 | self.input_channels = input_channels 16 | self.output_channels = output_channels 17 | self.kernel_size = kernel_size 18 | self.stride = stride 19 | self.drop_out_prob = drop_out_prob 20 | self.dilation = dilation 21 | self.activation_use = activation_use 22 | self.padding = kernel_size[0] 23 | '''Padding Calculation''' 24 | input_rows = input_channels 25 | filter_rows = kernel_size[0] 26 | out_rows = (input_rows + stride - 1) // stride 27 | self.padding_rows = max(0, (out_rows -1) * stride + (filter_rows -1) * dilation + 1 - input_rows) 28 | if self.padding_rows > 0: 29 | if self.padding_rows % 2 == 0: 30 | self.paddingAdded = nn.ReflectionPad1d(self.padding_rows // 2) 31 | else: 32 | self.paddingAdded = nn.ReflectionPad1d((self.padding_rows //2,(self.padding_rows +1)//2)) 33 | else: 34 | self.paddingAdded = nn.Identity() 35 | self.conv1 = nn.Conv1d(in_channels=input_channels,out_channels=output_channels, 36 | kernel_size=kernel_size,stride=stride,padding=0,dilation=dilation) 37 | self.batch_norm = nn.BatchNorm1d(num_features=output_channels,momentum=0.9,eps=0.001) if bn else nn.Identity() 38 | self.drop_out = nn.Dropout(drop_out_prob) if self.drop_out_prob != -1 else nn.Identity() 39 | 40 | def forward(self,xs): 41 | xs = self.paddingAdded(xs) 42 | output = self.conv1(xs) 43 | output = self.batch_norm(output) 44 | output = self.drop_out(output) 45 | if self.activation_use: 46 | output = torch.clamp(input=output,min=0,max=20) 47 | return output 48 | 49 | class Wav2Letter(ConvCTCASR): 50 | def __init__(self, cfg): 51 | super(Wav2Letter,self).__init__(cfg) 52 | self.mid_layers = cfg.mid_layers 53 | if not cfg.input_size: 54 | nfft = (self.audio_conf['sample_rate'] * self.audio_conf['window_size']) 55 | self.input_size = int(1+(nfft/2)) 56 | else: 57 | self.input_size = cfg.input_size 58 | 59 | layers = cfg.layers[: self.mid_layers] 60 | layer_size = self.input_size 61 | conv_blocks = [] 62 | for idx in range(len(layers)): 63 | layer_params = layers[idx] # TODO: can we use **layer_params here? 64 | layer = Conv1dBlock(input_channels=layer_size,output_channels=layer_params.output_size, 65 | kernel_size=(layer_params.kernel_size,),stride=layer_params.stride, 66 | dilation=layer_params.dilation,drop_out_prob=layer_params.dropout) 67 | layer_size = layer_params.output_size 68 | conv_blocks.append(('conv1d_{}'.format(idx),layer)) 69 | last_layer = Conv1dBlock(input_channels=layer_size, output_channels=len(self.labels), kernel_size=(1,), stride=1,bn=False,activation_use=False) 70 | conv_blocks.append(('conv1d_{}'.format(len(layers)),last_layer)) 71 | self.conv1ds = nn.Sequential(OrderedDict(conv_blocks)) 72 | 73 | 74 | @property 75 | def scaling_factor(self): 76 | if not hasattr(self,'_scaling_factor'): 77 | strides = [] 78 | for module in self.conv1ds.children(): 79 | strides.append(module.conv1.stride[0]) 80 | self._scaling_factor = int(np.prod(strides)) 81 | return self._scaling_factor 82 | 83 | 84 | def forward(self, x, input_lengths=None): 85 | x = self.conv1ds(x) 86 | x = x.transpose(1,2) 87 | x = F.log_softmax(x,dim=-1) 88 | if input_lengths is not None: 89 | output_lengths = self.compute_output_lengths(input_lengths) 90 | else: 91 | output_lengths = None 92 | return x, output_lengths 93 | 94 | 95 | 96 | 97 | --------------------------------------------------------------------------------
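As a quick sanity check of the Conv1dBlock padding arithmetic in wav2letter.py above, a small standalone sketch (assuming the repository's requirements are installed). The 161-channel input is only an illustrative feature count, and the batch and sequence sizes are arbitrary:

```python
import torch

from wav2letter import Conv1dBlock

# With stride 1, padding_rows works out to (kernel_size - 1) * dilation, so the
# reflection padding keeps the time dimension unchanged after the convolution.
block = Conv1dBlock(input_channels=161, output_channels=256,
                    kernel_size=(11,), stride=1, dilation=1, drop_out_prob=0.2)

spec = torch.randn(4, 161, 300)  # [batch, features, time]
out = block(spec)
print(out.shape)  # expected: torch.Size([4, 256, 300])
```

With a stride greater than 1 the block instead shrinks the time dimension by roughly the stride factor; those per-layer strides are what the scaling_factor property accumulates, and compute_output_lengths (inherited from the base class) presumably uses that factor to map input lengths to output lengths.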