├── LICENCE
├── README.md
├── data
│   ├── __init__.py
│   ├── an4.py
│   ├── common_voice.py
│   ├── data_loader.py
│   ├── distributed.py
│   ├── librispeech.py
│   ├── merge_manifests.py
│   ├── ted.py
│   ├── utils.py
│   └── voxforge.py
├── decoder.py
├── labels.json
├── model.py
├── multiproc.py
├── noise_inject.py
├── requirements.txt
├── test.py
├── train.py
└── transcribe.py


/LICENCE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2017 Vaibhav Gusain & SilverSparro
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # wav2Letter.pytorch
2 | 
3 | Implementation of Wav2Letter using [Baidu Warp-CTC](https://github.com/baidu-research/warp-ctc).
4 | Creates a network based on the [Wav2Letter](https://arxiv.org/abs/1609.03193) architecture, trained with the CTC loss.
5 | 
6 | Currently tested on PyTorch 1.3.1 with CUDA 10.1 and Python 3.7.
7 | 
8 | Branch selfAttentionExps : contains the code with an attention layer between the final layer and the starting layer;
9 | this improves the training time.
10 | 
11 | Branch trainableFrontEnd : contains work-in-progress code to train the model on raw audio samples only.
12 | 
13 | Branch python27 : contains the same code as master but for Python 2.7 and PyTorch 0.4.1.
14 | 
15 | The current checkpoint can be downloaded from : https://drive.google.com/file/d/1HH_4TkPUrfcfRSUp2wqgKUu72bfJ8y8t/view?usp=sharing
16 | 
17 | NOTE : The model currently gives around 37 WER with a greedy decoder; performance can be improved by using a beam decoder and a language model.
18 | ## Features
19 | 
20 | * Train Wav2Letter.
21 | * Language model support using KenLM.
22 | * Noise injection for online training to improve noise robustness.
23 | * Audio augmentation to improve noise robustness.
24 | * Easy start/stop capabilities in the event of a crash or hard stop during training.
25 | * Visdom/Tensorboard support for visualizing training graphs.
26 | * Train the model directly on the raw waveform, removing the dependency on creating spectrograms. (The old code has been moved to the branch 'speechRecognitionSpectogram'.)
27 | 
28 | 
29 | # Installation
30 | 
31 | Several libraries need to be installed for training to work. I will assume that everything is being installed in
32 | an Anaconda installation on Ubuntu.
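For example, a fresh environment could be created and activated like this (the environment name `wav2letter` is only an illustration; the repo is currently tested with Python 3.7):

```
conda create -n wav2letter python=3.7
conda activate wav2letter
```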
33 | 
34 | Install [PyTorch](https://github.com/pytorch/pytorch#installation) if you haven't already.
35 | 
36 | Install this fork for Warp-CTC bindings:
37 | ```
38 | git clone https://github.com/SeanNaren/warp-ctc.git
39 | cd warp-ctc
40 | mkdir build; cd build
41 | cmake ..
42 | make
43 | export CUDA_HOME="/usr/local/cuda"
44 | cd ../pytorch_binding
45 | python setup.py install
46 | ```
47 | 
48 | Install PyTorch audio:
49 | ```
50 | sudo apt-get install sox libsox-dev libsox-fmt-all
51 | git clone https://github.com/pytorch/audio.git
52 | cd audio
53 | pip install cffi
54 | python setup.py install
55 | ```
56 | 
57 | If you want decoding to support beam search with an optional language model, install ctcdecode:
58 | ```
59 | git clone --recursive https://github.com/parlance/ctcdecode.git
60 | cd ctcdecode
61 | pip install .
62 | ```
63 | 
64 | Finally, clone this repo and run this within the repo:
65 | ```
66 | pip install -r requirements.txt
67 | ```
68 | 
69 | # Usage
70 | 
71 | ### Custom Dataset
72 | 
73 | To create a custom dataset, you must create a CSV file containing the locations of the training data. This has to be in the format of:
74 | 
75 | ```
76 | /path/to/audio.wav,transcription
77 | /path/to/audio2.wav,transcription
78 | ...
79 | ```
80 | 
81 | The first path is to the audio file, and the second is the text containing the transcript on one line. This can then be used as stated below.
82 | 
83 | ## Training
84 | 
85 | ```
86 | python train.py --train-manifest data/train_manifest.csv --val-manifest data/val_manifest.csv
87 | ```
88 | 
89 | Use `python train.py --help` for more parameters and options.
90 | 
91 | There is also [Visdom](https://github.com/facebookresearch/visdom) support to visualize training. Once a server has been started, to use:
92 | 
93 | ```
94 | python train.py --visdom
95 | ```
96 | 
97 | There is also [Tensorboard](https://github.com/lanpa/tensorboard-pytorch) support to visualize training. Follow the instructions to set up. To use:
98 | 
99 | ```
100 | python train.py --tensorboard --logdir log_dir/ # Make sure the Tensorboard instance points to this log directory
101 | ```
102 | 
103 | ## Multi-GPU support
104 | ```
105 | python -m multiproc train.py --visdom --cuda # Add your parameters as normal, multiproc will scale to all GPUs automatically
106 | ```
107 | 
108 | For both visualisation tools, you can add your own name to the run by changing the `--id` parameter when training.
109 | 
110 | ## Testing
111 | 
112 | For testing, write all the file paths into a CSV and run:
113 | ```
114 | python test.py
115 | ```
116 | PS : for speed improvements, try running test.py with the flag '--fuse-layers'. This option will fuse all the conv-bn operations and increase the model inference speed.
117 | 
118 | ### Noise Augmentation/Injection
119 | 
120 | There is support for two different types of noise: noise augmentation and noise injection.
121 | 
122 | #### Noise Augmentation
123 | 
124 | Applies small changes to the tempo and gain when loading audio to increase robustness. To enable, use the `--augment` flag when training.
125 | 
126 | #### Noise Injection
127 | 
128 | Dynamically adds noise into the training data to increase robustness. To use it, first fill a directory with all the noise files you want to sample from.
129 | The dataloader will randomly pick samples from this directory.
130 | 
131 | To enable noise injection, use the `--noise-dir /path/to/noise/dir/` flag to specify where your noise files are.
There are a few noise parameters to tweak, such as
132 | `--noise_prob` to determine the probability that noise is added, and the `--noise-min`, `--noise-max` parameters to determine the minimum and maximum noise to add in training.
133 | 
134 | Included is a script to inject noise into an audio file to hear what different noise levels/files would sound like. Useful for curating the noise dataset.
135 | 
136 | ```
137 | python noise_inject.py --input-path /path/to/input.wav --noise-path /path/to/noise.wav --output-path /path/to/input_injected.wav --noise-level 0.5 # higher levels mean more noise
138 | ```
139 | 
140 | ### Checkpoints
141 | 
142 | Training supports saving checkpoints of the model to continue training from, should an error occur or training be terminated early. To enable epoch
143 | checkpoints use:
144 | 
145 | ```
146 | python train.py --checkpoint
147 | ```
148 | 
149 | To enable checkpoints every N batches through the epoch as well as epoch saving:
150 | 
151 | ```
152 | python train.py --checkpoint --checkpoint-per-batch N # N is the number of batches to wait before saving a checkpoint.
153 | ```
154 | 
155 | Note that for the batch checkpointing system to work, you cannot change the batch size when loading a checkpointed model from its original training
156 | run.
157 | 
158 | To continue from a checkpointed model that has been saved:
159 | 
160 | ```
161 | python train.py --continue-from models/wav2Letter_checkpoint_epoch_N_iter_N.pth.tar
162 | ```
163 | 
164 | This continues from the same training state and, if enabled, recreates the Visdom graph to continue from.
165 | 
166 | If you would like to start from a previous checkpoint model but not continue training, add the `--finetune` flag to restart training
167 | from the `--continue-from` weights.
168 | 
169 | ### Choosing batch sizes
170 | 
171 | Included is a script that can be used to benchmark whether training can occur on your hardware, and the limits on the model/batch
172 | sizes you can use. To use:
173 | 
174 | ```
175 | python benchmark.py --batch-size 32
176 | ```
177 | 
178 | Use the flag `--help` to see other parameters that can be used with the script.
179 | 
180 | ### Model details
181 | 
182 | Saved models contain the metadata of their training process. To see the metadata, run the command below:
183 | 
184 | ```
185 | python model.py --model-path models/wav2Letter.pth.tar
186 | ```
187 | 
188 | Also note that there is no final softmax layer on the model, since warp-ctc applies the softmax internally during training. Any decoder built on top of the model will have to implement this softmax itself, so take this into consideration!
189 | 
190 | ## Testing/Inference
191 | 
192 | To evaluate a trained model on a test set (which has to be in the same format as the training set):
193 | 
194 | ```
195 | python test.py --model-path models/wav2Letter.pth --test-manifest /path/to/test_manifest.csv --cuda
196 | ```
197 | 
198 | ### Alternate Decoders
199 | By default, `test.py` uses a `GreedyDecoder` which picks the highest-likelihood output label at each timestep. Repeated and blank symbols are then filtered to give the final output.
200 | 
201 | A beam search decoder can optionally be used with the installation of the `ctcdecode` library as described in the Installation section. The `test` and `transcribe` scripts have a `--decoder` argument. To use the beam decoder, add `--decoder beam`.
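For example, a beam-search run with a KenLM binary might look like the following (the exact flag spellings such as `--lm-path`, `--alpha`, `--beta` and `--beam-width` are assumptions here; check `python test.py --help` for the real names):

```
python test.py --model-path models/wav2Letter.pth --test-manifest /path/to/test_manifest.csv --cuda --decoder beam --lm-path /path/to/lm.binary --alpha 0.8 --beta 1.0 --beam-width 128
```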
The beam decoder enables additional decoding parameters:
202 | - **beam_width** how many beams to consider at each timestep
203 | - **lm_path** optional binary KenLM language model to use for decoding
204 | - **alpha** weight for language model
205 | - **beta** bonus weight for words
206 | 
207 | ### Time offsets
208 | 
209 | Use the `--offsets` flag to get positional information for each character in the transcription when using the `transcribe.py` script. The offsets are based on the size
210 | of the output tensor, which you will need to convert into the format you require.
211 | For example, based on default parameters you could multiply the offsets by a scalar (duration of file in seconds / size of output) to get the offsets in seconds.
212 | 
213 | ## Acknowledgements
214 | 
215 | This work is inspired by the [deepspeech.pytorch](https://github.com/SeanNaren/deepspeech.pytorch) repository of [Sean Naren](https://github.com/SeanNaren).
216 | This work was done as part of a [Silversparro](https://www.silversparro.com) project regarding speech to text.
217 | 
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from . import data_loader
2 | 
--------------------------------------------------------------------------------
/data/an4.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import io
4 | import shutil
5 | import tarfile
6 | import wget
7 | 
8 | from utils import create_manifest
9 | 
10 | parser = argparse.ArgumentParser(description='Processes and downloads an4.')
11 | parser.add_argument('--target-dir', default='an4_dataset/', help='Path to save dataset')
12 | parser.add_argument('--min-duration', default=1, type=int,
13 |                     help='Prunes training samples shorter than the min duration (given in seconds, default 1)')
14 | parser.add_argument('--max-duration', default=15, type=int,
15 |                     help='Prunes training samples longer than the max duration (given in seconds, default 15)')
16 | args = parser.parse_args()
17 | 
18 | 
19 | def _format_data(root_path, data_tag, name, wav_folder):
20 |     data_path = args.target_dir + data_tag + '/' + name + '/'
21 |     new_transcript_path = data_path + '/txt/'
22 |     new_wav_path = data_path + '/wav/'
23 | 
24 |     os.makedirs(new_transcript_path)
25 |     os.makedirs(new_wav_path)
26 | 
27 |     wav_path = root_path + 'wav/'
28 |     file_ids = root_path + 'etc/an4_%s.fileids' % data_tag
29 |     transcripts = root_path + 'etc/an4_%s.transcription' % data_tag
30 |     train_path = wav_path + wav_folder
31 | 
32 |     _convert_audio_to_wav(train_path)
33 |     _format_files(file_ids, new_transcript_path, new_wav_path, transcripts, wav_path)
34 | 
35 | 
36 | def _convert_audio_to_wav(train_path):
37 |     with os.popen('find %s -type f -name "*.raw"' % train_path) as pipe:
38 |         for line in pipe:
39 |             raw_path = line.strip()
40 |             new_path = line.replace('.raw', '.wav').strip()
41 |             cmd = 'sox -t raw -r %d -b 16 -e signed-integer -B -c 1 \"%s\" \"%s\"' % (
42 |                 16000, raw_path, new_path)
43 |             os.system(cmd)
44 | 
45 | 
46 | def _format_files(file_ids, new_transcript_path, new_wav_path, transcripts, wav_path):
47 |     with open(file_ids, 'r') as f:
48 |         with open(transcripts, 'r') as t:
49 |             paths = f.readlines()
50 |             transcripts = t.readlines()
51 |             for x in range(len(paths)):
52 |                 path = wav_path + paths[x].strip() + '.wav'
53 |                 filename = path.split('/')[-1]
54 |                 extracted_transcript = _process_transcript(transcripts, x)
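                # Write the transcript into the new txt/ directory and move the
                # converted wav into the new wav/ layout (handled by the lines below).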
55 | current_path = os.path.abspath(path) 56 | new_path = new_wav_path + filename 57 | text_path = new_transcript_path + filename.replace('.wav', '.txt') 58 | with io.FileIO(text_path, "w") as file: 59 | file.write(extracted_transcript.encode('utf-8')) 60 | os.rename(current_path, new_path) 61 | 62 | 63 | def _process_transcript(transcripts, x): 64 | extracted_transcript = transcripts[x].split('(')[0].strip("").split('<')[0].strip().upper() 65 | return extracted_transcript 66 | 67 | 68 | def main(): 69 | root_path = 'an4/' 70 | name = 'an4' 71 | wget.download('http://www.speech.cs.cmu.edu/databases/an4/an4_raw.bigendian.tar.gz') 72 | tar = tarfile.open('an4_raw.bigendian.tar.gz') 73 | tar.extractall() 74 | os.makedirs(args.target_dir) 75 | _format_data(root_path, 'train', name, 'an4_clstk') 76 | _format_data(root_path, 'test', name, 'an4test_clstk') 77 | shutil.rmtree(root_path) 78 | os.remove('an4_raw.bigendian.tar.gz') 79 | train_path = args.target_dir + '/train/' 80 | test_path = args.target_dir + '/test/' 81 | print ('\n', 'Creating manifests...') 82 | create_manifest(train_path, 'an4_train_manifest.csv', args.min_duration, args.max_duration) 83 | create_manifest(test_path, 'an4_val_manifest.csv') 84 | 85 | 86 | if __name__ == '__main__': 87 | main() 88 | -------------------------------------------------------------------------------- /data/common_voice.py: -------------------------------------------------------------------------------- 1 | import os 2 | import wget 3 | import tarfile 4 | import argparse 5 | import csv 6 | from multiprocessing.pool import ThreadPool 7 | import subprocess 8 | from utils import create_manifest 9 | 10 | parser = argparse.ArgumentParser(description='Downloads and processes Mozilla Common Voice dataset.') 11 | parser.add_argument("--target-dir", default='CommonVoice_dataset/', type=str, help="Directory to store the dataset.") 12 | parser.add_argument("--tar-path", type=str, help="Path to the Common Voice *.tar file if downloaded (Optional).") 13 | parser.add_argument('--sample-rate', default=16000, type=int, help='Sample rate') 14 | parser.add_argument('--min-duration', default=1, type=int, 15 | help='Prunes training samples shorter than the min duration (given in seconds, default 1)') 16 | parser.add_argument('--max-duration', default=15, type=int, 17 | help='Prunes training samples longer than the max duration (given in seconds, default 15)') 18 | parser.add_argument('--files-to-process', default="cv-valid-dev.csv,cv-valid-test.csv,cv-valid-train.csv", 19 | type=str, help='list of *.csv file names to process') 20 | args = parser.parse_args() 21 | COMMON_VOICE_URL = "https://common-voice-data-download.s3.amazonaws.com/cv_corpus_v1.tar.gz" 22 | 23 | def convert_to_wav(csv_file, target_dir): 24 | """ Read *.csv file description, convert mp3 to wav, process text. 25 | Save results to target_dir. 
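    Audio is converted with sox to 16-bit mono wav at the sample rate given by --sample-rate (16 kHz by default).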
26 | 27 | Args: 28 | csv_file: str, path to *.csv file with data description, usually start from 'cv-' 29 | target_dir: str, path to dir to save results; wav/ and txt/ dirs will be created 30 | """ 31 | wav_dir = os.path.join(target_dir, 'wav/') 32 | txt_dir = os.path.join(target_dir, 'txt/') 33 | os.makedirs(wav_dir, exist_ok=True) 34 | os.makedirs(txt_dir, exist_ok=True) 35 | path_to_data = os.path.dirname(csv_file) 36 | 37 | def process(x): 38 | file_path, text = x 39 | file_name = os.path.splitext(os.path.basename(file_path))[0] 40 | text = text.strip().upper() 41 | with open(os.path.join(txt_dir, file_name + '.txt'), 'w') as f: 42 | f.write(text) 43 | cmd = "sox {} -r {} -b 16 -c 1 {}".format( 44 | os.path.join(path_to_data, file_path), 45 | args.sample_rate, 46 | os.path.join(wav_dir, file_name + '.wav')) 47 | subprocess.call([cmd], shell=True) 48 | 49 | print('Converting mp3 to wav for {}.'.format(csv_file)) 50 | with open(csv_file) as csvfile: 51 | reader = csv.DictReader(csvfile) 52 | data = [(row['filename'], row['text']) for row in reader] 53 | with ThreadPool(10) as pool: 54 | pool.map(process, data) 55 | 56 | def main(): 57 | target_dir = args.target_dir 58 | os.makedirs(target_dir, exist_ok=True) 59 | 60 | target_unpacked_dir = os.path.join(target_dir, "CV_unpacked") 61 | os.makedirs(target_unpacked_dir, exist_ok=True) 62 | 63 | if args.tar_path and os.path.exists(args.tar_path): 64 | print('Find existing file {}'.format(args.tar_path)) 65 | target_file = args.tar_path 66 | else: 67 | print("Could not find downloaded Common Voice archive, Downloading corpus...") 68 | filename = wget.download(COMMON_VOICE_URL, target_dir) 69 | target_file = os.path.join(target_dir, os.path.basename(filename)) 70 | 71 | print("Unpacking corpus to {} ...".format(target_unpacked_dir)) 72 | tar = tarfile.open(target_file) 73 | tar.extractall(target_unpacked_dir) 74 | tar.close() 75 | 76 | for csv_file in args.files_to_process.split(','): 77 | convert_to_wav(os.path.join(target_unpacked_dir, 'cv_corpus_v1/', csv_file), 78 | os.path.join(target_dir, os.path.splitext(csv_file)[0])) 79 | 80 | print('Creating manifests...') 81 | for csv_file in args.files_to_process.split(','): 82 | create_manifest(os.path.join(target_dir, os.path.splitext(csv_file)[0]), 83 | os.path.splitext(csv_file)[0] + '_manifest.csv', 84 | args.min_duration, 85 | args.max_duration) 86 | 87 | if __name__ == "__main__": 88 | main() 89 | -------------------------------------------------------------------------------- /data/data_loader.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import os 3 | import subprocess 4 | from tempfile import NamedTemporaryFile 5 | 6 | from torch.distributed import get_rank 7 | from torch.distributed import get_world_size 8 | from torch.utils.data.sampler import Sampler 9 | 10 | import librosa 11 | import numpy as np 12 | import scipy.signal 13 | import scipy.signal 14 | import torch 15 | from scipy.io.wavfile import read 16 | import math 17 | from torch.utils.data import DataLoader 18 | from torch.utils.data import Dataset 19 | import scipy.io.wavfile as wave 20 | 21 | windows = {'hamming': scipy.signal.hamming, 'hann': scipy.signal.hann, 'blackman': scipy.signal.blackman, 22 | 'bartlett': scipy.signal.bartlett} 23 | 24 | 25 | def load_audio(path): 26 | sample_rate, sound = read(path) 27 | sound = sound.astype('float32') / 32767 # normalize audio 28 | if len(sound.shape) > 1: 29 | if sound.shape[1] == 1: 30 | sound = 
sound.squeeze() 31 | else: 32 | sound = sound.mean(axis=1) # multiple channels, average 33 | return sound 34 | 35 | 36 | def normalize_tf_data(signal): 37 | return signal / np.max(np.abs(signal)) 38 | 39 | def load_audioW2l2(path): 40 | # sound, _ = torchaudio.load(path) 41 | sample_freq, sound = wave.read(path) 42 | sound = (normalize_tf_data(sound.astype(np.float32)) * 32767.0).astype( 43 | np.int16) 44 | return sound 45 | 46 | def pcen2(E, sr=22050, hop_length=512, t=0.395,eps=0.000001,alpha=0.98,delta=2.0,r=0.5): 47 | 48 | s = 1 - np.exp(- float(hop_length) / (t * sr)) 49 | M = scipy.signal.lfilter([s], [1, s - 1], E) 50 | smooth = (eps + M)**(-alpha) 51 | return (E * smooth + delta)**r - delta**r 52 | # return M 53 | 54 | def split_normalize_with_librosa( 55 | audio, top_db=50, frame_length=1024, hop_length=256, 56 | skip_idx=0): 57 | 58 | edges = librosa.effects.split(audio, 59 | top_db=top_db, frame_length=frame_length, hop_length=hop_length) 60 | 61 | new_audio = np.zeros_like(audio) 62 | for idx, (start, end) in enumerate(edges[skip_idx:]): 63 | segment = audio[start:end] 64 | if start==end: 65 | print ("Warning: splitting in librosa resulted in an empty segment") 66 | continue 67 | new_audio[start:end] = librosa.util.normalize(segment) 68 | 69 | return new_audio 70 | 71 | 72 | class AudioParser(object): 73 | def parse_transcript(self, transcript_path): 74 | """ 75 | :param transcript_path: Path where transcript is stored from the manifest file 76 | :return: Transcript in training/testing format 77 | """ 78 | raise NotImplementedError 79 | 80 | def parse_audio(self, audio_path): 81 | """ 82 | :param audio_path: Path where audio is stored from the manifest file 83 | :return: Audio in training/testing format 84 | """ 85 | raise NotImplementedError 86 | 87 | 88 | class NoiseInjection(object): 89 | def __init__(self, 90 | path=None, 91 | noise_levels=(0, 0.5)): 92 | """ 93 | Adds noise to an input signal with specific SNR. Higher the noise level, the more noise added. 
94 | Modified code from https://github.com/willfrey/audio/blob/master/torchaudio/transforms.py 95 | """ 96 | self.paths = path is not None and librosa.util.find_files(path) 97 | self.noise_levels = noise_levels 98 | 99 | def inject_noise(self, data): 100 | noise_path = np.random.choice(self.paths) 101 | noise_level = np.random.uniform(*self.noise_levels) 102 | return self.inject_noise_sample(data, noise_path, noise_level) 103 | 104 | def inject_noise_sample(self, data, noise_path, noise_level): 105 | noise_src = load_audio(noise_path) 106 | noise_offset_fraction = np.random.rand() 107 | noise_dst = np.zeros_like(data) 108 | src_offset = int(len(noise_src) * noise_offset_fraction) 109 | src_left = len(noise_src) - src_offset 110 | dst_offset = 0 111 | dst_left = len(data) 112 | while dst_left > 0: 113 | copy_size = min(dst_left, src_left) 114 | np.copyto(noise_dst[dst_offset:dst_offset + copy_size], 115 | noise_src[src_offset:src_offset + copy_size]) 116 | if src_left > dst_left: 117 | dst_left = 0 118 | else: 119 | dst_left -= copy_size 120 | dst_offset += copy_size 121 | src_left = len(noise_src) 122 | src_offset = 0 123 | data += noise_level * noise_dst 124 | return data 125 | 126 | class SpectrogramParser(AudioParser): 127 | def __init__(self, audio_conf, normalize=False, peak_normalization=False, augment=False): 128 | """ 129 | Parses audio file into spectrogram with optional normalization and various augmentations 130 | :param audio_conf: Dictionary containing the sample rate, window and the window length/stride in seconds 131 | :param normalize(default False): Apply standard mean and deviation normalization to audio tensor 132 | :param augment(default False): Apply random tempo and gain perturbations 133 | """ 134 | super(SpectrogramParser, self).__init__() 135 | self.window_stride = audio_conf['window_stride'] 136 | self.window_size = audio_conf['window_size'] 137 | self.sample_rate = audio_conf['sample_rate'] 138 | self.window = windows.get(audio_conf['window'], windows['hamming']) 139 | self.peak_normalization = peak_normalization 140 | self.normalize = normalize 141 | self.augment = augment 142 | self.noiseInjector = NoiseInjection(audio_conf['noise_dir'], 143 | audio_conf['noise_levels']) if audio_conf.get( 144 | 'noise_dir') is not None else None 145 | self.noise_prob = audio_conf.get('noise_prob') 146 | 147 | def parse_audio(self, audio_path): 148 | if self.augment: 149 | y = load_randomly_augmented_audio(audio_path, self.sample_rate) 150 | else: 151 | y = (load_audio(audio_path) * 32767.0).astype( 152 | np.float32) 153 | # y = load_audio(audio_path).astype(np.float32) 154 | if self.peak_normalization: 155 | y = split_normalize_with_librosa(y) 156 | 157 | if self.noiseInjector: 158 | add_noise = np.random.binomial(1, self.noise_prob) 159 | if add_noise: 160 | y = self.noiseInjector.inject_noise(y) 161 | n_fft = int(self.sample_rate * self.window_size) 162 | win_length = n_fft 163 | hop_length = int(self.sample_rate * self.window_stride) 164 | # STFT 165 | D = librosa.stft(y, n_fft=n_fft, hop_length=hop_length, 166 | win_length=win_length, window=self.window) 167 | 168 | spect, phase = librosa.magphase(D) 169 | # S = log(S+1) 170 | pcenResult = pcen2(E=spect,sr=self.sample_rate,hop_length=hop_length) 171 | 172 | spect = np.log1p(spect) 173 | # spect = torch.FloatTensor(spect) 174 | # pcenResult = torch.FloatTensor(pcenResult) 175 | if self.normalize: 176 | mean = spect.mean() 177 | std = spect.std() 178 | # spect.add_(-mean) 179 | # spect.div_(std) 180 | spect = 
np.add(spect,-mean) 181 | spect = spect/std 182 | meanPcen = pcenResult.mean() 183 | stdPcen = pcenResult.std() 184 | # spect.add_(-mean) 185 | # spect.div_(std) 186 | pcenResult = np.add(pcenResult, -meanPcen) 187 | pcenResult = pcenResult / stdPcen 188 | 189 | return spect ,pcenResult 190 | 191 | def parse_audio_w2l2(self, audio_path): 192 | if self.augment: 193 | y = load_randomly_augmented_audio(audio_path, self.sample_rate) 194 | else: 195 | y = load_audio(audio_path) 196 | if self.peak_normalization: 197 | y = split_normalize_with_librosa(y) 198 | 199 | if self.noiseInjector: 200 | add_noise = np.random.binomial(1, self.noise_prob) 201 | if add_noise: 202 | y = self.noiseInjector.inject_noise(y) 203 | 204 | return y,None 205 | 206 | def parse_transcript(self, transcript_path): 207 | raise NotImplementedError 208 | 209 | 210 | class SpectrogramDataset(Dataset, SpectrogramParser): 211 | def __init__(self, audio_conf, manifest_filepath, labels, normalize=False,peak_normalization=False, augment=False,w2l2=True): 212 | """ 213 | Dataset that loads tensors via a csv containing file paths to audio files and transcripts separated by 214 | a comma. Each new line is a different sample. Example below: 215 | 216 | /path/to/audio.wav,/path/to/audio.txt 217 | ... 218 | 219 | :param audio_conf: Dictionary containing the sample rate, window and the window length/stride in seconds 220 | :param manifest_filepath: Path to manifest csv as describe above 221 | :param labels: String containing all the possible characters to map to 222 | :param normalize: Apply standard mean and deviation normalization to audio tensor 223 | :param augment(default False): Apply random tempo and gain perturbations 224 | """ 225 | with open(manifest_filepath) as f: 226 | ids = f.readlines() 227 | ids = [x.strip().split(',') for x in ids] 228 | # ids.sort(key=self.sortAccTosizeOfAudio) 229 | self.ids = ids 230 | self.size = len(ids) 231 | self.w2l2 = w2l2 232 | self.labels_map = dict([(labels[i], i) for i in range(len(labels))]) 233 | super(SpectrogramDataset, self).__init__(audio_conf, normalize, peak_normalization, augment) 234 | 235 | def __getitem__(self, index): 236 | sample = self.ids[index] 237 | audio_path, transcriptLoaded = sample[0], sample[-1] 238 | # transcriptToUse=transcriptLoaded 239 | if self.w2l2: 240 | spect,magnitudeOfAudio = self.parse_audio_w2l2(audio_path) 241 | else: 242 | spect, magnitudeOfAudio = self.parse_audio(audio_path) 243 | transcript = self.parse_transcript(transcriptLoaded) 244 | transcriptToUse = transcript 245 | return spect, transcript, magnitudeOfAudio, audio_path, transcriptToUse 246 | 247 | def parse_transcript(self, transcript_path): 248 | with open(transcript_path, 'r') as transcript_file: 249 | transcript = transcript_file.read().replace('\n', '') 250 | transcript = list(filter(None, [self.labels_map.get(x.lower()) for x in list(transcript)])) 251 | return transcript 252 | 253 | def __len__(self): 254 | return self.size 255 | 256 | def sortAccToLengthOFTranscription(self,elem): 257 | return len(elem[-1].split(' ')) 258 | 259 | def sortAccTosizeOfAudio(self,elem): 260 | return os.stat(elem[0]).st_size 261 | 262 | def _collate_fn(batch): 263 | def func(p): 264 | return p[0].shape[0] 265 | 266 | longest_sample = max(batch, key=func)[0] 267 | # freq_size = longest_sample.shape[0] 268 | minibatch_size = len(batch) 269 | max_seqlength = longest_sample.shape[0] 270 | inputs = torch.zeros(minibatch_size, 1, max_seqlength) 271 | inputsMags = torch.zeros(minibatch_size, 1, max_seqlength) 272 | 
input_percentages = torch.FloatTensor(minibatch_size) 273 | target_sizes = torch.IntTensor(minibatch_size) 274 | targets = [] 275 | inputFilePathAndTranscription = [] 276 | 277 | for x in range(minibatch_size): 278 | sample = batch[x] 279 | tensor = sample[0] 280 | target = sample[1] 281 | tensorMag = sample[2] 282 | tensorPath = sample[3] 283 | orignalTranscription = sample[4] 284 | seq_length = tensor.shape[0] 285 | tensorShape = tensor.shape 286 | # tensorMagShape = tensorMag.shape 287 | # tensorNew = np.pad(tensor,((0,0),(0,abs(tensorShape[1]-max_seqlength))),'wrap') 288 | tensorNew = np.pad(tensor, (0,abs(tensorShape[0]-max_seqlength)), 'constant', constant_values=(0)) 289 | inputs[x][0].copy_(torch.FloatTensor(tensorNew)) 290 | if tensorMag is not None: 291 | tensorMagNew = np.pad(tensorMag,((0,0),(0,abs(tensorShape[1]-max_seqlength))),'wrap') 292 | inputsMags[x][0].narrow(1, 0, max_seqlength).copy_(torch.FloatTensor(tensorMagNew)) 293 | # inputs[x][0].narrow(1, 0, max_seqlength).copy_(torch.FloatTensor(tensorNew)) 294 | input_percentages[x] = seq_length / float(max_seqlength) 295 | target_sizes[x] = len(target) 296 | targets.extend(target) 297 | sumValueForInput = 0#sum(sum(tensor)) 298 | sumValueForInputMag = 0#sum(sum(tensorMag)) 299 | inputFilePathAndTranscription.append([tensorPath,orignalTranscription,sumValueForInput,sumValueForInputMag,tensorShape[0]]) 300 | numChars = len(targets) 301 | targets = torch.IntTensor(targets) 302 | # inputFilePath = torch.IntTensor(inputFilePath) 303 | return inputs, targets, input_percentages, target_sizes, inputFilePathAndTranscription, inputsMags 304 | 305 | class AudioDataLoader(DataLoader): 306 | def __init__(self, *args, **kwargs): 307 | """ 308 | Creates a data loader for AudioDatasets. 309 | """ 310 | super(AudioDataLoader, self).__init__(*args, **kwargs) 311 | self.collate_fn = _collate_fn 312 | 313 | 314 | class BucketingSampler(Sampler): 315 | def __init__(self, data_source, batch_size=1): 316 | """ 317 | Samples batches assuming they are in order of size to batch similarly sized samples together. 318 | """ 319 | super(BucketingSampler, self).__init__(data_source) 320 | self.data_source = data_source 321 | ids = list(range(0, len(data_source))) 322 | self.bins = [ids[i:i + batch_size] for i in range(0, len(ids), batch_size)] 323 | 324 | def __iter__(self): 325 | for ids in self.bins: 326 | np.random.shuffle(ids) 327 | yield ids 328 | 329 | def __len__(self): 330 | return len(self.bins) 331 | 332 | def shuffle(self, epoch): 333 | np.random.shuffle(self.bins) 334 | 335 | 336 | class DistributedBucketingSampler(Sampler): 337 | def __init__(self, data_source, batch_size=1, num_replicas=None, rank=None): 338 | """ 339 | Samples batches assuming they are in order of size to batch similarly sized samples together. 
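        Each replica iterates over every num_replicas-th bin starting from its own rank, so the processes see disjoint batches.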
340 | """ 341 | super(DistributedBucketingSampler, self).__init__(data_source) 342 | if num_replicas is None: 343 | num_replicas = get_world_size() 344 | if rank is None: 345 | rank = get_rank() 346 | self.data_source = data_source 347 | self.ids = list(range(0, len(data_source))) 348 | self.batch_size = batch_size 349 | self.bins = [self.ids[i:i + batch_size] for i in range(0, len(self.ids), batch_size)] 350 | self.num_replicas = num_replicas 351 | self.rank = rank 352 | self.num_samples = int(math.ceil(len(self.bins) * 1.0 / self.num_replicas)) 353 | self.total_size = self.num_samples * self.num_replicas 354 | 355 | def __iter__(self): 356 | offset = self.rank 357 | # add extra samples to make it evenly divisible 358 | bins = self.bins + self.bins[:(self.total_size - len(self.bins))] 359 | assert len(bins) == self.total_size 360 | samples = bins[offset::self.num_replicas] # Get every Nth bin, starting from rank 361 | return iter(samples) 362 | 363 | def __len__(self): 364 | return self.num_samples 365 | 366 | def shuffle(self, epoch): 367 | # deterministically shuffle based on epoch 368 | g = torch.Generator() 369 | g.manual_seed(epoch) 370 | bin_ids = list(torch.randperm(len(self.bins), generator=g)) 371 | self.bins = [self.bins[i] for i in bin_ids] 372 | 373 | 374 | def get_audio_length(path): 375 | output = subprocess.check_output(['soxi -D \"%s\"' % path.strip()], shell=True) 376 | return float(output) 377 | 378 | 379 | def audio_with_sox(path, sample_rate, start_time, end_time): 380 | """ 381 | crop and resample the recording with sox and loads it. 382 | """ 383 | with NamedTemporaryFile(suffix=".wav") as tar_file: 384 | tar_filename = tar_file.name 385 | sox_params = "sox \"{}\" -r {} -c 1 -b 16 -e si {} trim {} ={} >/dev/null 2>&1".format(path, sample_rate, 386 | tar_filename, start_time, 387 | end_time) 388 | os.system(sox_params) 389 | y = load_audio(tar_filename) 390 | return y 391 | 392 | 393 | def augment_audio_with_sox(path, sample_rate, tempo, gain,w2l2): 394 | """ 395 | Changes tempo and gain of the recording with sox and loads it. 396 | """ 397 | with NamedTemporaryFile(suffix=".wav") as augmented_file: 398 | augmented_filename = augmented_file.name 399 | sox_augment_params = ["tempo", "{:.3f}".format(tempo), "gain", "{:.3f}".format(gain)] 400 | sox_params = "sox \"{}\" -r {} -c 1 -b 16 -e si {} {} >/dev/null 2>&1".format(path, sample_rate, 401 | augmented_filename, 402 | " ".join(sox_augment_params)) 403 | os.system(sox_params) 404 | if w2l2: 405 | y = load_audioW2l2(augmented_filename).astype(np.float32) 406 | else: 407 | y = load_audio(augmented_filename) 408 | return y 409 | 410 | 411 | def load_randomly_augmented_audio(path, sample_rate=16000, tempo_range=(0.85, 1.15), 412 | gain_range=(-6, 8),w2l2=False): 413 | """ 414 | Picks tempo and gain uniformly, applies it to the utterance by using sox utility. 415 | Returns the augmented utterance. 
416 | """ 417 | low_tempo, high_tempo = tempo_range 418 | tempo_value = np.random.uniform(low=low_tempo, high=high_tempo) 419 | low_gain, high_gain = gain_range 420 | gain_value = np.random.uniform(low=low_gain, high=high_gain) 421 | audio = augment_audio_with_sox(path=path, sample_rate=sample_rate, 422 | tempo=tempo_value, gain=gain_value,w2l2=w2l2) 423 | return audio 424 | -------------------------------------------------------------------------------- /data/distributed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 3 | import torch.distributed as dist 4 | from torch.nn.modules import Module 5 | 6 | ''' 7 | This version of DistributedDataParallel is designed to be used in conjunction with the multiproc.py 8 | launcher included with this example. It assumes that your run is using multiprocess with 1 9 | GPU/process, that the model is on the correct device, and that torch.set_device has been 10 | used to set the device. 11 | 12 | Parameters are broadcasted to the other processes on initialization of DistributedDataParallel, 13 | and will be allreduced at the finish of the backward pass. 14 | ''' 15 | 16 | 17 | class DistributedDataParallel(Module): 18 | def __init__(self, module): 19 | super(DistributedDataParallel, self).__init__() 20 | # self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False 21 | self.warn_on_half = False 22 | self.module = module 23 | 24 | for p in self.module.state_dict().values(): 25 | if not torch.is_tensor(p): 26 | continue 27 | # if dist._backend == dist.dist_backend.NCCL: 28 | # assert p.is_cuda, "NCCL backend only supports model parameters to be on GPU." 29 | dist.broadcast(p, 0) 30 | 31 | def allreduce_params(): 32 | if (self.needs_reduction): 33 | self.needs_reduction = False 34 | buckets = {} 35 | for param in self.module.parameters(): 36 | if param.requires_grad and param.grad is not None: 37 | tp = type(param.data) 38 | if tp not in buckets: 39 | buckets[tp] = [] 40 | buckets[tp].append(param) 41 | if self.warn_on_half: 42 | if torch.cuda.HalfTensor in buckets: 43 | print("WARNING: gloo dist backend for half parameters may be extremely slow." 
+ 44 | " It is recommended to use the NCCL backend in this case.") 45 | self.warn_on_half = False 46 | 47 | for tp in buckets: 48 | bucket = buckets[tp] 49 | grads = [param.grad.data for param in bucket] 50 | coalesced = _flatten_dense_tensors(grads) 51 | dist.all_reduce(coalesced) 52 | coalesced /= dist.get_world_size() 53 | for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): 54 | buf.copy_(synced) 55 | 56 | for param in list(self.module.parameters()): 57 | def allreduce_hook(*unused): 58 | param._execution_engine.queue_callback(allreduce_params) 59 | 60 | if param.requires_grad: 61 | param.register_hook(allreduce_hook) 62 | 63 | def forward(self, *inputs, **kwargs): 64 | self.needs_reduction = True 65 | return self.module(*inputs, **kwargs) 66 | -------------------------------------------------------------------------------- /data/librispeech.py: -------------------------------------------------------------------------------- 1 | import os 2 | import wget 3 | import tarfile 4 | import argparse 5 | import subprocess 6 | from utils import create_manifest 7 | import shutil 8 | 9 | parser = argparse.ArgumentParser(description='Processes and downloads LibriSpeech dataset.') 10 | parser.add_argument("--target-dir", default='LibriSpeech_dataset/', type=str, help="Directory to store the dataset.") 11 | parser.add_argument('--sample-rate', default=16000, type=int, help='Sample rate') 12 | parser.add_argument('--files-to-use', default="train-clean-100.tar.gz," 13 | "train-clean-360.tar.gz,train-other-500.tar.gz," 14 | "dev-clean.tar.gz,dev-other.tar.gz," 15 | "test-clean.tar.gz,test-other.tar.gz", type=str, 16 | help='list of file names to download') 17 | parser.add_argument('--min-duration', default=1, type=int, 18 | help='Prunes training samples shorter than the min duration (given in seconds, default 1)') 19 | parser.add_argument('--max-duration', default=15, type=int, 20 | help='Prunes training samples longer than the max duration (given in seconds, default 15)') 21 | args = parser.parse_args() 22 | 23 | LIBRI_SPEECH_URLS = { 24 | "train": ["http://www.openslr.org/resources/12/train-clean-100.tar.gz", 25 | "http://www.openslr.org/resources/12/train-clean-360.tar.gz", 26 | "http://www.openslr.org/resources/12/train-other-500.tar.gz"], 27 | 28 | "val": ["http://www.openslr.org/resources/12/dev-clean.tar.gz", 29 | "http://www.openslr.org/resources/12/dev-other.tar.gz"], 30 | 31 | "test_clean": ["http://www.openslr.org/resources/12/test-clean.tar.gz"], 32 | "test_other": ["http://www.openslr.org/resources/12/test-other.tar.gz"] 33 | } 34 | 35 | 36 | def _preprocess_transcript(phrase): 37 | return phrase.strip().upper() 38 | 39 | 40 | def _process_file(wav_dir, txt_dir, base_filename, root_dir): 41 | full_recording_path = os.path.join(root_dir, base_filename) 42 | assert os.path.exists(full_recording_path) and os.path.exists(root_dir) 43 | wav_recording_path = os.path.join(wav_dir, base_filename.replace(".flac", ".wav")) 44 | subprocess.call(["sox {} -r {} -b 16 -c 1 {}".format(full_recording_path, str(args.sample_rate), 45 | wav_recording_path)], shell=True) 46 | # process transcript 47 | txt_transcript_path = os.path.join(txt_dir, base_filename.replace(".flac", ".txt")) 48 | transcript_file = os.path.join(root_dir, "-".join(base_filename.split('-')[:-1]) + ".trans.txt") 49 | assert os.path.exists(transcript_file), "Transcript file {} does not exist.".format(transcript_file) 50 | transcriptions = open(transcript_file).read().strip().split("\n") 51 | transcriptions = 
{t.split()[0].split("-")[-1]: " ".join(t.split()[1:]) for t in transcriptions} 52 | with open(txt_transcript_path, "w") as f: 53 | key = base_filename.replace(".flac", "").split("-")[-1] 54 | assert key in transcriptions, "{} is not in the transcriptions".format(key) 55 | f.write(_preprocess_transcript(transcriptions[key])) 56 | f.flush() 57 | 58 | 59 | def main(): 60 | target_dl_dir = args.target_dir 61 | if not os.path.exists(target_dl_dir): 62 | os.makedirs(target_dl_dir) 63 | files_to_dl = args.files_to_use.strip().split(',') 64 | for split_type, lst_libri_urls in LIBRI_SPEECH_URLS.items(): 65 | split_dir = os.path.join(target_dl_dir, split_type) 66 | if not os.path.exists(split_dir): 67 | os.makedirs(split_dir) 68 | split_wav_dir = os.path.join(split_dir, "wav") 69 | if not os.path.exists(split_wav_dir): 70 | os.makedirs(split_wav_dir) 71 | split_txt_dir = os.path.join(split_dir, "txt") 72 | if not os.path.exists(split_txt_dir): 73 | os.makedirs(split_txt_dir) 74 | extracted_dir = os.path.join(split_dir, "LibriSpeech") 75 | if os.path.exists(extracted_dir): 76 | shutil.rmtree(extracted_dir) 77 | for url in lst_libri_urls: 78 | # check if we want to dl this file 79 | dl_flag = False 80 | for f in files_to_dl: 81 | if url.find(f) != -1: 82 | dl_flag = True 83 | if not dl_flag: 84 | print("Skipping url: {}".format(url)) 85 | continue 86 | filename = url.split("/")[-1] 87 | target_filename = os.path.join(split_dir, filename) 88 | if not os.path.exists(target_filename): 89 | wget.download(url, split_dir) 90 | print("Unpacking {}...".format(filename)) 91 | tar = tarfile.open(target_filename) 92 | tar.extractall(split_dir) 93 | tar.close() 94 | os.remove(target_filename) 95 | print("Converting flac files to wav and extracting transcripts...") 96 | assert os.path.exists(extracted_dir), "Archive {} was not properly uncompressed.".format(filename) 97 | for root, subdirs, files in os.walk(extracted_dir): 98 | for f in files: 99 | if f.find(".flac") != -1: 100 | _process_file(wav_dir=split_wav_dir, txt_dir=split_txt_dir, 101 | base_filename=f, root_dir=root) 102 | 103 | print("Finished {}".format(url)) 104 | shutil.rmtree(extracted_dir) 105 | if split_type == 'train': # Prune to min/max duration 106 | create_manifest(split_dir, 'libri_' + split_type + '_manifest.csv', args.min_duration, args.max_duration) 107 | else: 108 | create_manifest(split_dir, 'libri_' + split_type + '_manifest.csv') 109 | 110 | 111 | if __name__ == "__main__": 112 | main() 113 | -------------------------------------------------------------------------------- /data/merge_manifests.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import argparse 4 | import io 5 | import os 6 | 7 | from tqdm import tqdm 8 | from utils import order_and_prune_files 9 | 10 | parser = argparse.ArgumentParser(description='Merges all manifest CSV files in specified folder.') 11 | parser.add_argument('--merge-dir', default='manifests/', help='Path to all manifest files you want to merge') 12 | parser.add_argument('--min-duration', default=1, type=int, 13 | help='Prunes any samples shorter than the min duration (given in seconds, default 1)') 14 | parser.add_argument('--max-duration', default=15, type=int, 15 | help='Prunes any samples longer than the max duration (given in seconds, default 15)') 16 | parser.add_argument('--output-path', default='merged_manifest.csv', help='Output path to merged manifest') 17 | 18 | args = parser.parse_args() 19 | 20 | file_paths = 
[] 21 | for file in os.listdir(args.merge_dir): 22 | if file.endswith(".csv"): 23 | with open(os.path.join(args.merge_dir, file), 'r') as fh: 24 | file_paths += fh.readlines() 25 | file_paths = [file_path.split(',')[0] for file_path in file_paths] 26 | file_paths = order_and_prune_files(file_paths, args.min_duration, args.max_duration) 27 | with io.FileIO(args.output_path, "w") as file: 28 | for wav_path in tqdm(file_paths, total=len(file_paths)): 29 | transcript_path = wav_path.replace('/wav/', '/txt/').replace('.wav', '.txt') 30 | sample = os.path.abspath(wav_path) + ',' + os.path.abspath(transcript_path) + '\n' 31 | file.write(sample.encode('utf-8')) 32 | -------------------------------------------------------------------------------- /data/ted.py: -------------------------------------------------------------------------------- 1 | import os 2 | import wget 3 | import tarfile 4 | import argparse 5 | import subprocess 6 | import unicodedata 7 | import io 8 | from utils import create_manifest 9 | from tqdm import tqdm 10 | 11 | parser = argparse.ArgumentParser(description='Processes and downloads TED-LIUMv2 dataset.') 12 | parser.add_argument("--target-dir", default='TEDLIUM_dataset/', type=str, help="Directory to store the dataset.") 13 | parser.add_argument("--tar-path", type=str, help="Path to the TEDLIUM_release tar if downloaded (Optional).") 14 | parser.add_argument('--sample-rate', default=16000, type=int, help='Sample rate') 15 | parser.add_argument('--min-duration', default=1, type=int, 16 | help='Prunes training samples shorter than the min duration (given in seconds, default 1)') 17 | parser.add_argument('--max-duration', default=15, type=int, 18 | help='Prunes training samples longer than the max duration (given in seconds, default 15)') 19 | args = parser.parse_args() 20 | 21 | TED_LIUM_V2_DL_URL = "http://www.openslr.org/resources/19/TEDLIUM_release2.tar.gz" 22 | 23 | 24 | def get_utterances_from_stm(stm_file): 25 | """ 26 | Return list of entries containing phrase and its start/end timings 27 | :param stm_file: 28 | :return: 29 | """ 30 | res = [] 31 | with io.open(stm_file, "r", encoding='utf-8') as f: 32 | for stm_line in f: 33 | tokens = stm_line.split() 34 | start_time = float(tokens[3]) 35 | end_time = float(tokens[4]) 36 | filename = tokens[0] 37 | transcript = unicodedata.normalize("NFKD", 38 | " ".join(t for t in tokens[6:]).strip()). 
\ 39 | encode("utf-8", "ignore").decode("utf-8", "ignore") 40 | if transcript != "ignore_time_segment_in_scoring": 41 | res.append({ 42 | "start_time": start_time, "end_time": end_time, 43 | "filename": filename, "transcript": transcript 44 | }) 45 | return res 46 | 47 | 48 | def cut_utterance(src_sph_file, target_wav_file, start_time, end_time, sample_rate=16000): 49 | subprocess.call(["sox {} -r {} -b 16 -c 1 {} trim {} ={}".format(src_sph_file, str(sample_rate), 50 | target_wav_file, start_time, end_time)], 51 | shell=True) 52 | 53 | 54 | def _preprocess_transcript(phrase): 55 | return phrase.strip().upper() 56 | 57 | 58 | def filter_short_utterances(utterance_info, min_len_sec=1.0): 59 | return utterance_info["end_time"] - utterance_info["start_time"] > min_len_sec 60 | 61 | 62 | def prepare_dir(ted_dir): 63 | converted_dir = os.path.join(ted_dir, "converted") 64 | # directories to store converted wav files and their transcriptions 65 | wav_dir = os.path.join(converted_dir, "wav") 66 | if not os.path.exists(wav_dir): 67 | os.makedirs(wav_dir) 68 | txt_dir = os.path.join(converted_dir, "txt") 69 | if not os.path.exists(txt_dir): 70 | os.makedirs(txt_dir) 71 | counter = 0 72 | entries = os.listdir(os.path.join(ted_dir, "sph")) 73 | for sph_file in tqdm(entries, total=len(entries)): 74 | speaker_name = sph_file.split('.sph')[0] 75 | 76 | sph_file_full = os.path.join(ted_dir, "sph", sph_file) 77 | stm_file_full = os.path.join(ted_dir, "stm", "{}.stm".format(speaker_name)) 78 | 79 | assert os.path.exists(sph_file_full) and os.path.exists(stm_file_full) 80 | all_utterances = get_utterances_from_stm(stm_file_full) 81 | 82 | all_utterances = filter(filter_short_utterances, all_utterances) 83 | for utterance_id, utterance in enumerate(all_utterances): 84 | target_wav_file = os.path.join(wav_dir, "{}_{}.wav".format(utterance["filename"], str(utterance_id))) 85 | target_txt_file = os.path.join(txt_dir, "{}_{}.txt".format(utterance["filename"], str(utterance_id))) 86 | cut_utterance(sph_file_full, target_wav_file, utterance["start_time"], utterance["end_time"], 87 | sample_rate=args.sample_rate) 88 | with io.FileIO(target_txt_file, "w") as f: 89 | f.write(_preprocess_transcript(utterance["transcript"]).encode('utf-8')) 90 | counter += 1 91 | 92 | 93 | def main(): 94 | target_dl_dir = args.target_dir 95 | if not os.path.exists(target_dl_dir): 96 | os.makedirs(target_dl_dir) 97 | 98 | target_unpacked_dir = os.path.join(target_dl_dir, "TEDLIUM_release2") 99 | if args.tar_path and os.path.exists(args.tar_path): 100 | target_file = args.tar_path 101 | else: 102 | print("Could not find downloaded TEDLIUM archive, Downloading corpus...") 103 | wget.download(TED_LIUM_V2_DL_URL, target_dl_dir) 104 | target_file = os.path.join(target_dl_dir, "TEDLIUM_release2.tar.gz") 105 | 106 | if not os.path.exists(target_unpacked_dir): 107 | print("Unpacking corpus...") 108 | tar = tarfile.open(target_file) 109 | tar.extractall(target_dl_dir) 110 | tar.close() 111 | else: 112 | print("Found TEDLIUM directory, skipping unpacking of tar files") 113 | 114 | train_ted_dir = os.path.join(target_unpacked_dir, "train") 115 | val_ted_dir = os.path.join(target_unpacked_dir, "dev") 116 | test_ted_dir = os.path.join(target_unpacked_dir, "test") 117 | 118 | prepare_dir(train_ted_dir) 119 | prepare_dir(val_ted_dir) 120 | prepare_dir(test_ted_dir) 121 | print('Creating manifests...') 122 | 123 | create_manifest(train_ted_dir, 'ted_train_manifest.csv', args.min_duration, args.max_duration) 124 | create_manifest(val_ted_dir, 
'ted_val_manifest.csv') 125 | create_manifest(test_ted_dir, 'ted_test_manifest.csv') 126 | 127 | 128 | if __name__ == "__main__": 129 | main() 130 | -------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import fnmatch 4 | import io 5 | import os 6 | from tqdm import tqdm 7 | import subprocess 8 | 9 | 10 | def create_manifest(data_path, output_path, min_duration=None, max_duration=None): 11 | file_paths = [os.path.join(dirpath, f) 12 | for dirpath, dirnames, files in os.walk(data_path) 13 | for f in fnmatch.filter(files, '*.wav')] 14 | file_paths = order_and_prune_files(file_paths, min_duration, max_duration) 15 | with io.FileIO(output_path, "w") as file: 16 | for wav_path in tqdm(file_paths, total=len(file_paths)): 17 | transcript_path = wav_path.replace('/wav/', '/txt/').replace('.wav', '.txt') 18 | sample = os.path.abspath(wav_path) + ',' + os.path.abspath(transcript_path) + '\n' 19 | file.write(sample.encode('utf-8')) 20 | print('\n') 21 | 22 | 23 | def order_and_prune_files(file_paths, min_duration, max_duration): 24 | print("Sorting manifests...") 25 | duration_file_paths = [(path, float(subprocess.check_output( 26 | ['soxi -D \"%s\"' % path.strip()], shell=True))) for path in file_paths] 27 | if min_duration and max_duration: 28 | print("Pruning manifests between %d and %d seconds" % (min_duration, max_duration)) 29 | duration_file_paths = [(path, duration) for path, duration in duration_file_paths if 30 | min_duration <= duration <= max_duration] 31 | 32 | def func(element): 33 | return element[1] 34 | 35 | duration_file_paths.sort(key=func) 36 | return [x[0] for x in duration_file_paths] # Remove durations 37 | -------------------------------------------------------------------------------- /data/voxforge.py: -------------------------------------------------------------------------------- 1 | import os 2 | from six.moves import urllib 3 | import argparse 4 | import re 5 | import tempfile 6 | import shutil 7 | import subprocess 8 | import tarfile 9 | import io 10 | from tqdm import tqdm 11 | 12 | from utils import create_manifest 13 | 14 | VOXFORGE_URL_16kHz = 'http://www.repository.voxforge1.org/downloads/SpeechCorpus/Trunk/Audio/Main/16kHz_16bit/' 15 | 16 | parser = argparse.ArgumentParser(description='Processes and downloads VoxForge dataset.') 17 | parser.add_argument("--target-dir", default='voxforge_dataset/', type=str, help="Directory to store the dataset.") 18 | parser.add_argument('--sample-rate', default=16000, 19 | type=int, help='Sample rate') 20 | parser.add_argument('--min-duration', default=1, type=int, 21 | help='Prunes training samples shorter than the min duration (given in seconds, default 1)') 22 | parser.add_argument('--max-duration', default=15, type=int, 23 | help='Prunes training samples longer than the max duration (given in seconds, default 15)') 24 | args = parser.parse_args() 25 | 26 | 27 | def _get_recordings_dir(sample_dir, recording_name): 28 | wav_dir = os.path.join(sample_dir, recording_name, "wav") 29 | if os.path.exists(wav_dir): 30 | return "wav", wav_dir 31 | flac_dir = os.path.join(sample_dir, recording_name, "flac") 32 | if os.path.exists(flac_dir): 33 | return "flac", flac_dir 34 | raise Exception("wav or flac directory was not found for recording name: {}".format(recording_name)) 35 | 36 | 37 | def prepare_sample(recording_name, url, target_folder): 38 | """ 39 | 
Downloads and extracts a sample from VoxForge and puts the wav and txt files into :target_folder. 40 | """ 41 | wav_dir = os.path.join(target_folder, "wav") 42 | if not os.path.exists(wav_dir): 43 | os.makedirs(wav_dir) 44 | txt_dir = os.path.join(target_folder, "txt") 45 | if not os.path.exists(txt_dir): 46 | os.makedirs(txt_dir) 47 | # check if sample is processed 48 | filename_set = set(['_'.join(wav_file.split('_')[:-1]) for wav_file in os.listdir(wav_dir)]) 49 | if recording_name in filename_set: 50 | return 51 | 52 | request = urllib.request.Request(url) 53 | response = urllib.request.urlopen(request) 54 | content = response.read() 55 | response.close() 56 | with tempfile.NamedTemporaryFile(suffix=".tgz", mode='wb') as target_tgz: 57 | target_tgz.write(content) 58 | target_tgz.flush() 59 | dirpath = tempfile.mkdtemp() 60 | 61 | tar = tarfile.open(target_tgz.name) 62 | tar.extractall(dirpath) 63 | tar.close() 64 | 65 | recordings_type, recordings_dir = _get_recordings_dir(dirpath, recording_name) 66 | tgz_prompt_file = os.path.join(dirpath, recording_name, "etc", "PROMPTS") 67 | 68 | if os.path.exists(recordings_dir) and os.path.exists(tgz_prompt_file): 69 | transcriptions = open(tgz_prompt_file).read().strip().split("\n") 70 | transcriptions = {t.split()[0]: " ".join(t.split()[1:]) for t in transcriptions} 71 | for wav_file in os.listdir(recordings_dir): 72 | recording_id = wav_file.split('.{}'.format(recordings_type))[0] 73 | transcription_key = recording_name + "/mfc/" + recording_id 74 | if transcription_key not in transcriptions: 75 | continue 76 | utterance = transcriptions[transcription_key] 77 | 78 | target_wav_file = os.path.join(wav_dir, "{}_{}.wav".format(recording_name, recording_id)) 79 | target_txt_file = os.path.join(txt_dir, "{}_{}.txt".format(recording_name, recording_id)) 80 | with io.FileIO(target_txt_file, "w") as file: 81 | file.write(utterance.encode('utf-8')) 82 | original_wav_file = os.path.join(recordings_dir, wav_file) 83 | subprocess.call(["sox {} -r {} -b 16 -c 1 {}".format(original_wav_file, str(args.sample_rate), 84 | target_wav_file)], shell=True) 85 | 86 | shutil.rmtree(dirpath) 87 | 88 | 89 | if __name__ == '__main__': 90 | target_dir = args.target_dir 91 | sample_rate = args.sample_rate 92 | 93 | if not os.path.isdir(target_dir): 94 | os.makedirs(target_dir) 95 | request = urllib.request.Request(VOXFORGE_URL_16kHz) 96 | response = urllib.request.urlopen(request) 97 | content = response.read() 98 | all_files = re.findall("href\=\"(.*\.tgz)\"", content.decode("utf-8")) 99 | for f in tqdm(all_files, total=len(all_files)): 100 | prepare_sample(f.replace(".tgz", ""), VOXFORGE_URL_16kHz + f, target_dir) 101 | print('Creating manifests...') 102 | create_manifest(target_dir, 'voxforge_train_manifest.csv', args.min_duration, args.max_duration) 103 | -------------------------------------------------------------------------------- /decoder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # ---------------------------------------------------------------------------- 3 | # Copyright 2015-2016 Nervana Systems Inc. 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ---------------------------------------------------------------------------- 16 | # Modified to support pytorch Tensors 17 | 18 | import Levenshtein as Lev 19 | import torch 20 | from six.moves import xrange 21 | 22 | 23 | class Decoder(object): 24 | """ 25 | Basic decoder class from which all other decoders inherit. Implements several 26 | helper functions. Subclasses should implement the decode() method. 27 | 28 | Arguments: 29 | labels (string): mapping from integers to characters. 30 | blank_index (int, optional): index for the blank '_' character. Defaults to 0. 31 | space_index (int, optional): index for the space ' ' character. Defaults to 28. 32 | """ 33 | 34 | def __init__(self, labels, blank_index=0): 35 | # e.g. labels = "_'ABCDEFGHIJKLMNOPQRSTUVWXYZ#" 36 | self.labels = labels 37 | self.int_to_char = dict([(i, c) for (i, c) in enumerate(labels)]) 38 | self.blank_index = blank_index 39 | space_index = len(labels) # To prevent errors in decode, we add an out of bounds index for the space 40 | if ' ' in labels: 41 | space_index = labels.index(' ') 42 | self.space_index = space_index 43 | 44 | def wer(self, s1, s2): 45 | """ 46 | Computes the Word Error Rate, defined as the edit distance between the 47 | two provided sentences after tokenizing to words. 48 | Arguments: 49 | s1 (string): space-separated sentence 50 | s2 (string): space-separated sentence 51 | """ 52 | 53 | # build mapping of words to integers 54 | b = set(s1.split() + s2.split()) 55 | word2char = dict(zip(b, range(len(b)))) 56 | 57 | # map the words to a char array (Levenshtein packages only accepts 58 | # strings) 59 | w1 = [chr(word2char[w]) for w in s1.split()] 60 | w2 = [chr(word2char[w]) for w in s2.split()] 61 | 62 | return Lev.distance(''.join(w1), ''.join(w2)) 63 | 64 | def cer(self, s1, s2): 65 | """ 66 | Computes the Character Error Rate, defined as the edit distance. 
67 | 68 | Arguments: 69 | s1 (string): space-separated sentence 70 | s2 (string): space-separated sentence 71 | """ 72 | s1, s2, = s1.replace(' ', ''), s2.replace(' ', '') 73 | return Lev.distance(s1, s2) 74 | 75 | def decode(self, probs, sizes=None): 76 | """ 77 | Given a matrix of character probabilities, returns the decoder's 78 | best guess of the transcription 79 | 80 | Arguments: 81 | probs: Tensor of character probabilities, where probs[c,t] 82 | is the probability of character c at time t 83 | sizes(optional): Size of each sequence in the mini-batch 84 | Returns: 85 | string: sequence of the model's best guess for the transcription 86 | """ 87 | raise NotImplementedError 88 | 89 | 90 | class BeamCTCDecoder(Decoder): 91 | def __init__(self, labels, lm_path=None, alpha=0, beta=0, cutoff_top_n=40, cutoff_prob=1.0, beam_width=100, 92 | num_processes=4, blank_index=0): 93 | super(BeamCTCDecoder, self).__init__(labels) 94 | try: 95 | from ctcdecode import CTCBeamDecoder 96 | except ImportError: 97 | raise ImportError("BeamCTCDecoder requires paddledecoder package.") 98 | self._decoder = CTCBeamDecoder(labels, lm_path, alpha, beta, cutoff_top_n, cutoff_prob, beam_width, 99 | num_processes, blank_index) 100 | 101 | def convert_to_strings(self, out, seq_len): 102 | results = [] 103 | for b, batch in enumerate(out): 104 | utterances = [] 105 | for p, utt in enumerate(batch): 106 | size = seq_len[b][p] 107 | if size > 0: 108 | transcript = ''.join(map(lambda x: self.int_to_char[x.item()], utt[0:size])) 109 | else: 110 | transcript = '' 111 | utterances.append(transcript) 112 | results.append(utterances) 113 | return results 114 | 115 | def convert_tensor(self, offsets, sizes): 116 | results = [] 117 | for b, batch in enumerate(offsets): 118 | utterances = [] 119 | for p, utt in enumerate(batch): 120 | size = sizes[b][p] 121 | if sizes[b][p] > 0: 122 | utterances.append(utt[0:size]) 123 | else: 124 | utterances.append(torch.IntTensor()) 125 | results.append(utterances) 126 | return results 127 | 128 | def decode(self, probs, sizes=None): 129 | """ 130 | Decodes probability output using ctcdecode package. 
131 | Arguments: 132 | probs: Tensor of character probabilities, where probs[c,t] 133 | is the probability of character c at time t 134 | sizes: Size of each sequence in the mini-batch 135 | Returns: 136 | string: sequences of the model's best guess for the transcription 137 | """ 138 | probs = probs.cpu() 139 | out, scores, offsets, seq_lens = self._decoder.decode(probs, sizes) 140 | 141 | strings = self.convert_to_strings(out, seq_lens) 142 | offsets = self.convert_tensor(offsets, seq_lens) 143 | return strings, offsets 144 | 145 | 146 | class GreedyDecoder(Decoder): 147 | def __init__(self, labels, blank_index=0): 148 | super(GreedyDecoder, self).__init__(labels, blank_index) 149 | 150 | def convert_to_strings(self, sequences, sizes=None, remove_repetitions=False, return_offsets=False): 151 | """Given a list of numeric sequences, returns the corresponding strings""" 152 | strings = [] 153 | offsets = [] if return_offsets else None 154 | for x in xrange(len(sequences)): 155 | seq_len = sizes[x] if sizes is not None else len(sequences[x]) 156 | string, string_offsets = self.process_string(sequences[x], seq_len, remove_repetitions) 157 | strings.append([string]) # We only return one path 158 | if return_offsets: 159 | offsets.append([string_offsets]) 160 | if return_offsets: 161 | return strings, offsets 162 | else: 163 | return strings 164 | 165 | def process_string(self, sequence, size, remove_repetitions=False): 166 | string = '' 167 | offsets = [] 168 | for i in range(size): 169 | char = self.int_to_char[sequence[i].item()] 170 | if char != self.int_to_char[self.blank_index]: 171 | # if this char is a repetition and remove_repetitions=true, then skip 172 | if remove_repetitions and i != 0 and char == self.int_to_char[sequence[i - 1].item()]: 173 | pass 174 | elif char == self.labels[self.space_index]: 175 | string += ' ' 176 | offsets.append(i) 177 | else: 178 | string = string + char 179 | offsets.append(i) 180 | return string, torch.IntTensor(offsets) 181 | 182 | def decode(self, probs, sizes=None): 183 | """ 184 | Returns the argmax decoding given the probability matrix. Removes 185 | repeated elements in the sequence, as well as blanks. 186 | 187 | Arguments: 188 | probs: Tensor of character probabilities from the network. 
Expected shape of batch x seq_length x output_dim 189 | sizes(optional): Size of each sequence in the mini-batch 190 | Returns: 191 | strings: sequences of the model's best guess for the transcription on inputs 192 | offsets: time step per character predicted 193 | """ 194 | _, max_probs = torch.max(probs, 2) 195 | strings, offsets = self.convert_to_strings(max_probs.view(max_probs.size(0), max_probs.size(1)), sizes, 196 | remove_repetitions=True, return_offsets=True) 197 | return strings, offsets 198 | -------------------------------------------------------------------------------- /labels.json: -------------------------------------------------------------------------------- 1 | [ 2 | "_", 3 | "'", 4 | "a", 5 | "b", 6 | "c", 7 | "d", 8 | "e", 9 | "f", 10 | "g", 11 | "h", 12 | "i", 13 | "j", 14 | "k", 15 | "l", 16 | "m", 17 | "n", 18 | "o", 19 | "p", 20 | "q", 21 | "r", 22 | "s", 23 | "t", 24 | "u", 25 | "v", 26 | "w", 27 | "x", 28 | "y", 29 | "z", 30 | " " 31 | ] -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import OrderedDict 3 | import numpy as np 4 | import scipy.signal 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.nn.parameter import Parameter 9 | 10 | 11 | class stft(nn.Module): 12 | def __init__(self, nfft=1024, hop_length=512, window="hanning"): 13 | super(stft, self).__init__() 14 | assert nfft % 2 == 0 15 | 16 | self.hop_length = hop_length 17 | self.n_freq = n_freq = nfft//2 + 1 18 | 19 | self.real_kernels, self.imag_kernels = _get_stft_kernels(nfft, window) 20 | self.real_kernels_size = self.real_kernels.size() 21 | self.conv = nn.Sequential( 22 | nn.Conv2d(1, self.real_kernels_size[0], kernel_size=(self.real_kernels_size[2], self.real_kernels_size[3]), stride=(self.hop_length)), 23 | nn.BatchNorm2d(self.real_kernels_size[0]), 24 | nn.Hardtanh(0, 20, inplace=True) 25 | ) 26 | def forward(self, sample): 27 | sample = sample.unsqueeze(1) 28 | magn = self.conv(sample) 29 | 30 | magn = magn.permute(0, 2, 1, 3) 31 | return magn 32 | 33 | 34 | def _get_stft_kernels(nfft, window): 35 | nfft = int(nfft) 36 | assert nfft % 2 == 0 37 | 38 | def kernel_fn(freq, time): 39 | return np.exp(-1j * (2 * np.pi * time * freq) / float(nfft)) 40 | 41 | kernels = np.fromfunction(kernel_fn, (nfft//2+1, nfft), dtype=np.float64) 42 | 43 | if window == "hanning": 44 | win_cof = scipy.signal.get_window("hanning", nfft)[np.newaxis, :] 45 | else: 46 | win_cof = np.ones((1, nfft), dtype=np.float64) 47 | 48 | kernels = kernels[:, np.newaxis, np.newaxis, :] * win_cof 49 | 50 | real_kernels = nn.Parameter(torch.from_numpy(np.real(kernels)).float()) 51 | imag_kernels = nn.Parameter(torch.from_numpy(np.imag(kernels)).float()) 52 | 53 | return real_kernels, imag_kernels 54 | 55 | class PCEN(nn.Module): 56 | def __init__(self): 57 | super(PCEN,self).__init__() 58 | 59 | ''' 60 | initialising the layer param with the best parametrised values i searched on web (scipy using theese values) 61 | alpha = 0.98 62 | delta=2 63 | r=0.5 64 | ''' 65 | self.log_alpha = Parameter(torch.FloatTensor([0.98])) 66 | self.log_delta = Parameter(torch.FloatTensor([2])) 67 | self.log_r = Parameter(torch.FloatTensor([0.5])) 68 | self.eps = 0.000001 69 | 70 | def forward(self,x,smoother): 71 | # t = x.size(0) 72 | # t = x.size(1) 73 | # t = x.size(2) 74 | # t = x.size(3) 75 | # alpha = self.log_alpha.exp().expand_as(x) 76 | # 
delta = self.log_delta.exp().expand_as(x) 77 | # r = self.log_r.exp().expand_as(x) 78 | # print 'updated values are alpha={} , delta={} , r={}'.format(self.log_alpha,self.log_delta,self.log_r) 79 | smooth = (self.eps + smoother) ** (-(self.log_alpha)) 80 | # pcen = (x/(self.eps + smoother)**alpha + delta)**r - delta**r 81 | pcen = (x * smooth + self.log_delta)**self.log_r - self.log_delta**self.log_r 82 | return pcen 83 | 84 | class InferenceBatchSoftmax(nn.Module): 85 | def forward(self, input_): 86 | if not self.training: 87 | return F.softmax(input_, dim=-1) 88 | else: 89 | return input_ 90 | 91 | 92 | class Cov1dBlock(nn.Module): 93 | def __init__(self, input_size, output_size, kernal_size, stride, drop_out_prob=-1.0, dilation=1, padding='same',bn=True,activationUse=True): 94 | super(Cov1dBlock, self).__init__() 95 | self.input_size = input_size 96 | self.output_size = output_size 97 | self.kernal_size = kernal_size 98 | self.stride = stride 99 | self.dilation = dilation 100 | self.drop_out_prob = drop_out_prob 101 | self.activationUse = activationUse 102 | self.padding = kernal_size[0] #(kernal_size[0]-stride)//2 if kernal_size[0]!=1 else 103 | '''using the below code for the padding calculation''' 104 | input_rows = input_size 105 | filter_rows = kernal_size[0] 106 | effective_filter_size_rows = (filter_rows - 1) * dilation + 1 107 | out_rows = (input_rows + stride - 1) // stride 108 | self.rows_odd = False 109 | if padding=='same': 110 | self.padding_needed =max(0, (out_rows - 1) * stride + effective_filter_size_rows - 111 | input_rows) 112 | 113 | self.padding_rows = max(0, (out_rows - 1) * stride + 114 | (filter_rows - 1) * dilation + 1 - input_rows) 115 | 116 | self.rows_odd = (self.padding_rows % 2 != 0) 117 | 118 | self.addPaddings = self.padding_rows 119 | elif padding=='half': 120 | self.addPaddings = kernal_size[0] 121 | elif padding == 'invalid': 122 | self.addPaddings = 0 123 | 124 | self.paddingAdded = nn.ReflectionPad1d(self.addPaddings//2) if self.addPaddings >0 else None 125 | self.conv1 = nn.Sequential( 126 | 127 | nn.Conv1d(in_channels=input_size, out_channels=output_size, kernel_size=kernal_size, 128 | stride=stride, padding=(0), dilation=dilation), 129 | # nn.ReLU6() 130 | ) 131 | self.batchNorm = nn.BatchNorm1d(num_features=output_size,momentum=0.90,eps=0.001) if bn else None 132 | # self.activation = nn.Hardtanh(min_val=0,max_val=20) if activationUse else None 133 | # self.activation = nn.ReLU6() if activationUse else None 134 | # self.activation = True if activationUse else False 135 | self.drop_out_layer = nn.Dropout(drop_out_prob) if self.drop_out_prob != -1 else None 136 | # self.activation = nn.ReLU6() if activationUse else None 137 | 138 | # torch.nn.init.xavier_normal(self.conv1._modules['0'].weight) 139 | 140 | def forward(self, xs, hid=None): 141 | if self.paddingAdded is not None: 142 | xs = self.paddingAdded(xs) 143 | fusedLayer = getattr(self,'fusedLayer',None) 144 | if fusedLayer is not None: 145 | output = fusedLayer(xs) 146 | else: 147 | output = self.conv1(xs) 148 | if self.batchNorm is not None: 149 | output = self.batchNorm(output) 150 | if self.activationUse: 151 | output = torch.clamp(input=output,min=0,max=20) 152 | # output = self.activation(output) 153 | if self.training: 154 | if self.drop_out_layer is not None: 155 | output = self.drop_out_layer(output) 156 | 157 | return output 158 | 159 | class WaveToLetter(nn.Module): 160 | def __init__(self,sample_rate,window_size, labels="abc",audio_conf=None,mixed_precision=False): 161 | 
super(WaveToLetter, self).__init__() 162 | 163 | # model metadata needed for serialization/deserialization 164 | if audio_conf is None: 165 | audio_conf = {} 166 | self._version = '0.0.1' 167 | self._audio_conf = audio_conf or {} 168 | self._labels = labels 169 | self._sample_rate=sample_rate 170 | self._window_size=window_size 171 | self.mixed_precision=mixed_precision 172 | 173 | nfft = (self._sample_rate * self._window_size) 174 | input_size = 1+int((nfft/2)) 175 | hop_length = sample_rate * self._audio_conf.get("window_stride", 0.01) 176 | 177 | # self.pcen = PCEN() 178 | self.frontEnd = stft(hop_length=int(hop_length), nfft=int(nfft)) 179 | conv1 = Cov1dBlock(input_size=input_size,output_size=256,kernal_size=(11,),stride=2,dilation=1,drop_out_prob=0.2,padding='same') 180 | conv2s = [] 181 | conv2s.append(('conv1d_{}'.format(0),conv1)) 182 | inputSize = 256 183 | for idx in range(15): 184 | layergroup = idx//3 185 | if (layergroup) == 0: 186 | convTemp = Cov1dBlock(input_size=inputSize,output_size=256,kernal_size=(11,),stride=1,dilation=1,drop_out_prob=0.2,padding='same') 187 | conv2s.append(('conv1d_{}'.format(idx+1),convTemp)) 188 | inputSize = 256 189 | elif (layergroup) == 1: 190 | convTemp = Cov1dBlock(input_size=inputSize, output_size=384, kernal_size=(13,), stride=1, dilation=1, 191 | drop_out_prob=0.2) 192 | conv2s.append(('conv1d_{}'.format(idx + 1), convTemp)) 193 | inputSize=384 194 | elif (layergroup) ==2: 195 | convTemp = Cov1dBlock(input_size=inputSize, output_size=512, kernal_size=(17,), stride=1, dilation=1, 196 | drop_out_prob=0.2) 197 | conv2s.append(('conv1d_{}'.format(idx + 1), convTemp)) 198 | inputSize = 512 199 | 200 | elif (layergroup) ==3: 201 | convTemp = Cov1dBlock(input_size=inputSize, output_size=640, kernal_size=(21,), stride=1, dilation=1, 202 | drop_out_prob=0.3) 203 | conv2s.append(('conv1d_{}'.format(idx + 1), convTemp)) 204 | inputSize = 640 205 | 206 | elif (layergroup) ==4: 207 | convTemp = Cov1dBlock(input_size=inputSize, output_size=768, kernal_size=(25,), stride=1, dilation=1, 208 | drop_out_prob=0.3) 209 | conv2s.append(('conv1d_{}'.format(idx + 1), convTemp)) 210 | inputSize = 768 211 | 212 | conv1 = Cov1dBlock(input_size=inputSize, output_size=896, kernal_size=(29,), stride=1, dilation=2, drop_out_prob=0.4) 213 | conv2s.append(('conv1d_{}'.format(16), conv1)) 214 | conv1 = Cov1dBlock(input_size=896, output_size=1024, kernal_size=(1,), stride=1, dilation=1, drop_out_prob=0.4) 215 | conv2s.append(('conv1d_{}'.format(17), conv1)) 216 | conv1 = Cov1dBlock(input_size=1024, output_size=len(self._labels), kernal_size=(1,),stride=1,bn=False,activationUse=False) 217 | conv2s.append(('conv1d_{}'.format(18), conv1)) 218 | 219 | self.conv1ds = nn.Sequential(OrderedDict(conv2s)) 220 | self.inference_softmax = InferenceBatchSoftmax() 221 | 222 | def forward(self, x): 223 | x = self.frontEnd(x) 224 | x = x.squeeze(1) 225 | x = self.conv1ds(x) 226 | x = x.transpose(1,2) 227 | x = self.inference_softmax(x) 228 | 229 | return x 230 | 231 | @classmethod 232 | def load_model(cls, path, cuda=False): 233 | package = torch.load(path, map_location=lambda storage, loc: storage) 234 | model = cls(labels=package['labels'], audio_conf=package['audio_conf'],sample_rate=package["sample_rate"] 235 | ,window_size=package["window_size"],mixed_precision=package.get('mixed_precision',False)) 236 | # the blacklist parameters are params that were previous erroneously saved by the model 237 | # care should be taken in future versions that if batch_norm on the first rnn is required 
238 | # that it be named something else 239 | blacklist = ['rnns.0.batch_norm.module.weight', 'rnns.0.batch_norm.module.bias', 240 | 'rnns.0.batch_norm.module.running_mean', 'rnns.0.batch_norm.module.running_var'] 241 | 242 | for x in blacklist: 243 | if x in package['state_dict']: 244 | del package['state_dict'][x] 245 | # keyNames = package['state_dict'].keys() 246 | # 247 | # for keyname in keyNames: 248 | # if "num_batches_tracked" in keyname: 249 | # del package['state_dict'][keyname] 250 | model.load_state_dict(package['state_dict']) 251 | # for x in model.rnns: 252 | # x.flatten_parameters() 253 | if cuda: 254 | model = torch.nn.DataParallel(model).cuda() 255 | return model 256 | 257 | @classmethod 258 | def load_model_package(cls, package, cuda=False): 259 | model = cls(labels=package['labels'], audio_conf=package['audio_conf'],sample_rate=package.get("sample_rate",16000) 260 | ,window_size=package.get("window_size",.02),mixed_precision=package.get('mixed_precision',False)) 261 | model.load_state_dict(package['state_dict']) 262 | if cuda: 263 | model = torch.nn.DataParallel(model).cuda() 264 | return model 265 | 266 | @staticmethod 267 | def serialize(model, optimizer=None, epoch=None, iteration=None, loss_results=None, 268 | cer_results=None, wer_results=None, avg_loss=None, meta=None): 269 | model_is_cuda = next(model.parameters()).is_cuda 270 | # model = model.module if model_is_cuda else model 271 | package = { 272 | 'version': model._version, 273 | 'audio_conf': model._audio_conf, 274 | 'labels': model._labels, 275 | 'state_dict': model.state_dict(), 276 | 'mixed_precision': model.mixed_precision, 277 | 'sample_rate': model._sample_rate, 278 | 'window_size': model._window_size 279 | } 280 | if optimizer is not None: 281 | package['optim_dict'] = optimizer.state_dict() 282 | if avg_loss is not None: 283 | package['avg_loss'] = avg_loss 284 | if epoch is not None: 285 | package['epoch'] = epoch + 1 # increment for readability 286 | if iteration is not None: 287 | package['iteration'] = iteration 288 | if loss_results is not None: 289 | package['loss_results'] = loss_results 290 | package['cer_results'] = cer_results 291 | package['wer_results'] = wer_results 292 | if meta is not None: 293 | package['meta'] = meta 294 | return package 295 | 296 | @staticmethod 297 | def get_labels(model): 298 | model_is_cuda = next(model.parameters()).is_cuda 299 | return model.module._labels if model_is_cuda else model._labels 300 | @staticmethod 301 | def get_sample_rate(model): 302 | model_is_cuda = next(model.parameters()).is_cuda 303 | return model.module._sample_rate if model_is_cuda else model._sample_rate 304 | @staticmethod 305 | def get_window_size(model): 306 | model_is_cuda = next(model.parameters()).is_cuda 307 | return model.module._window_size if model_is_cuda else model._window_size 308 | @staticmethod 309 | def setAudioConfKey(model,key,value): 310 | model._audio_conf[key] = value 311 | return model 312 | @staticmethod 313 | def get_param_size(model): 314 | params = 0 315 | for p in model.parameters(): 316 | tmp = 1 317 | for x in p.size(): 318 | tmp *= x 319 | params += tmp 320 | return params 321 | 322 | @staticmethod 323 | def get_audio_conf(model): 324 | model_is_cuda = next(model.parameters()).is_cuda 325 | return model.module._audio_conf if model_is_cuda else model._audio_conf 326 | 327 | @staticmethod 328 | def get_meta(model): 329 | model_is_cuda = next(model.parameters()).is_cuda 330 | m = model.module if model_is_cuda else model 331 | meta = { 332 | "version": 
m._version 333 | } 334 | return meta 335 | 336 | def fuse_model(self): 337 | for m in self.modules(): 338 | if type(m) == Cov1dBlock: 339 | torch.quantization.fuse_modules(m, [ 'conv1','batchNorm'], inplace=True) 340 | # if type(m) == InvertedResidual: 341 | # for idx in range(len(m.conv)): 342 | # if type(m.conv[idx]) == nn.Conv2d: 343 | # torch.quantization.fuse_modules(m.conv, [str(idx), str(idx + 1)], inplace=True) 344 | 345 | def convertTensorType(self,dtypeToUse=torch.float16): 346 | module = self._modules 347 | layermodules = module['conv1ds'] 348 | # for layermodule in layermodules: 349 | convLayerPrefix = 'conv1d_{}' 350 | 351 | for i in range(0,19): 352 | convLayer = getattr(layermodules,convLayerPrefix.format(i)) 353 | modulesconv = convLayer._modules 354 | if 'batchNorm' in modulesconv: 355 | fusedLayer = self.fuse_conv_and_bn(modulesconv['conv1']._modules['0'],modulesconv['batchNorm']) 356 | setattr(convLayer,'fusedLayer',fusedLayer) 357 | return 358 | 359 | def fuse_conv_and_bn(self,conv, bn): 360 | # 361 | # init 362 | fusedconv = torch.nn.Conv1d( 363 | conv.in_channels, 364 | conv.out_channels, 365 | kernel_size=conv.kernel_size, 366 | stride=conv.stride, 367 | padding=0, 368 | dilation=conv.dilation, 369 | bias=True 370 | ) 371 | # 372 | # prepare filters 373 | w_conv = conv.weight.clone().view(conv.out_channels, -1) 374 | w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var))) 375 | fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size())) 376 | # 377 | # prepare spatial bias 378 | if conv.bias is not None: 379 | b_conv = conv.bias 380 | else: 381 | b_conv = torch.zeros(conv.weight.size(0)) 382 | b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps)) 383 | fusedconv.bias.copy_(b_conv + b_bn) 384 | # 385 | # we're done 386 | return fusedconv 387 | 388 | if __name__ == '__main__': 389 | pass 390 | -------------------------------------------------------------------------------- /multiproc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | import subprocess 4 | 5 | argslist = list(sys.argv)[1:] 6 | world_size = torch.cuda.device_count() 7 | 8 | if '--world-size' in argslist: 9 | argslist[argslist.index('--world-size') + 1] = str(world_size) 10 | else: 11 | argslist.append('--world-size') 12 | argslist.append(str(world_size)) 13 | 14 | workers = [] 15 | 16 | for i in range(world_size): 17 | if '--rank' in argslist: 18 | argslist[argslist.index('--rank') + 1] = str(i) 19 | else: 20 | argslist.append('--rank') 21 | argslist.append(str(i)) 22 | if '--gpu-rank' in argslist: 23 | argslist[argslist.index('--gpu-rank') + 1] = str(i) 24 | else: 25 | argslist.append('--gpu-rank') 26 | argslist.append(str(i)) 27 | stdout = None if i == 0 else open("GPU_" + str(i) + ".log", "w") 28 | print(argslist) 29 | p = subprocess.Popen([str(sys.executable)] + argslist, stdout=stdout, stderr=stdout) 30 | workers.append(p) 31 | 32 | for p in workers: 33 | p.wait() 34 | -------------------------------------------------------------------------------- /noise_inject.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import torchaudio 5 | 6 | from data.data_loader import load_audio, NoiseInjection 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--input-path', default='input.wav', help='The input audio to inject noise into') 10 | parser.add_argument('--noise-path', 
default='noise.wav', help='The noise file to mix in') 11 | parser.add_argument('--output-path', default='output.wav', help='The noise file to mix in') 12 | parser.add_argument('--sample-rate', default=16000, help='Sample rate to save output as') 13 | parser.add_argument('--noise-level', type=float, default=1.0, 14 | help='The Signal to Noise ratio (higher means more noise)') 15 | args = parser.parse_args() 16 | 17 | noise_injector = NoiseInjection() 18 | data = load_audio(args.input_path) 19 | mixed_data = noise_injector.inject_noise_sample(data, args.noise_path, args.noise_level) 20 | mixed_data = torch.FloatTensor(mixed_data).unsqueeze(1) # Add channels dim 21 | torchaudio.save(args.output_path, mixed_data, args.sample_rate) 22 | print('Saved mixed file to %s' % args.output_path) 23 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | python-levenshtein 2 | visdom 3 | wget 4 | librosa 5 | tqdm 6 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import numpy as np 4 | from torch.autograd import Variable 5 | from tqdm import tqdm 6 | from decoder import GreedyDecoder 7 | import torch 8 | from data.data_loader import SpectrogramDataset, AudioDataLoader 9 | from model import WaveToLetter 10 | import torch.quantization 11 | # from torch.quantization import QuantStub, DeQuantStub 12 | np.random.seed(123456) 13 | 14 | parser = argparse.ArgumentParser(description='Wav2Letter transcription') 15 | parser.add_argument('--model-path', default='~/models/wave2Letter/wav2Letter_final.pth.tar', 16 | help='Path to model file created by training') 17 | parser.add_argument('--cuda', default=True, action="store_true", help='Use cuda to test model') 18 | parser.add_argument('--test-manifest', metavar='DIR', 19 | help='path to validation manifest csv', default='~/data/validation.csv') 20 | parser.add_argument('--batch-size', default=10, type=int, help='Batch size for training') 21 | parser.add_argument('--fuse-layers', default=False, action="store_true" 22 | , help='if True then combine all the CONV-BN layer to increase the speed of network. 
W/o Decreasing the accuracy.') 23 | parser.add_argument('--num-workers', default=0, type=int, help='Number of workers used in dataloading') 24 | parser.add_argument('--decoder', default="greedy", choices=["greedy", "beam", "none"], type=str, help="Decoder to use") 25 | parser.add_argument('--verbose', default=True, action="store_true", help="print out decoded output and error of each sample") 26 | no_decoder_args = parser.add_argument_group("No Decoder Options", "Configuration options for when no decoder is " 27 | "specified") 28 | no_decoder_args.add_argument('--output-path', default=None, type=str, help="Where to save raw acoustic output") 29 | beam_args = parser.add_argument_group("Beam Decode Options", "Configurations options for the CTC Beam Search decoder") 30 | beam_args.add_argument('--top-paths', default=1, type=int, help='number of beams to return') 31 | beam_args.add_argument('--beam-width', default=10, type=int, help='Beam width to use') 32 | beam_args.add_argument('--lm-path', default="", type=str, 33 | help='Path to an (optional) kenlm language model for use with beam search (req\'d with trie)') 34 | beam_args.add_argument('--alpha', default=0.75, type=float, help='Language model weight') 35 | beam_args.add_argument('--beta', default=1.0, type=float, help='Language model word bonus (all words)') 36 | beam_args.add_argument('--cutoff-top-n', default=40, type=int, 37 | help='Cutoff number in pruning, only top cutoff_top_n characters with highest probs in ' 38 | 'vocabulary will be used in beam search, default 40.') 39 | beam_args.add_argument('--cutoff-prob', default=1.0, type=float, 40 | help='Cutoff probability in pruning,default 1.0, no pruning.') 41 | beam_args.add_argument('--lm-workers', default=4, type=int, help='Number of LM processes to use') 42 | parser.add_argument('--mixPrec', default=False,dest='mixPrec', action='store_true', help='Use mix prec for inference even if it was not avail for training.') 43 | 44 | parser.add_argument('--usePCEN', default=True,dest='usePcen', action='store_true', help='Use cuda to train model') 45 | args = parser.parse_args() 46 | 47 | if __name__ == '__main__': 48 | torch.set_grad_enabled(False) 49 | device = torch.device("cuda" if args.cuda else "cpu") 50 | # device = torch.device("cpu") 51 | model = WaveToLetter.load_model(args.model_path, cuda=args.cuda) 52 | if args.fuse_layers: 53 | model.module.convertTensorType() 54 | model = model.to(device) 55 | model.eval() 56 | avgTime = [] 57 | labels = WaveToLetter.get_labels(model) 58 | audio_conf = WaveToLetter.get_audio_conf(model) 59 | # model.module.fuse_model() 60 | # model.qconfig = torch.quantization.default_qconfig 61 | # torch.quantization.prepare(model, inplace=True) 62 | if args.decoder == "beam": 63 | from decoder import BeamCTCDecoder 64 | 65 | decoder = BeamCTCDecoder(labels, lm_path=args.lm_path, alpha=args.alpha, beta=args.beta, 66 | cutoff_top_n=args.cutoff_top_n, cutoff_prob=args.cutoff_prob, 67 | beam_width=args.beam_width, num_processes=args.lm_workers) 68 | elif args.decoder == "greedy": 69 | decoder = GreedyDecoder(labels, blank_index=labels.index('_')) 70 | else: 71 | decoder = None 72 | target_decoder = GreedyDecoder(labels, blank_index=labels.index('_')) 73 | test_dataset = SpectrogramDataset(audio_conf=audio_conf, manifest_filepath=args.test_manifest, labels=labels, 74 | normalize=True) 75 | test_loader = AudioDataLoader(test_dataset, batch_size=args.batch_size, 76 | num_workers=args.num_workers) 77 | total_cer, total_wer = 0, 0 78 | output_data = [] 79 | for i, 
(data) in tqdm(enumerate(test_loader), total=len(test_loader)): 80 | inputs, targets, input_percentages, target_sizes, inputFilePaths, inputsMags = data 81 | 82 | inputs = Variable(inputs, volatile=True) 83 | inputs = inputs.to(device) 84 | 85 | # unflatten targets 86 | split_targets = [] 87 | offset = 0 88 | for size in target_sizes: 89 | split_targets.append(targets[offset:offset + size]) 90 | offset += size 91 | 92 | # if args.cuda: 93 | # inputsMags = inputsMags.cuda() 94 | beforeInferenceTime = time.time() 95 | out = model(inputs) # NxTxH 96 | seq_length = out.size(1) 97 | 98 | afterInferenceTime = time.time() 99 | sizes = input_percentages.mul_(int(seq_length)).int() 100 | 101 | if decoder is None: 102 | # add output to data array, and continue 103 | output_data.append((out.data.cpu().numpy(), sizes.numpy())) 104 | continue 105 | beforeDecoderTime = time.time() 106 | avgTime.append(afterInferenceTime - beforeInferenceTime) 107 | try: 108 | decoded_output, _, = decoder.decode(out.data, sizes) 109 | except Exception as e: 110 | continue 111 | target_strings = target_decoder.convert_to_strings(split_targets) 112 | wer, cer = 0, 0 113 | afterDecoderTime = time.time() 114 | 115 | print ('inferenceTime Total {}, only decodingTime {}, model outputTime {}' 116 | ''.format((afterDecoderTime-beforeInferenceTime), 117 | (afterDecoderTime-beforeDecoderTime),(afterInferenceTime-beforeInferenceTime))) 118 | for x in range(len(target_strings)): 119 | transcript, reference = decoded_output[x][0], target_strings[x][0] 120 | wer_inst = decoder.wer(transcript, reference) / float(len(reference.split())) 121 | cer_inst = decoder.cer(transcript, reference) / float(len(reference)) 122 | wer += wer_inst 123 | cer += cer_inst 124 | if args.verbose: 125 | print("Ref:", reference.lower()) 126 | print("Hyp:", transcript.lower()) 127 | print("WER:", wer_inst, "CER:", cer_inst, "\n") 128 | total_cer += cer 129 | total_wer += wer 130 | temp = (i+1)*args.batch_size 131 | if args.verbose: 132 | print("average_wer: ", total_wer/temp,"average_cer:", total_cer/temp) 133 | 134 | if decoder is not None: 135 | wer = total_wer / len(test_loader.dataset) 136 | cer = total_cer / len(test_loader.dataset) 137 | 138 | print('Test Summary \t' 139 | 'Average WER {wer:.3f}\t' 140 | 'Average CER {cer:.3f}\t'.format(wer=wer * 100, cer=cer * 100)) 141 | else: 142 | np.save(args.output_path, output_data) 143 | avgTimeVal = float(sum(avgTime)) / len(avgTime) 144 | print ('avg time to run inference is {}'.format(avgTimeVal)) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import errno 3 | import json 4 | import os 5 | import random 6 | import time 7 | 8 | import numpy as np 9 | import torch.distributed as dist 10 | import torch.utils.data.distributed 11 | from torch.autograd import Variable 12 | from tqdm import tqdm 13 | from warpctc_pytorch import CTCLoss 14 | from apex import amp 15 | from data.data_loader import AudioDataLoader, SpectrogramDataset, BucketingSampler, DistributedBucketingSampler 16 | from data.distributed import DistributedDataParallel 17 | from decoder import GreedyDecoder 18 | from model import WaveToLetter 19 | import Levenshtein as Lev 20 | 21 | parser = argparse.ArgumentParser(description='Wav2Letter training') 22 | parser.add_argument('--train-manifest', metavar='DIR', 23 | help='path to train manifest csv', default='~/data/train.csv') 24 | 
parser.add_argument('--val-manifest', metavar='DIR', 25 | help='path to validation manifest csv', default='~/data/validation.csv') 26 | parser.add_argument('--sample-rate', default=16000, type=int, help='Sample rate') 27 | parser.add_argument('--batch-size', default=16, type=int, help='Batch size for training') 28 | parser.add_argument('--num-workers', default=0, type=int, help='Number of workers used in data-loading') 29 | parser.add_argument('--labels-path', default='labels.json', help='Contains all characters for transcription') 30 | parser.add_argument('--window-size', default=.02, type=float, help='Window size for spectrogram in seconds') 31 | parser.add_argument('--peak-normalization',dest='peak_normalization', default=False, action='store_true', help='Apply peak normalization while training and validation') 32 | parser.add_argument('--window-stride', default=.01, type=float, help='Window stride for spectrogram in seconds') 33 | parser.add_argument('--window', default='hamming', help='Window type for spectrogram generation') 34 | parser.add_argument('--epochs', default=200, type=int, help='Number of training epochs') 35 | parser.add_argument('--cuda', default=True,dest='cuda', action='store_true', help='Use cuda to train model') 36 | parser.add_argument('--lr', '--learning-rate', default=1e-5, type=float, help='initial learning rate') 37 | parser.add_argument('--mixPrec',dest='mixPrec',default=False,action='store_true', help='use mix precision for training') 38 | parser.add_argument('--reg-scale', dest='reg_scale', default=0.9, type=float, help='L2 regularizationScale') 39 | parser.add_argument('--momentum', default=0.90, type=float, help='momentum') 40 | parser.add_argument('--max-norm', default=400, type=int, help='Norm cutoff to prevent explosion of gradients') 41 | parser.add_argument('--learning-anneal', default=1.2, type=float, help='Annealing applied to learning rate every epoch') 42 | parser.add_argument('--silent', dest='silent', action='store_true', help='Turn off progress tracking per iteration') 43 | parser.add_argument('--checkpoint', default=True,dest='checkpoint', action='store_true', help='Enables checkpoint saving of model') 44 | parser.add_argument('--checkpoint-per-batch', default=0, type=int, help='Save checkpoint per batch. 
0 means never save') 45 | parser.add_argument('--visdom', dest='visdom', action='store_true', help='Turn on visdom graphing') 46 | parser.add_argument('--tensorboard', default=False,dest='tensorboard', action='store_true', help='Turn on tensorboard graphing') 47 | parser.add_argument('--log-dir', default='visualize/w2lOnMozillaDataAftr118Epch', help='Location of tensorboard log') 48 | parser.add_argument('--log-params', dest='log_params', action='store_true', help='Log parameter values and gradients') 49 | parser.add_argument('--seed', default=1234 ) 50 | parser.add_argument('--id', default='Wav2Letter training', help='Identifier for visdom/tensorboard run') 51 | parser.add_argument('--save-folder', default='~/models/wave2Letter', help='Location to save epoch models') 52 | parser.add_argument('--model-path', default='~/models/wave2Letter/wav2Letter_final.pth.tar', 53 | help='Location to save best validation model') 54 | parser.add_argument('--continue-from', default='', help='Continue from checkpoint model') 55 | parser.add_argument('--finetune', default=False,dest='finetune', action='store_true', 56 | help='Finetune the model from checkpoint "continue_from"') 57 | parser.add_argument('--augment', default=False ,dest='augment', action='store_true', help='Use random tempo and gain perturbations.') 58 | parser.add_argument('--noise-dir', default=None, 59 | help='Directory to inject noise into audio. If default, noise Inject not added') 60 | parser.add_argument('--noise-prob', default=0.9, help='Probability of noise being added per sample') 61 | parser.add_argument('--noise-min', default=0.1, 62 | help='Minimum noise level to sample from. (1.0 means all noise, not original signal)', type=float) 63 | parser.add_argument('--noise-max', default=0.7, 64 | help='Maximum noise levels to sample from. 
Maximum 1.0', type=float) 65 | parser.add_argument('--no-shuffle', dest='no_shuffle', action='store_true', 66 | help='Turn off shuffling and sample from dataset based on sequence length (smallest to largest)') 67 | parser.add_argument('--no-sortaGrad', dest='no_sorta_grad', action='store_true', 68 | help='Turn off ordering of dataset on sequence length for the first epoch.') 69 | parser.add_argument('--no-bidirectional', dest='bidirectional', action='store_false', default=True, 70 | help='Turn off bi-directional RNNs, introduces lookahead convolution') 71 | parser.add_argument('--dist-url', default='tcp://127.0.0.1:1550', type=str, 72 | help='url used to set up distributed training') 73 | parser.add_argument('--dist-backend', default='gloo', type=str, help='distributed backend') 74 | parser.add_argument('--world-size', default=1, type=int, 75 | help='number of distributed processes') 76 | parser.add_argument('--rank', default=0, type=int, 77 | help='The rank of this process') 78 | parser.add_argument('--gpu-rank', default=None, 79 | help='If using distributed parallel for multi-gpu, sets the GPU for the process') 80 | 81 | torch.manual_seed(123456) 82 | torch.cuda.manual_seed_all(123456) 83 | 84 | 85 | def to_np(x): 86 | return x.data.cpu().numpy() 87 | 88 | 89 | class AverageMeter(object): 90 | """Computes and stores the average and current value""" 91 | 92 | def __init__(self): 93 | self.reset() 94 | 95 | def reset(self): 96 | self.val = 0 97 | self.avg = 0 98 | self.sum = 0 99 | self.count = 0 100 | 101 | def update(self, val, n=1): 102 | self.val = val 103 | self.sum += val * n 104 | self.count += n 105 | self.avg = self.sum / self.count 106 | 107 | 108 | def werCalc(s1, s2): 109 | """ 110 | Computes the Word Error Rate, defined as the edit distance between the 111 | two provided sentences after tokenizing to words. 112 | Arguments: 113 | s1 (string): space-separated sentence 114 | s2 (string): space-separated sentence 115 | """ 116 | 117 | # build mapping of words to integers 118 | b = set(s1.split() + s2.split()) 119 | word2char = dict(zip(b, range(len(b)))) 120 | 121 | # map the words to a char array (Levenshtein packages only accepts 122 | # strings) 123 | w1 = [chr(word2char[w]) for w in s1.split()] 124 | w2 = [chr(word2char[w]) for w in s2.split()] 125 | 126 | return Lev.distance(''.join(w1), ''.join(w2)) 127 | 128 | def cerCalc(s1, s2): 129 | """ 130 | Computes the Character Error Rate, defined as the edit distance. 
131 | 132 | Arguments: 133 | s1 (string): space-separated sentence 134 | s2 (string): space-separated sentence 135 | """ 136 | s1, s2, = s1.replace(' ', ''), s2.replace(' ', '') 137 | return Lev.distance(s1, s2) 138 | 139 | def poly_lr_scheduler(init_lr, iter, lr_decay_iter=1, 140 | max_iter=100, power=0.9): 141 | """Polynomial decay of learning rate 142 | :param init_lr is base learning rate 143 | :param iter is a current iteration 144 | :param lr_decay_iter how frequently decay occurs, default is 1 145 | :param max_iter is number of maximum iterations 146 | :param power is a polymomial power 147 | 148 | """ 149 | if iter % lr_decay_iter or iter > max_iter: 150 | return 0 151 | lr = init_lr*(1 - float(iter)/float(max_iter))**power 152 | return lr 153 | 154 | 155 | if __name__ == '__main__': 156 | args = parser.parse_args() 157 | 158 | # Set seeds for determinism 159 | torch.manual_seed(args.seed) 160 | torch.cuda.manual_seed_all(args.seed) 161 | np.random.seed(args.seed) 162 | random.seed(args.seed) 163 | 164 | device = torch.device("cuda" if args.cuda else "cpu") 165 | if args.mixPrec and not args.cuda: 166 | raise ValueError('If using mixed precision training, CUDA must be enabled!') 167 | args.distributed = args.world_size > 1 168 | main_proc = True 169 | device = torch.device("cuda" if args.cuda else "cpu") 170 | if args.distributed: 171 | if args.gpu_rank: 172 | torch.cuda.set_device(int(args.gpu_rank)) 173 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 174 | world_size=args.world_size, rank=args.rank) 175 | main_proc = args.rank == 0 # Only the first proc should save models 176 | save_folder = args.save_folder 177 | os.makedirs(save_folder, exist_ok=True) # Ensure save folder exists 178 | 179 | loss_results, cer_results, wer_results = torch.Tensor(args.epochs), torch.Tensor(args.epochs), torch.Tensor( 180 | args.epochs) 181 | best_wer = None 182 | optim_state = None 183 | if args.visdom and main_proc: 184 | from visdom import Visdom 185 | 186 | viz = Visdom() 187 | opts = dict(title=args.id, ylabel='', xlabel='Epoch', legend=['Loss', 'WER', 'CER']) 188 | viz_window = None 189 | epochs = torch.arange(1, args.epochs + 1) 190 | if args.tensorboard and main_proc: 191 | try: 192 | os.makedirs(args.log_dir) 193 | except OSError as e: 194 | if e.errno == errno.EEXIST: 195 | print('Tensorboard log directory already exists.') 196 | for file in os.listdir(args.log_dir): 197 | file_path = os.path.join(args.log_dir, file) 198 | try: 199 | if os.path.isfile(file_path): 200 | os.unlink(file_path) 201 | except Exception: 202 | raise 203 | else: 204 | raise 205 | from tensorboardX import SummaryWriter 206 | 207 | tensorboard_writer = SummaryWriter(args.log_dir) 208 | 209 | try: 210 | os.makedirs(save_folder) 211 | except OSError as e: 212 | if e.errno == errno.EEXIST: 213 | print('Model Save directory already exists.') 214 | else: 215 | raise 216 | criterion = CTCLoss() 217 | 218 | avg_loss, start_epoch, start_iter = 0, 0, 0 219 | if args.continue_from: # Starting from previous model 220 | print("Loading checkpoint model %s" % args.continue_from) 221 | package = torch.load(args.continue_from, map_location=lambda storage, loc: storage) 222 | model = WaveToLetter.load_model_package(package) 223 | audio_conf = WaveToLetter.get_audio_conf(model) 224 | labels = WaveToLetter.get_labels(model) 225 | parameters = model.parameters() 226 | optimizer = torch.optim.SGD(parameters, lr=args.lr, 227 | momentum=args.momentum, nesterov=True) 228 | 229 | if args.noise_dir is not 
None: 230 | model = WaveToLetter.setAudioConfKey(model,'noise_dir',args.noise_dir) 231 | model = WaveToLetter.setAudioConfKey(model,'noise_prob',args.noise_prob) 232 | model = WaveToLetter.setAudioConfKey(model,'noise_max',args.noise_max) 233 | model = WaveToLetter.setAudioConfKey(model,'noise_min',args.noise_min) 234 | 235 | if not args.finetune: # Don't want to restart training 236 | # if args.cuda: 237 | # model.cuda() 238 | optim_state = package['optim_dict'] 239 | # optimizer.load_state_dict(package['optim_dict']) 240 | 241 | # Temporary fix for pytorch #2830 & #1442 while pull request #3658 in not incorporated in a release 242 | # TODO : remove when a new release of pytorch include pull request #3658 243 | # if args.cuda: 244 | # for state in optimizer.state.values(): 245 | # for k, v in state.items(): 246 | # if torch.is_tensor(v): 247 | # state[k] = v.cuda() 248 | 249 | start_epoch = int(package.get('epoch', 1)) - 1 # Index start at 0 for training 250 | start_iter = package.get('iteration', None) 251 | if start_iter is None: 252 | start_epoch += 1 # We saved model after epoch finished, start at the next epoch. 253 | start_iter = 0 254 | else: 255 | start_iter += 1 256 | avg_loss = int(package.get('avg_loss', 0)) 257 | loss_results, cer_results, wer_results = package['loss_results'], package[ 258 | 'cer_results'], package['wer_results'] 259 | if main_proc and args.visdom and \ 260 | package[ 261 | 'loss_results'] is not None and start_epoch > 0: # Add previous scores to visdom graph 262 | x_axis = epochs[0:start_epoch] 263 | y_axis = torch.stack( 264 | (loss_results[0:start_epoch], wer_results[0:start_epoch], cer_results[0:start_epoch]), 265 | dim=1) 266 | viz_window = viz.line( 267 | X=x_axis, 268 | Y=y_axis, 269 | opts=opts, 270 | ) 271 | if main_proc and args.tensorboard and \ 272 | package[ 273 | 'loss_results'] is not None and start_epoch > 0: # Previous scores to tensorboard logs 274 | for i in range(start_epoch): 275 | values = { 276 | 'Avg Train Loss': loss_results[i], 277 | 'Avg WER': wer_results[i], 278 | 'Avg CER': cer_results[i] 279 | } 280 | tensorboard_writer.add_scalars(args.id, values, i + 1) 281 | else: 282 | with open(args.labels_path) as label_file: 283 | labels = str(''.join(json.load(label_file))) 284 | 285 | audio_conf = dict(sample_rate=args.sample_rate, 286 | window_size=args.window_size, 287 | window_stride=args.window_stride, 288 | window=args.window, 289 | noise_dir=args.noise_dir, 290 | noise_prob=args.noise_prob, 291 | noise_levels=(args.noise_min, args.noise_max)) 292 | 293 | model = WaveToLetter(labels=labels, 294 | audio_conf=audio_conf,sample_rate=args.sample_rate,window_size=args.window_size,mixed_precision=args.mixPrec) 295 | # parameters = model.parameters() 296 | # optimizer = torch.optim.SGD(parameters, lr=args.lr, 297 | # momentum=args.momentum, nesterov=True) 298 | 299 | decoder = GreedyDecoder(labels) 300 | 301 | train_dataset = SpectrogramDataset(audio_conf=audio_conf, manifest_filepath=args.train_manifest, labels=labels, 302 | normalize=True, peak_normalization=args.peak_normalization, augment=args.augment) 303 | test_dataset = SpectrogramDataset(audio_conf=audio_conf, manifest_filepath=args.val_manifest, labels=labels, 304 | normalize=True, peak_normalization=args.peak_normalization, augment=False) 305 | 306 | if not args.distributed: 307 | train_sampler = BucketingSampler(train_dataset, batch_size=args.batch_size) 308 | else: 309 | train_sampler = DistributedBucketingSampler(train_dataset, batch_size=args.batch_size, 310 | 
num_replicas=args.world_size, rank=args.rank) 311 | train_loader = AudioDataLoader(train_dataset, 312 | num_workers=args.num_workers, batch_sampler=train_sampler) 313 | test_loader = AudioDataLoader(test_dataset, batch_size=args.batch_size, 314 | num_workers=args.num_workers) 315 | 316 | if (not args.no_shuffle and start_epoch != 0) or args.no_sorta_grad: 317 | print("Shuffling batches for the following epochs") 318 | train_sampler.shuffle(start_epoch) 319 | 320 | model = model.to(device) 321 | parameters = model.parameters() 322 | optimizer = torch.optim.SGD(parameters, lr=args.lr, 323 | momentum=args.momentum, nesterov=True, weight_decay=1e-5) 324 | if args.distributed: 325 | model = DistributedDataParallel(model) 326 | # if args.cuda and not args.distributed: 327 | # model = torch.nn.DataParallel(model).cuda() 328 | # elif args.cuda and args.distributed: 329 | # model.cuda() 330 | # model = DistributedDataParallel(model) 331 | if optim_state is not None: 332 | optimizer.load_state_dict(optim_state) 333 | if args.mixPrec: 334 | model, optimizer = amp.initialize(model, optimizer, opt_level="O2") 335 | print(model) 336 | print("Number of parameters: %d" % WaveToLetter.get_param_size(model)) 337 | 338 | batch_time = AverageMeter() 339 | data_time = AverageMeter() 340 | losses = AverageMeter() 341 | globalStep = 0 342 | for epoch in range(start_epoch, args.epochs): 343 | torch.cuda.empty_cache() 344 | if not args.no_shuffle: 345 | print("Shuffling batches...") 346 | train_sampler.shuffle(epoch) 347 | model.train() 348 | end = time.time() 349 | start_epoch_time = time.time() 350 | for i, (data) in enumerate(train_loader, start=start_iter): 351 | if i == len(train_sampler): 352 | break 353 | inputs, targets, input_percentages, target_sizes, inputFilePaths, inputsMags = data 354 | globalStep +=1 355 | # measure data loading time 356 | data_time.update(time.time() - end) 357 | inputs = Variable(inputs, requires_grad=False) 358 | target_sizes = Variable(target_sizes, requires_grad=False) 359 | targets = Variable(targets, requires_grad=False) 360 | inputs = inputs.to(device) 361 | out = model(inputs) 362 | out = out.transpose(0, 1) # TxNxH 363 | 364 | seq_length = out.size(0) 365 | sizes = Variable(input_percentages.mul_(int(seq_length)).int(), requires_grad=False) 366 | 367 | out = out.cpu() 368 | loss = criterion(out, targets, sizes, target_sizes) 369 | loss = loss / inputs.size(0) # average the loss by minibatch 370 | loss_sum = loss.data.sum() 371 | inf = float("inf") 372 | 373 | if loss_sum == inf or loss_sum == -inf: 374 | print("WARNING: received an inf loss, setting loss value to 0") 375 | loss_value = 0 376 | else: 377 | loss_value = loss.data[0] 378 | 379 | 380 | avg_loss += loss_value 381 | losses.update(loss_value, inputs.size(0)) 382 | # compute gradient 383 | optimizer.zero_grad() 384 | if args.mixPrec: 385 | with amp.scale_loss(loss, optimizer) as scaled_loss: 386 | scaled_loss.backward() 387 | optimizer.clip_master_grads(args.max_norm) 388 | else: 389 | loss.backward() 390 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm) 391 | # SGD step 392 | optimizer.step() 393 | # if args.cuda: 394 | # torch.cuda.synchronize() 395 | 396 | # measure elapsed time 397 | batch_time.update(time.time() - end) 398 | end = time.time() 399 | if not args.silent: 400 | print('Epoch: [{0}][{1}/{2}]\t' 401 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 402 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 403 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format( 404 | (epoch + 1), 
(i + 1), len(train_sampler), batch_time=batch_time, 405 | data_time=data_time, loss=losses)) 406 | # if losses.val <0.01: 407 | # out = out.transpose(0,1) 408 | # decoded_output, _ = decoder.decode(out.data, sizes) 409 | # for numFile,idxP in enumerate(decoded_output): 410 | # print(idxP),inputFilePaths[numFile][1] 411 | if args.checkpoint_per_batch > 0 and i > 0 and (i + 1) % args.checkpoint_per_batch == 0 and main_proc: 412 | file_path = '%s/wav2Letter_checkpoint_epoch_%d_iter_%d.pth.tar' % (save_folder, epoch + 1, i + 1) 413 | print("Saving checkpoint model to %s" % file_path) 414 | torch.save(WaveToLetter.serialize(model, optimizer=optimizer, epoch=epoch, iteration=i, 415 | loss_results=loss_results, 416 | wer_results=wer_results, cer_results=cer_results, avg_loss=avg_loss), 417 | file_path) 418 | del loss 419 | del out 420 | torch.cuda.empty_cache() 421 | 422 | avg_loss /= len(train_sampler) 423 | 424 | epoch_time = time.time() - start_epoch_time 425 | print('Training Summary Epoch: [{0}]\t' 426 | 'Time taken (s): {epoch_time:.0f}\t' 427 | 'Average Loss {loss:.3f}\t'.format( 428 | epoch + 1, epoch_time=epoch_time, loss=avg_loss)) 429 | 430 | start_iter = 0 # Reset start iteration for next epoch 431 | total_cer, total_wer = 0, 0 432 | model.eval() 433 | if (epoch+1)%1==0: 434 | print ("coming into test loop") 435 | for i, (data) in tqdm(enumerate(test_loader), total=len(test_loader)): 436 | inputs, targets, input_percentages, target_sizes, inputFilePaths, inputsMags = data 437 | 438 | inputs = Variable(inputs, volatile=True) 439 | inputs = inputs.to(device) 440 | 441 | 442 | # unflatten targets 443 | split_targets = [] 444 | offset = 0 445 | for size in target_sizes: 446 | split_targets.append(targets[offset:offset + size]) 447 | offset += size 448 | 449 | # if args.cuda: 450 | # inputsMags = inputsMags.cuda() 451 | 452 | out = model(inputs) # NxTxH 453 | seq_length = out.size(1) 454 | sizes = input_percentages.mul_(int(seq_length)).int() 455 | 456 | decoded_output, _ = decoder.decode(out.data, sizes) 457 | target_strings = decoder.convert_to_strings(split_targets) 458 | 459 | wer, cer = 0, 0 460 | for x in range(len(target_strings)): 461 | transcript, reference = decoded_output[x][0], target_strings[x][0] 462 | print ('transcript : {}, reference :{} , filePath : {}'.format(transcript,reference,inputFilePaths[x][0])) 463 | # print 'reference : {}'.format(reference) 464 | try: 465 | wer += decoder.wer(transcript, reference) / float(len(reference.split())) 466 | cer += decoder.cer(transcript, reference) / float(len(reference)) 467 | except Exception as e: 468 | print ('encountered exception {}'.format(e)) 469 | total_cer += cer 470 | total_wer += wer 471 | 472 | if args.cuda: 473 | torch.cuda.synchronize() 474 | del out 475 | torch.cuda.empty_cache() 476 | wer = total_wer / len(test_loader.dataset) 477 | cer = total_cer / len(test_loader.dataset) 478 | wer *= 100 479 | cer *= 100 480 | loss_results[epoch] = avg_loss 481 | wer_results[epoch] = wer 482 | cer_results[epoch] = cer 483 | print('Validation Summary Epoch: [{0}]\t' 484 | 'Average WER {wer:.3f}\t' 485 | 'Average CER {cer:.3f}\t'.format( 486 | epoch + 1, wer=wer, cer=cer)) 487 | 488 | if args.visdom and main_proc: 489 | x_axis = epochs[0:epoch + 1] 490 | y_axis = torch.stack((loss_results[0:epoch + 1], wer_results[0:epoch + 1], cer_results[0:epoch + 1]), dim=1) 491 | if viz_window is None: 492 | viz_window = viz.line( 493 | X=x_axis, 494 | Y=y_axis, 495 | opts=opts, 496 | ) 497 | else: 498 | viz.line( 499 | 
X=x_axis.unsqueeze(0).expand(y_axis.size(1), x_axis.size(0)).transpose(0, 1), # Visdom fix 500 | Y=y_axis, 501 | win=viz_window, 502 | update='replace', 503 | ) 504 | if args.tensorboard and main_proc: 505 | values = { 506 | 'Avg Train Loss': avg_loss, 507 | 'Avg WER': wer, 508 | 'Avg CER': cer 509 | } 510 | tensorboard_writer.add_scalars(args.id, values, epoch + 1) 511 | if args.log_params: 512 | for tag, value in model.named_parameters(): 513 | tag = tag.replace('.', '/') 514 | tensorboard_writer.add_histogram(tag, to_np(value), epoch + 1) 515 | tensorboard_writer.add_histogram(tag + '/grad', to_np(value.grad), epoch + 1) 516 | if args.checkpoint and main_proc: 517 | file_path = '%s/wav2Letter_%d.pth.tar' % (save_folder, epoch + 1) 518 | torch.save(WaveToLetter.serialize(model, optimizer=optimizer, epoch=epoch, loss_results=loss_results, 519 | wer_results=wer_results, cer_results=cer_results), 520 | file_path) 521 | if (epoch + 1) % 1 == 0: 522 | # # anneal lr 523 | 524 | optim_state = optimizer.state_dict() 525 | lrToUse = optim_state['param_groups'][0]['lr'] / args.learning_anneal 526 | optim_state['param_groups'][0]['lr'] = optim_state['param_groups'][0]['lr'] / args.learning_anneal 527 | optim_state['param_groups'][0]['lr'] = lrToUse 528 | optimizer.load_state_dict(optim_state) 529 | print('Learning rate annealed to: {lr:.15f}'.format(lr=optim_state['param_groups'][0]['lr'])) 530 | 531 | # if (best_wer is None or best_wer > wer) and main_proc: 532 | if main_proc: 533 | # print("Found better validated model, saving to %s" % args.model_path) 534 | torch.save(WaveToLetter.serialize(model, optimizer=optimizer, epoch=epoch, loss_results=loss_results, 535 | wer_results=wer_results, cer_results=cer_results) 536 | , args.model_path) 537 | 538 | best_wer = wer 539 | 540 | avg_loss = 0 541 | 542 | -------------------------------------------------------------------------------- /transcribe.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import warnings 3 | 4 | warnings.simplefilter('ignore') 5 | 6 | from decoder import GreedyDecoder 7 | 8 | from torch.autograd import Variable 9 | 10 | from data.data_loader import SpectrogramParser 11 | from model import WaveToLetter 12 | import os.path 13 | import json 14 | 15 | parser = argparse.ArgumentParser(description='DeepSpeech transcription') 16 | parser.add_argument('--model-path', default='models/wav2letter.pth', 17 | help='Path to model file created by training') 18 | parser.add_argument('--audio-path', default='audio.wav', 19 | help='Audio file to predict on') 20 | parser.add_argument('--cuda', action="store_true", help='Use cuda to test model') 21 | parser.add_argument('--decoder', default="greedy", choices=["greedy", "beam"], type=str, help="Decoder to use") 22 | parser.add_argument('--offsets', dest='offsets', action='store_true', help='Returns time offset information') 23 | beam_args = parser.add_argument_group("Beam Decode Options", "Configurations options for the CTC Beam Search decoder") 24 | beam_args.add_argument('--top-paths', default=1, type=int, help='number of beams to return') 25 | beam_args.add_argument('--beam-width', default=10, type=int, help='Beam width to use') 26 | beam_args.add_argument('--lm-path', default=None, type=str, 27 | help='Path to an (optional) kenlm language model for use with beam search (req\'d with trie)') 28 | beam_args.add_argument('--alpha', default=0.8, type=float, help='Language model weight') 29 | beam_args.add_argument('--beta', default=1, 
type=float, help='Language model word bonus (all words)') 30 | beam_args.add_argument('--cutoff-top-n', default=40, type=int, 31 | help='Cutoff number in pruning, only top cutoff_top_n characters with highest probs in ' 32 | 'vocabulary will be used in beam search, default 40.') 33 | beam_args.add_argument('--cutoff-prob', default=1.0, type=float, 34 | help='Cutoff probability in pruning,default 1.0, no pruning.') 35 | beam_args.add_argument('--lm-workers', default=1, type=int, help='Number of LM processes to use') 36 | args = parser.parse_args() 37 | 38 | 39 | def decode_results(model, decoded_output, decoded_offsets): 40 | results = { 41 | "output": [], 42 | "_meta": { 43 | "acoustic_model": { 44 | "name": os.path.basename(args.model_path) 45 | }, 46 | "language_model": { 47 | "name": os.path.basename(args.lm_path) if args.lm_path else None, 48 | }, 49 | "decoder": { 50 | "lm": args.lm_path is not None, 51 | "alpha": args.alpha if args.lm_path is not None else None, 52 | "beta": args.beta if args.lm_path is not None else None, 53 | "type": args.decoder, 54 | } 55 | } 56 | } 57 | results['_meta']['acoustic_model'].update(WaveToLetter.get_meta(model)) 58 | 59 | for b in range(len(decoded_output)): 60 | for pi in range(min(args.top_paths, len(decoded_output[b]))): 61 | result = {'transcription': decoded_output[b][pi]} 62 | if args.offsets: 63 | result['offsets'] = decoded_offsets[b][pi] 64 | results['output'].append(result) 65 | return results 66 | 67 | 68 | if __name__ == '__main__': 69 | model = WaveToLetter.load_model(args.model_path, cuda=args.cuda) 70 | model.eval() 71 | 72 | labels = WaveToLetter.get_labels(model) 73 | audio_conf = WaveToLetter.get_audio_conf(model) 74 | 75 | if args.decoder == "beam": 76 | from decoder import BeamCTCDecoder 77 | 78 | decoder = BeamCTCDecoder(labels, lm_path=args.lm_path, alpha=args.alpha, beta=args.beta, 79 | cutoff_top_n=args.cutoff_top_n, cutoff_prob=args.cutoff_prob, 80 | beam_width=args.beam_width, num_processes=args.lm_workers) 81 | else: 82 | decoder = GreedyDecoder(labels, blank_index=labels.index('_')) 83 | 84 | parser = SpectrogramParser(audio_conf, normalize=True) 85 | 86 | spect = parser.parse_audio(args.audio_path).contiguous() 87 | spect = spect.view(1, 1, spect.size(0), spect.size(1)) 88 | out = model(Variable(spect, volatile=True)) 89 | out = out.transpose(0, 1) # TxNxH 90 | decoded_output, decoded_offsets = decoder.decode(out.data) 91 | print(json.dumps(decode_results(model, decoded_output, decoded_offsets))) 92 | --------------------------------------------------------------------------------
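A note on the WER/CER numbers reported by test.py and train.py: Decoder.wer (and the standalone werCalc) compute a word-level edit distance by mapping every distinct word to a single character and then running Levenshtein.distance on the mapped strings, since the Levenshtein package only accepts strings. Below is a minimal self-contained sketch of that trick; the helper name word_edit_distance is illustrative and not part of this repo.

```
import Levenshtein as Lev

def word_edit_distance(s1, s2):
    # Map each distinct word to one character so that the character-level
    # Levenshtein distance of the mapped strings equals the word-level
    # edit distance -- the same trick Decoder.wer uses.
    vocab = set(s1.split() + s2.split())
    word2char = {w: chr(i) for i, w in enumerate(vocab)}
    w1 = ''.join(word2char[w] for w in s1.split())
    w2 = ''.join(word2char[w] for w in s2.split())
    return Lev.distance(w1, w2)

reference = "the cat sat on the mat"
transcript = "the cat sat on mat"
dist = word_edit_distance(transcript, reference)    # one deleted word
print(dist / float(len(reference.split())))         # per-sample WER ~= 0.167, as computed in test.py
```

Decoder.cer works the same way at the character level, after stripping spaces from both strings.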
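A note on the --fuse-layers flag in test.py: WaveToLetter.convertTensorType builds a fused layer per block via fuse_conv_and_bn, folding each BatchNorm1d into its preceding Conv1d. The sketch below is a self-contained illustration of that folding algebra with a numerical check; it is not the repo function itself, and the layer shapes (161 input channels as in the default 16 kHz / 20 ms front end, 256 outputs, kernel 11, stride 2 like the first block) are chosen only for the demonstration.

```
import torch
import torch.nn as nn

def fuse_conv_bn(conv, bn):
    # BN(conv(x)) = gamma * (W*x + b - mean) / sqrt(var + eps) + beta
    #             = (scale * W) * x + (b - mean) * scale + beta,  with scale = gamma / sqrt(var + eps)
    fused = nn.Conv1d(conv.in_channels, conv.out_channels, kernel_size=conv.kernel_size,
                      stride=conv.stride, padding=conv.padding, dilation=conv.dilation, bias=True)
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    fused.weight.data.copy_(conv.weight * scale.view(-1, 1, 1))
    b_conv = conv.bias if conv.bias is not None else torch.zeros(conv.out_channels)
    fused.bias.data.copy_((b_conv - bn.running_mean) * scale + bn.bias)
    return fused

conv = nn.Conv1d(161, 256, kernel_size=11, stride=2)
bn = nn.BatchNorm1d(256, momentum=0.90, eps=0.001)
with torch.no_grad():                     # give the BN non-trivial running statistics for the check
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)
conv.eval(); bn.eval()                    # eval mode: BN uses running statistics, as at inference time

x = torch.randn(1, 161, 200)
with torch.no_grad():
    reference = bn(conv(x))
    fused_out = fuse_conv_bn(conv, bn)(x)
print(torch.allclose(reference, fused_out, atol=1e-4))    # expected: True
```

The check passes because, in eval mode, batch norm is just an affine transform of the convolution output, so its scale and shift can be absorbed into the convolution's weights and bias, which is what makes the fused model faster without changing its predictions.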