├── .gitignore ├── README.md ├── asr ├── __init__.py ├── language_model │ ├── __init__.py │ └── language_model.py ├── utils │ ├── __init__.py │ ├── beam_search_decoder.py │ └── utils.py └── wav2vec2 │ ├── __init__.py │ ├── decoder │ ├── __init__.py │ └── ctc_decoder.py │ ├── inference.py │ └── vocab.py ├── asr_inference_live.py ├── asr_inference_offline.py ├── asr_inference_recording.py ├── data ├── lm_training_corpus │ └── corpus.txt ├── models │ └── lm │ │ ├── twitter │ │ ├── bigram.pkl │ │ └── unigram.pkl │ │ └── wikipedia │ │ ├── bigram.pkl │ │ └── unigram.pkl └── samples │ ├── Achievements_of_the_Democratic_Party_(Homer_S._Cummings).ogg │ ├── rec.wav │ ├── rec2.wav │ └── shortened.wav ├── notebooks ├── Training_Simple_Lanugage_Model.ipynb ├── wav2vec2_asr_pretrained_inference.ipynb ├── wav2vec2_experiment_language_model.ipynb ├── wav2vec2_finetuning_version_1.ipynb ├── wav2vec2_finetuning_version_2_with_data_augmentations.ipynb └── wav2vec2large_experiment_language_model.ipynb ├── requirements.txt └── train_language_model.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.tar 2 | *.zip 3 | /data/models/asr/* 4 | /output/* 5 | /.vscode/* 6 | # Created by https://www.gitignore.io/api/python 7 | # Edit at https://www.gitignore.io/?templates=python 8 | 9 | ### Python ### 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | pip-wheel-metadata/ 33 | share/python-wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .nox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don’t work, or not 98 | # install all needed dependencies. 
99 | #Pipfile.lock
100 | 
101 | # celery beat schedule file
102 | celerybeat-schedule
103 | 
104 | # SageMath parsed files
105 | *.sage.py
106 | 
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 | 
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 | 
120 | # Rope project settings
121 | .ropeproject
122 | 
123 | # mkdocs documentation
124 | /site
125 | 
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 | 
131 | # Pyre type checker
132 | .pyre/
133 | 
134 | # End of https://www.gitignore.io/api/python
135 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Automatic Speech Recognition using Wav2Vec2
2 | 
3 | This repository uses the wav2vec2 model from Hugging Face Transformers to build an ASR system that takes a speech signal as input and outputs transcriptions asynchronously.
4 | 
5 | I have also written a [post](https://www.tarunbisht.com/deep%20learning/2021/06/17/speech-recognition-using-wav2vec-model/) explaining wav2vec2 in some detail, with some further learning directions.
6 | 
7 | ## Installation
8 | 
9 | ### Installing via pip
10 | - Download and install Python
11 | - Create a virtual environment using `python -m venv env_name`
12 | - Activate the created environment: `env_path\Scripts\activate`
13 | - Install PyTorch: `pip install torch==1.8.0+cu102 torchaudio===0.8.0 -f https://download.pytorch.org/whl/torch_stable.html`
14 | - Install required dependencies: `pip install -r requirements.txt`
15 | 
16 | ### Installing via conda
17 | - Download and install Miniconda
18 | - Create a new virtual environment using `conda create --name env_name python==3.8`
19 | - Activate the created environment: `conda activate env_name`
20 | - Install PyTorch: `conda install pytorch torchaudio cudatoolkit=11.1 -c pytorch`
21 | - Install required dependencies: `pip install -r requirements.txt`
22 | 
23 | ## Inferencing
24 | ### Transcribing an audio file
25 | - run `python asr_inference_offline.py` with parameters (`--recording` or `-rec`, the path to the audio file, is required):
26 |   - `--model` or `-m` : path to a locally saved wav2vec2 model; if not passed, it will be downloaded (Defaults to None)
27 |   - `--processor` or `-t` : path to a locally saved wav2vec2 processor; if not passed, it will be downloaded (Defaults to None)
28 |   - `--output` or `-out` : path to an output file to save transcriptions (optional)
29 |   - `--device` or `-d` : device to use for inference (choices=["cpu", "cuda"], Defaults to cpu)
30 |   - `--lm` or `-l` : path to the folder in which a trained language model is saved (unigram and bigram files). This language model is used by the beam search decoder to weight beam scores (Defaults to None)
31 |   - `--beam_width` or `-bw` : beam width for the beam search decoder (Defaults to 1). If `beam_width <= 1`, max (greedy) decoding is used to decode the CTC outputs, otherwise beam search decoding is used (a minimal sketch of greedy decoding follows this parameter list).
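For intuition, here is a minimal, illustrative sketch of what max (greedy) CTC decoding does. It is not the repository's exact code path (`CTCDecoder._greedy_path` only takes the per-frame argmax and leaves the final collapse to the processor's `decode`), but the idea is the same:

```python
import numpy as np

def greedy_ctc_decode(logits: np.ndarray, vocab: list, blank_idx: int = 0) -> str:
    """Collapse per-frame argmax predictions into text (illustrative sketch)."""
    best = logits.argmax(axis=-1)              # most likely vocabulary entry at every frame
    chars, prev = [], None
    for idx in best:
        if idx != blank_idx and idx != prev:   # drop CTC blanks and repeated frames
            chars.append(vocab[idx])
        prev = idx
    return "".join(chars).replace("|", " ")    # "|" is the word delimiter in the wav2vec2 vocab
```

With `beam_width > 1`, `asr/utils/beam_search_decoder.py` instead keeps several candidate labelings per frame and, when `--lm` is given, re-weights them with the character-level language model.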
32 | - example
33 |   - `python asr_inference_offline.py --recording data/samples/rec.wav -out output/transcription.txt`
34 |   - `python asr_inference_offline.py --recording data/samples/rec.wav --device cuda`
35 | ### Transcribing streaming audio
36 | - run `python asr_inference_recording.py` with parameters:
37 |   - `--recording` or `-rec` : path to the audio recording
38 |   - `--model` or `-m` : path to a locally saved wav2vec2 model; if not passed, it will be downloaded (Defaults to None)
39 |   - `--processor` or `-t` : path to a locally saved wav2vec2 processor; if not passed, it will be downloaded (Defaults to None)
40 |   - `--blocksize` or `-bs` : size of each audio block to be passed to the model (Defaults to 16000)
41 |   - `--overlap` or `-ov` : overlap between consecutive loaded blocks (Defaults to 0)
42 |   - `--output` or `-out` : path to an output file to save transcriptions (optional)
43 |   - `--device` or `-d` : device to use for inference (choices=["cpu", "cuda"], Defaults to cpu)
44 |   - `--lm` or `-l` : path to the folder in which a trained language model is saved (unigram and bigram files). This language model is used by the beam search decoder to weight beam scores (Defaults to None)
45 |   - `--beam_width` or `-bw` : beam width for the beam search decoder (Defaults to 1). If `beam_width <= 1`, max (greedy) decoding is used to decode the CTC outputs, otherwise beam search decoding is used.
46 | - example
47 |   - `python asr_inference_recording.py --recording data/samples/rec.wav -bs 16000 -out output/transcription.txt`
48 |   - `python asr_inference_recording.py --recording data/samples/rec.wav -bs 16000 -ov 1600 -out output/transcription.txt`
49 |   - `python asr_inference_recording.py --recording data/samples/rec.wav -bs 16000 -ov 1600 -out output/transcription.txt --device cuda`
50 | 
51 | ### Live recording and transcribing
52 | - run `python asr_inference_live.py` with parameters:
53 |   - `--model` or `-m` : path to a locally saved wav2vec2 model; if not passed, it will be downloaded (Defaults to None)
54 |   - `--processor` or `-t` : path to a locally saved wav2vec2 processor; if not passed, it will be downloaded (Defaults to None)
55 |   - `--blocksize` or `-bs` : size of each audio block to be passed to the model (Defaults to 16000)
56 |   - `--output` or `-out` : path to an output file to save transcriptions (optional)
57 |   - `--device` or `-d` : device to use for inference (choices=["cpu", "cuda"], Defaults to cpu)
58 |   - `--lm` or `-l` : path to the folder in which a trained language model is saved (unigram and bigram files). This language model is used by the beam search decoder to weight beam scores (Defaults to None)
59 |   - `--beam_width` or `-bw` : beam width for the beam search decoder (Defaults to 1). If `beam_width <= 1`, max (greedy) decoding is used to decode the CTC outputs, otherwise beam search decoding is used.
60 | - example
61 |   - `python asr_inference_live.py -bs 16000 -out output/transcription.txt`
62 |   - `python asr_inference_live.py`
63 |   - `python asr_inference_live.py --device cuda`
64 | 
65 | ## Training Language Model
66 | - run `python train_language_model.py` with parameters:
67 |   - `--corpus` or `-c` : path to the corpus text file.
68 |   - `--save` or `-s` : folder path in which to save the model files.
69 | 
70 | ## Notebooks
71 | All notebooks reside in the `notebooks` folder; they are handy when using Google Colab or similar platforms. All of these notebooks have been tested in Google Colab.
72 | - `wav2vec2_asr_pretrained_inference` : Basic inference notebook 73 | - `wav2vec2_experiment_language_model` : kenlm language model with beam search 74 | - `wav2vec2large_experiment_language_model` : kenlm language model with beam search for larger model 75 | - `wav2vec2_finetuning_version_1` : finetuning notebook without augmentation 76 | - `wav2vec2_finetuning_version_2_with_data_augmentations` : finetuning notebook with augmentation 77 | - `Training_Simple_Lanugage_Model` : training language model notebook version with wikipedia data 78 | 79 | ## Comparisions 80 | ### GPU inference vs CPU inference 81 | For 4min 10sec recorder audio total time taken 82 | 1. GPU (Nvidia GeForce 940MX) : 18.29sec 83 | 2. CPU : 116.85sec 84 | 85 | ## To do list 86 | - Environment Setup ✔ 87 | - Inferencing with CPU ✔ 88 | - Inferencing with GPU ✔ 89 | - Asyncio Compatible ✔ 90 | - Training and Finetuning Notebooks ✔ 91 | - Training and Finetuning Scripts 92 | - Converting model to TensorFlow with ONNX for inference using TensorFlow 93 | 94 | ## Tested Platforms 95 | - native windows 10 ✔ 96 | - windows-10 wsl2 cpu ✔ 97 | - windows-10 wsl2 gpu ✔ 98 | - Linux ✔ 99 | 100 | ## References 101 | - [Hugging Face Wav2Vec2](https://huggingface.co/transformers/master/model_doc/wav2vec2.html) 102 | - [CTC decoder adapted from githubharald/CTCDecoder](https://github.com/githubharald/CTCDecoder) 103 | -------------------------------------------------------------------------------- /asr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tarun-bisht/wav2vec2-asr/17308ac128f9762b30220f1e580f699bc98be2e7/asr/__init__.py -------------------------------------------------------------------------------- /asr/language_model/__init__.py: -------------------------------------------------------------------------------- 1 | from .language_model import LanguageModel -------------------------------------------------------------------------------- /asr/language_model/language_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | 5 | class LanguageModel: 6 | """Simple character-level language model""" 7 | 8 | def __init__(self, chars: list) -> None: 9 | self._unigram = {c: 0 for c in chars} 10 | self._bigram = {c: {d: 0 for d in chars} for c in chars} 11 | self.chars = chars 12 | 13 | def get_char_unigram(self, c: str) -> float: 14 | """Probability of character c.""" 15 | return self._unigram[c] 16 | 17 | def get_char_bigram(self, c: str, d: str) -> float: 18 | """Probability that character c is followed by character d.""" 19 | return self._bigram[c][d] 20 | 21 | def train(self, txt: str, normalize=False): 22 | """Create language model from text corpus.""" 23 | # compute unigrams 24 | for c in txt: 25 | # ignore unknown chars 26 | if c not in self._unigram: 27 | continue 28 | self._unigram[c] += 1 29 | 30 | # compute bigrams 31 | for i in range(len(txt) - 1): 32 | c = txt[i] 33 | d = txt[i + 1] 34 | 35 | # ignore unknown chars 36 | if c not in self._bigram or d not in self._bigram[c]: 37 | continue 38 | 39 | self._bigram[c][d] += 1 40 | if normalize: 41 | self.normalize() 42 | 43 | def normalize(self): 44 | # normalize 45 | sum_unigram = sum(self._unigram.values()) 46 | for c in self.chars: 47 | self._unigram[c] /= sum_unigram 48 | 49 | for c in self.chars: 50 | sum_bigram = sum(self._bigram[c].values()) 51 | if sum_bigram == 0: 52 | continue 53 | for d in self.chars: 54 | 
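# dividing by the row total turns the raw count for the pair (c, d) into the conditional probability P(d | c)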
self._bigram[c][d] /= sum_bigram 55 | 56 | def save(self, path): 57 | with open(os.path.join(path, "unigram.pkl"), 'wb') as pkl: 58 | pickle.dump(self._unigram, pkl) 59 | with open(os.path.join(path, "bigram.pkl"), 'wb') as pkl: 60 | pickle.dump(self._bigram, pkl) 61 | 62 | def load(self, path): 63 | with open(os.path.join(path, "unigram.pkl"), 'rb') as pkl: 64 | self._unigram = pickle.load(pkl) 65 | with open(os.path.join(path, "bigram.pkl"), 'rb') as pkl: 66 | self._bigram = pickle.load(pkl) -------------------------------------------------------------------------------- /asr/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import (MicrophoneStreaming, AudioStreaming, 2 | AudioReader, MicrophoneCaptureFailed) 3 | from .beam_search_decoder import BeamSearchDecoder -------------------------------------------------------------------------------- /asr/utils/beam_search_decoder.py: -------------------------------------------------------------------------------- 1 | '''adapted from https://github.com/githubharald/CTCDecoder''' 2 | from collections import defaultdict 3 | from dataclasses import dataclass 4 | from typing import Optional, List, Tuple 5 | import numpy as np 6 | from asr.language_model import LanguageModel 7 | 8 | 9 | def log(x: float) -> float: 10 | with np.errstate(divide='ignore'): 11 | return np.log(x) 12 | 13 | 14 | @dataclass 15 | class BeamEntry: 16 | """Information about one single beam at specific time-step.""" 17 | pr_total: float = log(0) # blank and non-blank 18 | pr_non_blank: float = log(0) # non-blank 19 | pr_blank: float = log(0) # blank 20 | pr_text: float = log(1) # LM score 21 | lm_applied: bool = False # flag if LM was already applied to this beam 22 | labeling: tuple = () # beam-labeling 23 | 24 | 25 | class BeamList: 26 | """Information about all beams at specific time-step.""" 27 | 28 | def __init__(self) -> None: 29 | self.entries = defaultdict(BeamEntry) 30 | 31 | def normalize(self) -> None: 32 | """Length-normalise LM score.""" 33 | for k in self.entries.keys(): 34 | labeling_len = len(self.entries[k].labeling) 35 | self.entries[k].pr_text = (1.0 / (labeling_len if labeling_len else 1.0)) * self.entries[k].pr_text 36 | 37 | def sort_labelings(self) -> List[Tuple[int]]: 38 | """Return beam-labelings, sorted by probability.""" 39 | beams = self.entries.values() 40 | sorted_beams = sorted(beams, reverse=True, key=lambda x: x.pr_total + x.pr_text) 41 | return [x.labeling for x in sorted_beams] 42 | 43 | 44 | class BeamSearchDecoder: 45 | def __init__(self, vocab: list, blank_idx: int, beam_width: int = 5, num_sentences: int = 1, lm: Optional[LanguageModel] = None): 46 | self.vocab = vocab 47 | self.blank_idx = blank_idx 48 | self.beam_width = beam_width 49 | self.num_sentences = num_sentences 50 | self.lm = lm 51 | 52 | def __apply_lm(self, parent_beam: BeamEntry, child_beam: BeamEntry) -> None: 53 | """Calculate LM score of child beam by taking score from 54 | parent beam and bigram probability of last two chars.""" 55 | if not self.lm or child_beam.lm_applied: 56 | return 57 | # take bigram if beam length at least 2 58 | if len(child_beam.labeling) > 1: 59 | c = self.vocab[child_beam.labeling[-2]] 60 | d = self.vocab[child_beam.labeling[-1]] 61 | ngram_prob = self.lm.get_char_bigram(c, d) 62 | # otherwise take unigram 63 | else: 64 | c = self.vocab[child_beam.labeling[-1]] 65 | ngram_prob = self.lm.get_char_unigram(c) 66 | 67 | lm_factor = 0.01 # influence of language model 68 | # 
probability of char sequence 69 | child_beam.pr_text = parent_beam.pr_text + lm_factor * log(ngram_prob) 70 | child_beam.lm_applied = True # only apply LM once per beam entry 71 | 72 | def __call__(self, logits: np.array) -> list: 73 | """Beam search decoder. 74 | 75 | See the paper of Hwang et al. and the paper of Graves et al. 76 | 77 | Args: 78 | logits: Output of neural network of shape TxC. 79 | 80 | Returns: 81 | The decoded text. 82 | """ 83 | 84 | max_T, max_C = logits.shape 85 | 86 | # initialise beam state 87 | last = BeamList() 88 | labeling = () 89 | last.entries[labeling] = BeamEntry() 90 | last.entries[labeling].pr_blank = log(1) 91 | last.entries[labeling].pr_total = log(1) 92 | 93 | # go over all time-steps 94 | for t in range(max_T): 95 | curr = BeamList() 96 | 97 | # get beam-labelings of best beams 98 | best_labelings = last.sort_labelings()[:self.beam_width] 99 | 100 | # go over best beams 101 | for labeling in best_labelings: 102 | # probability of paths ending with a non-blank 103 | pr_non_blank = log(0) 104 | # in case of non-empty beam 105 | if labeling: 106 | # probability of paths with repeated last char at the end 107 | pr_non_blank = last.entries[labeling].pr_non_blank + log(logits[t, labeling[-1]]) 108 | 109 | # probability of paths ending with a blank 110 | pr_blank = last.entries[labeling].pr_total + log(logits[t, self.blank_idx]) 111 | 112 | # fill in data for current beam 113 | curr.entries[labeling].labeling = labeling 114 | curr.entries[labeling].pr_non_blank = np.logaddexp(curr.entries[labeling].pr_non_blank, pr_non_blank) 115 | curr.entries[labeling].pr_blank = np.logaddexp(curr.entries[labeling].pr_blank, pr_blank) 116 | curr.entries[labeling].pr_total = np.logaddexp(curr.entries[labeling].pr_total, 117 | np.logaddexp(pr_blank, pr_non_blank)) 118 | curr.entries[labeling].pr_text = last.entries[labeling].pr_text 119 | curr.entries[labeling].lm_applied = True # LM already applied at previous time-step for this beam-labeling 120 | 121 | # extend current beam-labeling 122 | for c in range(max_C - 1): 123 | # add new char to current beam-labeling 124 | new_labeling = labeling + (c,) 125 | 126 | # if new labeling contains duplicate char at the end, only consider paths ending with a blank 127 | if labeling and labeling[-1] == c: 128 | pr_non_blank = last.entries[labeling].pr_blank + log(logits[t, c]) 129 | else: 130 | pr_non_blank = last.entries[labeling].pr_total + log(logits[t, c]) 131 | 132 | # fill in data 133 | curr.entries[new_labeling].labeling = new_labeling 134 | curr.entries[new_labeling].pr_non_blank = np.logaddexp(curr.entries[new_labeling].pr_non_blank, 135 | pr_non_blank) 136 | curr.entries[new_labeling].pr_total = np.logaddexp(curr.entries[new_labeling].pr_total, pr_non_blank) 137 | 138 | # apply LM 139 | self.__apply_lm(curr.entries[labeling], curr.entries[new_labeling]) 140 | 141 | # set new beam state 142 | last = curr 143 | 144 | # normalise LM scores according to beam-labeling-length 145 | last.normalize() 146 | 147 | # sort by probability and get most probable labelings 148 | best_labeling = last.sort_labelings()[:self.num_sentences] 149 | return best_labeling 150 | -------------------------------------------------------------------------------- /asr/utils/utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import asyncio 3 | import sounddevice as sd 4 | import numpy as np 5 | import soundfile as sf 6 | from scipy.signal import resample 7 | 8 | import torch 9 | from 
torchaudio.transforms import Resample 10 | 11 | 12 | class MicrophoneCaptureFailed(Exception): 13 | pass 14 | 15 | 16 | class MicrophoneStreaming: 17 | def __init__(self, sr=16000, blocksize=1024, channels=1, device=None, loop=None, dtype="float32"): 18 | self._sr = sr 19 | self._channels = channels 20 | self._device = device 21 | self._buffer = asyncio.Queue() 22 | self._buffersize = blocksize 23 | self._dtype = dtype 24 | self._loop = loop 25 | 26 | def __callback(self, indata, frame_count, time_info, status): 27 | self._loop.call_soon_threadsafe(self._buffer.put_nowait, (indata.copy(), status)) 28 | 29 | async def record_to_file(self, filename, duration=None): 30 | with sf.SoundFile(filename, mode='x', samplerate=self._sr, channels=self._channels) as f: 31 | t = time.time() 32 | rec = duration if duration is not None else 10 33 | async for block, status in self.generator(): 34 | f.write(block) 35 | rec = duration+0 if duration is not None else duration+1 36 | if(time.time() - t) > rec: 37 | break 38 | 39 | async def generator(self, future: asyncio.Future = None): 40 | if self._loop is None: 41 | self._loop = asyncio.get_running_loop() 42 | stream = sd.InputStream( 43 | samplerate=self._sr, 44 | device=self._device, 45 | channels=self._channels, 46 | callback=self.__callback, 47 | dtype=self._dtype, 48 | blocksize=self._buffersize) 49 | with stream: 50 | if not stream.active: 51 | # if it was not called start() or exception was raised 52 | # in the audio callback 53 | if future: 54 | # if the future is waiting for the start or any failure 55 | # set the exception 56 | future.set_exception(f"Could not open the {self._device} capture device") 57 | 58 | # coroutine also will be notified 59 | raise MicrophoneCaptureFailed 60 | else: 61 | if future: 62 | # if the future is waiting for the start or any failure 63 | # set True meaning that the microphone was successfully opened 64 | future.set_result(True) 65 | 66 | while stream.active: 67 | indata, status = await self._buffer.get() 68 | yield indata.squeeze(), status 69 | 70 | 71 | class AudioStreaming: 72 | def __init__(self, audio_path, blocksize, sr=16000, overlap=0, padding=None, dtype="float32"): 73 | assert blocksize >= 0, "blocksize cannot be 0 or negative" 74 | self._sr = sr 75 | self._orig_sr = sf.info(audio_path).samplerate 76 | self._sf_blocks = sf.blocks(audio_path, 77 | blocksize=blocksize, 78 | overlap=overlap, 79 | fill_value=padding, 80 | dtype=dtype) 81 | 82 | async def generator(self, future: asyncio.Future=None): 83 | for block in self._sf_blocks: 84 | chunk = await self.__resample_file(block, self._orig_sr, self._sr) 85 | yield chunk, self._orig_sr 86 | 87 | async def __resample_file(self, array, original_sr, target_sr): 88 | resampling_transform = Resample(orig_freq=original_sr, 89 | new_freq=target_sr) 90 | 91 | sample = resampling_transform(torch.Tensor([array])).squeeze() 92 | return sample 93 | 94 | 95 | class AudioReader: 96 | def __init__(self, audio_path, sr=16000, dtype="float32"): 97 | self._sr = sr 98 | self._dtype = dtype 99 | self._audio_path = audio_path 100 | 101 | def read(self): 102 | data, sr = sf.read(self._audio_path, dtype=self._dtype) 103 | data = self.__resample_file(data, sr, self._sr) 104 | return data, sr 105 | 106 | def __resample_file(self, array, original_sr, target_sr): 107 | return resample(array, num=int(len(array)*target_sr/original_sr)) -------------------------------------------------------------------------------- /asr/wav2vec2/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .inference import Wav2Vec2ASR -------------------------------------------------------------------------------- /asr/wav2vec2/decoder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tarun-bisht/wav2vec2-asr/17308ac128f9762b30220f1e580f699bc98be2e7/asr/wav2vec2/decoder/__init__.py -------------------------------------------------------------------------------- /asr/wav2vec2/decoder/ctc_decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from asr.language_model import LanguageModel 4 | from asr.utils import BeamSearchDecoder 5 | from asr.wav2vec2.vocab import vocab_list 6 | 7 | 8 | class CTCDecoder: 9 | def __init__(self, blank_idx: int = 0, beam_width: int = 100, lm_path: str = None): 10 | """constructor 11 | 12 | Args: 13 | blank_idx (int, optional): index of ctc blank token. Defaults to 0. 14 | beam_width (int, optional): beam width to search larget the value gives more accurate decoding costing computation. Defaults to 100. 15 | lm_path (str, optional): path to langugage model folder with unigram and bigrams. Defaults to None. 16 | """ 17 | lm = None 18 | if beam_width <= 1: 19 | self.mode = "greedy" 20 | else: 21 | self.mode = "beam" 22 | if lm_path is not None: 23 | self.mode = "beam_lm" 24 | lm = LanguageModel(chars=vocab_list[1:]) 25 | lm.load(lm_path) 26 | self._beam_search = BeamSearchDecoder(vocab_list[1:], 27 | blank_idx, 28 | beam_width, 29 | lm=lm) 30 | 31 | def __call__(self, logits: torch.tensor): 32 | return self.decode(logits) 33 | 34 | def decode(self, logits: torch.tensor): 35 | """decode logits using greedy method or beam search if beam width <= 1 then greedy else beam search. 
36 | 37 | Args: 38 | logits (torch.tensor): logits from model outputs 39 | 40 | Returns: 41 | np.array: ctc decoded output 42 | """ 43 | out_proba = torch.nn.functional.softmax(logits, dim=-1)[0] 44 | if self.mode == "greedy": 45 | out = self._greedy_path(out_proba).cpu().numpy() 46 | elif self.mode == "beam": 47 | out = self._beam_search(out_proba.cpu().numpy())[0] 48 | elif self.mode == "beam_lm": 49 | out = self._beam_search(out_proba.cpu().numpy())[0] 50 | else: 51 | out = None 52 | raise ValueError( 53 | "Mode not defined mode choices [greedy, beam and beam_lm]") 54 | return out 55 | 56 | def _greedy_path(self, probs: torch.tensor) -> torch.tensor: 57 | """max decoding ctc output by taking maximum probabilities from each timestep 58 | 59 | Args: 60 | probs (torch.tensor): softmax logits from model 61 | 62 | Returns: 63 | torch.tensor: max decoded outputs 64 | """ 65 | return torch.argmax(probs, axis=1) 66 | -------------------------------------------------------------------------------- /asr/wav2vec2/inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import asyncio 3 | import functools 4 | import transformers 5 | import numpy as np 6 | from asr.wav2vec2.decoder.ctc_decoder import CTCDecoder 7 | 8 | 9 | class Wav2Vec2ASR: 10 | """ 11 | Wav2Vec2 class wrapper for speech recognition 12 | """ 13 | 14 | def __init__(self, sr: int = 16000, device: str = "cpu", 15 | processor_path: str = None, model_path: str = None, 16 | pretrained_model_name: str = "facebook/wav2vec2-base-960h", 17 | beam_width: int = 5, lm_path: str = None): 18 | """Wave2Vec2 class constructor 19 | 20 | Args: 21 | sr (int, optional): sample rate of audio passing as input 22 | device (str, optional): device to load model and inputs choices are 'cpu' and 'cuda. Defaults to "cpu". 23 | processor_path (str, optional): path to saved local processor files. Defaults to None. 24 | model_path (str, optional): path to saved local model. Defaults to None 25 | pretrained_model_name (str, optional): pretrained model name as per hugging face pretrained models to load. Defaults to "facebook/wav2vec2-base-960h". 26 | beam_width (int, optional): width of beam search more the number better the results but increase computation. Defaults to 5. 27 | lm_path (str, optional): path to saved language model. Defaults to None. 
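            Example (illustrative sketch, assuming the default pretrained checkpoint can be downloaded):
                asr = Wav2Vec2ASR(device="cpu", beam_width=1)
                asr.load()
                # transcription = await asr.transcribe(speech_tensor)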
28 | """ 29 | self.sr = sr 30 | self.device = torch.device(device) 31 | self.processor_path = processor_path 32 | self.model_path = model_path 33 | self.pretrained_model_name = pretrained_model_name 34 | self.decoder = CTCDecoder(blank_idx=0, 35 | beam_width=beam_width, 36 | lm_path=lm_path) 37 | 38 | def load(self): 39 | """load models and processors 40 | """ 41 | processor = (transformers.Wav2Vec2Processor.from_pretrained(self.pretrained_model_name) 42 | if self.processor_path is None else torch.load(self.processor_path)) 43 | model = (transformers.Wav2Vec2ForCTC.from_pretrained(self.pretrained_model_name) 44 | if self.model_path is None else torch.load(self.model_path)) 45 | model.eval() 46 | model.to(self.device) 47 | self.model = model 48 | self.processor = processor 49 | 50 | def _transcribe(self, inputs: torch.tensor) -> str: 51 | """transcribe input speech and return resulting transcription 52 | 53 | Args: 54 | inputs (torch.tensor): single raw speech torch tensor (timestep,1) 55 | 56 | Returns: 57 | str: transcription of raw speech signal 58 | """ 59 | inputs = self.processor(inputs, sampling_rate=self.sr, 60 | padding="longest", 61 | return_tensors='pt').input_values.to(self.device) 62 | with torch.no_grad(): 63 | logits = self.model(inputs).logits 64 | outs = self.decoder(logits) 65 | return self.processor.decode(outs) 66 | 67 | async def capture_and_transcribe(self, 68 | stream_obj, 69 | started_future: asyncio.Future = None, 70 | loop=None): 71 | """capture streaming audio and transcribe 72 | 73 | Args: 74 | stream_obj (asr.utils.MicrophoneStreaming or asr.utils.AudioStreaming): streaming object with generator that yields audio blocks 75 | started_future (asyncio.Future, optional): asyncio future. Defaults to None. 76 | loop (optional): asyncio event loop which we can get using asyncio.get_running_loop(). Defaults to None. 77 | 78 | Yields: 79 | [generator object]: returns generator that yield outputs from streaming audio 80 | """ 81 | if loop is None: 82 | loop = asyncio.get_running_loop() 83 | async for block, status in stream_obj.generator(started_future): 84 | process_func = functools.partial(self._transcribe, inputs=block) 85 | transcriptions = await loop.run_in_executor(None, process_func) 86 | yield transcriptions 87 | 88 | async def transcribe(self, inputs: torch.tensor, loop=None): 89 | """transcribe and audio signal use for offline audio transcription 90 | 91 | Args: 92 | inputs (torch.tensor): raw speech signal as pytorch tensor (timestep,1) 93 | loop (optional): asyncio event loop which we can get using asyncio.get_running_loop(). Defaults to None. 
94 | 95 | Returns: 96 | [corountine object]: coroutine object which we get await and get results asynchronously 97 | """ 98 | if loop is None: 99 | loop = asyncio.get_running_loop() 100 | process_func = functools.partial(self._transcribe, inputs=inputs) 101 | return await loop.run_in_executor(None, process_func) 102 | -------------------------------------------------------------------------------- /asr/wav2vec2/vocab.py: -------------------------------------------------------------------------------- 1 | vocab_dict = {"": 0, 2 | "": 1, 3 | "": 2, 4 | "": 3, 5 | "|": 4, 6 | "E": 5, 7 | "T": 6, 8 | "A": 7, 9 | "O": 8, 10 | "N": 9, 11 | "I": 10, 12 | "H": 11, 13 | "S": 12, 14 | "R": 13, 15 | "D": 14, 16 | "L": 15, 17 | "U": 16, 18 | "M": 17, 19 | "W": 18, 20 | "C": 19, 21 | "F": 20, 22 | "G": 21, 23 | "Y": 22, 24 | "P": 23, 25 | "B": 24, 26 | "V": 25, 27 | "K": 26, 28 | "'": 27, 29 | "X": 28, 30 | "J": 29, 31 | "Q": 30, 32 | "Z": 31 33 | } 34 | 35 | vocab_list = [key for key, value in vocab_dict.items()] 36 | -------------------------------------------------------------------------------- /asr_inference_live.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import asyncio 3 | import functools 4 | import sys 5 | from asr.utils import MicrophoneStreaming 6 | from asr.wav2vec2 import Wav2Vec2ASR 7 | 8 | parser = argparse.ArgumentParser(description="ASR with live audio") 9 | parser.add_argument("--model", "-m", default=None, required=False, 10 | help="Trained Model local path") 11 | parser.add_argument("--processor", "-t", default=None, required=False, 12 | help="Local asr processor path") 13 | parser.add_argument("--blocksize", "-bs", default=16000, type=int, required=False, 14 | help="Size of each audio block to be passed to model") 15 | parser.add_argument("--output", "-out", required=False, 16 | help="path to save resultant transcriptions") 17 | parser.add_argument("--device", "-d", default='cpu', nargs='?', choices=['cuda', 'cpu'], required=False, 18 | help="device to use for inferencing") 19 | parser.add_argument("--beam_width", "-bw", default=1, type=int, required=False, 20 | help="beam width to use for beam search decoder during inferencing") 21 | parser.add_argument("--lm", "-l", default=None, required=False, 22 | help="Trained lm folder path with unigram and bigram files") 23 | parser.add_argument("--pretrained_model_name", "-pwmn", default="facebook/wav2vec2-base-960h", 24 | type=str, required=False, help="Pretrained wav2vec2 model name") 25 | 26 | args = parser.parse_args() 27 | 28 | asr = Wav2Vec2ASR(device=args.device, 29 | processor_path=args.processor, 30 | model_path=args.model, 31 | pretrained_model_name=args.pretrained_model_name, 32 | beam_width=args.beam_width, 33 | lm_path=args.lm) 34 | 35 | print("Loading Models ...") 36 | asr.load() 37 | print("Models Loaded ...") 38 | 39 | 40 | def write_to_file(output_file, transcriptions): 41 | output_file.write(transcriptions) 42 | 43 | 44 | def print_transcription(transcription): 45 | print(transcription, end=" ") 46 | sys.stdout.flush() 47 | 48 | 49 | async def main(output_file=None): 50 | loop = asyncio.get_running_loop() 51 | stream = MicrophoneStreaming(blocksize=args.blocksize, loop=loop) 52 | async for transcription in asr.capture_and_transcribe(stream, loop=loop): 53 | if not transcription == "": 54 | print_func = functools.partial( 55 | print_transcription, transcription=transcription) 56 | await loop.run_in_executor(None, print_func) 57 | if output_file is not None: 
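# run the blocking file write in the default executor as well, so it does not stall the capture-and-transcribe loop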
58 | write_func = functools.partial(write_to_file, output_file=output_file, 59 | transcriptions=transcriptions) 60 | await loop.run_in_executor(None, write_func) 61 | 62 | if __name__ == "__main__": 63 | print("Start Transcribing...") 64 | try: 65 | if args.output: 66 | with open(args.output, "w") as f: 67 | asyncio.run(main(f)) 68 | else: 69 | asyncio.run(main()) 70 | except KeyboardInterrupt: 71 | print("Exited") 72 | -------------------------------------------------------------------------------- /asr_inference_offline.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import asyncio 3 | import functools 4 | from asr.utils import AudioReader 5 | from asr.wav2vec2 import Wav2Vec2ASR 6 | 7 | parser = argparse.ArgumentParser( 8 | description="ASR with recorded audio (offline)") 9 | parser.add_argument("--recording", "-rec", required=True, 10 | help="path to recording file") 11 | parser.add_argument("--model", "-m", default=None, required=False, 12 | help="path to local saved model") 13 | parser.add_argument("--processor", "-t", default=None, required=False, 14 | help="path to local saved processor") 15 | parser.add_argument("--output", "-out", required=False, 16 | help="path to save resultant transcriptions") 17 | parser.add_argument("--lm", "-l", default=None, required=False, 18 | help="Trained lm folder path with unigram and bigram files") 19 | parser.add_argument("--device", "-d", default='cpu', nargs='?', choices=['cuda', 'cpu'], required=False, 20 | help="device to use for inferencing") 21 | parser.add_argument("--beam_width", "-bw", default=1, type=int, required=False, 22 | help="beam width to use for beam search decoder during inferencing") 23 | parser.add_argument("--pretrained_model_name", "-pwmn", default="facebook/wav2vec2-base-960h", 24 | type=str, required=False, help="Pretrained wav2vec2 model name") 25 | 26 | args = parser.parse_args() 27 | 28 | asr = Wav2Vec2ASR(device=args.device, 29 | processor_path=args.processor, 30 | model_path=args.model, 31 | pretrained_model_name=args.pretrained_model_name, 32 | beam_width=args.beam_width, 33 | lm_path=args.lm) 34 | 35 | print("Loading Models ...") 36 | asr.load() 37 | print("Models Loaded ...") 38 | 39 | 40 | async def main(): 41 | loop = asyncio.get_running_loop() 42 | reader = AudioReader(audio_path=args.recording, 43 | sr=16000, 44 | dtype="float32") 45 | inputs, sr = reader.read() 46 | transcriptions = await asr.transcribe(inputs, loop=loop) 47 | print(transcriptions) 48 | if args.output: 49 | with open(args.output, "w") as f: 50 | f.write(transcriptions) 51 | 52 | 53 | if __name__ == "__main__": 54 | print("Start Transcribing...") 55 | try: 56 | asyncio.run(main()) 57 | except KeyboardInterrupt: 58 | print("Exited") 59 | -------------------------------------------------------------------------------- /asr_inference_recording.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import asyncio 3 | import functools 4 | import sys 5 | from asr.utils import AudioStreaming 6 | from asr.wav2vec2 import Wav2Vec2ASR 7 | 8 | parser = argparse.ArgumentParser(description="ASR with live audio") 9 | parser.add_argument("--recording", "-rec", required=True, 10 | help="path to recording file") 11 | parser.add_argument("--model", "-m", default=None, required=False, 12 | help="Trained Model local path") 13 | parser.add_argument("--processor", "-t", default=None, required=False, 14 | help="Local asr processor path") 15 | 
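# streaming-specific options: --blocksize sets how many samples are fed to the model per chunk, --overlap how many samples consecutive chunks share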
parser.add_argument("--blocksize", "-bs", default=16000, type=int, required=False, 16 | help="Size of each audio block to be passed to model") 17 | parser.add_argument("--overlap", "-ov", default=0, type=int, required=False, 18 | help="Overlapping amount in audio blocks") 19 | parser.add_argument("--output", "-out", required=False, 20 | help="path to save resultant transcriptions") 21 | parser.add_argument("--device", "-d", default='cpu', nargs='?', choices=['cuda', 'cpu'], required=False, 22 | help="device to use for inferencing") 23 | parser.add_argument("--beam_width", "-bw", default=1, type=int, required=False, 24 | help="beam width to use for beam search decoder during inferencing") 25 | parser.add_argument("--lm", "-l", default=None, required=False, 26 | help="Trained lm folder path with unigram and bigram files") 27 | parser.add_argument("--pretrained_model_name", "-pwmn", default="facebook/wav2vec2-base-960h", 28 | type=str, required=False, help="Pretrained wav2vec2 model name") 29 | 30 | args = parser.parse_args() 31 | 32 | asr = Wav2Vec2ASR(device=args.device, 33 | processor_path=args.processor, 34 | model_path=args.model, 35 | pretrained_model_name=args.pretrained_model_name, 36 | beam_width=args.beam_width, 37 | lm_path=args.lm) 38 | 39 | print("Loading Models ...") 40 | asr.load() 41 | print("Models Loaded ...") 42 | 43 | 44 | def write_to_file(output_file, transcriptions): 45 | output_file.write(transcriptions) 46 | 47 | 48 | def print_transcription(transcription): 49 | print(transcription, end=" ") 50 | sys.stdout.flush() 51 | 52 | 53 | async def main(output_file=None): 54 | loop = asyncio.get_running_loop() 55 | stream = AudioStreaming(audio_path=args.recording, blocksize=args.blocksize, 56 | overlap=args.overlap) 57 | async for transcription in asr.capture_and_transcribe(stream, loop=loop): 58 | if not transcription == "": 59 | print_func = functools.partial( 60 | print_transcription, transcription=transcription) 61 | await loop.run_in_executor(None, print_func) 62 | if output_file is not None: 63 | write_func = functools.partial(write_to_file, output_file=output_file, 64 | transcriptions=transcriptions) 65 | await loop.run_in_executor(None, write_func) 66 | 67 | if __name__ == "__main__": 68 | print("Start Transcribing...") 69 | try: 70 | if args.output: 71 | with open(args.output, "w") as f: 72 | asyncio.run(main(f)) 73 | else: 74 | asyncio.run(main()) 75 | except KeyboardInterrupt: 76 | print("Exited") 77 | -------------------------------------------------------------------------------- /data/models/lm/twitter/bigram.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tarun-bisht/wav2vec2-asr/17308ac128f9762b30220f1e580f699bc98be2e7/data/models/lm/twitter/bigram.pkl -------------------------------------------------------------------------------- /data/models/lm/twitter/unigram.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tarun-bisht/wav2vec2-asr/17308ac128f9762b30220f1e580f699bc98be2e7/data/models/lm/twitter/unigram.pkl -------------------------------------------------------------------------------- /data/models/lm/wikipedia/bigram.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tarun-bisht/wav2vec2-asr/17308ac128f9762b30220f1e580f699bc98be2e7/data/models/lm/wikipedia/bigram.pkl -------------------------------------------------------------------------------- 
/data/models/lm/wikipedia/unigram.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tarun-bisht/wav2vec2-asr/17308ac128f9762b30220f1e580f699bc98be2e7/data/models/lm/wikipedia/unigram.pkl -------------------------------------------------------------------------------- /data/samples/Achievements_of_the_Democratic_Party_(Homer_S._Cummings).ogg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tarun-bisht/wav2vec2-asr/17308ac128f9762b30220f1e580f699bc98be2e7/data/samples/Achievements_of_the_Democratic_Party_(Homer_S._Cummings).ogg -------------------------------------------------------------------------------- /data/samples/rec.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tarun-bisht/wav2vec2-asr/17308ac128f9762b30220f1e580f699bc98be2e7/data/samples/rec.wav -------------------------------------------------------------------------------- /data/samples/rec2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tarun-bisht/wav2vec2-asr/17308ac128f9762b30220f1e580f699bc98be2e7/data/samples/rec2.wav -------------------------------------------------------------------------------- /data/samples/shortened.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tarun-bisht/wav2vec2-asr/17308ac128f9762b30220f1e580f699bc98be2e7/data/samples/shortened.wav -------------------------------------------------------------------------------- /notebooks/Training_Simple_Lanugage_Model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "Training Simple Lanugage Model.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [] 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "language_info": { 15 | "name": "python" 16 | }, 17 | "widgets": { 18 | "application/vnd.jupyter.widget-state+json": { 19 | "0677f310a8104790b2f4545d1bece7b0": { 20 | "model_module": "@jupyter-widgets/controls", 21 | "model_name": "HBoxModel", 22 | "model_module_version": "1.5.0", 23 | "state": { 24 | "_view_name": "HBoxView", 25 | "_dom_classes": [], 26 | "_model_name": "HBoxModel", 27 | "_view_module": "@jupyter-widgets/controls", 28 | "_model_module_version": "1.5.0", 29 | "_view_count": null, 30 | "_view_module_version": "1.5.0", 31 | "box_style": "", 32 | "layout": "IPY_MODEL_5cb0a81006e14c4eba6da4cb4b9bf14b", 33 | "_model_module": "@jupyter-widgets/controls", 34 | "children": [ 35 | "IPY_MODEL_0cd3f576c0c2445f9817c5b56a87b900", 36 | "IPY_MODEL_4f066cb8912e40108de0c8d0d5667b09", 37 | "IPY_MODEL_35e18b46304a40e696ea98c7e05d6c94" 38 | ] 39 | } 40 | }, 41 | "5cb0a81006e14c4eba6da4cb4b9bf14b": { 42 | "model_module": "@jupyter-widgets/base", 43 | "model_name": "LayoutModel", 44 | "model_module_version": "1.2.0", 45 | "state": { 46 | "_view_name": "LayoutView", 47 | "grid_template_rows": null, 48 | "right": null, 49 | "justify_content": null, 50 | "_view_module": "@jupyter-widgets/base", 51 | "overflow": null, 52 | "_model_module_version": "1.2.0", 53 | "_view_count": null, 54 | "flex_flow": null, 55 | "width": null, 56 | "min_width": null, 57 | "border": null, 58 | "align_items": null, 59 | "bottom": null, 60 | "_model_module": 
"@jupyter-widgets/base", 61 | "top": null, 62 | "grid_column": null, 63 | "overflow_y": null, 64 | "overflow_x": null, 65 | "grid_auto_flow": null, 66 | "grid_area": null, 67 | "grid_template_columns": null, 68 | "flex": null, 69 | "_model_name": "LayoutModel", 70 | "justify_items": null, 71 | "grid_row": null, 72 | "max_height": null, 73 | "align_content": null, 74 | "visibility": null, 75 | "align_self": null, 76 | "height": null, 77 | "min_height": null, 78 | "padding": null, 79 | "grid_auto_rows": null, 80 | "grid_gap": null, 81 | "max_width": null, 82 | "order": null, 83 | "_view_module_version": "1.2.0", 84 | "grid_template_areas": null, 85 | "object_position": null, 86 | "object_fit": null, 87 | "grid_auto_columns": null, 88 | "margin": null, 89 | "display": null, 90 | "left": null 91 | } 92 | }, 93 | "0cd3f576c0c2445f9817c5b56a87b900": { 94 | "model_module": "@jupyter-widgets/controls", 95 | "model_name": "HTMLModel", 96 | "model_module_version": "1.5.0", 97 | "state": { 98 | "_view_name": "HTMLView", 99 | "style": "IPY_MODEL_54b399d1ad7747088cc4a73837b87c60", 100 | "_dom_classes": [], 101 | "description": "", 102 | "_model_name": "HTMLModel", 103 | "placeholder": "​", 104 | "_view_module": "@jupyter-widgets/controls", 105 | "_model_module_version": "1.5.0", 106 | "value": " 0%", 107 | "_view_count": null, 108 | "_view_module_version": "1.5.0", 109 | "description_tooltip": null, 110 | "_model_module": "@jupyter-widgets/controls", 111 | "layout": "IPY_MODEL_e2d00fc27634452383943b40e131f36d" 112 | } 113 | }, 114 | "4f066cb8912e40108de0c8d0d5667b09": { 115 | "model_module": "@jupyter-widgets/controls", 116 | "model_name": "FloatProgressModel", 117 | "model_module_version": "1.5.0", 118 | "state": { 119 | "_view_name": "ProgressView", 120 | "style": "IPY_MODEL_b8da8be70fb342b0ae900853af0bb7d1", 121 | "_dom_classes": [], 122 | "description": "", 123 | "_model_name": "FloatProgressModel", 124 | "bar_style": "danger", 125 | "max": 5824596, 126 | "_view_module": "@jupyter-widgets/controls", 127 | "_model_module_version": "1.5.0", 128 | "value": 999, 129 | "_view_count": null, 130 | "_view_module_version": "1.5.0", 131 | "orientation": "horizontal", 132 | "min": 0, 133 | "description_tooltip": null, 134 | "_model_module": "@jupyter-widgets/controls", 135 | "layout": "IPY_MODEL_3321512e63f9446f8d2d8d57d7586607" 136 | } 137 | }, 138 | "35e18b46304a40e696ea98c7e05d6c94": { 139 | "model_module": "@jupyter-widgets/controls", 140 | "model_name": "HTMLModel", 141 | "model_module_version": "1.5.0", 142 | "state": { 143 | "_view_name": "HTMLView", 144 | "style": "IPY_MODEL_ec1edc03d12241ba96e467f120e46ac0", 145 | "_dom_classes": [], 146 | "description": "", 147 | "_model_name": "HTMLModel", 148 | "placeholder": "​", 149 | "_view_module": "@jupyter-widgets/controls", 150 | "_model_module_version": "1.5.0", 151 | "value": " 999/5824596 [00:01<36:25, 2665.15it/s]", 152 | "_view_count": null, 153 | "_view_module_version": "1.5.0", 154 | "description_tooltip": null, 155 | "_model_module": "@jupyter-widgets/controls", 156 | "layout": "IPY_MODEL_030880aa5a054d788231490e0f4a7e78" 157 | } 158 | }, 159 | "54b399d1ad7747088cc4a73837b87c60": { 160 | "model_module": "@jupyter-widgets/controls", 161 | "model_name": "DescriptionStyleModel", 162 | "model_module_version": "1.5.0", 163 | "state": { 164 | "_view_name": "StyleView", 165 | "_model_name": "DescriptionStyleModel", 166 | "description_width": "", 167 | "_view_module": "@jupyter-widgets/base", 168 | "_model_module_version": "1.5.0", 169 | "_view_count": 
null, 170 | "_view_module_version": "1.2.0", 171 | "_model_module": "@jupyter-widgets/controls" 172 | } 173 | }, 174 | "e2d00fc27634452383943b40e131f36d": { 175 | "model_module": "@jupyter-widgets/base", 176 | "model_name": "LayoutModel", 177 | "model_module_version": "1.2.0", 178 | "state": { 179 | "_view_name": "LayoutView", 180 | "grid_template_rows": null, 181 | "right": null, 182 | "justify_content": null, 183 | "_view_module": "@jupyter-widgets/base", 184 | "overflow": null, 185 | "_model_module_version": "1.2.0", 186 | "_view_count": null, 187 | "flex_flow": null, 188 | "width": null, 189 | "min_width": null, 190 | "border": null, 191 | "align_items": null, 192 | "bottom": null, 193 | "_model_module": "@jupyter-widgets/base", 194 | "top": null, 195 | "grid_column": null, 196 | "overflow_y": null, 197 | "overflow_x": null, 198 | "grid_auto_flow": null, 199 | "grid_area": null, 200 | "grid_template_columns": null, 201 | "flex": null, 202 | "_model_name": "LayoutModel", 203 | "justify_items": null, 204 | "grid_row": null, 205 | "max_height": null, 206 | "align_content": null, 207 | "visibility": null, 208 | "align_self": null, 209 | "height": null, 210 | "min_height": null, 211 | "padding": null, 212 | "grid_auto_rows": null, 213 | "grid_gap": null, 214 | "max_width": null, 215 | "order": null, 216 | "_view_module_version": "1.2.0", 217 | "grid_template_areas": null, 218 | "object_position": null, 219 | "object_fit": null, 220 | "grid_auto_columns": null, 221 | "margin": null, 222 | "display": null, 223 | "left": null 224 | } 225 | }, 226 | "b8da8be70fb342b0ae900853af0bb7d1": { 227 | "model_module": "@jupyter-widgets/controls", 228 | "model_name": "ProgressStyleModel", 229 | "model_module_version": "1.5.0", 230 | "state": { 231 | "_view_name": "StyleView", 232 | "_model_name": "ProgressStyleModel", 233 | "description_width": "", 234 | "_view_module": "@jupyter-widgets/base", 235 | "_model_module_version": "1.5.0", 236 | "_view_count": null, 237 | "_view_module_version": "1.2.0", 238 | "bar_color": null, 239 | "_model_module": "@jupyter-widgets/controls" 240 | } 241 | }, 242 | "3321512e63f9446f8d2d8d57d7586607": { 243 | "model_module": "@jupyter-widgets/base", 244 | "model_name": "LayoutModel", 245 | "model_module_version": "1.2.0", 246 | "state": { 247 | "_view_name": "LayoutView", 248 | "grid_template_rows": null, 249 | "right": null, 250 | "justify_content": null, 251 | "_view_module": "@jupyter-widgets/base", 252 | "overflow": null, 253 | "_model_module_version": "1.2.0", 254 | "_view_count": null, 255 | "flex_flow": null, 256 | "width": null, 257 | "min_width": null, 258 | "border": null, 259 | "align_items": null, 260 | "bottom": null, 261 | "_model_module": "@jupyter-widgets/base", 262 | "top": null, 263 | "grid_column": null, 264 | "overflow_y": null, 265 | "overflow_x": null, 266 | "grid_auto_flow": null, 267 | "grid_area": null, 268 | "grid_template_columns": null, 269 | "flex": null, 270 | "_model_name": "LayoutModel", 271 | "justify_items": null, 272 | "grid_row": null, 273 | "max_height": null, 274 | "align_content": null, 275 | "visibility": null, 276 | "align_self": null, 277 | "height": null, 278 | "min_height": null, 279 | "padding": null, 280 | "grid_auto_rows": null, 281 | "grid_gap": null, 282 | "max_width": null, 283 | "order": null, 284 | "_view_module_version": "1.2.0", 285 | "grid_template_areas": null, 286 | "object_position": null, 287 | "object_fit": null, 288 | "grid_auto_columns": null, 289 | "margin": null, 290 | "display": null, 291 | "left": null 292 | } 
293 | }, 294 | "ec1edc03d12241ba96e467f120e46ac0": { 295 | "model_module": "@jupyter-widgets/controls", 296 | "model_name": "DescriptionStyleModel", 297 | "model_module_version": "1.5.0", 298 | "state": { 299 | "_view_name": "StyleView", 300 | "_model_name": "DescriptionStyleModel", 301 | "description_width": "", 302 | "_view_module": "@jupyter-widgets/base", 303 | "_model_module_version": "1.5.0", 304 | "_view_count": null, 305 | "_view_module_version": "1.2.0", 306 | "_model_module": "@jupyter-widgets/controls" 307 | } 308 | }, 309 | "030880aa5a054d788231490e0f4a7e78": { 310 | "model_module": "@jupyter-widgets/base", 311 | "model_name": "LayoutModel", 312 | "model_module_version": "1.2.0", 313 | "state": { 314 | "_view_name": "LayoutView", 315 | "grid_template_rows": null, 316 | "right": null, 317 | "justify_content": null, 318 | "_view_module": "@jupyter-widgets/base", 319 | "overflow": null, 320 | "_model_module_version": "1.2.0", 321 | "_view_count": null, 322 | "flex_flow": null, 323 | "width": null, 324 | "min_width": null, 325 | "border": null, 326 | "align_items": null, 327 | "bottom": null, 328 | "_model_module": "@jupyter-widgets/base", 329 | "top": null, 330 | "grid_column": null, 331 | "overflow_y": null, 332 | "overflow_x": null, 333 | "grid_auto_flow": null, 334 | "grid_area": null, 335 | "grid_template_columns": null, 336 | "flex": null, 337 | "_model_name": "LayoutModel", 338 | "justify_items": null, 339 | "grid_row": null, 340 | "max_height": null, 341 | "align_content": null, 342 | "visibility": null, 343 | "align_self": null, 344 | "height": null, 345 | "min_height": null, 346 | "padding": null, 347 | "grid_auto_rows": null, 348 | "grid_gap": null, 349 | "max_width": null, 350 | "order": null, 351 | "_view_module_version": "1.2.0", 352 | "grid_template_areas": null, 353 | "object_position": null, 354 | "object_fit": null, 355 | "grid_auto_columns": null, 356 | "margin": null, 357 | "display": null, 358 | "left": null 359 | } 360 | } 361 | } 362 | } 363 | }, 364 | "cells": [ 365 | { 366 | "cell_type": "markdown", 367 | "metadata": { 368 | "id": "F3MEfwY-PRn1" 369 | }, 370 | "source": [ 371 | "## Simple Language Model Training with wikipedia data" 372 | ] 373 | }, 374 | { 375 | "cell_type": "code", 376 | "metadata": { 377 | "id": "FkkJxWk89_kM" 378 | }, 379 | "source": [ 380 | "import os\n", 381 | "import pickle\n", 382 | "import tensorflow as tf\n", 383 | "import tensorflow_datasets as tfds\n", 384 | "import numpy as np\n", 385 | "from tqdm.auto import tqdm" 386 | ], 387 | "execution_count": 21, 388 | "outputs": [] 389 | }, 390 | { 391 | "cell_type": "code", 392 | "metadata": { 393 | "id": "wJmMm4doOjxM" 394 | }, 395 | "source": [ 396 | "class LanguageModel:\n", 397 | " \"\"\"Simple character-level language model\"\"\"\n", 398 | "\n", 399 | " def __init__(self, chars: list) -> None:\n", 400 | " self._unigram = {c: 0 for c in chars}\n", 401 | " self._bigram = {c: {d: 0 for d in chars} for c in chars}\n", 402 | " self.chars = chars\n", 403 | "\n", 404 | " def get_char_unigram(self, c: str) -> float:\n", 405 | " \"\"\"Probability of character c.\"\"\"\n", 406 | " return self._unigram[c]\n", 407 | "\n", 408 | " def get_char_bigram(self, c: str, d: str) -> float:\n", 409 | " \"\"\"Probability that character c is followed by character d.\"\"\"\n", 410 | " return self._bigram[c][d]\n", 411 | "\n", 412 | " def train(self, txt: str):\n", 413 | " \"\"\"Create language model from text corpus.\"\"\"\n", 414 | " # compute unigrams\n", 415 | " for c in txt:\n", 416 | " # ignore 
unknown chars\n", 417 | " if c not in self._unigram:\n", 418 | " continue\n", 419 | " self._unigram[c] += 1\n", 420 | "\n", 421 | " # compute bigrams\n", 422 | " for i in range(len(txt) - 1):\n", 423 | " c = txt[i]\n", 424 | " d = txt[i + 1]\n", 425 | "\n", 426 | " # ignore unknown chars\n", 427 | " if c not in self._bigram or d not in self._bigram[c]:\n", 428 | " continue\n", 429 | "\n", 430 | " self._bigram[c][d] += 1\n", 431 | "\n", 432 | " def normalize(self):\n", 433 | " # normalize\n", 434 | " sum_unigram = sum(self._unigram.values())\n", 435 | " for c in self.chars:\n", 436 | " self._unigram[c] /= sum_unigram\n", 437 | "\n", 438 | " for c in self.chars:\n", 439 | " sum_bigram = sum(self._bigram[c].values())\n", 440 | " if sum_bigram == 0:\n", 441 | " continue\n", 442 | " for d in self.chars:\n", 443 | " self._bigram[c][d] /= sum_bigram\n", 444 | "\n", 445 | " def save(self, path):\n", 446 | " with open(os.path.join(path, \"unigram.pkl\"), 'wb') as pkl:\n", 447 | " pickle.dump(self._unigram, pkl)\n", 448 | " with open(os.path.join(path, \"bigram.pkl\"), 'wb') as pkl:\n", 449 | " pickle.dump(self._bigram, pkl)\n", 450 | "\n", 451 | " def load(self, path):\n", 452 | " with open(os.path.join(path, \"unigram.pkl\"), 'rb') as pkl:\n", 453 | " self._unigram = pickle.load(pkl)\n", 454 | " with open(os.path.join(path, \"bigram.pkl\"), 'rb') as pkl:\n", 455 | " self._bigram = pickle.load(pkl)" 456 | ], 457 | "execution_count": 22, 458 | "outputs": [] 459 | }, 460 | { 461 | "cell_type": "code", 462 | "metadata": { 463 | "id": "BdAZjTyu2EgL" 464 | }, 465 | "source": [ 466 | "# Loading the wikipedia dataset.\n", 467 | "DATASET_NAME = 'wikipedia/20190301.en'\n", 468 | "# DATASET_NAME = 'wikipedia/20190301.uk'\n", 469 | "\n", 470 | "dataset, dataset_info = tfds.load(\n", 471 | " name=DATASET_NAME,\n", 472 | " data_dir='tmp',\n", 473 | " with_info=True,\n", 474 | " split=tfds.Split.TRAIN)" 475 | ], 476 | "execution_count": 23, 477 | "outputs": [] 478 | }, 479 | { 480 | "cell_type": "code", 481 | "metadata": { 482 | "colab": { 483 | "base_uri": "https://localhost:8080/" 484 | }, 485 | "id": "bMhatKag2Nw9", 486 | "outputId": "412f4dce-a3cf-4779-a0aa-e3963c625cba" 487 | }, 488 | "source": [ 489 | "print(dataset)" 490 | ], 491 | "execution_count": 24, 492 | "outputs": [ 493 | { 494 | "output_type": "stream", 495 | "name": "stdout", 496 | "text": [ 497 | "\n" 498 | ] 499 | } 500 | ] 501 | }, 502 | { 503 | "cell_type": "code", 504 | "metadata": { 505 | "colab": { 506 | "base_uri": "https://localhost:8080/" 507 | }, 508 | "id": "CUoeROKZ6bw7", 509 | "outputId": "f57b710d-f87a-4b51-9346-5a76453a8646" 510 | }, 511 | "source": [ 512 | "TRAIN_NUM_EXAMPLES = dataset_info.splits['train'].num_examples\n", 513 | "print('Total number of articles: ', TRAIN_NUM_EXAMPLES)" 514 | ], 515 | "execution_count": 25, 516 | "outputs": [ 517 | { 518 | "output_type": "stream", 519 | "name": "stdout", 520 | "text": [ 521 | "Total number of articles: 5824596\n" 522 | ] 523 | } 524 | ] 525 | }, 526 | { 527 | "cell_type": "code", 528 | "metadata": { 529 | "id": "Vrs6RqLE_of_" 530 | }, 531 | "source": [ 532 | "vocab_dict = {\"\": 0,\n", 533 | " \"\": 1,\n", 534 | " \"\": 2,\n", 535 | " \"\": 3,\n", 536 | " \"|\": 4,\n", 537 | " \"E\": 5,\n", 538 | " \"T\": 6,\n", 539 | " \"A\": 7,\n", 540 | " \"O\": 8,\n", 541 | " \"N\": 9,\n", 542 | " \"I\": 10,\n", 543 | " \"H\": 11,\n", 544 | " \"S\": 12,\n", 545 | " \"R\": 13,\n", 546 | " \"D\": 14,\n", 547 | " \"L\": 15,\n", 548 | " \"U\": 16,\n", 549 | " \"M\": 17,\n", 550 | " \"W\": 18,\n", 
551 | " \"C\": 19,\n", 552 | " \"F\": 20,\n", 553 | " \"G\": 21,\n", 554 | " \"Y\": 22,\n", 555 | " \"P\": 23,\n", 556 | " \"B\": 24,\n", 557 | " \"V\": 25,\n", 558 | " \"K\": 26,\n", 559 | " \"'\": 27,\n", 560 | " \"X\": 28,\n", 561 | " \"J\": 29,\n", 562 | " \"Q\": 30,\n", 563 | " \"Z\": 31\n", 564 | " }\n", 565 | "\n", 566 | "vocab_list = [key for key, value in vocab_dict.items()]" 567 | ], 568 | "execution_count": 26, 569 | "outputs": [] 570 | }, 571 | { 572 | "cell_type": "code", 573 | "metadata": { 574 | "id": "Bi3qjUKjAYR-" 575 | }, 576 | "source": [ 577 | "def change_digit_to_word(x):\n", 578 | " x = x.replace(\"0\", \"zero \")\n", 579 | " x = x.replace(\"1\", \"one \")\n", 580 | " x = x.replace(\"2\", \"two \")\n", 581 | " x = x.replace(\"3\", \"three \")\n", 582 | " x = x.replace(\"4\", \"four \")\n", 583 | " x = x.replace(\"5\", \"five \")\n", 584 | " x = x.replace(\"6\", \"six \")\n", 585 | " x = x.replace(\"7\", \"seven \")\n", 586 | " x = x.replace(\"8\", \"eight \")\n", 587 | " x = x.replace(\"9\", \"nine \")\n", 588 | " x = x.replace(\" \", \" \")\n", 589 | " x = x.strip()\n", 590 | " return x" 591 | ], 592 | "execution_count": 27, 593 | "outputs": [] 594 | }, 595 | { 596 | "cell_type": "code", 597 | "metadata": { 598 | "id": "mE4QEWz8_ORR" 599 | }, 600 | "source": [ 601 | "lm = LanguageModel(chars=vocab_list[1:])" 602 | ], 603 | "execution_count": 31, 604 | "outputs": [] 605 | }, 606 | { 607 | "cell_type": "code", 608 | "metadata": { 609 | "colab": { 610 | "base_uri": "https://localhost:8080/", 611 | "height": 49, 612 | "referenced_widgets": [ 613 | "0677f310a8104790b2f4545d1bece7b0", 614 | "5cb0a81006e14c4eba6da4cb4b9bf14b", 615 | "0cd3f576c0c2445f9817c5b56a87b900", 616 | "4f066cb8912e40108de0c8d0d5667b09", 617 | "35e18b46304a40e696ea98c7e05d6c94", 618 | "54b399d1ad7747088cc4a73837b87c60", 619 | "e2d00fc27634452383943b40e131f36d", 620 | "b8da8be70fb342b0ae900853af0bb7d1", 621 | "3321512e63f9446f8d2d8d57d7586607", 622 | "ec1edc03d12241ba96e467f120e46ac0", 623 | "030880aa5a054d788231490e0f4a7e78" 624 | ] 625 | }, 626 | "id": "ap0wYRTa4w6L", 627 | "outputId": "1ba7cd8d-d70b-489b-9dc7-a143668c1fe0" 628 | }, 629 | "source": [ 630 | "sample_per_corpus = 1000\n", 631 | "corpus = \"\"\n", 632 | "step = 0\n", 633 | "for example in tqdm(dataset):\n", 634 | " corpus += example['title'].numpy().decode('utf-8')\n", 635 | " corpus += \" \"\n", 636 | " corpus += example['text'].numpy().decode('utf-8')\n", 637 | " step += 1\n", 638 | " if step == sample_per_corpus:\n", 639 | " lm.train(corpus)\n", 640 | " step = 0\n", 641 | " corpus = \"\"\n", 642 | "lm.normalize()" 643 | ], 644 | "execution_count": 32, 645 | "outputs": [ 646 | { 647 | "output_type": "display_data", 648 | "data": { 649 | "application/vnd.jupyter.widget-view+json": { 650 | "model_id": "0677f310a8104790b2f4545d1bece7b0", 651 | "version_minor": 0, 652 | "version_major": 2 653 | }, 654 | "text/plain": [ 655 | " 0%| | 0/5824596 [00:00\n", 1556 | "var my_div = document.createElement(\"DIV\");\n", 1557 | "var my_p = document.createElement(\"P\");\n", 1558 | "var my_btn = document.createElement(\"BUTTON\");\n", 1559 | "var t = document.createTextNode(\"Press to start recording\");\n", 1560 | "\n", 1561 | "my_btn.appendChild(t);\n", 1562 | "//my_p.appendChild(my_btn);\n", 1563 | "my_div.appendChild(my_btn);\n", 1564 | "document.body.appendChild(my_div);\n", 1565 | "\n", 1566 | "var base64data = 0;\n", 1567 | "var reader;\n", 1568 | "var recorder, gumStream;\n", 1569 | "var recordButton = my_btn;\n", 1570 | "\n", 1571 | "var 
handleSuccess = function(stream) {\n", 1572 | " gumStream = stream;\n", 1573 | " var options = {\n", 1574 | " //bitsPerSecond: 8000, //chrome seems to ignore, always 48k\n", 1575 | " mimeType : 'audio/webm;codecs=opus'\n", 1576 | " //mimeType : 'audio/webm;codecs=pcm'\n", 1577 | " }; \n", 1578 | " //recorder = new MediaRecorder(stream, options);\n", 1579 | " recorder = new MediaRecorder(stream);\n", 1580 | " recorder.ondataavailable = function(e) { \n", 1581 | " var url = URL.createObjectURL(e.data);\n", 1582 | " var preview = document.createElement('audio');\n", 1583 | " preview.controls = true;\n", 1584 | " preview.src = url;\n", 1585 | " document.body.appendChild(preview);\n", 1586 | "\n", 1587 | " reader = new FileReader();\n", 1588 | " reader.readAsDataURL(e.data); \n", 1589 | " reader.onloadend = function() {\n", 1590 | " base64data = reader.result;\n", 1591 | " //console.log(\"Inside FileReader:\" + base64data);\n", 1592 | " }\n", 1593 | " };\n", 1594 | " recorder.start();\n", 1595 | " };\n", 1596 | "\n", 1597 | "recordButton.innerText = \"Recording... press to stop\";\n", 1598 | "\n", 1599 | "navigator.mediaDevices.getUserMedia({audio: true}).then(handleSuccess);\n", 1600 | "\n", 1601 | "\n", 1602 | "function toggleRecording() {\n", 1603 | " if (recorder && recorder.state == \"recording\") {\n", 1604 | " recorder.stop();\n", 1605 | " gumStream.getAudioTracks()[0].stop();\n", 1606 | " recordButton.innerText = \"Saving the recording... pls wait!\"\n", 1607 | " }\n", 1608 | "}\n", 1609 | "\n", 1610 | "// https://stackoverflow.com/a/951057\n", 1611 | "function sleep(ms) {\n", 1612 | " return new Promise(resolve => setTimeout(resolve, ms));\n", 1613 | "}\n", 1614 | "\n", 1615 | "var data = new Promise(resolve=>{\n", 1616 | "//recordButton.addEventListener(\"click\", toggleRecording);\n", 1617 | "recordButton.onclick = ()=>{\n", 1618 | "toggleRecording()\n", 1619 | "\n", 1620 | "sleep(2000).then(() => {\n", 1621 | " // wait 2000ms for the data to be available...\n", 1622 | " // ideally this should use something like await...\n", 1623 | " //console.log(\"Inside data:\" + base64data)\n", 1624 | " resolve(base64data.toString())\n", 1625 | "\n", 1626 | "});\n", 1627 | "\n", 1628 | "}\n", 1629 | "});\n", 1630 | " \n", 1631 | "\n", 1632 | "\"\"\"\n", 1633 | "\n", 1634 | "def get_audio(sr):\n", 1635 | " display(HTML(AUDIO_HTML))\n", 1636 | " data = eval_js(\"data\")\n", 1637 | " binary = b64decode(data.split(',')[1])\n", 1638 | " \n", 1639 | " process = (ffmpeg\n", 1640 | " .input('pipe:0')\n", 1641 | " .output('pipe:1', format='wav')\n", 1642 | " .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True, quiet=True, overwrite_output=True)\n", 1643 | " )\n", 1644 | " output, err = process.communicate(input=binary)\n", 1645 | " \n", 1646 | " riff_chunk_size = len(output) - 8\n", 1647 | " # Break up the chunk size into four bytes, held in b.\n", 1648 | " q = riff_chunk_size\n", 1649 | " b = []\n", 1650 | " for i in range(4):\n", 1651 | " q, r = divmod(q, 256)\n", 1652 | " b.append(r)\n", 1653 | "\n", 1654 | " # Replace bytes 4:8 in proc.stdout with the actual size of the RIFF chunk.\n", 1655 | " riff = output[:4] + bytes(b) + output[8:]\n", 1656 | "\n", 1657 | " speech, rate = librosa.load(io.BytesIO(riff),sr=16000)\n", 1658 | " return speech, sr" 1659 | ], 1660 | "execution_count": null, 1661 | "outputs": [] 1662 | }, 1663 | { 1664 | "cell_type": "markdown", 1665 | "metadata": { 1666 | "id": "rLr68Rei5p8g" 1667 | }, 1668 | "source": [ 1669 | "# Recording and loading audio" 1670 | ] 1671 | 
}, 1672 | { 1673 | "cell_type": "code", 1674 | "metadata": { 1675 | "colab": { 1676 | "base_uri": "https://localhost:8080/", 1677 | "height": 96 1678 | }, 1679 | "id": "KBIeAAWAwB7A", 1680 | "outputId": "bf5c3fe4-b62c-4ab1-d8b7-8d6fb83c95f3" 1681 | }, 1682 | "source": [ 1683 | "#load any audio file of your choice\n", 1684 | "speech, rate = get_audio(sr=16000)" 1685 | ], 1686 | "execution_count": null, 1687 | "outputs": [ 1688 | { 1689 | "output_type": "display_data", 1690 | "data": { 1691 | "text/html": [ 1692 | "\n", 1693 | "\n" 1770 | ], 1771 | "text/plain": [ 1772 | "" 1773 | ] 1774 | }, 1775 | "metadata": { 1776 | "tags": [] 1777 | } 1778 | } 1779 | ] 1780 | }, 1781 | { 1782 | "cell_type": "markdown", 1783 | "metadata": { 1784 | "id": "0Zya_f855yY_" 1785 | }, 1786 | "source": [ 1787 | "# Inferencing\n", 1788 | "- tokenizing(encoding) speech data and return pytorch tensor\n", 1789 | "- pass encodings to model" 1790 | ] 1791 | }, 1792 | { 1793 | "cell_type": "code", 1794 | "metadata": { 1795 | "id": "tA8ODZzNzp1_" 1796 | }, 1797 | "source": [ 1798 | "input_values = tokenizer(speech, return_tensors = 'pt').input_values\n", 1799 | "#logits (non-normalized predictions)\n", 1800 | "logits = model(input_values).logits" 1801 | ], 1802 | "execution_count": null, 1803 | "outputs": [] 1804 | }, 1805 | { 1806 | "cell_type": "markdown", 1807 | "metadata": { 1808 | "id": "Bw7t8DuE7iNg" 1809 | }, 1810 | "source": [ 1811 | "decoding transcript" 1812 | ] 1813 | }, 1814 | { 1815 | "cell_type": "code", 1816 | "metadata": { 1817 | "colab": { 1818 | "base_uri": "https://localhost:8080/" 1819 | }, 1820 | "id": "E_5sijeizdPy", 1821 | "outputId": "33064179-153f-4c76-fd0f-684ffac66d6b" 1822 | }, 1823 | "source": [ 1824 | "predicted_ids = torch.argmax(logits, dim =-1)\n", 1825 | "#decode the audio to generate text\n", 1826 | "transcriptions = tokenizer.decode(predicted_ids[0])\n", 1827 | "print(transcriptions)" 1828 | ], 1829 | "execution_count": null, 1830 | "outputs": [ 1831 | { 1832 | "output_type": "stream", 1833 | "text": [ 1834 | "A WO\n" 1835 | ], 1836 | "name": "stdout" 1837 | } 1838 | ] 1839 | }, 1840 | { 1841 | "cell_type": "code", 1842 | "metadata": { 1843 | "id": "soSVdBHtB1Ab" 1844 | }, 1845 | "source": [ 1846 | "" 1847 | ], 1848 | "execution_count": null, 1849 | "outputs": [] 1850 | } 1851 | ] 1852 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | sounddevice==0.4.2 2 | SoundFile==0.10.3.post1 3 | transformers==4.9.2 -------------------------------------------------------------------------------- /train_language_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from asr.language_model import LanguageModel 4 | from asr.wav2vec2.vocab import vocab_list 5 | from tqdm import tqdm 6 | 7 | corpus_path = os.path.join("data", "lm_training_corpus", "corpus.txt") 8 | save_path = os.path.join("data", "models", "lm") 9 | parser = argparse.ArgumentParser(description="Train character language model from text corpus") 10 | parser.add_argument("--corpus", "-c", default=corpus_path, type=str, help="path to text corpus for training") 11 | parser.add_argument("--save", "-s", default=save_path, type=str, help="path to save trained model") 12 | args = parser.parse_args() 13 | 14 | 15 | def change_digit_to_word(x): 16 | x = x.replace("0", "zero ") 17 | x = x.replace("1", "one ") 18 | x = x.replace("2", 
"two ") 19 | x = x.replace("3", "three ") 20 | x = x.replace("4", "four ") 21 | x = x.replace("5", "five ") 22 | x = x.replace("6", "six ") 23 | x = x.replace("7", "seven ") 24 | x = x.replace("8", "eight ") 25 | x = x.replace("9", "nine ") 26 | x = x.replace(" ", " ") 27 | x = x.strip() 28 | return x 29 | 30 | 31 | # excluding pad token for language model 32 | lm = LanguageModel(chars=vocab_list[1:]) 33 | 34 | lines = sum(1 for i in open(args.corpus, "r")) 35 | with open(args.corpus, "r") as txt: 36 | for line in tqdm(txt, total=lines): 37 | line = change_digit_to_word(line) 38 | lm.train(line) 39 | lm.normalize() 40 | lm.save(args.save) --------------------------------------------------------------------------------