├── .gitignore
├── README.md
├── create_ngram.py
├── eval.py
├── fix_lm.py
├── requirements.txt
└── result.txt

/.gitignore:
--------------------------------------------------------------------------------
*.arpa
kenlm

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# 🤗 Transformers Wav2Vec2 + PyCTCDecode

## **IMPORTANT**: This GitHub repo is not actively maintained. Please use https://huggingface.co/docs/transformers/model_doc/wav2vec2#transformers.Wav2Vec2ProcessorWithLM instead.

## UPDATE 2:

An in-detail blog post, which should be much better than this repo, is available here: https://huggingface.co/blog/wav2vec2-with-ngram

## UPDATE: PyCTCDecode is merged into Transformers!

```diff
import torch
from datasets import load_dataset
from transformers import AutoModelForCTC, AutoProcessor
import torchaudio.functional as F


model_id = "patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm"

sample = next(iter(load_dataset("common_voice", "es", split="test", streaming=True)))
resampled_audio = F.resample(torch.tensor(sample["audio"]["array"]), 48_000, 16_000).numpy()

model = AutoModelForCTC.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

input_values = processor(resampled_audio, return_tensors="pt").input_values

with torch.no_grad():
    logits = model(input_values).logits

-prediction_ids = torch.argmax(logits, dim=-1)
-transcription = processor.batch_decode(prediction_ids)
+transcription = processor.batch_decode(logits.numpy()).text
# => 'bien y qué regalo vas a abrir primero'
```

## Introduction

This repo shows how [🤗 **Transformers**](https://github.com/huggingface/transformers) can be used in combination
with [kensho-technologies' **PyCTCDecode**](https://github.com/kensho-technologies/pyctcdecode) & a [**KenLM** ngram](https://github.com/kpu/kenlm)
as a simple way to improve the word error rate (WER).

Included is a script to create an ngram with **KenLM**, as well as a simple evaluation script to
compare the results of using Wav2Vec2 with **PyCTCDecode** + **KenLM** vs. without using any language model.


**Note**: The scripts are written to be used on GPU. If you want to use a CPU instead,
simply remove all `.to("cuda")` occurrences in `eval.py`.
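Alternatively, the device can be picked automatically. The following is only a minimal sketch of such a device-agnostic setup (it is not part of the scripts in this repo; the checkpoint is the Polish model already used in `eval.py`):

```python
import torch
from transformers import AutoModelForCTC, Wav2Vec2Processor

# use the GPU when available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

model_id = "facebook/wav2vec2-base-10k-voxpopuli-ft-pl"  # Polish checkpoint from eval.py
model = AutoModelForCTC.from_pretrained(model_id).to(device)
processor = Wav2Vec2Processor.from_pretrained(model_id)

# any input tensors then have to be moved to the same device, e.g.
# `input_values.to(device)` instead of `input_values.to("cuda")`
```

With this pattern the evaluation code runs unchanged on machines without a GPU.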
## Installation

As a first step, one should install **KenLM**. For Ubuntu, it should be enough to follow the installation steps
described [here](https://github.com/kpu/kenlm/blob/master/BUILDING). The installed `kenlm` folder
should be moved into this repo for `./create_ngram.py` to function correctly. Alternatively, one can also
link the `lmplz` binary to an `lmplz` bash command so that `lmplz` can be run directly instead of `./kenlm/build/bin/lmplz`.

Next, some Python dependencies should be installed. Assuming PyTorch is already installed, it should be sufficient to run
`pip install -r requirements.txt`.

## Run evaluation


### Create ngram

In a first step, one should create an ngram. *E.g.* for `polish` the command would be:

```bash
./create_ngram.py --language polish --path_to_ngram polish.arpa
```

After the language model is created, a few lines need to be converted so that it is compatible with `pyctcdecode`.

Execute the following script to run the conversion:

```bash
./fix_lm.py --path_to_ngram polish.arpa --path_to_fixed polish_fixed.arpa
```

Now the generated `polish_fixed.arpa` ngram can be correctly used with `pyctcdecode`.


### Run eval

Having created and fixed the ngram, one can run:

```bash
./eval.py --language polish --path_to_ngram polish_fixed.arpa
```

to compare Wav2Vec2 + LM vs. Wav2Vec2 without LM on Polish.


## Results

Without tuning any hyperparameters, the following results were obtained:

```
Comparison of Wav2Vec2 without language model vs. Wav2Vec2 with `pyctcdecode` + KenLM 5-gram.
Fine-tuned Wav2Vec2 models were used and evaluated on MLS datasets.
Take a closer look at `./eval.py` for the comparison.

==================================================polish==================================================
polish - No LM - | WER: 0.3069742867206763 | CER: 0.06054530156286364 | Time: 58.04590034484863
polish - With LM - | WER: 0.2291299753434308 | CER: 0.06211174564528545 | Time: 191.65409898757935

==================================================portuguese==================================================
portuguese - No LM - | WER: 0.18208286674132138 | CER: 0.05016682956422096 | Time: 114.61633825302124
portuguese - With LM - | WER: 0.1487761958086706 | CER: 0.04489231909945738 | Time: 429.78511357307434

==================================================spanish==================================================
spanish - No LM - | WER: 0.2581272104769545 | CER: 0.0703088156033147 | Time: 147.8634352684021
spanish - With LM - | WER: 0.14927852292116295 | CER: 0.052034208044195916 | Time: 563.0732748508453
```

It can be seen that the word error rate (WER) is significantly improved when using PyCTCDecode + KenLM,
whereas the character error rate (CER) improves much less or not at all.
This is expected: a language model makes sure that the predicted words actually exist in the language's vocabulary.
Wav2Vec2 without a LM produces many words that are more or less correct but contain a couple of spelling errors, which hurts the WER.
Such words are likely to be "corrected" by Wav2Vec2 + LM, leading to an improved WER. However, Wav2Vec2 already has a good character error rate on its own, since its
vocabulary is composed of characters, so a word-based language model does not help much with the CER.

Overall, WER is probably the more important metric, so it can make a lot of sense to add a LM to Wav2Vec2.

In terms of speed, adding a LM slows inference down significantly. However, the script is not at all optimized for speed,
so using multi-processing and batched inference would significantly speed up both Wav2Vec2 without LM and with LM.
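As an illustration of what such an optimization could look like, here is a minimal sketch of batched inference combined with `pyctcdecode`'s multi-process `decode_batch`. It is not part of this repo's scripts: it assumes a `polish_fixed.arpa` built as described above, uses the Polish checkpoint from `eval.py`, and the all-zero audio arrays are only placeholders for real dataset samples.

```python
import multiprocessing

import numpy as np
import torch
from pyctcdecode import build_ctcdecoder
from transformers import AutoModelForCTC, Wav2Vec2Processor

model_id = "facebook/wav2vec2-base-10k-voxpopuli-ft-pl"  # Polish checkpoint from eval.py
model = AutoModelForCTC.from_pretrained(model_id)
processor = Wav2Vec2Processor.from_pretrained(model_id)

# same vocab preparation as in eval.py: sort tokens by index and lower-case them
vocab = processor.tokenizer.get_vocab()
labels = [k.lower() for k, _ in sorted(vocab.items(), key=lambda item: item[1])]
decoder = build_ctcdecoder(labels, "polish_fixed.arpa")

# a "batch" of two dummy 1-second utterances; in practice these would come from the
# dataset, e.g. via `ds.map(..., batched=True)`
audio_batch = [np.zeros(16_000, dtype=np.float32), np.zeros(16_000, dtype=np.float32)]

inputs = processor(audio_batch, sampling_rate=16_000, return_tensors="pt", padding=True)
with torch.no_grad():
    logits = model(inputs.input_values).logits  # (batch, time, vocab)

# with real, variable-length audio one would also trim the padded frames here
logits_list = [single_logits.numpy() for single_logits in logits]

# decode_batch spreads the beam search over the worker processes of the pool
with multiprocessing.get_context("fork").Pool() as pool:
    transcriptions = decoder.decode_batch(pool, logits_list)
```

In `eval.py`, the same idea would amount to mapping over the dataset with `batched=True` and replacing the per-sample `decoder.decode` call with `decoder.decode_batch`.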
--------------------------------------------------------------------------------
/create_ngram.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
from datasets import load_dataset
import os
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--language", default="polish", type=str, required=True, help="Language to run comparison on. Choose one of 'polish', 'portuguese', 'spanish' or add more to this script."
)
parser.add_argument(
    "--path_to_ngram", type=str, required=True, help="Path to kenLM ngram"
)
args = parser.parse_args()

ds = load_dataset("multilingual_librispeech", f"{args.language}", split="train")

# dump the training transcripts into a single text file for KenLM
with open("text.txt", "w") as f:
    f.write(" ".join(ds["text"]))

# build a 5-gram ARPA LM; lmplz reads the corpus from stdin and writes the ARPA to stdout
os.system(f"./kenlm/build/bin/lmplz -o 5 < text.txt > {args.path_to_ngram}")

## VERY IMPORTANT!!!:
# After the language model is created, one should open the file and add a `</s>` token.
# The file should have a structure which looks more or less as follows:

# \data\
# ngram 1=86586
# ngram 2=546387
# ngram 3=796581
# ngram 4=843999
# ngram 5=850874

# \1-grams:
# -5.7532206 <unk> 0
# 0 <s> -0.06677356
# -3.4645514 drugi -0.2088903
# ...

# Now it is very important to also add a `</s>` token to the n-gram
# so that it can be correctly loaded. You can simply copy the line:

# 0 <s> -0.06677356

# and change `<s>` to `</s>`. When doing this you should also increase `ngram 1` by 1.
# The new ngram should look as follows:

# \data\
# ngram 1=86587
# ngram 2=546387
# ngram 3=796581
# ngram 4=843999
# ngram 5=850874

# \1-grams:
# -5.7532206 <unk> 0
# 0 <s> -0.06677356
# 0 </s> -0.06677356
# -3.4645514 drugi -0.2088903
# ...

# Now the ngram can be correctly used with `pyctcdecode`
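# Note: instead of applying the above edit by hand, the companion script
# `./fix_lm.py` in this repo performs the same conversion, e.g.:
#
#   ./fix_lm.py --path_to_ngram polish.arpa --path_to_fixed polish_fixed.arpa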
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
import torch
from transformers import AutoModelForCTC, Wav2Vec2Processor
import time
import argparse

from datasets import load_dataset, load_metric
from pyctcdecode import build_ctcdecoder


LANG_TO_ID_LOOK_UP = {
    "polish": "pl",
    "portuguese": "pt",
    "spanish": "es",
}


def main(args):
    language = args.language
    lang_id = LANG_TO_ID_LOOK_UP[language]

    if lang_id == "pt":
        model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-large-xlsr-53-portuguese").to("cuda")
        processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-xlsr-53-portuguese")
    else:
        model = AutoModelForCTC.from_pretrained(f"facebook/wav2vec2-base-10k-voxpopuli-ft-{lang_id}").to("cuda")
        processor = Wav2Vec2Processor.from_pretrained(f"facebook/wav2vec2-base-10k-voxpopuli-ft-{lang_id}")

    wer = load_metric("wer")
    cer = load_metric("cer")

    # build a pyctcdecode decoder from the (index-sorted, lower-cased) tokenizer vocab
    # and the trained KenLM ngram
    vocab_dict = processor.tokenizer.get_vocab()
    sorted_dict = {k.lower(): v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])}

    decoder = build_ctcdecoder(
        list(sorted_dict.keys()),
        args.path_to_ngram,
    )

    # load the test split to evaluate on
    ds = load_dataset("multilingual_librispeech", f"{language}", split="test")

    # Uncomment for dummy run
    # ds = ds.select(range(20))

    def map_to_wer_no_lm(batch):
        # greedy (argmax) decoding without a language model
        input_values = processor(batch["audio"]["array"], return_tensors="pt", sampling_rate=16_000).input_values.to("cuda")

        with torch.no_grad():
            pred_ids = torch.argmax(model(input_values).logits, -1)

        pred_str = processor.batch_decode(pred_ids)

        batch["pred_str"] = pred_str[0]
        batch["ref_str"] = batch["text"]
        return batch

    def map_to_wer_with_lm(batch):
        # beam search decoding with the KenLM language model via pyctcdecode
        input_values = processor(batch["audio"]["array"], return_tensors="pt", sampling_rate=16_000).input_values.to("cuda")

        with torch.no_grad():
            logits = model(input_values).logits.cpu().numpy()[0]

        batch["pred_str"] = decoder.decode(logits)
        batch["ref_str"] = batch["text"]

        return batch

    start_time_1 = time.time()
    result_no_lm = ds.map(map_to_wer_no_lm, remove_columns=ds.column_names)

    wer_result_no_lm = wer.compute(predictions=result_no_lm["pred_str"], references=result_no_lm["ref_str"])
    cer_result_no_lm = cer.compute(predictions=result_no_lm["pred_str"], references=result_no_lm["ref_str"])

    start_time_2 = time.time()
    result_with_lm = ds.map(map_to_wer_with_lm, remove_columns=ds.column_names)

    wer_result_with_lm = wer.compute(predictions=result_with_lm["pred_str"], references=result_with_lm["ref_str"])
    cer_result_with_lm = cer.compute(predictions=result_with_lm["pred_str"], references=result_with_lm["ref_str"])

    end_time = time.time()

    print(50 * "=" + language + 50 * "=")
    print(f"{language} - No LM - | WER: {wer_result_no_lm} | CER: {cer_result_no_lm} | Time: {start_time_2 - start_time_1}")
    print(f"{language} - With LM - | WER: {wer_result_with_lm} | CER: {cer_result_with_lm} | Time: {end_time - start_time_2}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--language", default="polish", type=str, required=True, help="Language to run comparison on. Choose one of 'polish', 'portuguese', 'spanish' or add more to this script."
    )
    parser.add_argument(
        "--path_to_ngram", type=str, required=True, help="Path to kenLM ngram"
    )
    args = parser.parse_args()

    main(args)
--------------------------------------------------------------------------------
/fix_lm.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
import argparse


def main(args):
    '''
    Searches for the lines that need to be changed so the ARPA file is supported by
    the PyCTCDecode lib, changes them and writes a new KenLM arpa.
    '''
    original = open(args.path_to_ngram, 'r').readlines()
    fixed = open(args.path_to_fixed, 'w')

    for line in original:
        if 'ngram 1=' in line:
            # the unigram count grows by one because a `</s>` entry is added below;
            # keep it an integer so the ARPA header stays well-formed
            base_ngram_1_line = line
            text, value = line.split('=')
            value = str(int(value.replace('\n', '')) + 1)
            fixed_ngram_1_line = f"{text}={value}\n"
            fixed.write(fixed_ngram_1_line)
        elif '\t<s>\t' in line:
            # duplicate the `<s>` unigram entry as a `</s>` entry
            base_token_line = line
            fixed_token_line = line.replace('\t<s>\t', '\t</s>\t')
            fixed.write(base_token_line)
            fixed.write(fixed_token_line)
        else:
            fixed.write(line)
    fixed.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--path_to_ngram", type=str, required=True, help="Path to original KenLM ngram"
    )
    parser.add_argument(
        "--path_to_fixed", type=str, required=True, help="Path to write fixed KenLM ngram"
    )
    args = parser.parse_args()

    main(args)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
transformers>=4.12.0
datasets>=1.15.0
jiwer
librosa
https://github.com/kpu/kenlm/archive/master.zip
pyctcdecode
--------------------------------------------------------------------------------
/result.txt:
--------------------------------------------------------------------------------
Comparison of Wav2Vec2 without language model vs. Wav2Vec2 with `pyctcdecode` + KenLM 5-gram.
Fine-tuned Wav2Vec2 models were used and evaluated on MLS datasets.
Take a closer look at `./eval.py` for the comparison.

==================================================polish==================================================
polish - No LM - | WER: 0.3069742867206763 | CER: 0.06054530156286364 | Time: 58.04590034484863
polish - With LM - | WER: 0.2291299753434308 | CER: 0.06211174564528545 | Time: 191.65409898757935

==================================================portuguese==================================================
portuguese - No LM - | WER: 0.18208286674132138 | CER: 0.05016682956422096 | Time: 114.61633825302124
portuguese - With LM - | WER: 0.1487761958086706 | CER: 0.04489231909945738 | Time: 429.78511357307434

==================================================spanish==================================================
spanish - No LM - | WER: 0.2581272104769545 | CER: 0.0703088156033147 | Time: 147.8634352684021
spanish - With LM - | WER: 0.14927852292116295 | CER: 0.052034208044195916 | Time: 563.0732748508453
--------------------------------------------------------------------------------