├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── data └── dict.ltr.txt ├── src ├── recognize.hydra.py ├── recognize.py └── requirements.txt └── wav2letter.Dockerfile /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # 2 | # Wav2vec 2.0 3 | # @author Loreto Parisi (loretoparisi at gmail dot com) 4 | # Copyright (c) 2020 Loreto Parisi 5 | # 6 | 7 | FROM python:3.8.6-slim-buster 8 | 9 | LABEL maintainer Loreto Parisi loreto@musixmatch.com 10 | 11 | WORKDIR /python 12 | 13 | RUN apt-get update && apt-get install -y --no-install-recommends \ 14 | git build-essential 15 | 16 | # Install fairseq 17 | RUN git clone https://github.com/pytorch/fairseq --depth=1 && cd fairseq && \ 18 | git fetch origin ac11107ed41cb06a758af850373c239309d1c961 && \ 19 | git checkout ac11107ed41cb06a758af850373c239309d1c961 && \ 20 | pip install --editable . 
21 | 
22 | # Install kenlm
23 | RUN apt install -y --no-install-recommends \
24 |     build-essential cmake libboost-system-dev libboost-thread-dev libboost-program-options-dev libboost-test-dev libeigen3-dev zlib1g-dev libbz2-dev liblzma-dev
25 | RUN git clone https://github.com/kpu/kenlm --depth=1 && \
26 |     cd kenlm && \
27 |     mkdir -p build && \
28 |     cd build && \
29 |     cmake .. -DCMAKE_BUILD_TYPE=Release -DKENLM_MAX_ORDER=20 -DCMAKE_POSITION_INDEPENDENT_CODE=ON && \
30 |     make -j 16
31 | 
32 | # Install Additional Dependencies (ATLAS, OpenBLAS, Accelerate, Intel MKL)
33 | RUN apt-get install -y --no-install-recommends \
34 |     libopenblas-dev libfftw3-dev
35 | 
36 | # Install wav2letter
37 | RUN pip install packaging
38 | RUN git clone https://github.com/facebookresearch/wav2letter -b v0.2 --depth=1 && \
39 |     cd wav2letter/bindings/python && \
40 |     # USE_CUDA: 0 for CPU, 1 for GPU
41 |     USE_CUDA=0 \
42 |     # USE_MKL=1 would use Intel MKL for featurization, but this may cause dynamic loading conflicts
43 |     USE_MKL=0 \
44 |     KENLM_ROOT_DIR=/python/kenlm/ \
45 |     pip install -e .
46 | 
47 | # RUN pip install editdistance
48 | RUN pip install soundfile && \
49 |     apt install -y --no-install-recommends libsndfile1
50 | 
51 | COPY src/ .
52 | CMD ["bash"]
53 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 Loreto Parisi
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # wav2vec
2 | wav2vec 2.0 Recognize Implementation.
3 | 
4 | ## Disclaimer
5 | [Wav2vec 2.0](https://github.com/pytorch/fairseq/tree/master/examples/wav2vec) is part of [fairseq](https://github.com/pytorch/fairseq).
6 | This repository is the result of the issue submitted in the `fairseq` repository [here](https://github.com/pytorch/fairseq/issues/2651).
7 | 
8 | ## Resource
9 | Please first download one of the pre-trained models available from `fairseq` (see the table below).
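For example, to fetch the 10-minute finetuned Base model into the `data/` folder (assuming `wget` is available on the host; the URL is taken from the table in the next section):

```bash
wget -P data/ https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small_10m.pt
```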
10 | 
11 | ## Pre-trained models
12 | 
13 | Model | Finetuning split | Dataset | Download
14 | |---|---|---|---
15 | Wav2Vec 2.0 Base | No finetuning | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small.pt)
16 | Wav2Vec 2.0 Base | 10 minutes | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small_10m.pt)
17 | Wav2Vec 2.0 Base | 100 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small_100h.pt)
18 | Wav2Vec 2.0 Base | 960 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small_960h.pt)
19 | Wav2Vec 2.0 Large | No finetuning | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/libri960_big.pt)
20 | Wav2Vec 2.0 Large | 10 minutes | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_big_10m.pt)
21 | Wav2Vec 2.0 Large | 100 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_big_100h.pt)
22 | Wav2Vec 2.0 Large | 960 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_big_960h.pt)
23 | Wav2Vec 2.0 Large (LV-60) | No finetuning | [Libri-Light](https://github.com/facebookresearch/libri-light) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox.pt)
24 | Wav2Vec 2.0 Large (LV-60) | 10 minutes | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_10m.pt)
25 | Wav2Vec 2.0 Large (LV-60) | 100 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_100h.pt)
26 | Wav2Vec 2.0 Large (LV-60) | 960 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec2_vox_960h.pt)
27 | 
28 | 
29 | ## How to install
30 | We use `python:3.8.6-slim-buster` as the base image in order to give developers more flexibility in customizing this `Dockerfile`. For a simplified install, please refer to the [Alternative Install](#alternative-install) section. If you go with this container, build the image using the provided `Dockerfile`:
31 | ```bash
32 | docker build -t wav2vec -f Dockerfile .
33 | ```
34 | 
35 | ## How to Run
36 | There are two versions of the recognition script:
37 | - `recognize.py`: for running a legacy finetuned model (without Hydra).
38 | - `recognize.hydra.py`: for running a model finetuned with a newer (Hydra-based) version of **fairseq**.
39 | 
40 | Before running, please copy the downloaded model (e.g. `wav2vec_small_10m.pt`) to the `data/` folder, together with the wav file to transcribe (`data/temp.wav` in the following examples). The `data/` folder will then look like this:
41 | 
42 | ```
43 | .
44 | ├── dict.ltr.txt
45 | ├── temp.wav
46 | └── wav2vec_small_10m.pt
47 | ```
48 | 
49 | We now run the container, enter it, and execute the recognition (`recognize.py` or `recognize.hydra.py`).
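Note that the Librispeech-finetuned wav2vec 2.0 models expect 16 kHz audio; the recognition scripts average multi-channel input down to mono but do not resample. If your recording is in another format, one possible conversion on the host before mounting the folder (this assumes `ffmpeg` is installed; `input.mp3` stands in for your source file) is:

```bash
ffmpeg -i input.mp3 -ac 1 -ar 16000 data/temp.wav
```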
50 | ```bash
51 | docker run -d -it --rm -v $PWD/data:/app/data --name w2v wav2vec
52 | docker exec -it w2v bash
53 | python recognize.py --target_dict_path=/app/data/dict.ltr.txt /app/data/wav2vec_small_10m.pt /app/data/temp.wav
54 | ```
55 | 
56 | ## Common issues
57 | ### 1. What if my model is not compatible with **fairseq**?
58 | 
59 | At the very least, we have tested against the fairseq master branch (> v0.10.1, commit [ac11107](https://github.com/pytorch/fairseq/commit/ac11107ed41cb06a758af850373c239309d1c961)). If you run into an error like this:
60 | ```txt
61 | omegaconf.errors.ValidationError: Invalid value 'False', expected one of [hard, soft]
62 |     full_key: generation.print_alignment
63 |     reference_type=GenerationConfig
64 |     object_type=GenerationConfig
65 | ```
66 | it is likely that your model was finetuned (or trained) with a different version of **fairseq**.
67 | You should find out which fairseq version your model was trained with and edit the commit hash in the Dockerfile accordingly, **BUT IT MIGHT BREAK src/recognize.py**.
68 | 
69 | The workaround is to look for what changed in the relevant parameters in the **fairseq** source code. In the above example, I managed to find that:
70 | 
71 | ***fairseq/dataclass/configs.py (72a25a4 -> 032a404)***
72 | ```diff
73 | -    print_alignment: bool = field(
74 | +    print_alignment: Optional[PRINT_ALIGNMENT_CHOICES] = field(
75 | -        default=False,
76 | +        default=None,
77 |          metadata={
78 | -            "help": "if set, uses attention feedback to compute and print alignment to source tokens"
79 | +            "help": "if set, uses attention feedback to compute and print alignment to source tokens "
80 | +            "(valid options are: hard, soft, otherwise treated as hard alignment)",
81 | +            "argparse_const": "hard",
82 |          },
83 |      )
84 | ```
85 | The problem is that fairseq changed `generation.print_alignment` so that a boolean value is no longer valid, so I modified `recognize.hydra.py` as below (you might want to fix up the value instead):
86 | ```diff
87 |  OmegaConf.set_struct(w2v["cfg"], False)
88 | + del w2v["cfg"].generation["print_alignment"]
89 |  cfg = OmegaConf.merge(OmegaConf.structured(Wav2Vec2CheckpointConfig), w2v["cfg"])
90 | ```
91 | 
92 | ## Alternative install
93 | We provide an alternative Dockerfile named `wav2letter.Dockerfile` that uses the `wav2letter/wav2letter:cpu-latest` Docker image as its base (`FROM`).
94 | Here are the commands to build, run, and execute the recognition in this case:
95 | 
96 | ```bash
97 | docker build -t wav2vec2 -f wav2letter.Dockerfile .
98 | docker run -d -it --rm -v $PWD/data:/root/data --name w2v2 wav2vec2
99 | docker exec -it w2v2 bash
100 | python examples/wav2vec/recognize.py --target_dict_path /root/data/dict.ltr.txt /root/data/wav2vec_small_10m.pt /root/data/temp.wav
101 | ```
102 | 
103 | ## Contributors
104 | Thanks to all contributors to this repo.
105 | 
106 | - [@sooftware](https://github.com/sooftware)
107 | - [@mychiux413](https://github.com/mychiux413)
108 | - [@osddeitf](https://github.com/osddeitf)
109 | 
--------------------------------------------------------------------------------
/data/dict.ltr.txt:
--------------------------------------------------------------------------------
1 | | 94802
2 | E 51860
3 | T 38431
4 | A 33152
5 | O 31495
6 | N 28855
7 | I 28794
8 | H 27187
9 | S 26071
10 | R 23546
11 | D 18289
12 | L 16308
13 | U 12400
14 | M 10685
15 | W 10317
16 | C 9844
17 | F 9062
18 | G 8924
19 | Y 8226
20 | P 6890
21 | B 6339
22 | V 3936
23 | K 3456
24 | ' 1023
25 | X 636
26 | J 598
27 | Q 437
28 | Z 213
29 | 
--------------------------------------------------------------------------------
/src/recognize.hydra.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import soundfile as sf
3 | import torch.nn.functional as F
4 | import itertools as it
5 | from fairseq.data import Dictionary
6 | from fairseq.data.data_utils import post_process
7 | from fairseq.models.wav2vec.wav2vec2_asr import Wav2VecEncoder, Wav2VecCtc, Wav2Vec2AsrConfig
8 | from wav2letter.decoder import CriterionType
9 | from wav2letter.criterion import CpuViterbiPath, get_data_ptr_as_bytes
10 | 
11 | def parse_args():
12 |     import argparse
13 |     parser = argparse.ArgumentParser(description='Wav2vec-2.0 Recognize')
14 |     parser.add_argument('w2v_path', type=str,
15 |                         help='path of pre-trained wav2vec-2.0 model')
16 |     parser.add_argument('wav_path', type=str,
17 |                         help='path of wave file')
18 |     parser.add_argument('--target_dict_path', type=str,
19 |                         default='dict.ltr.txt',
20 |                         help='path of target dict (dict.ltr.txt)')
21 |     return parser.parse_args()
22 | 
23 | def base_architecture(args):
24 |     args.no_pretrained_weights = getattr(args, "no_pretrained_weights", False)
25 |     args.dropout_input = getattr(args, "dropout_input", 0)
26 |     args.final_dropout = getattr(args, "final_dropout", 0)
27 |     args.apply_mask = getattr(args, "apply_mask", False)
28 |     args.dropout = getattr(args, "dropout", 0)
29 |     args.attention_dropout = getattr(args, "attention_dropout", 0)
30 |     args.activation_dropout = getattr(args, "activation_dropout", 0)
31 | 
32 |     args.mask_length = getattr(args, "mask_length", 10)
33 |     args.mask_prob = getattr(args, "mask_prob", 0.5)
34 |     args.mask_selection = getattr(args, "mask_selection", "static")
35 |     args.mask_other = getattr(args, "mask_other", 0)
36 |     args.no_mask_overlap = getattr(args, "no_mask_overlap", False)
37 |     args.mask_channel_length = getattr(args, "mask_channel_length", 10)
38 |     args.mask_channel_prob = getattr(args, "mask_channel_prob", 0.5)
39 |     args.mask_channel_selection = getattr(args, "mask_channel_selection", "static")
40 |     args.mask_channel_other = getattr(args, "mask_channel_other", 0)
41 |     args.no_mask_channel_overlap = getattr(args, "no_mask_channel_overlap", False)
42 | 
43 |     args.freeze_finetune_updates = getattr(args, "freeze_finetune_updates", 0)
44 |     args.feature_grad_mult = getattr(args, "feature_grad_mult", 0)
45 |     args.layerdrop = getattr(args, "layerdrop", 0.0)
46 |     return args
47 | 
48 | class W2lDecoder(object):
49 |     def __init__(self, tgt_dict):
50 |         self.tgt_dict = tgt_dict
51 |         self.vocab_size = len(tgt_dict)
52 |         self.nbest = 1
53 | 
54 |         self.criterion_type = CriterionType.CTC
55 |         self.blank = (
56 |             tgt_dict.index("<ctc_blank>")
57 |             if "<ctc_blank>" in tgt_dict.indices
58 |             else tgt_dict.bos()
59 |         )
60 |         self.asg_transitions = None
61 | 
62 |     def generate(self, models, sample,
**unused): 63 | """Generate a batch of inferences.""" 64 | # model.forward normally channels prev_output_tokens into the decoder 65 | # separately, but SequenceGenerator directly calls model.encoder 66 | encoder_input = { 67 | k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens" 68 | } 69 | emissions = self.get_emissions(models, encoder_input) 70 | return self.decode(emissions) 71 | 72 | def get_emissions(self, models, encoder_input): 73 | """Run encoder and normalize emissions""" 74 | # encoder_out = models[0].encoder(**encoder_input) 75 | encoder_out = models[0](**encoder_input) 76 | if self.criterion_type == CriterionType.CTC: 77 | emissions = models[0].get_normalized_probs(encoder_out, log_probs=True) 78 | 79 | return emissions.transpose(0, 1).float().cpu().contiguous() 80 | 81 | def get_tokens(self, idxs): 82 | """Normalize tokens by handling CTC blank, ASG replabels, etc.""" 83 | idxs = (g[0] for g in it.groupby(idxs)) 84 | idxs = filter(lambda x: x != self.blank, idxs) 85 | 86 | return torch.LongTensor(list(idxs)) 87 | 88 | 89 | # from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder 90 | class W2lViterbiDecoder(W2lDecoder): 91 | def __init__(self, tgt_dict): 92 | super().__init__(tgt_dict) 93 | 94 | def decode(self, emissions): 95 | B, T, N = emissions.size() 96 | hypos = list() 97 | 98 | if self.asg_transitions is None: 99 | transitions = torch.FloatTensor(N, N).zero_() 100 | else: 101 | transitions = torch.FloatTensor(self.asg_transitions).view(N, N) 102 | 103 | viterbi_path = torch.IntTensor(B, T) 104 | workspace = torch.ByteTensor(CpuViterbiPath.get_workspace_size(B, T, N)) 105 | CpuViterbiPath.compute( 106 | B, 107 | T, 108 | N, 109 | get_data_ptr_as_bytes(emissions), 110 | get_data_ptr_as_bytes(transitions), 111 | get_data_ptr_as_bytes(viterbi_path), 112 | get_data_ptr_as_bytes(workspace), 113 | ) 114 | return [ 115 | [{"tokens": self.get_tokens(viterbi_path[b].tolist()), "score": 0}] for b in range(B) 116 | ] 117 | 118 | 119 | from dataclasses import dataclass 120 | from omegaconf import OmegaConf 121 | from fairseq.dataclass.configs import FairseqConfig 122 | 123 | @dataclass 124 | class Wav2Vec2CheckpointConfig(FairseqConfig): 125 | model: Wav2Vec2AsrConfig = Wav2Vec2AsrConfig() 126 | 127 | class Wav2VecPredictor: 128 | def __init__(self, w2v_path, target_dict_path): 129 | self._target_dict = Dictionary.load(target_dict_path) 130 | self._generator = W2lViterbiDecoder(self._target_dict) 131 | self._model = self._load_model(w2v_path, self._target_dict) 132 | self._model.eval() 133 | 134 | def _get_feature(self, filepath): 135 | def postprocess(feats, sample_rate): 136 | if feats.dim() == 2: 137 | feats = feats.mean(-1) 138 | 139 | assert feats.dim() == 1, feats.dim() 140 | 141 | with torch.no_grad(): 142 | feats = F.layer_norm(feats, feats.shape) 143 | return feats 144 | 145 | wav, sample_rate = sf.read(filepath) 146 | feats = torch.from_numpy(wav).float() 147 | feats = postprocess(feats, sample_rate) 148 | return feats 149 | 150 | def _load_model(self, model_path, target_dict): 151 | w2v = torch.load(model_path) 152 | 153 | # Finetuned with Hydra: w2v["args"] -> w2v["cfg"] + Wav2Vec2AsrConfig 154 | OmegaConf.set_struct(w2v["cfg"], False) 155 | cfg = OmegaConf.merge(OmegaConf.structured(Wav2Vec2CheckpointConfig), w2v["cfg"]) 156 | 157 | # Imitate `Wav2VecCtc.build_model()` Without creating a FairseqTask 158 | model = Wav2VecCtc(cfg.model, Wav2VecEncoder(cfg.model, target_dict)) 159 | 160 | # Load checkpoint's saved weights 161 | 
model.load_state_dict(w2v["model"], strict=True) 162 | 163 | return model 164 | 165 | def predict(self, wav_path): 166 | sample = dict() 167 | net_input = dict() 168 | 169 | feature = self._get_feature(wav_path) 170 | net_input["source"] = feature.unsqueeze(0) 171 | 172 | padding_mask = torch.BoolTensor(net_input["source"].size(1)).fill_(False).unsqueeze(0) 173 | 174 | net_input["padding_mask"] = padding_mask 175 | sample["net_input"] = net_input 176 | 177 | with torch.no_grad(): 178 | hypo = self._generator.generate([ self._model ], sample, prefix_tokens=None) 179 | 180 | hyp_pieces = self._target_dict.string(hypo[0][0]["tokens"].int().cpu()) 181 | return post_process(hyp_pieces, 'letter') 182 | 183 | if __name__ == '__main__': 184 | args = parse_args() 185 | model = Wav2VecPredictor(args.w2v_path, args.target_dict_path) 186 | print(model.predict(args.wav_path)) 187 | -------------------------------------------------------------------------------- /src/recognize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import soundfile as sf 3 | import torch.nn.functional as F 4 | import itertools as it 5 | from fairseq.data import Dictionary 6 | from fairseq.data.data_utils import post_process 7 | from fairseq.models.wav2vec.wav2vec2_asr import Wav2VecEncoder, Wav2VecCtc 8 | from wav2letter.decoder import CriterionType 9 | from wav2letter.criterion import CpuViterbiPath, get_data_ptr_as_bytes 10 | 11 | def parse_args(): 12 | import argparse 13 | parser = argparse.ArgumentParser(description='Wav2vec-2.0 Recognize') 14 | parser.add_argument('w2v_path', type=str, 15 | help='path of pre-trained wav2vec-2.0 model') 16 | parser.add_argument('wav_path', type=str, 17 | help='path of wave file') 18 | parser.add_argument('--target_dict_path', type=str, 19 | default='dict.ltr.txt', 20 | help='path of target dict (dict.ltr.txt)') 21 | return parser.parse_args() 22 | 23 | def base_architecture(args): 24 | args.no_pretrained_weights = getattr(args, "no_pretrained_weights", False) 25 | args.dropout_input = getattr(args, "dropout_input", 0) 26 | args.final_dropout = getattr(args, "final_dropout", 0) 27 | args.apply_mask = getattr(args, "apply_mask", False) 28 | args.dropout = getattr(args, "dropout", 0) 29 | args.attention_dropout = getattr(args, "attention_dropout", 0) 30 | args.activation_dropout = getattr(args, "activation_dropout", 0) 31 | 32 | args.mask_length = getattr(args, "mask_length", 10) 33 | args.mask_prob = getattr(args, "mask_prob", 0.5) 34 | args.mask_selection = getattr(args, "mask_selection", "static") 35 | args.mask_other = getattr(args, "mask_other", 0) 36 | args.no_mask_overlap = getattr(args, "no_mask_overlap", False) 37 | args.mask_channel_length = getattr(args, "mask_channel_length", 10) 38 | args.mask_channel_prob = getattr(args, "mask_channel_prob", 0.5) 39 | args.mask_channel_selection = getattr(args, "mask_channel_selection", "static") 40 | args.mask_channel_other = getattr(args, "mask_channel_other", 0) 41 | args.no_mask_channel_overlap = getattr(args, "no_mask_channel_overlap", False) 42 | 43 | args.freeze_finetune_updates = getattr(args, "freeze_finetune_updates", 0) 44 | args.feature_grad_mult = getattr(args, "feature_grad_mult", 0) 45 | args.layerdrop = getattr(args, "layerdrop", 0.0) 46 | return args 47 | 48 | class W2lDecoder(object): 49 | def __init__(self, tgt_dict): 50 | self.tgt_dict = tgt_dict 51 | self.vocab_size = len(tgt_dict) 52 | self.nbest = 1 53 | 54 | self.criterion_type = CriterionType.CTC 55 | self.blank = 
(
56 |             tgt_dict.index("<ctc_blank>")
57 |             if "<ctc_blank>" in tgt_dict.indices
58 |             else tgt_dict.bos()
59 |         )
60 |         self.asg_transitions = None
61 | 
62 |     def generate(self, models, sample, **unused):
63 |         """Generate a batch of inferences."""
64 |         # model.forward normally channels prev_output_tokens into the decoder
65 |         # separately, but SequenceGenerator directly calls model.encoder
66 |         encoder_input = {
67 |             k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
68 |         }
69 |         emissions = self.get_emissions(models, encoder_input)
70 |         return self.decode(emissions)
71 | 
72 |     def get_emissions(self, models, encoder_input):
73 |         """Run encoder and normalize emissions"""
74 |         # encoder_out = models[0].encoder(**encoder_input)
75 |         encoder_out = models[0](**encoder_input)
76 |         if self.criterion_type == CriterionType.CTC:
77 |             emissions = models[0].get_normalized_probs(encoder_out, log_probs=True)
78 | 
79 |         return emissions.transpose(0, 1).float().cpu().contiguous()
80 | 
81 |     def get_tokens(self, idxs):
82 |         """Normalize tokens by handling CTC blank, ASG replabels, etc."""
83 |         idxs = (g[0] for g in it.groupby(idxs))
84 |         idxs = filter(lambda x: x != self.blank, idxs)
85 | 
86 |         return torch.LongTensor(list(idxs))
87 | 
88 | 
89 | # from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder
90 | class W2lViterbiDecoder(W2lDecoder):
91 |     def __init__(self, tgt_dict):
92 |         super().__init__(tgt_dict)
93 | 
94 |     def decode(self, emissions):
95 |         B, T, N = emissions.size()
96 |         hypos = list()
97 | 
98 |         if self.asg_transitions is None:
99 |             transitions = torch.FloatTensor(N, N).zero_()
100 |         else:
101 |             transitions = torch.FloatTensor(self.asg_transitions).view(N, N)
102 | 
103 |         viterbi_path = torch.IntTensor(B, T)
104 |         workspace = torch.ByteTensor(CpuViterbiPath.get_workspace_size(B, T, N))
105 |         CpuViterbiPath.compute(
106 |             B,
107 |             T,
108 |             N,
109 |             get_data_ptr_as_bytes(emissions),
110 |             get_data_ptr_as_bytes(transitions),
111 |             get_data_ptr_as_bytes(viterbi_path),
112 |             get_data_ptr_as_bytes(workspace),
113 |         )
114 |         return [
115 |             [{"tokens": self.get_tokens(viterbi_path[b].tolist()), "score": 0}] for b in range(B)
116 |         ]
117 | 
118 | class Wav2VecPredictor:
119 |     def __init__(self, w2v_path, target_dict_path):
120 |         self._target_dict = Dictionary.load(target_dict_path)
121 |         self._generator = W2lViterbiDecoder(self._target_dict)
122 |         self._model = self._load_model(w2v_path, self._target_dict)
123 |         self._model.eval()
124 | 
125 |     def _get_feature(self, filepath):
126 |         def postprocess(feats, sample_rate):
127 |             if feats.dim() == 2:
128 |                 feats = feats.mean(-1)
129 | 
130 |             assert feats.dim() == 1, feats.dim()
131 | 
132 |             with torch.no_grad():
133 |                 feats = F.layer_norm(feats, feats.shape)
134 |             return feats
135 | 
136 |         wav, sample_rate = sf.read(filepath)
137 |         feats = torch.from_numpy(wav).float()
138 |         feats = postprocess(feats, sample_rate)
139 |         return feats
140 | 
141 |     def _load_model(self, model_path, target_dict):
142 |         w2v = torch.load(model_path)
143 | 
144 |         # Without creating a FairseqTask
145 |         args = base_architecture(w2v["args"])
146 |         model = Wav2VecCtc(args, Wav2VecEncoder(args, target_dict))
147 |         model.load_state_dict(w2v["model"], strict=True)
148 |         return model
149 | 
150 |     def predict(self, wav_path):
151 |         sample = dict()
152 |         net_input = dict()
153 | 
154 |         feature = self._get_feature(wav_path)
155 |         net_input["source"] = feature.unsqueeze(0)
156 | 
157 |         padding_mask = torch.BoolTensor(net_input["source"].size(1)).fill_(False).unsqueeze(0)
158 | 
net_input["padding_mask"] = padding_mask 160 | sample["net_input"] = net_input 161 | 162 | with torch.no_grad(): 163 | hypo = self._generator.generate([ self._model ], sample, prefix_tokens=None) 164 | 165 | hyp_pieces = self._target_dict.string(hypo[0][0]["tokens"].int().cpu()) 166 | return post_process(hyp_pieces, 'letter') 167 | 168 | if __name__ == '__main__': 169 | args = parse_args() 170 | model = Wav2VecPredictor(args.w2v_path, args.target_dict_path) 171 | print(model.predict(args.wav_path)) 172 | -------------------------------------------------------------------------------- /src/requirements.txt: -------------------------------------------------------------------------------- 1 | fairseq==0.9.0 -------------------------------------------------------------------------------- /wav2letter.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM wav2letter/wav2letter:cpu-latest 2 | 3 | ENV USE_CUDA=0 4 | ENV KENLM_ROOT_DIR=/root/kenlm 5 | 6 | # will use Intel MKL for featurization but this may cause dynamic loading conflicts. 7 | # ENV USE_MKL=1 8 | 9 | ENV LD_LIBRARY_PATH=/opt/intel/compilers_and_libraries_2018.5.274/linux/mkl/lib/intel64:$LD_IBRARY_PATH 10 | WORKDIR /root/wav2letter/bindings/python 11 | 12 | #added editdistance package as pip install 13 | RUN TMPDIR=/data/mydir/ pip install --upgrade pip && pip install --cache-dir=/data/vincents/ --build /data/mydir/ editdistance soundfile packaging && pip install -e . 14 | 15 | WORKDIR /root 16 | RUN git clone https://github.com/pytorch/fairseq.git 17 | RUN mkdir data 18 | COPY src/recognize.py /root/fairseq/examples/wav2vec/recognize.py 19 | 20 | WORKDIR /root/fairseq 21 | RUN TMPDIR=/data/mydir/ pip install --cache-dir=/data/mydir/ --editable ./ && python examples/speech_recognition/infer.py --help && python examples/wav2vec/recognize.py --help 22 | 23 | --------------------------------------------------------------------------------