├── README.md
├── run_eval.sh
├── eval_xlsr_wav2vec2.py
└── eval_whisper.py

/README.md:
--------------------------------------------------------------------------------
# benchmark-asr
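
Scripts for benchmarking automatic speech recognition (ASR) models with 🤗 Transformers and 🤗 Datasets: `eval_whisper.py` evaluates Whisper checkpoints and `eval_xlsr_wav2vec2.py` evaluates XLS-R/wav2vec2 checkpoints. Both scripts stream a dataset, run batched pipeline inference, and report the word error rate (WER) in percent. Gated datasets such as Common Voice require authentication first (`huggingface-cli login`), since the scripts load data with `use_auth_token=True`.

To reproduce the run in `run_eval.sh` (Whisper tiny on the Egyptian Arabic test split of FLEURS; `--device=0` selects the first GPU, `-1` runs on CPU):

```bash
python eval_whisper.py \
    --model_id="openai/whisper-tiny" \
    --dataset="google/fleurs" \
    --config="ar_eg" \
    --device=0 \
    --language="ar" \
    --split="test"
```

An analogous, commented-out invocation of `eval_xlsr_wav2vec2.py` is sketched at the bottom of `run_eval.sh`.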
--------------------------------------------------------------------------------
/run_eval.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Evaluate Whisper tiny on the Egyptian Arabic test split of FLEURS.
python eval_whisper.py \
    --model_id="openai/whisper-tiny" \
    --dataset="google/fleurs" \
    --config="ar_eg" \
    --device=0 \
    --language="ar" \
    --split="test"
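
# A sketch of the analogous run for the XLS-R/wav2vec2 script. The model_id
# below is a placeholder, not a real checkpoint -- substitute the model you
# want to benchmark, then uncomment:
# python eval_xlsr_wav2vec2.py \
#     --model_id="<your-wav2vec2-checkpoint>" \
#     --dataset="google/fleurs" \
#     --config="ar_eg" \
#     --split="test" \
#     --device=0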
--------------------------------------------------------------------------------
/eval_xlsr_wav2vec2.py:
--------------------------------------------------------------------------------
import argparse

from transformers import pipeline
from datasets import load_dataset, Audio
import evaluate

wer_metric = evaluate.load("wer")


def is_target_text_in_range(ref):
    # Skip samples explicitly marked for exclusion, as well as empty references.
    if ref.strip() == "ignore time segment in scoring":
        return False
    else:
        return ref.strip() != ""


def get_text(sample):
    # Return the transcript, whichever column name the dataset uses for it.
    if "text" in sample:
        return sample["text"]
    elif "sentence" in sample:
        return sample["sentence"]
    elif "normalized_text" in sample:
        return sample["normalized_text"]
    elif "transcript" in sample:
        return sample["transcript"]
    else:
        raise ValueError(f"Sample: {sample.keys()} has no transcript.")


def normalise(batch):
    # Store the reference transcript under a common column name so that the
    # filtering and scoring below work for any of the supported datasets.
    batch["norm_text"] = get_text(batch)
    return batch


def data(dataset):
    # Yield the audio together with its reference; the pipeline passes the
    # extra `reference` key through to its outputs.
    for item in dataset:
        yield {**item["audio"], "reference": item["norm_text"]}


def main(args):
    batch_size = args.batch_size
    wav2vec2_asr = pipeline(
        "automatic-speech-recognition", model=args.model_id, device=args.device
    )

    dataset = load_dataset(
        args.dataset, args.config, split=args.split, streaming=True, use_auth_token=True
    )

    # Only uncomment for debugging
    # dataset = dataset.take(64)

    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
    dataset = dataset.map(normalise)
    dataset = dataset.filter(is_target_text_in_range, input_columns=["norm_text"])

    predictions = []
    references = []

    # run streamed inference
    for out in wav2vec2_asr(data(dataset), batch_size=batch_size):
        predictions.append(out["text"])
        references.append(out["reference"][0])

    wer = wer_metric.compute(references=references, predictions=predictions)
    wer = round(100 * wer, 2)

    print("WER:", wer)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_id",
        type=str,
        required=True,
        help="Model identifier. Should be loadable with 🤗 Transformers",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="Dataset name to evaluate the `model_id`. Should be loadable with 🤗 Datasets",
    )
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Config of the dataset. *E.g.* `'en'` for Common Voice",
    )
    parser.add_argument(
        "--split", type=str, required=True, help="Split of the dataset. *E.g.* `'test'`"
    )
    parser.add_argument(
        "--device",
        type=int,
        default=-1,
        help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="Number of samples to go through each streamed batch.",
    )
    args = parser.parse_args()

    main(args)
--------------------------------------------------------------------------------
/eval_whisper.py:
--------------------------------------------------------------------------------
import argparse

from transformers import pipeline
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
from datasets import load_dataset, Audio
import evaluate

wer_metric = evaluate.load("wer")


def is_target_text_in_range(ref):
    # Skip samples explicitly marked for exclusion, as well as empty references.
    if ref.strip() == "ignore time segment in scoring":
        return False
    else:
        return ref.strip() != ""


def get_text(sample):
    # Return the transcript, whichever column name the dataset uses for it.
    if "text" in sample:
        return sample["text"]
    elif "sentence" in sample:
        return sample["sentence"]
    elif "normalized_text" in sample:
        return sample["normalized_text"]
    elif "transcript" in sample:
        return sample["transcript"]
    elif "transcription" in sample:
        return sample["transcription"]
    else:
        raise ValueError(
            f"Expected transcript column of either 'text', 'sentence', 'normalized_text', "
            f"'transcript' or 'transcription'. Got sample with columns: {', '.join(sample.keys())}. "
            "Ensure a text column name is present in the dataset."
        )


whisper_norm = BasicTextNormalizer()


def normalise(batch):
    # Normalise the reference the same way the predictions are normalised in
    # `main`, so that WER is not inflated by casing or punctuation differences.
    batch["norm_text"] = whisper_norm(get_text(batch))
    return batch


def data(dataset):
    # Yield the audio together with its reference; the pipeline passes the
    # extra `reference` key through to its outputs.
    for item in dataset:
        yield {**item["audio"], "reference": item["norm_text"]}


def main(args):
    batch_size = args.batch_size
    whisper_asr = pipeline(
        "automatic-speech-recognition", model=args.model_id, device=args.device
    )

    # Force the target language and the transcription task so that the model
    # does not have to detect the language itself.
    whisper_asr.model.config.forced_decoder_ids = (
        whisper_asr.tokenizer.get_decoder_prompt_ids(
            language=args.language, task="transcribe"
        )
    )

    dataset = load_dataset(
        args.dataset,
        args.config,
        split=args.split,
        streaming=args.streaming,
        use_auth_token=True,
    )

    # Only evaluate the first `max_eval_samples` when set (useful for debugging)
    if args.max_eval_samples is not None:
        if args.streaming:
            dataset = dataset.take(args.max_eval_samples)
        else:
            dataset = dataset.select(range(args.max_eval_samples))

    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
    dataset = dataset.map(normalise)
    dataset = dataset.filter(is_target_text_in_range, input_columns=["norm_text"])

    predictions = []
    references = []

    # run streamed inference
    for out in whisper_asr(data(dataset), batch_size=batch_size):
        predictions.append(whisper_norm(out["text"]))
        references.append(out["reference"][0])

    wer = wer_metric.compute(references=references, predictions=predictions)
    wer = round(100 * wer, 2)

    print("WER:", wer)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_id",
        type=str,
        required=True,
        help="Model identifier. Should be loadable with 🤗 Transformers",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="mozilla-foundation/common_voice_11_0",
        help="Dataset name to evaluate the `model_id`. Should be loadable with 🤗 Datasets",
    )
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Config of the dataset. *E.g.* `'en'` for the English split of Common Voice",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="test",
        help="Split of the dataset. *E.g.* `'test'`",
    )
    parser.add_argument(
        "--device",
        type=int,
        default=-1,
        help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="Number of samples to go through each streamed batch.",
    )
    parser.add_argument(
        "--max_eval_samples",
        type=int,
        default=None,
        help="Number of samples to be evaluated. Put a lower number e.g. 64 for testing this script.",
    )
    parser.add_argument(
        "--streaming",
        type=lambda x: x.lower() not in ("0", "false", "no"),
        default=True,
        help="Whether to stream the dataset during evaluation (default) or download it in full first. Pass 'False' to disable streaming.",
    )
    parser.add_argument(
        "--language",
        type=str,
        required=True,
        help="Two letter language code for the transcription language, e.g. use 'en' for English.",
    )
    args = parser.parse_args()

    main(args)
--------------------------------------------------------------------------------