├── data ├── __init__.py ├── visspeech.py ├── ascend.py ├── mustcv1.py ├── seame.py ├── libritrans.py └── covost2.py ├── whisper ├── __main__.py ├── assets │ ├── multilingual │ │ ├── added_tokens.json │ │ ├── special_tokens_map.json │ │ └── tokenizer_config.json │ ├── mel_filters.npz │ └── gpt2 │ │ ├── special_tokens_map.json │ │ └── tokenizer_config.json ├── normalizers │ ├── __init__.py │ ├── basic.py │ └── english.py ├── utils.py ├── audio.py ├── __init__.py ├── model.py ├── tokenizer.py ├── transcribe.py └── decoding.py ├── .gitignore ├── requirements.txt ├── setup.py ├── scripts ├── libritrans.sh ├── mustcv1.sh ├── covost2.sh ├── visspeech.sh ├── ascend.sh └── seame.sh ├── README.md ├── config.py ├── csasr_st.py └── avsr.py /data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /whisper/__main__.py: -------------------------------------------------------------------------------- 1 | from .transcribe import cli 2 | 3 | 4 | cli() 5 | -------------------------------------------------------------------------------- /whisper/assets/multilingual/added_tokens.json: -------------------------------------------------------------------------------- 1 | {"<|endoftext|>": 50257} 2 | -------------------------------------------------------------------------------- /whisper/assets/mel_filters.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonppy/PromptingWhisper/HEAD/whisper/assets/mel_filters.npz -------------------------------------------------------------------------------- /whisper/normalizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .basic import BasicTextNormalizer 2 | from .english import EnglishTextNormalizer 3 | -------------------------------------------------------------------------------- /whisper/assets/gpt2/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"} -------------------------------------------------------------------------------- /whisper/assets/multilingual/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | *$py.class 4 | *.egg-info 5 | .pytest_cache 6 | .ipynb_checkpoints 7 | build/ 8 | thumbs.db 9 | .DS_Store 10 | .idea 11 | *.log 12 | *rtx* 13 | *.pdf 14 | *.mkv 15 | *.mp4 16 | *a40* 17 | *durip* 18 | testing*.sh 19 | *.ipynb -------------------------------------------------------------------------------- /whisper/assets/gpt2/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | {"unk_token": "<|endoftext|>", "bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "gpt2", "tokenizer_class": "GPT2Tokenizer"} -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | numpy 2 | torch==1.13.1 3 | torchaudio==0.13.1 4 | torchvision==0.14.1 5 | tqdm 6 | more-itertools 7 | transformers>=4.19.0 8 | ffmpeg-python==0.2.0 9 | opencc 10 | jieba 11 | editdistance 12 | pandas 13 | inflect 14 | sacrebleu 15 | profanityfilter 16 | ftfy 17 | regex 18 | av -------------------------------------------------------------------------------- /whisper/assets/multilingual/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | {"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "multilingual", "errors": "replace", "tokenizer_class": "GPT2Tokenizer"} -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pkg_resources 4 | from setuptools import setup, find_packages 5 | 6 | setup( 7 | name="whisper", 8 | py_modules=["whisper"], 9 | version="1.0", 10 | description="", 11 | author="OpenAI", 12 | packages=find_packages(exclude=["tests*"]), 13 | install_requires=[ 14 | str(r) 15 | for r in pkg_resources.parse_requirements( 16 | open(os.path.join(os.path.dirname(__file__), "requirements.txt")) 17 | ) 18 | ], 19 | entry_points = { 20 | 'console_scripts': ['whisper=whisper.transcribe:cli'], 21 | }, 22 | include_package_data=True, 23 | extras_require={'dev': ['pytest']}, 24 | ) 25 | -------------------------------------------------------------------------------- /scripts/libritrans.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source ~/miniconda3/etc/profile.d/conda.sh 3 | conda activate pw 4 | export CUDA_VISIBLE_DEVICES=2 5 | 6 | 7 | dataset="libritrans" 8 | model="tiny" 9 | dataset_dir="path/to/libritrans" 10 | core_metric="bleu" 11 | split="dev" 12 | language="fr" 13 | logit_mask=${language} 14 | vocab_cap=0.5 15 | # for unconstraint gen (no vocab constaint), don't pass logit_mask 16 | 17 | echo "currently testing ${model}" 18 | exp_name="${language}_${model}_${split}" 19 | python ../csasr_st.py \ 20 | --language ${language} \ 21 | --logit_mask ${logit_mask} \ 22 | --vocab_cap ${vocab_cap} \ 23 | --data_split ${split} \ 24 | --model ${model} \ 25 | --dataset ${dataset} \ 26 | --dataset_dir ${dataset_dir} \ 27 | --core_metric ${core_metric} \ 28 | --beam_size 5 \ 29 | --topk 1000 \ 30 | --task transcribe 31 | # >> "./logs/${dataset}/${exp_name}.log" 2>&1 -------------------------------------------------------------------------------- /scripts/mustcv1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source ~/miniconda3/etc/profile.d/conda.sh 3 | conda activate pw 4 | export CUDA_VISIBLE_DEVICES=2 5 | 6 | 7 | dataset="mustcv1" 8 | model="tiny" 9 | dataset_dir="path/to/mustcv1" 10 | core_metric="bleu" 11 | split="dev" 12 | language="ru" 13 | logit_mask=${language} 14 | vocab_cap=0.5 15 | # for unconstraint gen (no vocab 
constraint), don't pass logit_mask 16 | # for language ru, we directly constrain vocab using script, so vocab_cap won't be needed 17 | echo "currently testing ${model}" 18 | exp_name="${language}_${model}_${split}" 19 | python ../csasr_st.py \ 20 | --language ${language} \ 21 | --logit_mask ${logit_mask} \ 22 | --vocab_cap ${vocab_cap} \ 23 | --data_split ${split} \ 24 | --model ${model} \ 25 | --dataset ${dataset} \ 26 | --dataset_dir ${dataset_dir} \ 27 | --core_metric ${core_metric} \ 28 | --beam_size 5 \ 29 | --topk 1000 \ 30 | --task transcribe 31 | # >> "./logs/${dataset}/${exp_name}.log" 2>&1 -------------------------------------------------------------------------------- /scripts/covost2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source ~/miniconda3/etc/profile.d/conda.sh 3 | conda activate pw 4 | export CUDA_VISIBLE_DEVICES=2 5 | 6 | 7 | dataset="covost2" 8 | model="tiny" 9 | dataset_dir="path/to/covost2" 10 | core_metric="bleu" 11 | split="dev" 12 | language="de" 13 | logit_mask=${language} 14 | vocab_cap=0.5 15 | # for unconstrained gen (no vocab constraint), don't pass logit_mask 16 | # for language zh and ar, we directly constrain vocab using script, so vocab_cap won't be needed 17 | echo "currently testing ${model}" 18 | exp_name="${language}_${model}_${split}" 19 | python ../csasr_st.py \ 20 | --language ${language} \ 21 | --logit_mask ${logit_mask} \ 22 | --vocab_cap ${vocab_cap} \ 23 | --data_split ${split} \ 24 | --model ${model} \ 25 | --dataset ${dataset} \ 26 | --dataset_dir ${dataset_dir} \ 27 | --core_metric ${core_metric} \ 28 | --beam_size 5 \ 29 | --topk 1000 \ 30 | --task transcribe 31 | # >> "./logs/${dataset}/${exp_name}.log" 2>&1 -------------------------------------------------------------------------------- /scripts/visspeech.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | source ~/miniconda3/etc/profile.d/conda.sh 3 | conda activate pw 4 | export CUDA_VISIBLE_DEVICES=2 5 | 6 | 7 | dataset="visspeech" 8 | model="medium.en" 9 | dataset_dir="path/to/visspeech/data" 10 | core_metric="wer" 11 | pk="0" 12 | ok="50" 13 | num_img=3 14 | socratic="1" # "1" means also input the visual prompt utilizing CLIP, "0" means audio only 15 | 16 | mkdir -p logs/${dataset} 17 | echo "currently testing ${model} pk ${pk} ok ${ok}" 18 | exp_name="${model}_placesk${pk}_objectk${ok}" 19 | python ../avsr.py \ 20 | --place_topk $pk \ 21 | --obj_topk $ok \ 22 | --socratic $socratic \ 23 | --language "en" \ 24 | --num_img ${num_img} \ 25 | --model ${model} \ 26 | --dataset ${dataset} \ 27 | --dataset_dir ${dataset_dir} \ 28 | --core_metric ${core_metric} \ 29 | --batch_size 32 \ 30 | --beam_size 5 \ 31 | --topk 600 \ 32 | --task transcribe \ 33 | --object_txt_fn 'path/to/place_and_object/dictionary_and_semantic_hierarchy.txt' \ 34 | --place_txt_fn 'path/to/place_and_object/categories_places365.txt' \ 35 | --object_pkl_fn "path/to/place_and_object/tencent_336.pkl" \ 36 | --place_pkl_fn "path/to/place_and_object/places365_336.pkl" >> "./logs/${dataset}/${exp_name}.log" 2>&1 37 | -------------------------------------------------------------------------------- /scripts/ascend.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source ~/miniconda3/etc/profile.d/conda.sh 3 | conda activate pw 4 | export CUDA_VISIBLE_DEVICES=2 5 | 6 | 7 | dataset="ascend" 8 | model="tiny" 9 | dataset_dir="path/to/ascend/ASCEND" 10 | split="dev" 11 | 
core_metric="mer" 12 | single_lang_threshold=0.9 13 | concat_lang_token=1 14 | code_switching="zh-en" 15 | # need both concat_lang_token to be 1 and code_switching to be "zh-en" to enable lang concat in the prompt 16 | # only turn code-switching to be "zh-en" will do normal whisper LID to select language token for the prompt 17 | # if code-switching is "0", you should pass in a language token e.g. "zh", and we will therefore use this for all utterances 18 | mkdir -p ./logs/${dataset} 19 | 20 | echo "currently testing ${model}" 21 | exp_name="${model}_${split}" 22 | python ../csasr_st.py \ 23 | --data_split ${split} \ 24 | --single_lang_threshold ${single_lang_threshold} \ 25 | --concat_lang_token ${concat_lang_token} \ 26 | --code_switching ${code_switching} \ 27 | --model ${model} \ 28 | --dataset ${dataset} \ 29 | --dataset_dir ${dataset_dir} \ 30 | --core_metric ${core_metric} \ 31 | --beam_size 5 \ 32 | --topk 1000 \ 33 | --task transcribe 34 | # >> "./logs/${dataset}/${exp_name}.log" 2>&1 -------------------------------------------------------------------------------- /scripts/seame.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source ~/miniconda3/etc/profile.d/conda.sh 3 | conda activate pw 4 | export CUDA_VISIBLE_DEVICES=2 5 | 6 | 7 | dataset="seame" 8 | model="tiny" 9 | dataset_dir="path/to/seame/seame/data" 10 | core_metric="mer" 11 | split="valid" # for seame, it should be valid, devsge or devman, the later two are usually treated as test set in the literature 12 | single_lang_threshold=1 13 | concat_lang_token=1 14 | code_switching="zh-en" 15 | # need both concat_lang_token to be 1 and code_switching to be "zh-en" to enable lang concat in the prompt 16 | # only turn code-switching to be "zh-en" will do normal whisper LID to select language token for the prompt 17 | # if code-switching is "0", you should pass in a language token e.g. 
"zh", and we will therefore use this for all utterances 18 | mkdir -p ./logs/${dataset} 19 | 20 | echo "currently testing ${model}" 21 | exp_name="${model}_${split}" 22 | python ../csasr_st.py \ 23 | --data_split ${split} \ 24 | --single_lang_threshold ${single_lang_threshold} \ 25 | --concat_lang_token ${concat_lang_token} \ 26 | --code_switching ${code_switching} \ 27 | --model ${model} \ 28 | --dataset ${dataset} \ 29 | --dataset_dir ${dataset_dir} \ 30 | --core_metric ${core_metric} \ 31 | --beam_size 5 \ 32 | --topk 1000 \ 33 | --task transcribe 34 | # >> "./logs/${dataset}/${exp_name}.log" 2>&1 -------------------------------------------------------------------------------- /whisper/normalizers/basic.py: -------------------------------------------------------------------------------- 1 | import re 2 | import unicodedata 3 | 4 | import regex 5 | 6 | # non-ASCII letters that are not separated by "NFKD" normalization 7 | ADDITIONAL_DIACRITICS = { 8 | "œ": "oe", 9 | "Œ": "OE", 10 | "ø": "o", 11 | "Ø": "O", 12 | "æ": "ae", 13 | "Æ": "AE", 14 | "ß": "ss", 15 | "ẞ": "SS", 16 | "đ": "d", 17 | "Đ": "D", 18 | "ð": "d", 19 | "Ð": "D", 20 | "þ": "th", 21 | "Þ": "th", 22 | "ł": "l", 23 | "Ł": "L", 24 | } 25 | 26 | 27 | def remove_symbols_and_diacritics(s: str, keep=""): 28 | """ 29 | Replace any other markers, symbols, and punctuations with a space, 30 | and drop any diacritics (category 'Mn' and some manual mappings) 31 | """ 32 | return "".join( 33 | c 34 | if c in keep 35 | else ADDITIONAL_DIACRITICS[c] 36 | if c in ADDITIONAL_DIACRITICS 37 | else "" 38 | if unicodedata.category(c) == "Mn" 39 | else " " 40 | if unicodedata.category(c)[0] in "MSP" 41 | else c 42 | for c in unicodedata.normalize("NFKD", s) 43 | ) 44 | 45 | 46 | def remove_symbols(s: str): 47 | """ 48 | Replace any other markers, symbols, punctuations with a space, keeping diacritics 49 | """ 50 | return "".join( 51 | " " if unicodedata.category(c)[0] in "MSP" else c for c in unicodedata.normalize("NFKC", s) 52 | ) 53 | 54 | 55 | class BasicTextNormalizer: 56 | def __init__(self, remove_diacritics: bool = False, split_letters: bool = False): 57 | self.clean = remove_symbols_and_diacritics if remove_diacritics else remove_symbols 58 | self.split_letters = split_letters 59 | 60 | def __call__(self, s: str): 61 | s = s.lower() 62 | s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets 63 | s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis 64 | s = self.clean(s).lower() 65 | 66 | if self.split_letters: 67 | s = " ".join(regex.findall(r"\X", s, regex.U)) 68 | 69 | s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space 70 | 71 | return s 72 | -------------------------------------------------------------------------------- /whisper/utils.py: -------------------------------------------------------------------------------- 1 | import zlib 2 | from typing import Iterator, TextIO 3 | 4 | 5 | def exact_div(x, y): 6 | assert x % y == 0 7 | return x // y 8 | 9 | 10 | def str2bool(string): 11 | str2val = {"True": True, "False": False} 12 | if string in str2val: 13 | return str2val[string] 14 | else: 15 | raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}") 16 | 17 | 18 | def optional_int(string): 19 | return None if string == "None" else int(string) 20 | 21 | 22 | def optional_float(string): 23 | return None if string == "None" else float(string) 24 | 25 | 26 | def compression_ratio(text) -> float: 27 | return len(text) / 
len(zlib.compress(text.encode("utf-8"))) 28 | 29 | 30 | def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'): 31 | assert seconds >= 0, "non-negative timestamp expected" 32 | milliseconds = round(seconds * 1000.0) 33 | 34 | hours = milliseconds // 3_600_000 35 | milliseconds -= hours * 3_600_000 36 | 37 | minutes = milliseconds // 60_000 38 | milliseconds -= minutes * 60_000 39 | 40 | seconds = milliseconds // 1_000 41 | milliseconds -= seconds * 1_000 42 | 43 | hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else "" 44 | return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}" 45 | 46 | 47 | def write_txt(transcript: Iterator[dict], file: TextIO): 48 | for segment in transcript: 49 | print(segment['text'].strip(), file=file, flush=True) 50 | 51 | 52 | def write_vtt(transcript: Iterator[dict], file: TextIO): 53 | print("WEBVTT\n", file=file) 54 | for segment in transcript: 55 | print( 56 | f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n" 57 | f"{segment['text'].strip().replace('-->', '->')}\n", 58 | file=file, 59 | flush=True, 60 | ) 61 | 62 | 63 | def write_srt(transcript: Iterator[dict], file: TextIO): 64 | """ 65 | Write a transcript to a file in SRT format. 66 | 67 | Example usage: 68 | from pathlib import Path 69 | from whisper.utils import write_srt 70 | 71 | result = transcribe(model, audio_path, temperature=temperature, **args) 72 | 73 | # save SRT 74 | audio_basename = Path(audio_path).stem 75 | with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt: 76 | write_srt(result["segments"], file=srt) 77 | """ 78 | for i, segment in enumerate(transcript, start=1): 79 | # write srt lines 80 | print( 81 | f"{i}\n" 82 | f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> " 83 | f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n" 84 | f"{segment['text'].strip().replace('-->', '->')}\n", 85 | file=file, 86 | flush=True, 87 | ) 88 | -------------------------------------------------------------------------------- /whisper/audio.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import lru_cache 3 | from typing import Union 4 | 5 | import ffmpeg 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from .utils import exact_div 11 | 12 | # hard-coded audio hyperparameters 13 | SAMPLE_RATE = 16000 14 | N_FFT = 400 15 | N_MELS = 80 16 | HOP_LENGTH = 160 17 | CHUNK_LENGTH = 30 18 | N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk 19 | N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input 20 | 21 | 22 | def load_audio(file: str, sr: int = SAMPLE_RATE): 23 | """ 24 | Open an audio file and read as mono waveform, resampling as necessary 25 | 26 | Parameters 27 | ---------- 28 | file: str 29 | The audio file to open 30 | 31 | sr: int 32 | The sample rate to resample the audio if necessary 33 | 34 | Returns 35 | ------- 36 | A NumPy array containing the audio waveform, in float32 dtype. 37 | """ 38 | try: 39 | # This launches a subprocess to decode audio while down-mixing and resampling as necessary. 40 | # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed. 
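        # For reference (comment added, not part of the upstream code): the ffmpeg-python pipeline below
        # builds a command roughly equivalent to
        #   ffmpeg -nostdin -threads 0 -i <file> -f s16le -acodec pcm_s16le -ac 1 -ar <sr> -
        # i.e. decode to 16-bit PCM on stdout, down-mix to mono, and resample to `sr`;
        # the raw bytes are then converted to a float32 waveform in [-1, 1] further below.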
41 | out, _ = ( 42 | ffmpeg.input(file, threads=0) 43 | .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr) 44 | .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True) 45 | ) 46 | except ffmpeg.Error as e: 47 | raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e 48 | 49 | return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0 50 | 51 | 52 | def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): 53 | """ 54 | Pad or trim the audio array to N_SAMPLES, as expected by the encoder. 55 | """ 56 | if torch.is_tensor(array): 57 | if array.shape[axis] > length: 58 | # array = array.index_select(dim=axis, index=torch.arange(length)) # this will give error that not all tensor on same device, cuda:0 and cpu 59 | array = array[...,:length] 60 | 61 | if array.shape[axis] < length: 62 | pad_widths = [(0, 0)] * array.ndim 63 | pad_widths[axis] = (0, length - array.shape[axis]) 64 | array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes]) 65 | else: 66 | if array.shape[axis] > length: 67 | array = array.take(indices=range(length), axis=axis) 68 | # array = array[...,:length] 69 | 70 | if array.shape[axis] < length: 71 | pad_widths = [(0, 0)] * array.ndim 72 | pad_widths[axis] = (0, length - array.shape[axis]) 73 | array = np.pad(array, pad_widths) 74 | 75 | return array 76 | 77 | 78 | @lru_cache(maxsize=None) 79 | def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor: 80 | """ 81 | load the mel filterbank matrix for projecting STFT into a Mel spectrogram. 82 | Allows decoupling librosa dependency; saved using: 83 | 84 | np.savez_compressed( 85 | "mel_filters.npz", 86 | mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), 87 | ) 88 | """ 89 | assert n_mels == 80, f"Unsupported n_mels: {n_mels}" 90 | with np.load(os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")) as f: 91 | return torch.from_numpy(f[f"mel_{n_mels}"]).to(device) 92 | 93 | 94 | def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS): 95 | """ 96 | Compute the log-Mel spectrogram of 97 | 98 | Parameters 99 | ---------- 100 | audio: Union[str, np.ndarray, torch.Tensor], shape = (*) 101 | The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz 102 | 103 | n_mels: int 104 | The number of Mel-frequency filters, only 80 is supported 105 | 106 | Returns 107 | ------- 108 | torch.Tensor, shape = (80, n_frames) 109 | A Tensor that contains the Mel spectrogram 110 | """ 111 | if not torch.is_tensor(audio): 112 | if isinstance(audio, str): 113 | audio = load_audio(audio) 114 | audio = torch.from_numpy(audio) 115 | 116 | window = torch.hann_window(N_FFT).to(audio.device) 117 | stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) 118 | magnitudes = stft[:, :-1].abs() ** 2 119 | 120 | filters = mel_filters(audio.device, n_mels) 121 | mel_spec = filters @ magnitudes 122 | 123 | log_spec = torch.clamp(mel_spec, min=1e-10).log10() 124 | log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) 125 | log_spec = (log_spec + 4.0) / 4.0 126 | return log_spec 127 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Prompting the Hidden Talent of Web-Scale Speech Models for Zero-Shot Task Generalization 3 | This is the official codebase for paper [Prompting the Hidden Talent of Web-Scale Speech Models 4 | 
for Zero-Shot Task Generalization](https://arxiv.org/abs/2305.11095). 5 | 6 | ``` 7 | @inproceedings{peng2023whisper, 8 | title={Prompting the Hidden Talent of Web-Scale Speech Models for Zero-Shot Task Generalization}, 9 | author={Peng, Puyuan and Yan, Brian and Watanabe, Shinji and Harwath, David}, 10 | booktitle={Interspeech}, 11 | year={2023} 12 | } 13 | ``` 14 | 15 | # Table of Contents 16 | 1. [Environment](#1-environment) 17 | 2. [Audio Visual Speech Recognition](#2-audio-visual-speech-recognition) 18 | 3. [Code Switched Speech Recognition](#3-code-switched-speech-recognition) 19 | 4. [Speech Translation](#4-speech-translation) 20 | 21 | 22 | # 1. Environment 23 | It is recommended to create a new conda environment for this project with `conda create -n pw python=3.9.16` 24 | 25 | ```bash 26 | conda activate pw 27 | pip install torch==1.13.1 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116 28 | pip install transformers ffmpeg-python OpenCC jieba editdistance pandas inflect sacrebleu more-itertools 29 | 30 | # for avsr only 31 | pip install torchvision==0.14.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 32 | pip install profanityfilter 33 | pip install ftfy regex tqdm 34 | pip install git+https://github.com/openai/CLIP.git 35 | pip install av 36 | ``` 37 | 38 | In the PromptingWhisper directory, run `pip install -e ./` 39 | 40 | # 2. Audio Visual Speech Recognition 41 | We tested Whisper with different prompts on [VisSpeech](https://arxiv.org/abs/2206.07684) and [How2](https://arxiv.org/abs/1811.00347). Both datasets are collections of YouTube videos. Since How2 was proposed a few years ago, many of its videos are no longer available, and we didn't attempt to recover them. We randomly selected a 2000-example subset of How2 and used it for hyperparameter tuning; VisSpeech is the main dataset that we studied. 42 | 43 | The script for running AVSR on VisSpeech is provided at `./scripts/visspeech.sh`. To run the script, please download the VisSpeech [metafile](https://gabeur.github.io/data/VisSpeech.zip) and videos, and put them in `/path/to/visspeech`. In addition, we make use of [Places365 categories](https://github.com/CSAILVision/places365/blob/master/categories_places365.txt) and [Tencent ML-images categories](https://github.com/Tencent/tencent-ml-images/blob/master/data/dictionary_and_semantic_hierarchy.txt); please use the corresponding links to download these txt files. Change the paths to the data and txt files accordingly in `./scripts/visspeech.sh`, and then run 44 | 45 | ```bash 46 | cd scripts 47 | bash visspeech.sh 48 | ``` 49 | 50 | NOTE: we observe that if your downloaded videos are of lower quality, CLIP can perform worse at retrieving visual prompts, which leads to higher WER. We therefore recommend downloading the videos in as high a resolution as possible. Our video downloading setting (for [yt-dlp](https://github.com/yt-dlp/yt-dlp)) is `bestvideo[height<=720]+bestaudio/best[height<=720]` in `.mkv` format. We use [David Xu's code](https://github.com/DavidXu9000/yt-dl) for downloading. 51 | 52 | # 3. Code Switched Speech Recognition 53 | For code-switched speech recognition (CS-ASR) we use [ASCEND](https://arxiv.org/abs/2112.06223) and [SEAME](https://www.isca-speech.org/archive/pdfs/interspeech_2010/lyu10_interspeech.pdf). ASCEND can be obtained following the [official codebase](https://github.com/HLTCHKUST/ASCEND), and SEAME can be obtained through LDC [here](https://catalog.ldc.upenn.edu/LDC2015S04). 
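Conceptually, the `concat` prompt used for CS-ASR simply places both language tokens after `<|startoftranscript|>` in Whisper's decoder prefix. The snippet below is a minimal sketch of that idea; the attribute names (`sot`, `language_token`, `transcribe`, `no_timestamps`) are assumed to follow the vendored `whisper/tokenizer.py`, and the actual prompt construction and decoding changes live in `csasr_st.py` and the modified `whisper` package, so treat this as an illustration rather than the exact implementation:

```python
import whisper

# Tokenizers for the two languages of Mandarin-English code-switched speech
tok_zh = whisper.tokenizer.get_tokenizer(multilingual=True, language="zh", task="transcribe")
tok_en = whisper.tokenizer.get_tokenizer(multilingual=True, language="en", task="transcribe")

# Standard Whisper prefix: <|startoftranscript|> <|zh|> <|transcribe|> <|notimestamps|>
# "concat" prompt (zh-en): <|startoftranscript|> <|zh|> <|en|> <|transcribe|> <|notimestamps|>
concat_prompt = [
    tok_zh.sot,             # <|startoftranscript|>
    tok_zh.language_token,  # <|zh|>
    tok_en.language_token,  # <|en|>
    tok_zh.transcribe,      # <|transcribe|>
    tok_zh.no_timestamps,   # <|notimestamps|>
]
```

With `--concat_lang_token 1 --code_switching zh-en`, this two-language prefix is used whenever Whisper's language detector is not confident enough (see `--single_lang_threshold` in `config.py`); otherwise the single detected language token is kept.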
54 | 55 | For ASCEND, put the downloaded dataset at `/path/to/ascend/ASCEND`, and run `ascend.sh` in the `scripts` folder with the corresponding paths changed. Make sure to check out the instructions in `ascend.sh` on how to enable the `concat` prompt. 56 | 57 | For SEAME, we followed [this ESPnet recipe](https://github.com/espnet/espnet/tree/master/egs2/seame/asr1) to prepare the dataset, and put the data at `/path/to/seame/seame/data`. Run `seame.sh` with the corresponding paths changed. Also make sure to check out the instructions in `seame.sh` on how to enable the `concat` prompt. 58 | 59 | # 4. Speech Translation 60 | We prompt Whisper for En->X translation on three datasets: [COVOST2](https://github.com/facebookresearch/covost) (Arabic, Mandarin Chinese, German, Catalan), [MuST-C V1](https://ict.fbk.eu/must-c/) (German, Russian), and [Libri-Trans](https://github.com/alicank/Translation-Augmented-LibriSpeech-Corpus) (French). The data preparation for the three datasets should be relatively simple; just follow the instructions in the links (please let me know if you encounter any difficulties). Run `covost2.sh`, `mustcv1.sh` and `libritrans.sh` in the `scripts` folder, and please also check the instructions in those .sh files for vocabulary-constrained generation. 61 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def MyParser(): 5 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 6 | 7 | parser.add_argument("--seed", type=int, default=1) 8 | parser.add_argument("--num_workers", type=int, default=8) 9 | parser.add_argument("--data_split", type=str, default="dev", help="'val' or 'dev' both mean the development/validation set, 'test' means the test set. hyperparameter tuning should be done on the validation set. for seame it should be 'valid', 'devsge' or 'devman'; the latter two are treated as test sets in the literature") 10 | parser.add_argument("--batch_size", type=int, default=64, help="this is just for the dataloader, the model forward still runs with batch size == 1") 11 | parser.add_argument("--sample_rate", type=int, default=16000, help='target sample rate needs to be 16000 (fixed by whisper); if the native sample rate of the audio differs, it will be resampled') 12 | parser.add_argument("--audio_max_length", type=int, default=480000, help="30 sec * 16000. the input originally needed to be exactly 30 sec long (unclear what the result would be without a mask otherwise), but this is no longer required") 13 | parser.add_argument("--text_max_length", type=int, default=120, help='this is not used') 14 | parser.add_argument("--padding_idx", type=int, default=-100, help="this is not used") 15 | 16 | parser.add_argument("--model", type=str, choices=['tiny', 'tiny.en', 'base', 'base.en', 'small', 'small.en', 'medium', 'medium.en', 'large', 'largev2']) 17 | parser.add_argument("--whisper_root", type=str, default="/saltpool0/scratch/pyp/whisper/pretrained_models") 18 | parser.add_argument("--dataset", type=str, help="e.g. 
'ascend', 'seame', 'covost2'") 19 | parser.add_argument("--dataset_dir", type=str, help="need to be compatible with corresponding dataset py file") 20 | parser.add_argument("--core_metric", type=str, choices=['cer', 'wer', 'mer', 'bleu']) 21 | parser.add_argument("--task", type=str, choices=['transcribe', 'translate'], help="note that this is the task token of Whisper, not zero-shot tasks that we studied") 22 | parser.add_argument("--topk", type=int, default=100, help="print the top k worst pred and ref") 23 | parser.add_argument("--beam_size", type=int, default=None, help="if None, use greedy decoding") 24 | parser.add_argument("--block_ngrams", nargs="+", type=int, default=[], help="block repeated ngrams, if [], no blocking") 25 | parser.add_argument("--language", type=str, help="en, ar, zh, ru, etc. in the case of ST, the language token indicate ") 26 | parser.add_argument("--code_switching", type=str, default='0', help='0 means no code switching, for mandarin english cs speech, put zh-en. if concat_lang_token is specified as 1, we will insert both and in the prompt. If put en-zh, will insert and i.e. different order. We found zh-en to work better on both ascend and seame') 27 | parser.add_argument("--single_lang_threshold", type=float, default=0.8, help="if the probability of language detector result is bigger than equal to this number, use single language token even if we have specified --concat_lang_token") 28 | parser.add_argument("--concat_lang_token", type=int, default=0, help="if true, will use both two language tokens for the code switching input") 29 | parser.add_argument("--logit_mask", type=str, default="0", help="if not None, mask out the output logit to contraint the output vocabulary, currently might only support zh") 30 | parser.add_argument("--vocab_cap", type=float, default=0.7, help="for speech translation for now, only allow to generate tokens that has top vocab_cap frequency in the training set") 31 | 32 | # AVSR specific 33 | parser.add_argument("--socratic", type=str, default="0", help="whether use clip to detect place and object, and input them in the prompt of the decoder, 0 means no, 1 means yes. 
the name socratic comes from https://arxiv.org/abs/2204.00598") 34 | parser.add_argument("--num_img", type=int, default=3, help="number of images we sample from the video, which are later used for CLIP places and objects detection") 35 | parser.add_argument("--place_topk", type=int, default=0, help="we find it to be unhelpful") 36 | parser.add_argument("--obj_topk", type=int, default=50, help="a surprisingly large amount of obj can be very helpful") 37 | parser.add_argument("--object_txt_fn", type=str, default='/data/scratch/pyp/exp_pyp/whisper/place_and_object/dictionary_and_semantic_hierarchy.txt', help="this is downloaded") 38 | parser.add_argument("--place_txt_fn", type=str, default='/data/scratch/pyp/exp_pyp/whisper/place_and_object/categories_places365.txt', help="this is downloaded") 39 | parser.add_argument("--object_pkl_fn", type=str, default="/data/scratch/pyp/exp_pyp/whisper/place_and_object/tencent_336.pkl", help="CLIP embedding of tencent objects text, if not exist, running avsr.py will automatically run CLIP embedding on the downloaded txt, and store the results") 40 | parser.add_argument("--place_pkl_fn", type=str, default="/data/scratch/pyp/exp_pyp/whisper/place_and_object/places365_336.pkl", help="CLIP embedding of places365 text, if not exist, running avsr.py will automatically run CLIP embedding on the downloaded txt, and store the results") 41 | 42 | 43 | 44 | return parser -------------------------------------------------------------------------------- /whisper/__init__.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import io 3 | import os 4 | import urllib 5 | import warnings 6 | from typing import List, Optional, Union 7 | 8 | import torch 9 | from tqdm import tqdm 10 | 11 | from .audio import load_audio, log_mel_spectrogram, pad_or_trim 12 | from .decoding import DecodingOptions, DecodingResult, decode, detect_language 13 | from .model import Whisper, ModelDimensions 14 | from .transcribe import transcribe 15 | from .normalizers import BasicTextNormalizer, EnglishTextNormalizer 16 | 17 | 18 | _MODELS = { 19 | "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt", 20 | "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt", 21 | "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt", 22 | "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt", 23 | "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt", 24 | "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt", 25 | "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt", 26 | "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt", 27 | "large": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large.pt", 28 | "largev2": 
"https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt", 29 | "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt" 30 | } 31 | 32 | 33 | def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]: 34 | os.makedirs(root, exist_ok=True) 35 | 36 | expected_sha256 = url.split("/")[-2] 37 | download_target = os.path.join(root, os.path.basename(url)) 38 | 39 | if os.path.exists(download_target) and not os.path.isfile(download_target): 40 | raise RuntimeError(f"{download_target} exists and is not a regular file") 41 | 42 | if os.path.isfile(download_target): 43 | model_bytes = open(download_target, "rb").read() 44 | if hashlib.sha256(model_bytes).hexdigest() == expected_sha256: 45 | return model_bytes if in_memory else download_target 46 | else: 47 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 48 | 49 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 50 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 51 | while True: 52 | buffer = source.read(8192) 53 | if not buffer: 54 | break 55 | 56 | output.write(buffer) 57 | loop.update(len(buffer)) 58 | 59 | model_bytes = open(download_target, "rb").read() 60 | if hashlib.sha256(model_bytes).hexdigest() != expected_sha256: 61 | raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model.") 62 | 63 | return model_bytes if in_memory else download_target 64 | 65 | 66 | def available_models() -> List[str]: 67 | """Returns the names of available models""" 68 | return list(_MODELS.keys()) 69 | 70 | 71 | def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False) -> Whisper: 72 | """ 73 | Load a Whisper ASR model 74 | 75 | Parameters 76 | ---------- 77 | name : str 78 | one of the official model names listed by `whisper.available_models()`, or 79 | path to a model checkpoint containing the model dimensions and the model state_dict. 
80 | device : Union[str, torch.device] 81 | the PyTorch device to put the model into 82 | download_root: str 83 | path to download the model files; by default, it uses "~/.cache/whisper" 84 | in_memory: bool 85 | whether to preload the model weights into host memory 86 | 87 | Returns 88 | ------- 89 | model : Whisper 90 | The Whisper ASR model instance 91 | """ 92 | 93 | if device is None: 94 | device = "cuda" if torch.cuda.is_available() else "cpu" 95 | if download_root is None: 96 | download_root = os.getenv( 97 | "XDG_CACHE_HOME", 98 | os.path.join(os.path.expanduser("~"), ".cache", "whisper") 99 | ) 100 | print("model weights is downloaded to ", download_root) 101 | if name in _MODELS: 102 | checkpoint_file = _download(_MODELS[name], download_root, in_memory) 103 | elif os.path.isfile(name): 104 | checkpoint_file = open(name, "rb").read() if in_memory else name 105 | else: 106 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 107 | with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp: 108 | checkpoint = torch.load(fp, map_location=device) 109 | del checkpoint_file 110 | 111 | dims = ModelDimensions(**checkpoint["dims"]) 112 | model = Whisper(dims) 113 | model.load_state_dict(checkpoint["model_state_dict"]) 114 | 115 | return model.to(device) 116 | -------------------------------------------------------------------------------- /data/visspeech.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import numpy as np 3 | import os 4 | 5 | import torch 6 | import whisper 7 | import torchaudio.transforms as at 8 | 9 | import csv 10 | import editdistance 11 | import av 12 | 13 | 14 | class calc_metrics: 15 | def __init__(self): 16 | pass 17 | def __call__(self, refs, preds): 18 | """ 19 | refs are output from dataloader, so uses the collate fn, that already contains the normalization 20 | preds are the output of whisper tokenizer, which doesn't have dataset specific normalization 21 | 22 | they should both in list (list of list) 23 | """ 24 | distance = 0 25 | tokens = 0 26 | wer_list = [] 27 | processed_preds = [] 28 | processed_refs = [] 29 | exclude = [",", "?", ".", "!", ";"] 30 | for ref, pred in zip(refs, preds): 31 | pred = pred.lower() 32 | pred = ''.join(ch for ch in pred if ch not in exclude) 33 | processed_preds.append(pred) 34 | processed_refs.append(ref) # do not process ref 35 | cur_dist =editdistance.distance(pred.split(" "), ref.split(" ")) 36 | cur_tokens = len(ref.split(" ")) 37 | wer_list.append(cur_dist/cur_tokens) 38 | distance += cur_dist 39 | tokens += cur_tokens 40 | 41 | return {"wer":distance/tokens}, (wer_list, processed_preds, processed_refs) 42 | 43 | 44 | 45 | 46 | def load_wave(wave_path, sample_rate:int=16000) -> torch.Tensor: 47 | with av.open(wave_path, metadata_errors="ignore") as container: 48 | decode = container.decode(audio=0) 49 | first_frame = next(decode) 50 | cur_sample_rate = first_frame.sample_rate 51 | aframes_list = [first_frame.to_ndarray()] 52 | for frame in decode: 53 | aframes_list.append(frame.to_ndarray()) 54 | aframes = np.concatenate(aframes_list, 1) 55 | wav = torch.as_tensor(aframes).mean(dim=0) 56 | if cur_sample_rate != sample_rate: 57 | wav = at.Resample(cur_sample_rate, sample_rate, dtype=wav.dtype)(wav) 58 | if wav.mean() == 0: 59 | print(wave_path, "empty!") 60 | return wav 61 | 62 | def load_img(fn, num_img): 63 | if fn.endswith(".mkv"): 64 | img_fn = fn.replace(".mkv", f"-{num_img}.pt") 65 | 
elif fn.endswith(".mp4"): 66 | img_fn = fn.replace(".mp4", f"-{num_img}.pt") 67 | else: 68 | raise RuntimeError(f"video_fn extension not supported: {fn}") 69 | if os.path.isfile(img_fn): 70 | ret_frames = torch.load(img_fn, map_location="cpu") 71 | else: 72 | with av.open(fn, metadata_errors="ignore") as container: 73 | all_frames = [frame.to_image() for frame in container.decode(video=0)] 74 | mul = len(all_frames) // num_img 75 | ret_frames = [torch.from_numpy(np.array(f.convert("RGB"), dtype=np.float32)) for f in all_frames[::mul][:num_img]] 76 | ret_frames = torch.stack(ret_frames, dim=0) 77 | ret_frames = ret_frames.permute(0, 3, 1, 2) / 255.0 78 | torch.save(ret_frames, img_fn) 79 | return ret_frames 80 | 81 | class VisSpeechDataset(torch.utils.data.Dataset): 82 | def __init__(self, args, split, sample_rate): 83 | super().__init__() 84 | self.split = split 85 | self.args = args 86 | self.sample_rate = sample_rate 87 | self.data = [] 88 | with open(Path(args.dataset_dir)/"VisSpeech.csv", "r") as file: 89 | csv_file = csv.reader(file) 90 | header = next(csv_file) 91 | missing = [] 92 | for i, item in enumerate(csv_file): 93 | key,yt_id,start_time,end_time,text = item 94 | fn = Path(args.dataset_dir)/f"{key}.mkv" 95 | if fn.is_file(): 96 | self.data.append([fn, text]) 97 | else: 98 | fn = Path(str(fn).replace(".mkv", ".mp4")) 99 | assert fn.is_file(), f"{fn} doesn't exist!" 100 | self.data.append([fn, text]) 101 | 102 | print(f"expacting {i+1} files, and get {len(self.data)} files") 103 | print(f"missing: {missing}") 104 | 105 | 106 | def __len__(self): 107 | return len(self.data) 108 | 109 | def __getitem__(self, id): 110 | audio_path, raw_text = self.data[id] 111 | 112 | # audio 113 | audio = load_wave(str(audio_path), sample_rate=self.sample_rate) 114 | audio = whisper.pad_or_trim(audio) 115 | mel = whisper.log_mel_spectrogram(audio) 116 | 117 | if self.args.socratic == "1": 118 | imgs = load_img(str(audio_path), num_img=self.args.num_img) 119 | else: 120 | imgs = None 121 | return { 122 | "audio_path": audio_path, 123 | "input_mel": mel, 124 | "imgs": imgs, 125 | "raw_text": raw_text 126 | } 127 | 128 | def collate(self, batch): 129 | audio_paths, input_mels, imgs, raw_text = [], [], [], [] 130 | for f in batch: 131 | audio_paths.append(f['audio_path']) 132 | input_mels.append(f["input_mel"]) 133 | imgs.append(f['imgs']) 134 | raw_text.append(f['raw_text']) 135 | 136 | 137 | input_mels = torch.stack(input_mels, dim=0) 138 | 139 | collated_batch = {} 140 | collated_batch["input_mels"] = input_mels 141 | collated_batch["audio_paths"] = audio_paths 142 | collated_batch["imgs"] = imgs 143 | collated_batch["raw_text"] = raw_text 144 | 145 | return collated_batch 146 | 147 | 148 | def get_dataloader(args): 149 | dataset = VisSpeechDataset(args, "test", args.sample_rate) # there is only one split, only test 150 | print("dataset size: ", len(dataset)) 151 | loader = torch.utils.data.DataLoader(dataset, 152 | batch_size=args.batch_size, drop_last=False, shuffle=False, 153 | num_workers=args.num_workers, 154 | collate_fn=dataset.collate, persistent_workers=True 155 | ) 156 | 157 | return loader -------------------------------------------------------------------------------- /data/ascend.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | 4 | import torch 5 | import whisper 6 | import torchaudio 7 | import torchaudio.transforms as at 8 | 9 | import csv 10 | import opencc 11 | from itertools import chain 12 | 13 | 14 | 
##======== from eval.py and utils.py of https://github.com/HLTCHKUST/ASCEND ========## 15 | import re 16 | import jieba 17 | import editdistance 18 | ##### 19 | # Common Functions 20 | ##### 21 | CHARS_TO_IGNORE = [",", "?", "¿", ".", "!", "¡", ";", ";", ":", '""', "%", '"', "�", "ʿ", "·", "჻", "~", "՞", 22 | "؟", "،", "।", "॥", "«", "»", "„", "“", "”", "「", "」", "‘", "’", "《", "》", "(", ")", 23 | "{", "}", "=", "`", "_", "+", "<", ">", "…", "–", "°", "´", "ʾ", "‹", "›", "©", "®", "—", "→", "。", 24 | "、", "﹂", "﹁", "‧", "~", "﹏", ",", "{", "}", "(", ")", "[", "]", "【", "】", "‥", "〽", 25 | "『", "』", "〝", "〟", "⟨", "⟩", "〜", ":", "!", "?", "♪", "؛", "/", "\\", "º", "−", "^", "ʻ", "ˆ"] 26 | 27 | ##### 28 | # Metric Helper Functions 29 | ##### 30 | def tokenize_for_mer(text): 31 | tokens = list(filter(lambda tok: len(tok.strip()) > 0, jieba.lcut(text))) 32 | tokens = [[tok] if tok.isascii() else list(tok) for tok in tokens] 33 | return list(chain(*tokens)) 34 | 35 | def tokenize_for_cer(text): 36 | tokens = list(filter(lambda tok: len(tok.strip()) > 0, list(text))) 37 | return tokens 38 | 39 | # below is added to data processing pipeline, but actually the data doesn't contains the CHARS_TO_IGNORE (chekced in ascend_example.ipynb), so only need to do this for whisper prediction 40 | 41 | chars_to_ignore_re = f"[{re.escape(''.join(CHARS_TO_IGNORE))}]" 42 | def remove_special_characters(text): 43 | if chars_to_ignore_re is not None: 44 | return re.sub(chars_to_ignore_re, "", text).lower() 45 | else: 46 | return text.lower() 47 | 48 | class calc_metrics: 49 | # this follow the official evaluation code https://github.com/HLTCHKUST/ASCEND/blob/main/eval.py 50 | def __init__(self): 51 | self.converter = opencc.OpenCC('t2s.json') 52 | def __call__(self, refs, preds): 53 | """ 54 | refs are output from dataloader, so uses the collate fn, that already contains the normalization 55 | preds are the output of whisper tokenizer, which doesn't have dataset specific normalization 56 | 57 | they should both in list (list of list) 58 | """ 59 | mixed_distance = 0 60 | mixed_tokens = 0 61 | char_distance = 0 62 | char_tokens = 0 63 | mer_list = [] 64 | processed_preds = [] 65 | processed_refs = [] 66 | for ref, pred in zip(refs, preds): 67 | pred = remove_special_characters(self.converter.convert(pred)) 68 | ref = remove_special_characters(ref) 69 | 70 | processed_preds.append(pred) 71 | processed_refs.append(ref) 72 | 73 | m_pred = tokenize_for_mer(pred) 74 | m_ref = tokenize_for_mer(ref) 75 | cur_dist = editdistance.distance(m_pred, m_ref) 76 | cur_tokens = len(m_ref) 77 | mer_list.append(cur_dist/cur_tokens) 78 | mixed_distance += cur_dist 79 | mixed_tokens += cur_tokens 80 | 81 | c_pred = tokenize_for_cer(pred) 82 | c_ref = tokenize_for_cer(ref) 83 | char_distance += editdistance.distance(c_pred, c_ref) 84 | char_tokens += len(c_ref) 85 | 86 | return {"cer":char_distance/char_tokens, "mer": mixed_distance/mixed_tokens}, (mer_list, processed_preds, processed_refs) 87 | 88 | def load_wave(wave_path, sample_rate:int=16000) -> torch.Tensor: 89 | waveform, sr = torchaudio.load(wave_path, normalize=True) # normalization is not required, but since spectrogram is extracted, whether or not normalizing doesn't make a difference 90 | if sample_rate != sr: 91 | waveform = at.Resample(sr, sample_rate)(waveform) 92 | return waveform 93 | 94 | class ASCENDDataset(torch.utils.data.Dataset): 95 | def __init__(self, args, split, sample_rate): 96 | super().__init__() 97 | self.args = args 98 | self.sample_rate = sample_rate 99 
| self.tokenizer = whisper.tokenizer.get_tokenizer(True, language="zh", task="transcribe") 100 | self.tokenizer_en = whisper.tokenizer.get_tokenizer(True, language="en", task="transcribe") 101 | self.data = [] 102 | with open(Path(args.dataset_dir)/f"{split}_metadata.csv", "r") as f: 103 | file = csv.reader(f) 104 | header = next(file) 105 | self.data = [line[:4] for line in file] # path, text, duration, language 106 | print(f"pad audio to {self.args.audio_max_length/16000} seconds") 107 | 108 | 109 | def __len__(self): 110 | return len(self.data) 111 | 112 | def __getitem__(self, id): 113 | cur_path, raw_text, duration, language = self.data[id] 114 | audio_path = Path(self.args.dataset_dir)/cur_path 115 | 116 | # audio 117 | audio = load_wave(audio_path, sample_rate=self.sample_rate) 118 | audio = whisper.pad_or_trim(audio.flatten(), length=self.args.audio_max_length) 119 | mel = whisper.log_mel_spectrogram(audio) 120 | return { 121 | "audio_path": audio_path, 122 | "input_mel": mel, 123 | "raw_text": raw_text 124 | } 125 | def collate(self, batch): 126 | audio_paths, input_mels, raw_text = [], [], [] 127 | for f in batch: 128 | raw_text.append(f['raw_text']) 129 | audio_paths.append(f['audio_path']) 130 | input_mels.append(f["input_mel"]) 131 | 132 | input_mels = torch.stack(input_mels, dim=0) 133 | 134 | collated_batch = {} 135 | collated_batch["input_mels"] = input_mels 136 | collated_batch["audio_paths"] = audio_paths 137 | collated_batch["raw_text"] = raw_text 138 | 139 | return collated_batch 140 | 141 | 142 | def get_dataloader(args): 143 | tokenizer = whisper.tokenizer.get_tokenizer(multilingual=True, language="zh", task=args.task) 144 | dataset = ASCENDDataset(args, "validation" if args.data_split in ['dev', 'val'] else "test", args.sample_rate) 145 | print("dataset size: ", len(dataset)) 146 | loader = torch.utils.data.DataLoader(dataset, 147 | batch_size=args.batch_size, drop_last=False, shuffle=False, 148 | num_workers=args.num_workers, 149 | collate_fn=dataset.collate, persistent_workers=True 150 | ) 151 | 152 | return tokenizer, loader -------------------------------------------------------------------------------- /data/mustcv1.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | 5 | import yaml 6 | import torch 7 | import whisper 8 | import torchaudio 9 | import torchaudio.transforms as at 10 | import os 11 | from itertools import chain 12 | 13 | import re 14 | import jieba 15 | from sacrebleu import BLEU 16 | ##### 17 | # Common Functions 18 | ##### 19 | CHARS_TO_IGNORE = [",", "?", "¿", ".", "!", "¡", ";", ";", ":", '""', "%", '"', "�", "ʿ", "·", "჻", "~", "՞","؟", "،", "।", "॥", "«", "»", "„", "“", "”", "「", "」", "‘", "’", "《", "》", "(", ")","{", "}", "=", "`", "_", "+", "<", ">", "…", "–", "°", "´", "ʾ", "‹", "›", "©", "®", "—", "→", "。","、", "﹂", "﹁", "‧", "~", "﹏", ",", "{", "}", "(", ")", "[", "]", "【", "】", "‥", "〽", "『", "』", "〝", "〟", "⟨", "⟩", "〜", ":", "!", "?", "♪", "؛", "/", "\\", "º", "−", "^", "ʻ", "ˆ"] 20 | zh2en = {",": ",", "。": ".", "?":"?", "!":"!", ";": ";", "‘": "'", ":": ":", "’":"'", "(":"(", ")":")", "【": "[", "】": "]", "~":"~"} 21 | en2zh = {} 22 | for key in zh2en: 23 | en2zh[zh2en[key]] = key 24 | ##### 25 | # Metric Helper Functions 26 | ##### 27 | def tokenize_for_mer(text): 28 | tokens = list(filter(lambda tok: len(tok.strip()) > 0, jieba.lcut(text))) 29 | tokens = [[tok] if tok.isascii() else list(tok) for tok in tokens] 30 | return 
list(chain(*tokens)) 31 | 32 | def tokenize_for_cer(text): 33 | tokens = list(filter(lambda tok: len(tok.strip()) > 0, list(text))) 34 | return tokens 35 | 36 | # below is added to data processing pipeline, but actually the data doesn't contains the CHARS_TO_IGNORE (chekced in ascend_example.ipynb), so only need to do this for whisper prediction 37 | 38 | chars_to_ignore_re = f"[{re.escape(''.join(CHARS_TO_IGNORE))}]" 39 | def remove_special_characters(text): 40 | if chars_to_ignore_re is not None: 41 | return re.sub(chars_to_ignore_re, "", text).lower() 42 | else: 43 | return text.lower() 44 | 45 | 46 | def replace(item): 47 | return item if item not in en2zh else en2zh[item] 48 | class calc_metrics: 49 | def __init__(self): 50 | pass 51 | def __call__(self, refs, preds): 52 | """ 53 | refs are output from dataloader, so uses the collate fn, that already contains the normalization 54 | preds are the output of whisper tokenizer, which doesn't have dataset specific normalization 55 | 56 | they should both in list (list of list) 57 | """ 58 | ref4bleu = [[]] 59 | pred4bleu = [] 60 | bleu_fn = BLEU() 61 | sentence_blue = [] 62 | sentence_blue_fn = BLEU(effective_order=True) 63 | for ref, pred in zip(refs, preds): 64 | if len(ref) > 0: 65 | pred4bleu.append(pred) 66 | ref4bleu[0].append(ref) 67 | sentence_blue.append(sentence_blue_fn.sentence_score(pred, [ref]).score) 68 | 69 | bleu = bleu_fn.corpus_score(pred4bleu, ref4bleu) 70 | return {"bleu": bleu}, (sentence_blue, pred4bleu, ref4bleu[0]) 71 | 72 | def load_wave(wave_path, sample_rate:int=16000, start:float=-1., end:float=-1.) -> torch.Tensor: 73 | if start == -1.: 74 | waveform, sr = torchaudio.load(wave_path, normalize=True) 75 | else: 76 | metadata = torchaudio.info(wave_path) 77 | sr = metadata.sample_rate 78 | start_frame, end_frame = int(round(sr*start)), int(round(sr*end)) 79 | waveform, sr = torchaudio.load(filepath=wave_path, frame_offset=max(0,start_frame-1), num_frames=end_frame-start_frame, normalize=True) 80 | assert (waveform.shape[-1]/sr - (end-start))*(waveform.shape[-1]/sr - (end-start)) < 64, f"loaded waveform should have duration: {(end-start)}s, but it has duration {waveform.shape[-1]/sr}s" 81 | if sample_rate != sr: 82 | waveform = at.Resample(sr, sample_rate)(waveform) 83 | return waveform 84 | 85 | class MuSTCV1Dataset(torch.utils.data.Dataset): 86 | def __init__(self, args, split, sample_rate): 87 | super().__init__() 88 | self.args = args 89 | self.sample_rate = sample_rate 90 | self.tokenizer = whisper.tokenizer.get_tokenizer(True, language=args.language, task="transcribe") 91 | self.data = [] 92 | fn_dir = os.path.join(args.dataset_dir, f"en-{args.language}", "data", "tst-COMMON") 93 | all_wav_fn = os.path.join(fn_dir, "txt", "tst-COMMON.yaml") 94 | all_trans_fn = os.path.join(fn_dir, "txt", f"tst-COMMON.{args.language}") 95 | with open(all_trans_fn, "r") as f, open(all_wav_fn, "r") as g: 96 | all_trans = [l.strip() for l in f.readlines()] 97 | all_wav = yaml.load(g, Loader = yaml.FullLoader) 98 | for trans, wavitem in zip(all_trans, all_wav): 99 | start = float(wavitem['offset']) 100 | end = start + float(wavitem['duration']) 101 | wav_fn = os.path.join(fn_dir, "wav", wavitem['wav']) 102 | self.data.append([wav_fn, start, end, trans]) 103 | print(f"pad audio to {self.args.audio_max_length/16000} seconds") 104 | 105 | 106 | def __len__(self): 107 | return len(self.data) 108 | 109 | def __getitem__(self, id): 110 | cur_path, start, end, raw_text = self.data[id] 111 | audio_path = cur_path 112 | 113 | # audio 114 | 
audio = load_wave(audio_path, sample_rate=self.sample_rate, start=start, end=end) 115 | audio = whisper.pad_or_trim(audio.flatten(), length=self.args.audio_max_length) 116 | mel = whisper.log_mel_spectrogram(audio) 117 | 118 | return { 119 | "audio_path": audio_path, 120 | "input_mel": mel, 121 | "raw_text": raw_text 122 | } 123 | def collate(self, batch): 124 | audio_paths, input_mels, raw_text = [], [], [] 125 | for f in batch: 126 | raw_text.append(f['raw_text']) 127 | audio_paths.append(f['audio_path']) 128 | input_mels.append(f["input_mel"]) 129 | 130 | input_mels = torch.stack(input_mels, dim=0) 131 | collated_batch = {} 132 | 133 | collated_batch = {k: torch.tensor(np.array(v), requires_grad=False) for k, v in collated_batch.items()} 134 | collated_batch["input_mels"] = input_mels 135 | collated_batch["audio_paths"] = audio_paths 136 | collated_batch["raw_text"] = raw_text 137 | 138 | return collated_batch 139 | 140 | 141 | def get_dataloader(args): 142 | tokenizer = whisper.tokenizer.get_tokenizer(multilingual=True, language=args.language, task=args.task) 143 | dataset = MuSTCV1Dataset(args, "dev" if args.data_split in ['dev', 'val'] else "test", args.sample_rate) # split doesn't make a difference, will always on tst-COMMON, as we are not tuning any hyperparam on this dataset 144 | print("dataset size: ", len(dataset)) 145 | loader = torch.utils.data.DataLoader(dataset, 146 | batch_size=args.batch_size, 147 | drop_last=False, shuffle=False, num_workers=args.num_workers, 148 | collate_fn=dataset.collate, persistent_workers=True 149 | ) 150 | return tokenizer, loader -------------------------------------------------------------------------------- /data/seame.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | 5 | 6 | import torch 7 | import whisper 8 | import torchaudio 9 | import torchaudio.transforms as at 10 | 11 | 12 | import opencc 13 | 14 | 15 | ##======== from eval.py and utils.py of https://github.com/HLTCHKUST/ASCEND ========## 16 | import re 17 | import editdistance 18 | import inflect # convert numbers to words 19 | ##### 20 | # Common Functions 21 | ##### 22 | CHARS_TO_IGNORE = [",", "?", "¿", ".", "!", "¡", ";", ";", ":", '""', "%", '"', "�", "ʿ", "·", "჻", "~", "՞", 23 | "؟", "،", "।", "॥", "«", "»", "„", "“", "”", "「", "」", "‘", "’", "《", "》", "(", ")", 24 | "{", "}", "=", "`", "_", "+", "<", ">", "…", "–", "°", "´", "ʾ", "‹", "›", "©", "®", "—", "→", "。", 25 | "、", "﹂", "﹁", "‧", "~", "﹏", ",", "{", "}", "(", ")", "[", "]", "【", "】", "‥", "〽", 26 | "『", "』", "〝", "〟", "⟨", "⟩", "〜", ":", "!", "?", "♪", "؛", "/", "\\", "º", "−", "^", "ʻ", "ˆ"] 27 | 28 | import regex 29 | def tokenize_for_mer(text): 30 | reg_range = r"[\u4e00-\ufaff]|[0-9]+|[a-zA-Z]+\'*[a-z]*" 31 | matches = re.findall(reg_range, text, re.UNICODE) 32 | p = inflect.engine() 33 | res = [] 34 | for item in matches: 35 | try: 36 | temp = p.number_to_words(item) if (item.isnumeric() and len(regex.findall(r'\p{Han}+', item)) == 0) else item 37 | except: 38 | temp = item 39 | res.append(temp) 40 | return res 41 | 42 | def tokenize_for_cer(text): 43 | tokens = list(filter(lambda tok: len(tok.strip()) > 0, list(text))) 44 | return tokens 45 | 46 | 47 | chars_to_ignore_re = f"[{re.escape(''.join(CHARS_TO_IGNORE))}]" 48 | def remove_special_characters(text): 49 | if chars_to_ignore_re is not None: 50 | return re.sub(chars_to_ignore_re, "", text).lower() 51 | else: 52 | return text.lower() 53 | 54 | class calc_metrics: 55 | 
def __init__(self): 56 | self.converter = opencc.OpenCC('t2s.json') 57 | # pass 58 | def __call__(self, refs, preds): 59 | """ 60 | refs are output from dataloader, so uses the collate fn, that already contains the normalization 61 | preds are the output of whisper tokenizer, which doesn't have dataset specific normalization 62 | 63 | they should both in list (list of list) 64 | """ 65 | mixed_distance = 0 66 | mixed_tokens = 0 67 | char_distance = 0 68 | char_tokens = 0 69 | mer_list = [] 70 | processed_preds = [] 71 | processed_refs = [] 72 | for ref, pred in zip(refs, preds): 73 | pred = remove_special_characters(self.converter.convert(pred)) 74 | 75 | 76 | m_pred = tokenize_for_mer(pred) 77 | processed_preds.append(" ".join(m_pred)) 78 | processed_refs.append(ref) 79 | # m_ref = tokenize_for_mer(ref) 80 | m_ref = ref.split(" ") 81 | cur_dist = editdistance.distance(m_pred, m_ref) 82 | cur_tokens = len(m_ref) 83 | mer_list.append(cur_dist/cur_tokens) 84 | mixed_distance += cur_dist 85 | mixed_tokens += cur_tokens 86 | 87 | 88 | return {"mer": mixed_distance/mixed_tokens}, (mer_list, processed_preds, processed_refs) 89 | 90 | def load_wave(wave_path, sample_rate:int=16000, start:float=-1., end:float=-1.) -> torch.Tensor: 91 | if start == -1.: 92 | waveform, sr = torchaudio.load(wave_path, normalize=True) 93 | else: 94 | metadata = torchaudio.info(wave_path) 95 | sr = metadata.sample_rate 96 | start_frame, end_frame = int(round(sr*start)), int(round(sr*end)) 97 | waveform, sr = torchaudio.load(filepath=wave_path, frame_offset=max(0,start_frame-1), num_frames=end_frame-start_frame, normalize=True) 98 | assert (waveform.shape[-1]/sr - (end-start))*(waveform.shape[-1]/sr - (end-start)) < 64, f"loaded waveform should have duration: {(end-start)}s, but it has duration {waveform.shape[-1]/sr}s" 99 | if sample_rate != sr: 100 | waveform = at.Resample(sr, sample_rate)(waveform) 101 | return waveform 102 | 103 | 104 | class SEAMEDataset(torch.utils.data.Dataset): 105 | def __init__(self, args, split, sample_rate): 106 | super().__init__() 107 | self.split = split 108 | self.args = args 109 | self.sample_rate = sample_rate 110 | self.tokenizer = whisper.tokenizer.get_tokenizer(True, language="zh", task="transcribe") 111 | self.data = [] 112 | assert self.split in ['valid', 'devsge', 'devman'], self.split 113 | with open(Path(args.dataset_dir)/"espnet_prep_data"/f"{split}"/"segments", "r") as segs, open(Path(args.dataset_dir)/"espnet_prep_data"/f"{split}"/"text.clean", "r") as trans: 114 | for seg, tran in zip(segs.readlines(),trans.readlines()): 115 | seg_name_a, wav_name, start_time, end_time = seg.strip().split() 116 | temp = tran.strip().split(" ") 117 | seg_name_b, text = temp[0], " ".join(temp[1:]) 118 | assert seg_name_a == seg_name_b, f"wav order in segments file and txt file doesn't match" 119 | wav_name = wav_name.upper() + ".flac" 120 | if self.split == "train": 121 | audio_len = float(end_time) - float(start_time) 122 | text_len = len(self.tokenizer.encode(text)) 123 | if audio_len*16000 > self.args.audio_max_length or text_len > self.args.text_max_length: 124 | continue 125 | self.data.append([Path(args.dataset_dir)/"all_audio"/wav_name, float(start_time), float(end_time), text]) 126 | 127 | 128 | 129 | def __len__(self): 130 | return len(self.data) 131 | 132 | def __getitem__(self, id): 133 | audio_path, start_time, end_time, raw_text = self.data[id] 134 | audio = load_wave(audio_path, sample_rate=self.sample_rate, start=start_time, end=end_time) 135 | audio = 
whisper.pad_or_trim(audio.flatten()) 136 | mel = whisper.log_mel_spectrogram(audio) 137 | 138 | return { 139 | "raw_text": raw_text, 140 | "audio_path": audio_path, 141 | "input_mel": mel 142 | } 143 | 144 | def collate(self, batch): 145 | raw_text, audio_paths, input_mels = [], [], [] 146 | for f in batch: 147 | raw_text.append(f['raw_text']) 148 | audio_paths.append(f['audio_path']) 149 | input_mels.append(f["input_mel"]) 150 | 151 | input_mels = torch.stack(input_mels, dim=0) 152 | 153 | collated_batch = {} 154 | collated_batch["input_mels"] = input_mels 155 | collated_batch["audio_paths"] = audio_paths 156 | collated_batch['raw_text'] = raw_text 157 | 158 | return collated_batch 159 | 160 | 161 | def get_dataloader(args): 162 | tokenizer = whisper.tokenizer.get_tokenizer(multilingual=True, language="zh", task=args.task) 163 | dataset = SEAMEDataset(args, args.data_split, args.sample_rate) 164 | print("dataset size: ", len(dataset)) 165 | loader = torch.utils.data.DataLoader(dataset, 166 | batch_size=args.batch_size, 167 | drop_last=False, shuffle=False, num_workers=args.num_workers, 168 | collate_fn=dataset.collate, persistent_workers=True 169 | ) 170 | 171 | return tokenizer, loader -------------------------------------------------------------------------------- /csasr_st.py: -------------------------------------------------------------------------------- 1 | import re, random 2 | import torch 3 | 4 | from collections import Counter 5 | import csv, os 6 | import pandas as pd 7 | 8 | from tqdm import tqdm 9 | import numpy as np 10 | import regex 11 | 12 | from config import MyParser 13 | import whisper 14 | punc = [",", "?", "¿", ".", "!", "¡", ";", ";", ":", '""', "%", '"', "�", "ʿ", "·", "჻", "~", "՞","؟", "،", "।", "॥", "«", "»", "„", "“", "”", "「", "」", "‘", "’", "《", "》", "(", ")", "{", "}", "=", "`", "_", "+", "<", ">", "…", "–", "°", "´", "ʾ", "‹", "›", "©", "®", "—", "→", "。", "、", "﹂", "﹁", "‧", "~", "﹏", ",", "{", "}", "(", ")", "[", "]", "【", "】", "‥", "〽", "『", "』", "〝", "〟", "⟨", "⟩", "〜", ":", "!", "?", "♪", "؛", "/", "\\", "º", "−", "^", "ʻ", "ˆ"] 15 | 16 | 17 | if __name__ == "__main__": 18 | torch.cuda.empty_cache() 19 | args = MyParser().parse_args() 20 | print(args) 21 | 22 | # seed everything 23 | random.seed(args.seed) 24 | np.random.seed(args.seed) 25 | torch.manual_seed(args.seed) 26 | torch.cuda.manual_seed_all(args.seed) 27 | 28 | 29 | # both dataset and post-processing are dataset specific, so all done in ${dataset}.py 30 | if args.dataset == "ascend": 31 | from data.ascend import get_dataloader, calc_metrics 32 | elif args.dataset == "seame": 33 | from data.seame import get_dataloader, calc_metrics 34 | elif args.dataset == "covost2": 35 | from data.covost2 import get_dataloader 36 | if args.language == "zh": 37 | from data.covost2 import calc_metrics_zh as calc_metrics 38 | else: 39 | from data.covost2 import calc_metrics_ar as calc_metrics 40 | elif args.dataset == "libritrans": 41 | from data.libritrans import get_dataloader, calc_metrics 42 | elif args.dataset == "mustcv1": 43 | from data.mustcv1 import get_dataloader, calc_metrics 44 | ################################### 45 | 46 | tokenizer, data_loader = get_dataloader(args) 47 | model = whisper.load_model(args.model) 48 | 49 | model.eval() 50 | model.cuda() 51 | 52 | 53 | if args.logit_mask != "0": 54 | def construct(lang, path): 55 | local_tokenizer = whisper.tokenizer.get_tokenizer(multilingual=True, language=lang, task="transcribe") 56 | counter = Counter() 57 | if args.dataset == "covost2": 58 | 
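            # Build a token-frequency Counter over the training-set translations, tokenized with the
            # target-language tokenizer; the most frequent token ids later define the restricted
            # decoding vocabulary. CoVoST2 ships its translations in a TSV file, while libri-trans
            # and MuST-C provide plain-text files (one translation per line).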
data = pd.read_csv(path, sep="\t", header=0, encoding="utf-8", escapechar="\\", quoting=csv.QUOTE_NONE, na_filter=False) 59 | elif args.dataset == "libritrans" or args.dataset == "mustcv1": 60 | with open(path, "r") as ff: 61 | all_trans = [l for l in ff.readlines()] 62 | data = {'translation': all_trans} 63 | for text in data['translation']: 64 | tokens = local_tokenizer.encode(text.strip()) 65 | counter.update(tokens) 66 | del data 67 | return counter 68 | 69 | if args.language == "zh": 70 | lang_in = "zh-CN" 71 | else: 72 | lang_in = args.language 73 | if args.dataset == "covost2": 74 | path=f"{args.dataset_dir}/metadata/covost_v2.en_{lang_in}.train.tsv" 75 | elif args.dataset == "libritrans": 76 | path=f"{args.dataset_dir}/train/train.fr" 77 | elif args.dataset == "mustcv1": 78 | path = f"{args.dataset_dir}/en-{lang_in}/data/train/txt/train.{lang_in}" 79 | if not os.path.isfile(path): 80 | path = path.replace("/data/scratch/", "/data3/scratch/") # handle rtx path 81 | if not os.path.isfile(path): 82 | path = path.replace("/data3/scratch/", "/scratch/cluster/") 83 | 84 | 85 | # construct the vocabulary from the training-set translations 86 | counter = construct(args.language, path) 87 | 88 | # only allow the most frequent tokens 89 | n_vocab = model.dims.n_vocab 90 | cap_p = getattr(args, "vocab_cap", 0.7) 91 | cap_n = round(len(counter)*cap_p) 92 | constraint_ind = [item[0] for item in counter.most_common(cap_n)] 93 | special_inds = list(tokenizer.tokenizer.get_added_vocab().values()) 94 | constraint_ind += special_inds # add the indices of the special tokens 95 | 96 | # rebuild the constraint for zh, ar, and ru, since their output can be restricted to the target script 97 | if args.language == "zh" or args.language == "ar" or args.language == "ru": 98 | lang2range = {"zh": r"[\u4e00-\ufaff]", "ar": r"[\u0600-\u06ff]"} 99 | constraint_ind = [] 100 | 101 | 102 | for i in range(n_vocab): 103 | decoding_res = tokenizer.decode(i) 104 | if args.language == "ru": 105 | constraint = regex.findall(r'\p{Cyrillic}+', decoding_res) 106 | else: 107 | constraint_reg_range = lang2range[args.language] 108 | constraint = re.findall(constraint_reg_range, decoding_res, re.UNICODE) 109 | if len(decoding_res) > 0 and len(constraint) > 0: 110 | constraint_ind.append(i) 111 | constraint_ind += list(tokenizer.tokenizer.get_added_vocab().values()) # add the indices of the special tokens 112 | 113 | # control whether punctuation tokens are allowed in the output 114 | punc2ind = {} 115 | for p in punc: 116 | punc2ind[p] = tokenizer.encode(p) # encode() may return more than one token id per symbol 117 | # flatten the (possibly ragged) lists of punctuation token ids; duplicates are removed below 118 | for pids in punc2ind.values(): 119 | constraint_ind += pids 120 | 121 | constraint_ind = np.unique(constraint_ind).tolist() 122 | 123 | logit_mask = torch.ones((1, n_vocab)) * -1000000.
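        # Disallowed tokens keep a -1e6 penalty (effectively -inf) while allowed tokens are zeroed
        # out below; the mask is handed to whisper.DecodingOptions(logit_mask=...) so that decoding
        # can only emit tokens from the constrained vocabulary plus the special tokens.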
124 | logit_mask[:, constraint_ind] = 0.0 125 | print(f"allowed vocab: {args.language} scripts") 126 | print(f"total vocab size: {n_vocab}, allowed vocab size: {len(constraint_ind)}") 127 | else: 128 | logit_mask = None 129 | 130 | 131 | refs = [] 132 | preds = [] 133 | single_preds = [] 134 | prompts = [] 135 | 136 | for i, b in enumerate(tqdm(data_loader)): 137 | input_mels = b["input_mels"].half().cuda() 138 | raw_texts = b['raw_text'] 139 | with torch.no_grad(): 140 | 141 | # for input_mel, label in zip(input_mels, labels): 142 | for input_mel, raw_text in zip(input_mels, raw_texts): 143 | if args.code_switching != "0": 144 | main_lang, second_lang = args.code_switching.split("-") 145 | _, probs = whisper.detect_language(model, input_mel) 146 | max_lang = max(probs, key=probs.get) 147 | prob = probs[max_lang] 148 | 149 | if max_lang == main_lang: 150 | lang = main_lang 151 | elif max_lang == second_lang: 152 | lang = second_lang 153 | else: # Whisper language identification is not working well, assigning main_lang as the language 154 | lang = main_lang 155 | options = whisper.DecodingOptions(task=args.task, language=lang, without_timestamps=True, beam_size=args.beam_size, block_ngrams=args.block_ngrams, concat_lang_token=args.code_switching if (args.concat_lang_token != 0 and prob < args.single_lang_threshold) else "0", logit_mask=logit_mask) 156 | else: 157 | options = whisper.DecodingOptions(task=args.task, language=args.language, without_timestamps=True, beam_size=args.beam_size, block_ngrams=args.block_ngrams, concat_lang_token="0", logit_mask=logit_mask) 158 | with torch.no_grad(): 159 | results = whisper.decode(model, input_mel, options) 160 | preds.append(results.text) 161 | ref = raw_text 162 | refs.append(ref) 163 | 164 | 165 | inference_metrics, (wer_list, processed_preds, processed_refs) = calc_metrics()(refs, preds) 166 | print("results:", inference_metrics) 167 | print("results:", inference_metrics) 168 | # in the case of speech translation, the metric is actually BLUE score 169 | if args.topk > 0: 170 | import numpy as np 171 | inds = np.argsort(wer_list)[::-1] 172 | for ind in inds[:args.topk]: 173 | print("-"*10) 174 | print("wer/mer: ", wer_list[ind]) 175 | print("ref: ", processed_refs[ind]) 176 | print("pred: ", processed_preds[ind]) 177 | # print("prompt: ", prompts[ind]) 178 | else: 179 | for j, (k, v) in enumerate(zip(processed_refs, processed_preds)): 180 | if j % 100 == 0: 181 | print("-"*10) 182 | print("ref: ", k) 183 | print("pred: ", v) 184 | 185 | print("results:", inference_metrics) 186 | print("results:", inference_metrics) 187 | -------------------------------------------------------------------------------- /data/libritrans.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | 5 | 6 | import torch 7 | import whisper 8 | import torchaudio 9 | import torchaudio.transforms as at 10 | import os 11 | from itertools import chain 12 | 13 | 14 | import re 15 | import jieba 16 | 17 | from sacrebleu import BLEU 18 | ##### 19 | # Common Functions 20 | ##### 21 | 22 | LANGUAGES = { 23 | "en": "english", 24 | "zh": "chinese", 25 | "de": "german", 26 | "es": "spanish", 27 | "ru": "russian", 28 | "ko": "korean", 29 | "fr": "french", 30 | "ja": "japanese", 31 | "pt": "portuguese", 32 | "tr": "turkish", 33 | "pl": "polish", 34 | "ca": "catalan", 35 | "nl": "dutch", 36 | "ar": "arabic", 37 | "sv": "swedish", 38 | "it": "italian", 39 | "id": "indonesian", 40 | "hi": "hindi", 41 | 
"fi": "finnish", 42 | "vi": "vietnamese", 43 | "iw": "hebrew", 44 | "uk": "ukrainian", 45 | "el": "greek", 46 | "ms": "malay", 47 | "cs": "czech", 48 | "ro": "romanian", 49 | "da": "danish", 50 | "hu": "hungarian", 51 | "ta": "tamil", 52 | "no": "norwegian", 53 | "th": "thai", 54 | "ur": "urdu", 55 | "hr": "croatian", 56 | "bg": "bulgarian", 57 | "lt": "lithuanian", 58 | "la": "latin", 59 | "mi": "maori", 60 | "ml": "malayalam", 61 | "cy": "welsh", 62 | "sk": "slovak", 63 | "te": "telugu", 64 | "fa": "persian", 65 | "lv": "latvian", 66 | "bn": "bengali", 67 | "sr": "serbian", 68 | "az": "azerbaijani", 69 | "sl": "slovenian", 70 | "kn": "kannada", 71 | "et": "estonian", 72 | "mk": "macedonian", 73 | "br": "breton", 74 | "eu": "basque", 75 | "is": "icelandic", 76 | "hy": "armenian", 77 | "ne": "nepali", 78 | "mn": "mongolian", 79 | "bs": "bosnian", 80 | "kk": "kazakh", 81 | "sq": "albanian", 82 | "sw": "swahili", 83 | "gl": "galician", 84 | "mr": "marathi", 85 | "pa": "punjabi", 86 | "si": "sinhala", 87 | "km": "khmer", 88 | "sn": "shona", 89 | "yo": "yoruba", 90 | "so": "somali", 91 | "af": "afrikaans", 92 | "oc": "occitan", 93 | "ka": "georgian", 94 | "be": "belarusian", 95 | "tg": "tajik", 96 | "sd": "sindhi", 97 | "gu": "gujarati", 98 | "am": "amharic", 99 | "yi": "yiddish", 100 | "lo": "lao", 101 | "uz": "uzbek", 102 | "fo": "faroese", 103 | "ht": "haitian creole", 104 | "ps": "pashto", 105 | "tk": "turkmen", 106 | "nn": "nynorsk", 107 | "mt": "maltese", 108 | "sa": "sanskrit", 109 | "lb": "luxembourgish", 110 | "my": "myanmar", 111 | "bo": "tibetan", 112 | "tl": "tagalog", 113 | "mg": "malagasy", 114 | "as": "assamese", 115 | "tt": "tatar", 116 | "haw": "hawaiian", 117 | "ln": "lingala", 118 | "ha": "hausa", 119 | "ba": "bashkir", 120 | "jw": "javanese", 121 | "su": "sundanese", 122 | } 123 | CHARS_TO_IGNORE = [",", "?", "¿", ".", "!", "¡", ";", ";", ":", '""', "%", '"', "�", "ʿ", "·", "჻", "~", "՞","؟", "،", "।", "॥", "«", "»", "„", "“", "”", "「", "」", "‘", "’", "《", "》", "(", ")","{", "}", "=", "`", "_", "+", "<", ">", "…", "–", "°", "´", "ʾ", "‹", "›", "©", "®", "—", "→", "。","、", "﹂", "﹁", "‧", "~", "﹏", ",", "{", "}", "(", ")", "[", "]", "【", "】", "‥", "〽", "『", "』", "〝", "〟", "⟨", "⟩", "〜", ":", "!", "?", "♪", "؛", "/", "\\", "º", "−", "^", "ʻ", "ˆ"] 124 | zh2en = {",": ",", "。": ".", "?":"?", "!":"!", ";": ";", "‘": "'", ":": ":", "’":"'", "(":"(", ")":")", "【": "[", "】": "]", "~":"~"} 125 | en2zh = {} 126 | for key in zh2en: 127 | en2zh[zh2en[key]] = key 128 | ##### 129 | # Metric Helper Functions 130 | ##### 131 | def tokenize_for_mer(text): 132 | tokens = list(filter(lambda tok: len(tok.strip()) > 0, jieba.lcut(text))) 133 | tokens = [[tok] if tok.isascii() else list(tok) for tok in tokens] 134 | return list(chain(*tokens)) 135 | 136 | def tokenize_for_cer(text): 137 | tokens = list(filter(lambda tok: len(tok.strip()) > 0, list(text))) 138 | return tokens 139 | 140 | 141 | chars_to_ignore_re = f"[{re.escape(''.join(CHARS_TO_IGNORE))}]" 142 | def remove_special_characters(text): 143 | if chars_to_ignore_re is not None: 144 | return re.sub(chars_to_ignore_re, "", text).lower() 145 | else: 146 | return text.lower() 147 | 148 | 149 | def replace(item): 150 | return item if item not in en2zh else en2zh[item] 151 | 152 | class calc_metrics: 153 | def __init__(self): 154 | # self.converter = opencc.OpenCC('t2s.json') 155 | pass 156 | def __call__(self, refs, preds): 157 | """ 158 | refs are output from dataloader, so uses the collate fn, that already contains the normalization 159 | 
preds are the output of whisper tokenizer, which doesn't have dataset specific normalization 160 | 161 | they should both in list (list of list) 162 | """ 163 | 164 | ref4bleu = [[]] 165 | pred4bleu = [] 166 | bleu_fn = BLEU() 167 | sentence_blue = [] 168 | sentence_blue_fn = BLEU(effective_order=True) 169 | for ref, pred in zip(refs, preds): 170 | if len(ref) > 0: 171 | ref4bleu[0].append(ref) 172 | pred4bleu.append(pred) 173 | sentence_blue.append(sentence_blue_fn.sentence_score(pred, [ref]).score) 174 | 175 | bleu = bleu_fn.corpus_score(pred4bleu, ref4bleu) 176 | 177 | 178 | return {"bleu": bleu}, (sentence_blue, pred4bleu, ref4bleu[0]) 179 | 180 | 181 | def load_wave(wave_path, sample_rate:int=16000) -> torch.Tensor: 182 | waveform, sr = torchaudio.load(wave_path, normalize=True) 183 | if sample_rate != sr: 184 | waveform = at.Resample(sr, sample_rate)(waveform) 185 | return waveform 186 | 187 | class LibriTransDataset(torch.utils.data.Dataset): 188 | def __init__(self, args, split, sample_rate): 189 | super().__init__() 190 | self.args = args 191 | self.sample_rate = sample_rate 192 | self.tokenizer = whisper.tokenizer.get_tokenizer(True, language=args.language, task="transcribe") 193 | self.data = [] 194 | assert args.language in LANGUAGES, f"language {args.language} is not supported by whisper" 195 | print("running on libri-trans language:", LANGUAGES[args.language]) 196 | assert split in ["train", "dev", "test"], f"split {split} not in {['train', 'dev', 'test']}" 197 | lang = "zh-CN" if "zh" in args.language else args.language 198 | assert args.language == "fr", f"language needs to be fr, but it's {args.language}" 199 | for real_split in ['test', 'dev']: 200 | path = os.path.join(args.dataset_dir,real_split) 201 | with open(os.path.join(path, "alignments.meta"), "r") as f, open(os.path.join(path, f"{real_split}.fr"), "r") as g: 202 | all_flines = [l.strip().split("\t") for l in f.readlines()] 203 | all_flines = all_flines[1:] 204 | all_glines = [l.strip() for l in g.readlines()] 205 | assert len(all_flines) == len(all_glines), f"wav files length should equal to translation file length, but they are of length: {len(all_flines)}, and {len(all_glines)}" 206 | for fline, gline in zip(all_flines, all_glines): 207 | wav_fn = os.path.join(path, "audiofiles", fline[4] + ".wav") 208 | trans = gline 209 | self.data.append([wav_fn, None, trans]) 210 | print(f"pad audio to {self.args.audio_max_length/16000} seconds") 211 | 212 | 213 | def __len__(self): 214 | return len(self.data) 215 | 216 | def __getitem__(self, id): 217 | cur_path, raw_en, raw_text = self.data[id] 218 | audio_path = cur_path 219 | 220 | # audio 221 | audio = load_wave(audio_path, sample_rate=self.sample_rate) 222 | audio = whisper.pad_or_trim(audio.flatten(), length=self.args.audio_max_length) 223 | mel = whisper.log_mel_spectrogram(audio) 224 | return { 225 | "audio_path": audio_path, 226 | "input_mel": mel, 227 | "raw_text": raw_text, 228 | "raw_en": raw_en 229 | } 230 | def collate(self, batch): 231 | audio_paths, input_mels, raw_text, raw_en = [], [], [], [] 232 | for f in batch: 233 | raw_text.append(f['raw_text']) 234 | audio_paths.append(f['audio_path']) 235 | input_mels.append(f["input_mel"]) 236 | raw_en.append(f['raw_en']) 237 | 238 | input_mels = torch.stack(input_mels, dim=0) 239 | collated_batch = {} 240 | collated_batch["input_mels"] = input_mels 241 | collated_batch["audio_paths"] = audio_paths 242 | collated_batch["raw_text"] = raw_text 243 | collated_batch["raw_en"] = raw_en 244 | 245 | return 
collated_batch 246 | 247 | 248 | def get_dataloader(args): 249 | tokenizer = whisper.tokenizer.get_tokenizer(multilingual=True, language=args.language, task=args.task) 250 | dataset = LibriTransDataset(args, "dev" if args.data_split in ['dev', 'val'] else "test", args.sample_rate) # split doesn't make a difference, will use deev+test, as we are not tuning any hyperparams on this dataset 251 | print("dataset size: ", len(dataset)) 252 | loader = torch.utils.data.DataLoader(dataset, 253 | batch_size=args.batch_size, 254 | drop_last=False, shuffle=False, num_workers=args.num_workers, 255 | collate_fn=dataset.collate, persistent_workers=True 256 | ) 257 | 258 | return tokenizer, loader -------------------------------------------------------------------------------- /data/covost2.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | import torch 5 | import whisper 6 | import torchaudio 7 | import torchaudio.transforms as at 8 | import os 9 | import csv 10 | import opencc 11 | from itertools import chain 12 | 13 | 14 | import re 15 | import jieba 16 | import pandas as pd 17 | 18 | from sacrebleu import BLEU 19 | ##### 20 | # Common Functions 21 | ##### 22 | 23 | LANGUAGES = { 24 | "en": "english", 25 | "zh": "chinese", 26 | "de": "german", 27 | "es": "spanish", 28 | "ru": "russian", 29 | "ko": "korean", 30 | "fr": "french", 31 | "ja": "japanese", 32 | "pt": "portuguese", 33 | "tr": "turkish", 34 | "pl": "polish", 35 | "ca": "catalan", 36 | "nl": "dutch", 37 | "ar": "arabic", 38 | "sv": "swedish", 39 | "it": "italian", 40 | "id": "indonesian", 41 | "hi": "hindi", 42 | "fi": "finnish", 43 | "vi": "vietnamese", 44 | "iw": "hebrew", 45 | "uk": "ukrainian", 46 | "el": "greek", 47 | "ms": "malay", 48 | "cs": "czech", 49 | "ro": "romanian", 50 | "da": "danish", 51 | "hu": "hungarian", 52 | "ta": "tamil", 53 | "no": "norwegian", 54 | "th": "thai", 55 | "ur": "urdu", 56 | "hr": "croatian", 57 | "bg": "bulgarian", 58 | "lt": "lithuanian", 59 | "la": "latin", 60 | "mi": "maori", 61 | "ml": "malayalam", 62 | "cy": "welsh", 63 | "sk": "slovak", 64 | "te": "telugu", 65 | "fa": "persian", 66 | "lv": "latvian", 67 | "bn": "bengali", 68 | "sr": "serbian", 69 | "az": "azerbaijani", 70 | "sl": "slovenian", 71 | "kn": "kannada", 72 | "et": "estonian", 73 | "mk": "macedonian", 74 | "br": "breton", 75 | "eu": "basque", 76 | "is": "icelandic", 77 | "hy": "armenian", 78 | "ne": "nepali", 79 | "mn": "mongolian", 80 | "bs": "bosnian", 81 | "kk": "kazakh", 82 | "sq": "albanian", 83 | "sw": "swahili", 84 | "gl": "galician", 85 | "mr": "marathi", 86 | "pa": "punjabi", 87 | "si": "sinhala", 88 | "km": "khmer", 89 | "sn": "shona", 90 | "yo": "yoruba", 91 | "so": "somali", 92 | "af": "afrikaans", 93 | "oc": "occitan", 94 | "ka": "georgian", 95 | "be": "belarusian", 96 | "tg": "tajik", 97 | "sd": "sindhi", 98 | "gu": "gujarati", 99 | "am": "amharic", 100 | "yi": "yiddish", 101 | "lo": "lao", 102 | "uz": "uzbek", 103 | "fo": "faroese", 104 | "ht": "haitian creole", 105 | "ps": "pashto", 106 | "tk": "turkmen", 107 | "nn": "nynorsk", 108 | "mt": "maltese", 109 | "sa": "sanskrit", 110 | "lb": "luxembourgish", 111 | "my": "myanmar", 112 | "bo": "tibetan", 113 | "tl": "tagalog", 114 | "mg": "malagasy", 115 | "as": "assamese", 116 | "tt": "tatar", 117 | "haw": "hawaiian", 118 | "ln": "lingala", 119 | "ha": "hausa", 120 | "ba": "bashkir", 121 | "jw": "javanese", 122 | "su": "sundanese", 123 | } 124 | CHARS_TO_IGNORE = [",", "?", "¿", ".", "!", "¡", ";", ";", ":", '""', "%", '"', "�", "ʿ", 
"·", "჻", "~", "՞","؟", "،", "।", "॥", "«", "»", "„", "“", "”", "「", "」", "‘", "’", "《", "》", "(", ")","{", "}", "=", "`", "_", "+", "<", ">", "…", "–", "°", "´", "ʾ", "‹", "›", "©", "®", "—", "→", "。","、", "﹂", "﹁", "‧", "~", "﹏", ",", "{", "}", "(", ")", "[", "]", "【", "】", "‥", "〽", "『", "』", "〝", "〟", "⟨", "⟩", "〜", ":", "!", "?", "♪", "؛", "/", "\\", "º", "−", "^", "ʻ", "ˆ"] 125 | zh2en = {",": ",", "。": ".", "?":"?", "!":"!", ";": ";", "‘": "'", ":": ":", "’":"'", "(":"(", ")":")", "【": "[", "】": "]", "~":"~"} 126 | en2zh = {} 127 | for key in zh2en: 128 | en2zh[zh2en[key]] = key 129 | ##### 130 | # Metric Helper Functions 131 | ##### 132 | def tokenize_for_mer(text): 133 | tokens = list(filter(lambda tok: len(tok.strip()) > 0, jieba.lcut(text))) 134 | tokens = [[tok] if tok.isascii() else list(tok) for tok in tokens] 135 | return list(chain(*tokens)) 136 | 137 | def tokenize_for_cer(text): 138 | tokens = list(filter(lambda tok: len(tok.strip()) > 0, list(text))) 139 | return tokens 140 | 141 | # below is added to data processing pipeline, but actually the data doesn't contains the CHARS_TO_IGNORE (chekced in ascend_example.ipynb), so only need to do this for whisper prediction 142 | 143 | chars_to_ignore_re = f"[{re.escape(''.join(CHARS_TO_IGNORE))}]" 144 | def remove_special_characters(text): 145 | if chars_to_ignore_re is not None: 146 | return re.sub(chars_to_ignore_re, "", text).lower() 147 | else: 148 | return text.lower() 149 | 150 | 151 | def replace(item): 152 | return item if item not in en2zh else en2zh[item] 153 | 154 | class calc_metrics_ar: 155 | def __init__(self): 156 | # self.converter = opencc.OpenCC('t2s.json') 157 | pass 158 | def __call__(self, refs, preds): 159 | """ 160 | refs are output from dataloader, so uses the collate fn, that already contains the normalization 161 | preds are the output of whisper tokenizer, which doesn't have dataset specific normalization 162 | 163 | they should both in list (list of list) 164 | """ 165 | 166 | ref4bleu = [[]] 167 | pred4bleu = [] 168 | bleu_fn = BLEU() 169 | sentence_blue = [] 170 | sentence_blue_fn = BLEU(effective_order=True) 171 | for ref, pred in zip(refs, preds): 172 | if len(ref) > 0: 173 | ref4bleu[0].append(ref) 174 | pred4bleu.append(pred) 175 | sentence_blue.append(sentence_blue_fn.sentence_score(pred, [ref]).score) 176 | 177 | bleu = bleu_fn.corpus_score(pred4bleu, ref4bleu) 178 | 179 | 180 | return {"bleu": bleu}, (sentence_blue, pred4bleu, ref4bleu[0]) 181 | 182 | 183 | class calc_metrics_zh: 184 | def __init__(self): 185 | self.converter = opencc.OpenCC('t2s.json') 186 | # pass 187 | def __call__(self, refs, preds): 188 | """ 189 | refs are output from dataloader, so uses the collate fn, that already contains the normalization 190 | preds are the output of whisper tokenizer, which doesn't have dataset specific normalization 191 | 192 | they should both in list (list of list) 193 | """ 194 | ref4bleu = [[]] 195 | pred4bleu = [] 196 | bleu_fn = BLEU(tokenize='zh') 197 | sentence_blue = [] 198 | sentence_blue_fn = BLEU(tokenize='zh', effective_order=True) 199 | for ref, pred in zip(refs, preds): 200 | if len(ref) > 0: 201 | pred = self.converter.convert(pred) 202 | pred4bleu.append(pred) 203 | ref4bleu[0].append(ref) 204 | sentence_blue.append(sentence_blue_fn.sentence_score(pred, [ref]).score) 205 | 206 | bleu = bleu_fn.corpus_score(pred4bleu, ref4bleu) 207 | return {"cer":0, "bleu": bleu, "meteor": 0}, (sentence_blue, pred4bleu, ref4bleu[0]) 208 | 209 | 210 | 211 | def load_wave(wave_path, 
sample_rate:int=16000) -> torch.Tensor: 212 | waveform, sr = torchaudio.load(wave_path, normalize=True) 213 | if sample_rate != sr: 214 | waveform = at.Resample(sr, sample_rate)(waveform) 215 | return waveform 216 | 217 | class COVOST2Dataset(torch.utils.data.Dataset): 218 | def __init__(self, args, split, sample_rate): 219 | super().__init__() 220 | self.args = args 221 | self.sample_rate = sample_rate 222 | self.tokenizer = whisper.tokenizer.get_tokenizer(True, language=args.language, task="transcribe") 223 | self.data = [] 224 | assert args.language in LANGUAGES, f"language {args.language} is not supported by whisper" 225 | print("running on covost2 language:", LANGUAGES[args.language]) 226 | assert split in ["train", "dev", "test"], f"split {split} not in {['train', 'dev', 'test']}" 227 | lang = "zh-CN" if "zh" in args.language else args.language 228 | path = os.path.join(args.dataset_dir, "metadata", f"covost_v2.en_{lang}.{split}.tsv") 229 | data = pd.read_csv(path, sep="\t", header=0, encoding="utf-8", escapechar="\\", quoting=csv.QUOTE_NONE, na_filter=False) 230 | for audio_fn, eng, trans in zip(data['path'], data['sentence'], data['translation']): 231 | self.data.append([os.path.join(args.dataset_dir, "clips", audio_fn), eng, trans]) 232 | del data 233 | 234 | if split == "dev": 235 | self.data = self.data[:5000] # to speed up development cycle 236 | print(f"pad audio to {self.args.audio_max_length/16000} seconds") 237 | 238 | 239 | def __len__(self): 240 | return len(self.data) 241 | 242 | def __getitem__(self, id): 243 | cur_path, raw_en, raw_text = self.data[id] 244 | audio_path = cur_path 245 | 246 | # audio 247 | audio = load_wave(audio_path, sample_rate=self.sample_rate) 248 | audio = whisper.pad_or_trim(audio.flatten(), length=self.args.audio_max_length) 249 | mel = whisper.log_mel_spectrogram(audio) 250 | 251 | return { 252 | "audio_path": audio_path, 253 | "input_mel": mel, 254 | "raw_text": raw_text, 255 | "raw_en": raw_en 256 | } 257 | def collate(self, batch): 258 | audio_paths, input_mels, raw_text, raw_en = [], [], [], [] 259 | for f in batch: 260 | raw_text.append(f['raw_text']) 261 | audio_paths.append(f['audio_path']) 262 | input_mels.append(f["input_mel"]) 263 | raw_en.append(f['raw_en']) 264 | 265 | input_mels = torch.stack(input_mels, dim=0) 266 | 267 | collated_batch = {} 268 | collated_batch["input_mels"] = input_mels 269 | collated_batch["audio_paths"] = audio_paths 270 | collated_batch["raw_text"] = raw_text 271 | collated_batch["raw_en"] = raw_en 272 | 273 | return collated_batch 274 | 275 | 276 | def get_dataloader(args): 277 | tokenizer = whisper.tokenizer.get_tokenizer(multilingual=True, language=args.language, task=args.task) 278 | dataset = COVOST2Dataset(args, "dev" if args.data_split in ['dev', 'val'] else "test", args.sample_rate) 279 | print("dataset size: ", len(dataset)) 280 | loader = torch.utils.data.DataLoader(dataset, 281 | batch_size=args.batch_size, drop_last=False, shuffle=False, 282 | num_workers=args.num_workers, 283 | collate_fn=dataset.collate, persistent_workers=True 284 | ) 285 | 286 | return tokenizer, loader -------------------------------------------------------------------------------- /whisper/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | from typing import Dict 4 | from typing import Iterable, Optional 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | from torch import Tensor 10 | from torch 
import nn 11 | 12 | from .transcribe import transcribe as transcribe_function 13 | from .decoding import detect_language as detect_language_function, decode as decode_function 14 | 15 | 16 | @dataclass 17 | class ModelDimensions: 18 | n_mels: int 19 | n_audio_ctx: int 20 | n_audio_state: int 21 | n_audio_head: int 22 | n_audio_layer: int 23 | n_vocab: int 24 | n_text_ctx: int 25 | n_text_state: int 26 | n_text_head: int 27 | n_text_layer: int 28 | 29 | 30 | class LayerNorm(nn.LayerNorm): 31 | def forward(self, x: Tensor) -> Tensor: 32 | return super().forward(x.float()).type(x.dtype) 33 | 34 | 35 | class Linear(nn.Linear): 36 | def forward(self, x: Tensor) -> Tensor: 37 | return F.linear( 38 | x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype) 39 | ) 40 | 41 | 42 | class Conv1d(nn.Conv1d): 43 | def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor: 44 | return super()._conv_forward( 45 | x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype) 46 | ) 47 | 48 | 49 | def sinusoids(length, channels, max_timescale=10000): 50 | """Returns sinusoids for positional embedding""" 51 | assert channels % 2 == 0 52 | log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) 53 | inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2)) 54 | scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] 55 | return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) 56 | 57 | 58 | class MultiHeadAttention(nn.Module): 59 | def __init__(self, n_state: int, n_head: int): 60 | super().__init__() 61 | self.n_head = n_head 62 | self.query = Linear(n_state, n_state) 63 | self.key = Linear(n_state, n_state, bias=False) 64 | self.value = Linear(n_state, n_state) 65 | self.out = Linear(n_state, n_state) 66 | 67 | def forward( 68 | self, 69 | x: Tensor, 70 | xa: Optional[Tensor] = None, 71 | mask: Optional[Tensor] = None, 72 | kv_cache: Optional[dict] = None, 73 | ): 74 | q = self.query(x) 75 | 76 | if kv_cache is None or xa is None: 77 | # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors; 78 | # otherwise, perform key/value projections for self- or cross-attention as usual. 79 | k = self.key(x if xa is None else xa) 80 | v = self.value(x if xa is None else xa) 81 | else: 82 | # for cross-attention, calculate keys and values once and reuse in subsequent calls. 
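                # kv_cache maps each key/value Linear module to the tensor saved by the forward hooks
                # installed in install_kv_cache_hooks(); .get() returns the cached tensor for the module
                # if one exists, falling back to a fresh projection of the (fixed) audio features xa.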
83 | k = kv_cache.get(self.key, self.key(xa)) 84 | v = kv_cache.get(self.value, self.value(xa)) 85 | 86 | wv = self.qkv_attention(q, k, v, mask) 87 | return self.out(wv) 88 | 89 | def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None): 90 | n_batch, n_ctx, n_state = q.shape 91 | scale = (n_state // self.n_head) ** -0.25 92 | q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale 93 | k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale 94 | v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) 95 | 96 | qk = q @ k 97 | if mask is not None: 98 | qk = qk + mask[:n_ctx, :n_ctx] 99 | 100 | w = F.softmax(qk.float(), dim=-1).to(q.dtype) 101 | return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2) 102 | 103 | 104 | class ResidualAttentionBlock(nn.Module): 105 | def __init__(self, n_state: int, n_head: int, cross_attention: bool = False): 106 | super().__init__() 107 | 108 | self.attn = MultiHeadAttention(n_state, n_head) 109 | self.attn_ln = LayerNorm(n_state) 110 | 111 | self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None 112 | self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None 113 | 114 | n_mlp = n_state * 4 115 | self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)) 116 | self.mlp_ln = LayerNorm(n_state) 117 | 118 | def forward( 119 | self, 120 | x: Tensor, 121 | xa: Optional[Tensor] = None, 122 | mask: Optional[Tensor] = None, 123 | kv_cache: Optional[dict] = None, 124 | ): 125 | x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache) 126 | if self.cross_attn: 127 | x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache) 128 | x = x + self.mlp(self.mlp_ln(x)) 129 | return x 130 | 131 | 132 | class AudioEncoder(nn.Module): 133 | def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int): 134 | super().__init__() 135 | self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1) 136 | self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1) 137 | self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state)) 138 | 139 | self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( 140 | [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)] 141 | ) 142 | self.ln_post = LayerNorm(n_state) 143 | 144 | def forward(self, x: Tensor): 145 | """ 146 | x : torch.Tensor, shape = (batch_size, n_mels, n_ctx) 147 | the mel spectrogram of the audio 148 | """ 149 | x = F.gelu(self.conv1(x)) 150 | x = F.gelu(self.conv2(x)) 151 | x = x.permute(0, 2, 1) 152 | 153 | ############### original ############### 154 | # assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape" 155 | # x = (x + self.positional_embedding).to(x.dtype) 156 | ############### original ############### 157 | ############### change to ############### 158 | assert x.shape[1] <= self.positional_embedding.shape[0] and x.shape[2] == self.positional_embedding.shape[1], f"audio shape and positional embedding shape doesn't match: audio shape: {x.shape}, positional_embedding shape: {self.positional_embedding.shape}" 159 | x = (x + self.positional_embedding.unsqueeze(0)[:, :x.shape[1]]).to(x.dtype) 160 | ############### change to ############### 161 | for i, block in enumerate(self.blocks): 162 | x = block(x) 163 | 164 | x = self.ln_post(x) 165 | return x 166 | 167 | class TextDecoder(nn.Module): 168 | def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int): 169 | 
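        # Unlike the audio encoder's fixed sinusoidal embeddings, the text decoder learns its positional
        # embeddings as a parameter, applies a causal (upper-triangular -inf) attention mask, and reuses
        # token_embedding.weight as the output projection when computing logits.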
super().__init__() 170 | 171 | self.token_embedding = nn.Embedding(n_vocab, n_state) 172 | self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state)) 173 | 174 | self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( 175 | [ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)] 176 | ) 177 | self.ln = LayerNorm(n_state) 178 | 179 | mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1) 180 | self.register_buffer("mask", mask, persistent=False) 181 | 182 | def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None): 183 | """ 184 | x : torch.LongTensor, shape = (batch_size, <= n_ctx) 185 | the text tokens 186 | xa : torch.Tensor, shape = (batch_size, n_mels, n_audio_ctx) 187 | the encoded audio features to be attended on 188 | """ 189 | offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 190 | temp = self.token_embedding(x) 191 | input_length = x.shape[-1] 192 | x = temp + self.positional_embedding[offset : offset + input_length] 193 | x = x.to(xa.dtype) 194 | 195 | for block in self.blocks: 196 | x = block(x, xa, mask=self.mask, kv_cache=kv_cache) 197 | 198 | x = self.ln(x) 199 | logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float() 200 | 201 | return logits 202 | 203 | 204 | class Whisper(nn.Module): 205 | def __init__(self, dims: ModelDimensions): 206 | super().__init__() 207 | self.dims = dims 208 | self.encoder = AudioEncoder( 209 | self.dims.n_mels, 210 | self.dims.n_audio_ctx, 211 | self.dims.n_audio_state, 212 | self.dims.n_audio_head, 213 | self.dims.n_audio_layer, 214 | ) 215 | self.decoder = TextDecoder( 216 | self.dims.n_vocab, 217 | self.dims.n_text_ctx, 218 | self.dims.n_text_state, 219 | self.dims.n_text_head, 220 | self.dims.n_text_layer, 221 | ) 222 | 223 | def embed_audio(self, mel: torch.Tensor): 224 | return self.encoder.forward(mel) 225 | 226 | def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor): 227 | return self.decoder.forward(tokens, audio_features) 228 | 229 | def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]: 230 | return self.decoder(tokens, self.encoder(mel)) 231 | 232 | @property 233 | def device(self): 234 | return next(self.parameters()).device 235 | 236 | @property 237 | def is_multilingual(self): 238 | return self.dims.n_vocab == 51865 239 | 240 | def install_kv_cache_hooks(self, cache: Optional[dict] = None): 241 | """ 242 | The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value 243 | tensors calculated for the previous positions. This method returns a dictionary that stores 244 | all caches, and the necessary hooks for the key and value projection modules that save the 245 | intermediate tensors to be reused during later calculations. 
246 | 247 | Returns 248 | ------- 249 | cache : Dict[nn.Module, torch.Tensor] 250 | A dictionary object mapping the key/value projection modules to its cache 251 | hooks : List[RemovableHandle] 252 | List of PyTorch RemovableHandle objects to stop the hooks to be called 253 | """ 254 | cache = {**cache} if cache is not None else {} 255 | hooks = [] 256 | 257 | def save_to_cache(module, _, output): 258 | if module not in cache or output.shape[1] > self.decoder.positional_embedding.shape[0]: 259 | cache[module] = output # save as-is, for the first token or cross attention 260 | else: 261 | cache[module] = torch.cat([cache[module], output], dim=1).detach() 262 | return cache[module] 263 | 264 | def install_hooks(layer: nn.Module): 265 | if isinstance(layer, MultiHeadAttention): 266 | hooks.append(layer.key.register_forward_hook(save_to_cache)) 267 | hooks.append(layer.value.register_forward_hook(save_to_cache)) 268 | 269 | self.decoder.apply(install_hooks) 270 | return cache, hooks 271 | 272 | detect_language = detect_language_function 273 | transcribe = transcribe_function 274 | decode = decode_function -------------------------------------------------------------------------------- /whisper/tokenizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | from functools import lru_cache 4 | from typing import List, Optional, Tuple, Union 5 | 6 | import numpy as np 7 | import torch 8 | from transformers import GPT2TokenizerFast 9 | 10 | LANGUAGES = { 11 | "en": "english", 12 | "zh": "chinese", 13 | "de": "german", 14 | "es": "spanish", 15 | "ru": "russian", 16 | "ko": "korean", 17 | "fr": "french", 18 | "ja": "japanese", 19 | "pt": "portuguese", 20 | "tr": "turkish", 21 | "pl": "polish", 22 | "ca": "catalan", 23 | "nl": "dutch", 24 | "ar": "arabic", 25 | "sv": "swedish", 26 | "it": "italian", 27 | "id": "indonesian", 28 | "hi": "hindi", 29 | "fi": "finnish", 30 | "vi": "vietnamese", 31 | "iw": "hebrew", 32 | "uk": "ukrainian", 33 | "el": "greek", 34 | "ms": "malay", 35 | "cs": "czech", 36 | "ro": "romanian", 37 | "da": "danish", 38 | "hu": "hungarian", 39 | "ta": "tamil", 40 | "no": "norwegian", 41 | "th": "thai", 42 | "ur": "urdu", 43 | "hr": "croatian", 44 | "bg": "bulgarian", 45 | "lt": "lithuanian", 46 | "la": "latin", 47 | "mi": "maori", 48 | "ml": "malayalam", 49 | "cy": "welsh", 50 | "sk": "slovak", 51 | "te": "telugu", 52 | "fa": "persian", 53 | "lv": "latvian", 54 | "bn": "bengali", 55 | "sr": "serbian", 56 | "az": "azerbaijani", 57 | "sl": "slovenian", 58 | "kn": "kannada", 59 | "et": "estonian", 60 | "mk": "macedonian", 61 | "br": "breton", 62 | "eu": "basque", 63 | "is": "icelandic", 64 | "hy": "armenian", 65 | "ne": "nepali", 66 | "mn": "mongolian", 67 | "bs": "bosnian", 68 | "kk": "kazakh", 69 | "sq": "albanian", 70 | "sw": "swahili", 71 | "gl": "galician", 72 | "mr": "marathi", 73 | "pa": "punjabi", 74 | "si": "sinhala", 75 | "km": "khmer", 76 | "sn": "shona", 77 | "yo": "yoruba", 78 | "so": "somali", 79 | "af": "afrikaans", 80 | "oc": "occitan", 81 | "ka": "georgian", 82 | "be": "belarusian", 83 | "tg": "tajik", 84 | "sd": "sindhi", 85 | "gu": "gujarati", 86 | "am": "amharic", 87 | "yi": "yiddish", 88 | "lo": "lao", 89 | "uz": "uzbek", 90 | "fo": "faroese", 91 | "ht": "haitian creole", 92 | "ps": "pashto", 93 | "tk": "turkmen", 94 | "nn": "nynorsk", 95 | "mt": "maltese", 96 | "sa": "sanskrit", 97 | "lb": "luxembourgish", 98 | "my": "myanmar", 99 | "bo": "tibetan", 100 | "tl": "tagalog", 101 | 
"mg": "malagasy", 102 | "as": "assamese", 103 | "tt": "tatar", 104 | "haw": "hawaiian", 105 | "ln": "lingala", 106 | "ha": "hausa", 107 | "ba": "bashkir", 108 | "jw": "javanese", 109 | "su": "sundanese", 110 | } 111 | 112 | # language code lookup by name, with a few language aliases 113 | TO_LANGUAGE_CODE = { 114 | **{language: code for code, language in LANGUAGES.items()}, 115 | "burmese": "my", 116 | "valencian": "ca", 117 | "flemish": "nl", 118 | "haitian": "ht", 119 | "letzeburgesch": "lb", 120 | "pushto": "ps", 121 | "panjabi": "pa", 122 | "moldavian": "ro", 123 | "moldovan": "ro", 124 | "sinhalese": "si", 125 | "castilian": "es", 126 | } 127 | 128 | 129 | @dataclass(frozen=True) 130 | class Tokenizer: 131 | """A thin wrapper around `GPT2TokenizerFast` providing quick access to special tokens""" 132 | 133 | tokenizer: "GPT2TokenizerFast" 134 | language: Optional[str] 135 | sot_sequence: Tuple[int] 136 | 137 | def encode(self, text, **kwargs): 138 | return self.tokenizer.encode(text, **kwargs) 139 | 140 | def decode(self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs): 141 | return self.tokenizer.decode(token_ids, **kwargs) 142 | 143 | def decode_with_timestamps(self, tokens) -> str: 144 | """ 145 | Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. 146 | This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>". 147 | """ 148 | outputs = [[]] 149 | for token in tokens: 150 | if token >= self.timestamp_begin: 151 | timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>" 152 | outputs.append(timestamp) 153 | outputs.append([]) 154 | else: 155 | outputs[-1].append(token) 156 | outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs] 157 | return "".join(outputs) 158 | 159 | @property 160 | @lru_cache() 161 | def eot(self) -> int: 162 | return self.tokenizer.eos_token_id 163 | 164 | @property 165 | @lru_cache() 166 | def sot(self) -> int: 167 | return self._get_single_token_id("<|startoftranscript|>") 168 | 169 | @property 170 | @lru_cache() 171 | def sot_lm(self) -> int: 172 | return self._get_single_token_id("<|startoflm|>") 173 | 174 | @property 175 | @lru_cache() 176 | def sot_prev(self) -> int: 177 | return self._get_single_token_id("<|startofprev|>") 178 | 179 | @property 180 | @lru_cache() 181 | def no_speech(self) -> int: 182 | return self._get_single_token_id("<|nospeech|>") 183 | 184 | @property 185 | @lru_cache() 186 | def no_timestamps(self) -> int: 187 | return self._get_single_token_id("<|notimestamps|>") 188 | 189 | @property 190 | @lru_cache() 191 | def timestamp_begin(self) -> int: 192 | return self.tokenizer.all_special_ids[-1] + 1 193 | 194 | @property 195 | @lru_cache() 196 | def language_token(self) -> int: 197 | """Returns the token id corresponding to the value of the `language` field""" 198 | if self.language is None: 199 | raise ValueError(f"This tokenizer does not have language token configured") 200 | 201 | additional_tokens = dict( 202 | zip( 203 | self.tokenizer.additional_special_tokens, 204 | self.tokenizer.additional_special_tokens_ids, 205 | ) 206 | ) 207 | candidate = f"<|{self.language}|>" 208 | if candidate in additional_tokens: 209 | return additional_tokens[candidate] 210 | 211 | raise KeyError(f"Language {self.language} not found in tokenizer.") 212 | 213 | @property 214 | @lru_cache() 215 | def all_language_tokens(self) -> Tuple[int]: 216 | result = [] 217 | for token, token_id in zip( 218 | self.tokenizer.additional_special_tokens, 
219 | self.tokenizer.additional_special_tokens_ids, 220 | ): 221 | if token.strip("<|>") in LANGUAGES: 222 | result.append(token_id) 223 | return tuple(result) 224 | 225 | @property 226 | @lru_cache() 227 | def all_language_codes(self) -> Tuple[str]: 228 | return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens) 229 | 230 | @property 231 | @lru_cache() 232 | def sot_sequence_including_notimestamps(self) -> Tuple[int]: 233 | return tuple(list(self.sot_sequence) + [self.no_timestamps]) 234 | 235 | @property 236 | @lru_cache() 237 | def non_speech_tokens(self) -> Tuple[int]: 238 | """ 239 | Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech 240 | annotations, to prevent sampling texts that are not actually spoken in the audio, e.g. 241 | 242 | - ♪♪♪ 243 | - ( SPEAKING FOREIGN LANGUAGE ) 244 | - [DAVID] Hey there, 245 | 246 | keeping basic punctuations like commas, periods, question marks, exclamation points, etc. 247 | """ 248 | symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』") 249 | symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split() 250 | 251 | # symbols that may be a single token or multiple tokens depending on the tokenizer. 252 | # In case they're multiple tokens, suppress the first token, which is safe because: 253 | # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress 254 | # in generations, and in the 3-byte UTF-8 representation they share the first two bytes. 255 | miscellaneous = set("♩♪♫♬♭♮♯") 256 | assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous) 257 | 258 | # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word 259 | result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]} 260 | for symbol in symbols + list(miscellaneous): 261 | for tokens in [self.tokenizer.encode(symbol), self.tokenizer.encode(" " + symbol)]: 262 | if len(tokens) == 1 or symbol in miscellaneous: 263 | result.add(tokens[0]) 264 | 265 | return tuple(sorted(result)) 266 | 267 | def _get_single_token_id(self, text) -> int: 268 | tokens = self.tokenizer.encode(text) 269 | assert len(tokens) == 1, f"{text} is not encoded as a single token" 270 | return tokens[0] 271 | 272 | 273 | @lru_cache(maxsize=None) 274 | def build_tokenizer(name: str = "gpt2"): 275 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 276 | path = os.path.join(os.path.dirname(__file__), "assets", name) 277 | tokenizer = GPT2TokenizerFast.from_pretrained(path) 278 | 279 | specials = [ 280 | "<|startoftranscript|>", 281 | *[f"<|{lang}|>" for lang in LANGUAGES.keys()], 282 | "<|translate|>", 283 | "<|transcribe|>", 284 | "<|startoflm|>", 285 | "<|startofprev|>", 286 | "<|nospeech|>", 287 | "<|notimestamps|>", 288 | ] 289 | 290 | tokenizer.add_special_tokens(dict(additional_special_tokens=specials)) 291 | return tokenizer 292 | 293 | 294 | @lru_cache(maxsize=None) 295 | def get_tokenizer( 296 | multilingual: bool, 297 | *, 298 | task: Optional[str] = None, # Literal["transcribe", "translate", None] 299 | language: Optional[str] = None, 300 | concat_lang_token: Optional[str] = "0", # customized if 0, do nothing 301 | ) -> Tokenizer: 302 | if language is not None: 303 | language = language.lower() 304 | if language not in LANGUAGES: 305 | if language in TO_LANGUAGE_CODE: 306 | language = TO_LANGUAGE_CODE[language] 307 | else: 308 | raise ValueError(f"Unsupported language: {language}") 309 | 310 | if multilingual: 311 | tokenizer_name = "multilingual" 312 | task = 
task or "transcribe" 313 | language = language or "en" 314 | else: 315 | tokenizer_name = "gpt2" 316 | task = None 317 | language = None 318 | 319 | tokenizer = build_tokenizer(name=tokenizer_name) 320 | all_special_ids: List[int] = tokenizer.all_special_ids 321 | sot: int = all_special_ids[1] 322 | translate: int = all_special_ids[-6] 323 | transcribe: int = all_special_ids[-5] 324 | 325 | langs = tuple(LANGUAGES.keys()) 326 | sot_sequence = [sot] 327 | if concat_lang_token == "0": 328 | if language is not None: 329 | sot_sequence.append(sot + 1 + langs.index(language)) 330 | else: # use two language tokens 331 | main_lang, second_lang = concat_lang_token.split("-") 332 | sot_sequence.append(sot + 1 + langs.index(main_lang)) 333 | sot_sequence.append(sot + 1 + langs.index(second_lang)) 334 | 335 | if task is not None: 336 | sot_sequence.append(transcribe if task == "transcribe" else translate) 337 | 338 | return Tokenizer(tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence)) 339 | -------------------------------------------------------------------------------- /avsr.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import os 3 | import pickle 4 | import torch 5 | import random 6 | from tqdm import tqdm 7 | from profanityfilter import ProfanityFilter 8 | 9 | import numpy as np 10 | from config import MyParser 11 | import whisper 12 | import clip 13 | 14 | from torchvision.transforms import Compose, Resize, CenterCrop, Normalize 15 | 16 | try: 17 | from torchvision.transforms import InterpolationMode 18 | BICUBIC = InterpolationMode.BICUBIC 19 | except ImportError: 20 | from PIL import Image 21 | BICUBIC = Image.BICUBIC 22 | 23 | 24 | if __name__ == "__main__": 25 | torch.cuda.empty_cache() 26 | args = MyParser().parse_args() 27 | print(args) 28 | # seed everything 29 | random.seed(args.seed) 30 | np.random.seed(args.seed) 31 | torch.manual_seed(args.seed) 32 | torch.cuda.manual_seed_all(args.seed) 33 | 34 | 35 | if args.dataset == "visspeech": 36 | from data.visspeech import get_dataloader, calc_metrics 37 | else: 38 | raise NotImplementedError(f"we don't support dataset {args.dataset} yet") 39 | 40 | ###################### CLIP textual feature embedding ###################### 41 | ###################### CLIP textual feature embedding ###################### 42 | ###################### CLIP textual feature embedding ###################### 43 | 44 | # clip_version = "ViT-L/14" #@param ["RN50", "RN101", "RN50x4", "RN50x16", "RN50x64", "ViT-B/32", "ViT-B/16", "ViT-L/14", "ViT-L/14@336px"] {type:"string"} 45 | clip_version = "ViT-L/14@336px" 46 | 47 | clip_feat_dim = {'RN50': 1024, 'RN101': 512, 'RN50x4': 640, 'RN50x16': 768, 'RN50x64': 1024, 'ViT-B/32': 512, 'ViT-B/16': 512, 'ViT-L/14': 768, "ViT-L/14@336px": 768}[clip_version] 48 | clip_img_res = {'ViT-L/14': 224, "ViT-L/14@336px": 336}[clip_version] 49 | 50 | if args.socratic == "1": 51 | clip_model, _ = clip.load(clip_version) # clip.available_models() 52 | preprocess = Compose([ 53 | Resize(clip_img_res, interpolation=BICUBIC), 54 | CenterCrop(clip_img_res), 55 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 56 | ]) 57 | clip_model.cuda().eval() 58 | 59 | def num_params(model): 60 | return np.sum([int(np.prod(p.shape)) for p in model.parameters()]) 61 | print("clip_Model parameters (total):", num_params(clip_model)) 62 | print("clip_Model parameters (image encoder):", num_params(clip_model.visual)) 63 | 
print("clip_Model parameters (text encoder):", num_params(clip_model.token_embedding) + num_params(clip_model.transformer)) 64 | print("Input image resolution:", clip_model.visual.input_resolution) 65 | print("Context length:", clip_model.context_length) 66 | print("Vocab size:", clip_model.vocab_size) 67 | img_size = clip_model.visual.input_resolution 68 | 69 | def get_text_feats(in_text, batch_size=64): 70 | text_tokens = clip.tokenize(in_text).cuda() 71 | text_id = 0 72 | text_feats = np.zeros((len(in_text), clip_feat_dim), dtype=np.float32) 73 | while text_id < len(text_tokens): # Batched inference. 74 | batch_size = min(len(in_text) - text_id, batch_size) 75 | text_batch = text_tokens[text_id:text_id+batch_size] 76 | with torch.no_grad(): 77 | batch_feats = clip_model.encode_text(text_batch).float() 78 | batch_feats /= batch_feats.norm(dim=-1, keepdim=True) 79 | batch_feats = np.float32(batch_feats.cpu()) 80 | text_feats[text_id:text_id+batch_size, :] = batch_feats 81 | text_id += batch_size 82 | return text_feats 83 | 84 | def get_img_feats(img): 85 | assert len(img.shape) == 4 86 | img_in = preprocess(img) 87 | with torch.no_grad(): 88 | img_feats = clip_model.encode_image(img_in.cuda()).float() 89 | img_feats /= img_feats.norm(dim=-1, keepdim=True) 90 | img_feats = np.float32(img_feats.cpu()) 91 | return img_feats 92 | 93 | def get_nn_text(raw_texts, text_feats, img_feats, topk): 94 | assert len(img_feats.shape) == 2 and img_feats.shape[0] == args.num_img, f"img_feats shape: {img_feats.shape}" 95 | scores = [] 96 | texts = [] 97 | for img_feat in img_feats: 98 | cur_scores = text_feats @ img_feat[None,...].T 99 | cur_scores = cur_scores.squeeze() 100 | scores.append(cur_scores) 101 | texts += raw_texts 102 | scores = np.concatenate(scores) 103 | high_to_low_ids = np.argsort(scores).squeeze()[::-1] 104 | selected_texts = [] 105 | selected_scores = [] 106 | for id in high_to_low_ids: 107 | if texts[id] in selected_texts: 108 | continue 109 | if len(selected_texts) >= topk: 110 | break 111 | selected_texts.append(texts[id]) 112 | selected_scores.append(scores[id]) 113 | return selected_texts, selected_scores 114 | 115 | 116 | if args.socratic == "1": 117 | place_fn = args.place_pkl_fn 118 | object_fn = args.object_pkl_fn 119 | if os.path.isfile(place_fn): 120 | print("load place texts and feats from ", place_fn) 121 | with open(place_fn, "rb") as f: 122 | place_f = pickle.load(f) 123 | place_texts = place_f['place_texts'] 124 | place_feats = place_f['place_feats'] 125 | print("length of place texts: ", len(place_texts)) 126 | else: 127 | print("embed places365 text") 128 | # Load scene categories from Places365. 129 | place_categories = np.loadtxt(args.place_txt_fn, dtype=str) 130 | place_texts = [] 131 | for place in place_categories: 132 | try: 133 | place = place.split('/')[2:] 134 | if len(place) > 1: 135 | place = place[1] + ' ' + place[0] 136 | else: 137 | place = place[0] 138 | place = place.replace('_', ' ') 139 | place_texts.append(place) 140 | except: 141 | pass 142 | place_feats = get_text_feats([f'Photo of a {p}.' for p in place_texts]) 143 | print("length of place texts: ", len(place_texts)) 144 | with open(place_fn, "wb") as f: 145 | pickle.dump({"place_texts": place_texts, "place_feats": place_feats}, f) 146 | 147 | # Load object categories from Tencent ML Images. 
148 | if os.path.isfile(object_fn): 149 | print("load tencent ml image texts and feats from ", object_fn) 150 | with open(object_fn, "rb") as f: 151 | object_f = pickle.load(f) 152 | object_texts = object_f['object_texts'] 153 | object_feats = object_f['object_feats'] 154 | print("num of object texts: ", len(object_texts)) 155 | else: 156 | print("embed tencent ml image text") 157 | with open(args.object_txt_fn) as fid: 158 | object_categories = fid.readlines() 159 | object_texts = [] 160 | pf = ProfanityFilter() 161 | for object_text in object_categories[1:]: 162 | object_text = object_text.strip() 163 | object_text = object_text.split('\t')[3] 164 | safe_list = '' 165 | for variant in object_text.split(','): 166 | text = variant.strip() 167 | if pf.is_clean(text): 168 | safe_list += f'{text}, ' 169 | safe_list = safe_list[:-2] 170 | if len(safe_list) > 0: 171 | object_texts.append(safe_list) 172 | 173 | object_texts = [o for o in list(set(object_texts)) if o not in place_texts] # Remove redundant categories. 174 | object_feats = get_text_feats([f'Photo of a {o}.' for o in object_texts]) 175 | print("length of object texts: ", len(object_texts)) 176 | with open(object_fn, "wb") as f: 177 | pickle.dump({"object_texts": object_texts, "object_feats": object_feats}, f) 178 | ###################### CLIP textual feature embedding ###################### 179 | ###################### CLIP textual feature embedding ###################### 180 | ###################### CLIP textual feature embedding ###################### 181 | 182 | 183 | ################################### 184 | 185 | loader = get_dataloader(args) 186 | 187 | model = whisper.load_model(args.model) 188 | model.eval() 189 | model.cuda() 190 | 191 | refs = [] 192 | preds = [] 193 | all_prompts = [] 194 | for i, b in enumerate(tqdm(loader)): 195 | input_mels = b["input_mels"].half().cuda() 196 | raw_texts = b["raw_text"] 197 | imgs = b['imgs'] 198 | with torch.no_grad(): 199 | for input_mel, raw_text, img in zip(input_mels, raw_texts, imgs): 200 | if args.socratic == "1": 201 | img = img.cuda() 202 | img_feats = get_img_feats(img) 203 | place_list = '' 204 | if args.place_topk > 0: 205 | sorted_places, places_scores = get_nn_text(place_texts, place_feats, img_feats, args.place_topk) 206 | sorted_places = sorted_places[::-1] 207 | 208 | for i in range(len(sorted_places)): 209 | place_list += f'{sorted_places[i]}, ' 210 | object_list = '' 211 | if args.obj_topk > 0: 212 | sorted_obj_texts, obj_scores = get_nn_text(object_texts, object_feats, img_feats, args.obj_topk) 213 | sorted_obj_texts = sorted_obj_texts[::-1] 214 | 215 | for i in range(len(sorted_obj_texts)): 216 | object_list += f'{sorted_obj_texts[i].split(",")[0]}, ' 217 | object_list = object_list[:-2] + ". 
" 218 | prompt = place_list + object_list 219 | if len(prompt) == 0: 220 | prompt = None 221 | else: 222 | prompt = None 223 | all_prompts.append(prompt) 224 | 225 | options = whisper.DecodingOptions(task=args.task, language=args.language, without_timestamps=True, beam_size=args.beam_size, block_ngrams=args.block_ngrams, prompt=prompt) 226 | results = whisper.decode(model, input_mel, options) 227 | preds.append(results.text) 228 | refs.append(raw_text) 229 | 230 | 231 | 232 | inference_metrics, (wer_list, processed_preds, processed_refs) = calc_metrics()(refs, preds) 233 | print("results:", inference_metrics) 234 | print("results:", inference_metrics) 235 | if args.topk > 0: 236 | import numpy as np 237 | inds = np.argsort(wer_list)[::-1] 238 | for ind in inds[:args.topk]: 239 | print("-"*10) 240 | print("wer/mer: ", wer_list[ind]) 241 | print("ref: ", processed_refs[ind]) 242 | print("pred: ", processed_preds[ind]) 243 | print("prompt: ", all_prompts[ind]) 244 | else: 245 | for j, (k, v) in enumerate(zip(processed_refs, processed_preds)): 246 | if j % 100 == 0: 247 | print("-"*10) 248 | print("ref: ", k) 249 | print("pred: ", v) 250 | 251 | print("results:", inference_metrics) 252 | print("results:", inference_metrics) 253 | 254 | 255 | -------------------------------------------------------------------------------- /whisper/transcribe.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import warnings 4 | from typing import List, Optional, Tuple, Union, TYPE_CHECKING 5 | 6 | import numpy as np 7 | import torch 8 | import tqdm 9 | 10 | from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram 11 | from .decoding import DecodingOptions, DecodingResult 12 | from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer 13 | from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt 14 | 15 | if TYPE_CHECKING: 16 | from .model import Whisper 17 | 18 | 19 | def transcribe( 20 | model: "Whisper", 21 | audio: Union[str, np.ndarray, torch.Tensor], 22 | *, 23 | verbose: Optional[bool] = None, 24 | temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), 25 | compression_ratio_threshold: Optional[float] = 2.4, 26 | logprob_threshold: Optional[float] = -1.0, 27 | no_speech_threshold: Optional[float] = 0.6, 28 | condition_on_previous_text: bool = True, 29 | **decode_options, 30 | ): 31 | """ 32 | Transcribe an audio file using Whisper 33 | 34 | Parameters 35 | ---------- 36 | model: Whisper 37 | The Whisper model instance 38 | 39 | audio: Union[str, np.ndarray, torch.Tensor] 40 | The path to the audio file to open, or the audio waveform 41 | 42 | verbose: bool 43 | Whether to display the text being decoded to the console. If True, displays all the details, 44 | If False, displays minimal details. If None, does not display anything 45 | 46 | temperature: Union[float, Tuple[float, ...]] 47 | Temperature for sampling. It can be a tuple of temperatures, which will be successfully used 48 | upon failures according to either `compression_ratio_threshold` or `logprob_threshold`. 
49 | 50 | compression_ratio_threshold: float 51 | If the gzip compression ratio is above this value, treat as failed 52 | 53 | logprob_threshold: float 54 | If the average log probability over sampled tokens is below this value, treat as failed 55 | 56 | no_speech_threshold: float 57 | If the no_speech probability is higher than this value AND the average log probability 58 | over sampled tokens is below `logprob_threshold`, consider the segment as silent 59 | 60 | condition_on_previous_text: bool 61 | if True, the previous output of the model is provided as a prompt for the next window; 62 | disabling may make the text inconsistent across windows, but the model becomes less prone to 63 | getting stuck in a failure loop, such as repetition looping or timestamps going out of sync. 64 | 65 | decode_options: dict 66 | Keyword arguments to construct `DecodingOptions` instances 67 | 68 | Returns 69 | ------- 70 | A dictionary containing the resulting text ("text") and segment-level details ("segments"), and 71 | the spoken language ("language"), which is detected when `decode_options["language"]` is None. 72 | """ 73 | dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32 74 | if model.device == torch.device("cpu"): 75 | if torch.cuda.is_available(): 76 | warnings.warn("Performing inference on CPU when CUDA is available") 77 | if dtype == torch.float16: 78 | warnings.warn("FP16 is not supported on CPU; using FP32 instead") 79 | dtype = torch.float32 80 | 81 | if dtype == torch.float32: 82 | decode_options["fp16"] = False 83 | 84 | mel = log_mel_spectrogram(audio) 85 | 86 | if decode_options.get("language", None) is None: 87 | if not model.is_multilingual: 88 | decode_options["language"] = "en" 89 | else: 90 | if verbose: 91 | print("Detecting language using up to the first 30 seconds. 
Use `--language` to specify the language") 92 | segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype) 93 | _, probs = model.detect_language(segment) 94 | decode_options["language"] = max(probs, key=probs.get) 95 | if verbose is not None: 96 | print(f"Detected language: {LANGUAGES[decode_options['language']].title()}") 97 | 98 | language = decode_options["language"] 99 | task = decode_options.get("task", "transcribe") 100 | tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task) 101 | 102 | def decode_with_fallback(segment: torch.Tensor) -> DecodingResult: 103 | temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature 104 | decode_result = None 105 | 106 | for t in temperatures: 107 | kwargs = {**decode_options} 108 | if t > 0: 109 | # disable beam_size and patience when t > 0 110 | kwargs.pop("beam_size", None) 111 | kwargs.pop("patience", None) 112 | else: 113 | # disable best_of when t == 0 114 | kwargs.pop("best_of", None) 115 | 116 | options = DecodingOptions(**kwargs, temperature=t) 117 | decode_result = model.decode(segment, options) 118 | 119 | needs_fallback = False 120 | if compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold: 121 | needs_fallback = True # too repetitive 122 | if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold: 123 | needs_fallback = True # average log probability is too low 124 | 125 | if not needs_fallback: 126 | break 127 | 128 | return decode_result 129 | 130 | seek = 0 131 | input_stride = exact_div( 132 | N_FRAMES, model.dims.n_audio_ctx 133 | ) # mel frames per output token: 2 134 | time_precision = ( 135 | input_stride * HOP_LENGTH / SAMPLE_RATE 136 | ) # time per output token: 0.02 (seconds) 137 | all_tokens = [] 138 | all_segments = [] 139 | prompt_reset_since = 0 140 | 141 | initial_prompt = decode_options.pop("initial_prompt", None) or [] 142 | if initial_prompt: 143 | initial_prompt = tokenizer.encode(" " + initial_prompt.strip()) 144 | all_tokens.extend(initial_prompt) 145 | 146 | def add_segment( 147 | *, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult 148 | ): 149 | text = tokenizer.decode([token for token in text_tokens if token < tokenizer.eot]) 150 | if len(text.strip()) == 0: # skip empty text output 151 | return 152 | 153 | all_segments.append( 154 | { 155 | "id": len(all_segments), 156 | "seek": seek, 157 | "start": start, 158 | "end": end, 159 | "text": text, 160 | "tokens": result.tokens, 161 | "temperature": result.temperature, 162 | "avg_logprob": result.avg_logprob, 163 | "compression_ratio": result.compression_ratio, 164 | "no_speech_prob": result.no_speech_prob, 165 | } 166 | ) 167 | if verbose: 168 | print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}") 169 | 170 | # show the progress bar when verbose is False (otherwise the transcribed text will be printed) 171 | num_frames = mel.shape[-1] 172 | previous_seek_value = seek 173 | 174 | with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar: 175 | while seek < num_frames: 176 | timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) 177 | segment = pad_or_trim(mel[:, seek:], N_FRAMES).to(model.device).to(dtype) 178 | segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE 179 | 180 | decode_options["prompt"] = all_tokens[prompt_reset_since:] 181 | result: DecodingResult = decode_with_fallback(segment) 182 | tokens = torch.tensor(result.tokens) 183 | 184 | 
if no_speech_threshold is not None: 185 | # no voice activity check 186 | should_skip = result.no_speech_prob > no_speech_threshold 187 | if logprob_threshold is not None and result.avg_logprob > logprob_threshold: 188 | # don't skip if the logprob is high enough, despite the no_speech_prob 189 | should_skip = False 190 | 191 | if should_skip: 192 | seek += segment.shape[-1] # fast-forward to the next segment boundary 193 | continue 194 | 195 | timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin) 196 | consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1) 197 | if len(consecutive) > 0: # if the output contains two consecutive timestamp tokens 198 | last_slice = 0 199 | for current_slice in consecutive: 200 | sliced_tokens = tokens[last_slice:current_slice] 201 | start_timestamp_position = ( 202 | sliced_tokens[0].item() - tokenizer.timestamp_begin 203 | ) 204 | end_timestamp_position = ( 205 | sliced_tokens[-1].item() - tokenizer.timestamp_begin 206 | ) 207 | add_segment( 208 | start=timestamp_offset + start_timestamp_position * time_precision, 209 | end=timestamp_offset + end_timestamp_position * time_precision, 210 | text_tokens=sliced_tokens[1:-1], 211 | result=result, 212 | ) 213 | last_slice = current_slice 214 | last_timestamp_position = ( 215 | tokens[last_slice - 1].item() - tokenizer.timestamp_begin 216 | ) 217 | seek += last_timestamp_position * input_stride 218 | all_tokens.extend(tokens[: last_slice + 1].tolist()) 219 | else: 220 | duration = segment_duration 221 | timestamps = tokens[timestamp_tokens.nonzero().flatten()] 222 | if len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin: 223 | # no consecutive timestamps but it has a timestamp; use the last one. 224 | # single timestamp at the end means no speech after the last timestamp. 225 | last_timestamp_position = timestamps[-1].item() - tokenizer.timestamp_begin 226 | duration = last_timestamp_position * time_precision 227 | 228 | add_segment( 229 | start=timestamp_offset, 230 | end=timestamp_offset + duration, 231 | text_tokens=tokens, 232 | result=result, 233 | ) 234 | 235 | seek += segment.shape[-1] 236 | all_tokens.extend(tokens.tolist()) 237 | 238 | if not condition_on_previous_text or result.temperature > 0.5: 239 | # do not feed the prompt tokens if a high temperature was used 240 | prompt_reset_since = len(all_tokens) 241 | 242 | # update progress bar 243 | pbar.update(min(num_frames, seek) - previous_seek_value) 244 | previous_seek_value = seek 245 | 246 | return dict(text=tokenizer.decode(all_tokens[len(initial_prompt):]), segments=all_segments, language=language) 247 | 248 | 249 | def cli(): 250 | from . 
import available_models 251 | 252 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 253 | parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe") 254 | parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use") 255 | parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default") 256 | parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference") 257 | parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") 258 | parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages") 259 | 260 | parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')") 261 | parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection") 262 | 263 | parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling") 264 | parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature") 265 | parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero") 266 | parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search") 267 | parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default") 268 | 269 | parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations") 270 | parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.") 271 | parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop") 272 | parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default") 273 | 274 | parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below") 275 | parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed") 276 | parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed") 277 | 
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence") 278 | parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS") 279 | 280 | args = parser.parse_args().__dict__ 281 | model_name: str = args.pop("model") 282 | model_dir: str = args.pop("model_dir") 283 | output_dir: str = args.pop("output_dir") 284 | device: str = args.pop("device") 285 | os.makedirs(output_dir, exist_ok=True) 286 | 287 | if model_name.endswith(".en") and args["language"] not in {"en", "English"}: 288 | if args["language"] is not None: 289 | warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.") 290 | args["language"] = "en" 291 | 292 | temperature = args.pop("temperature") 293 | temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback") 294 | if temperature_increment_on_fallback is not None: 295 | temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback)) 296 | else: 297 | temperature = [temperature] 298 | 299 | threads = args.pop("threads") 300 | if threads > 0: 301 | torch.set_num_threads(threads) 302 | 303 | from . import load_model 304 | model = load_model(model_name, device=device, download_root=model_dir) 305 | 306 | for audio_path in args.pop("audio"): 307 | result = transcribe(model, audio_path, temperature=temperature, **args) 308 | 309 | audio_basename = os.path.basename(audio_path) 310 | 311 | # save TXT 312 | with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt: 313 | write_txt(result["segments"], file=txt) 314 | 315 | # save VTT 316 | with open(os.path.join(output_dir, audio_basename + ".vtt"), "w", encoding="utf-8") as vtt: 317 | write_vtt(result["segments"], file=vtt) 318 | 319 | # save SRT 320 | with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt: 321 | write_srt(result["segments"], file=srt) 322 | 323 | 324 | if __name__ == '__main__': 325 | cli() 326 | -------------------------------------------------------------------------------- /whisper/normalizers/english.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | from fractions import Fraction 5 | from typing import Iterator, List, Match, Optional, Union 6 | 7 | from more_itertools import windowed 8 | 9 | from .basic import remove_symbols_and_diacritics 10 | 11 | 12 | class EnglishNumberNormalizer: 13 | """ 14 | Convert any spelled-out numbers into arabic numbers, while handling: 15 | 16 | - remove any commas 17 | - keep the suffixes such as: `1960s`, `274th`, `32nd`, etc. 18 | - spell out currency symbols after the number. e.g. 
`$20 million` -> `20000000 dollars` 19 | - spell out `one` and `ones` 20 | - interpret successive single-digit numbers as nominal: `one oh one` -> `101` 21 | """ 22 | 23 | def __init__(self): 24 | super().__init__() 25 | 26 | self.zeros = {"o", "oh", "zero"} 27 | self.ones = { 28 | name: i 29 | for i, name in enumerate( 30 | [ 31 | "one", 32 | "two", 33 | "three", 34 | "four", 35 | "five", 36 | "six", 37 | "seven", 38 | "eight", 39 | "nine", 40 | "ten", 41 | "eleven", 42 | "twelve", 43 | "thirteen", 44 | "fourteen", 45 | "fifteen", 46 | "sixteen", 47 | "seventeen", 48 | "eighteen", 49 | "nineteen", 50 | ], 51 | start=1, 52 | ) 53 | } 54 | self.ones_plural = { 55 | "sixes" if name == "six" else name + "s": (value, "s") 56 | for name, value in self.ones.items() 57 | } 58 | self.ones_ordinal = { 59 | "zeroth": (0, "th"), 60 | "first": (1, "st"), 61 | "second": (2, "nd"), 62 | "third": (3, "rd"), 63 | "fifth": (5, "th"), 64 | "twelfth": (12, "th"), 65 | **{ 66 | name + ("h" if name.endswith("t") else "th"): (value, "th") 67 | for name, value in self.ones.items() 68 | if value > 3 and value != 5 and value != 12 69 | }, 70 | } 71 | self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal} 72 | 73 | self.tens = { 74 | "twenty": 20, 75 | "thirty": 30, 76 | "forty": 40, 77 | "fifty": 50, 78 | "sixty": 60, 79 | "seventy": 70, 80 | "eighty": 80, 81 | "ninety": 90, 82 | } 83 | self.tens_plural = { 84 | name.replace("y", "ies"): (value, "s") for name, value in self.tens.items() 85 | } 86 | self.tens_ordinal = { 87 | name.replace("y", "ieth"): (value, "th") for name, value in self.tens.items() 88 | } 89 | self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal} 90 | 91 | self.multipliers = { 92 | "hundred": 100, 93 | "thousand": 1_000, 94 | "million": 1_000_000, 95 | "billion": 1_000_000_000, 96 | "trillion": 1_000_000_000_000, 97 | "quadrillion": 1_000_000_000_000_000, 98 | "quintillion": 1_000_000_000_000_000_000, 99 | "sextillion": 1_000_000_000_000_000_000_000, 100 | "septillion": 1_000_000_000_000_000_000_000_000, 101 | "octillion": 1_000_000_000_000_000_000_000_000_000, 102 | "nonillion": 1_000_000_000_000_000_000_000_000_000_000, 103 | "decillion": 1_000_000_000_000_000_000_000_000_000_000_000, 104 | } 105 | self.multipliers_plural = { 106 | name + "s": (value, "s") for name, value in self.multipliers.items() 107 | } 108 | self.multipliers_ordinal = { 109 | name + "th": (value, "th") for name, value in self.multipliers.items() 110 | } 111 | self.multipliers_suffixed = {**self.multipliers_plural, **self.multipliers_ordinal} 112 | self.decimals = {*self.ones, *self.tens, *self.zeros} 113 | 114 | self.preceding_prefixers = { 115 | "minus": "-", 116 | "negative": "-", 117 | "plus": "+", 118 | "positive": "+", 119 | } 120 | self.following_prefixers = { 121 | "pound": "£", 122 | "pounds": "£", 123 | "euro": "€", 124 | "euros": "€", 125 | "dollar": "$", 126 | "dollars": "$", 127 | "cent": "¢", 128 | "cents": "¢", 129 | } 130 | self.prefixes = set( 131 | list(self.preceding_prefixers.values()) + list(self.following_prefixers.values()) 132 | ) 133 | self.suffixers = { 134 | "per": {"cent": "%"}, 135 | "percent": "%", 136 | } 137 | self.specials = {"and", "double", "triple", "point"} 138 | 139 | self.words = set( 140 | [ 141 | key 142 | for mapping in [ 143 | self.zeros, 144 | self.ones, 145 | self.ones_suffixed, 146 | self.tens, 147 | self.tens_suffixed, 148 | self.multipliers, 149 | self.multipliers_suffixed, 150 | self.preceding_prefixers, 151 | self.following_prefixers, 152 | self.suffixers, 
153 | self.specials, 154 | ] 155 | for key in mapping 156 | ] 157 | ) 158 | self.literal_words = {"one", "ones"} 159 | 160 | def process_words(self, words: List[str]) -> Iterator[str]: 161 | prefix: Optional[str] = None 162 | value: Optional[Union[str, int]] = None 163 | skip = False 164 | 165 | def to_fraction(s: str): 166 | try: 167 | return Fraction(s) 168 | except ValueError: 169 | return None 170 | 171 | def output(result: Union[str, int]): 172 | nonlocal prefix, value 173 | result = str(result) 174 | if prefix is not None: 175 | result = prefix + result 176 | value = None 177 | prefix = None 178 | return result 179 | 180 | if len(words) == 0: 181 | return 182 | 183 | for prev, current, next in windowed([None] + words + [None], 3): 184 | if skip: 185 | skip = False 186 | continue 187 | 188 | next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next) 189 | has_prefix = current[0] in self.prefixes 190 | current_without_prefix = current[1:] if has_prefix else current 191 | if re.match(r"^\d+(\.\d+)?$", current_without_prefix): 192 | # arabic numbers (potentially with signs and fractions) 193 | f = to_fraction(current_without_prefix) 194 | assert f is not None 195 | if value is not None: 196 | if isinstance(value, str) and value.endswith("."): 197 | # concatenate decimals / ip address components 198 | value = str(value) + str(current) 199 | continue 200 | else: 201 | yield output(value) 202 | 203 | prefix = current[0] if has_prefix else prefix 204 | if f.denominator == 1: 205 | value = f.numerator # store integers as int 206 | else: 207 | value = current_without_prefix 208 | elif current not in self.words: 209 | # non-numeric words 210 | if value is not None: 211 | yield output(value) 212 | yield output(current) 213 | elif current in self.zeros: 214 | value = str(value or "") + "0" 215 | elif current in self.ones: 216 | ones = self.ones[current] 217 | 218 | if value is None: 219 | value = ones 220 | elif isinstance(value, str) or prev in self.ones: 221 | if prev in self.tens and ones < 10: # replace the last zero with the digit 222 | assert value[-1] == "0" 223 | value = value[:-1] + str(ones) 224 | else: 225 | value = str(value) + str(ones) 226 | elif ones < 10: 227 | if value % 10 == 0: 228 | value += ones 229 | else: 230 | value = str(value) + str(ones) 231 | else: # eleven to nineteen 232 | if value % 100 == 0: 233 | value += ones 234 | else: 235 | value = str(value) + str(ones) 236 | elif current in self.ones_suffixed: 237 | # ordinal or cardinal; yield the number right away 238 | ones, suffix = self.ones_suffixed[current] 239 | if value is None: 240 | yield output(str(ones) + suffix) 241 | elif isinstance(value, str) or prev in self.ones: 242 | if prev in self.tens and ones < 10: 243 | assert value[-1] == "0" 244 | yield output(value[:-1] + str(ones) + suffix) 245 | else: 246 | yield output(str(value) + str(ones) + suffix) 247 | elif ones < 10: 248 | if value % 10 == 0: 249 | yield output(str(value + ones) + suffix) 250 | else: 251 | yield output(str(value) + str(ones) + suffix) 252 | else: # eleven to nineteen 253 | if value % 100 == 0: 254 | yield output(str(value + ones) + suffix) 255 | else: 256 | yield output(str(value) + str(ones) + suffix) 257 | value = None 258 | elif current in self.tens: 259 | tens = self.tens[current] 260 | if value is None: 261 | value = tens 262 | elif isinstance(value, str): 263 | value = str(value) + str(tens) 264 | else: 265 | if value % 100 == 0: 266 | value += tens 267 | else: 268 | value = str(value) + str(tens) 269 | elif current in 
self.tens_suffixed: 270 | # ordinal or cardinal; yield the number right away 271 | tens, suffix = self.tens_suffixed[current] 272 | if value is None: 273 | yield output(str(tens) + suffix) 274 | elif isinstance(value, str): 275 | yield output(str(value) + str(tens) + suffix) 276 | else: 277 | if value % 100 == 0: 278 | yield output(str(value + tens) + suffix) 279 | else: 280 | yield output(str(value) + str(tens) + suffix) 281 | elif current in self.multipliers: 282 | multiplier = self.multipliers[current] 283 | if value is None: 284 | value = multiplier 285 | elif isinstance(value, str) or value == 0: 286 | f = to_fraction(value) 287 | p = f * multiplier if f is not None else None 288 | if f is not None and p.denominator == 1: 289 | value = p.numerator 290 | else: 291 | yield output(value) 292 | value = multiplier 293 | else: 294 | before = value // 1000 * 1000 295 | residual = value % 1000 296 | value = before + residual * multiplier 297 | elif current in self.multipliers_suffixed: 298 | multiplier, suffix = self.multipliers_suffixed[current] 299 | if value is None: 300 | yield output(str(multiplier) + suffix) 301 | elif isinstance(value, str): 302 | f = to_fraction(value) 303 | p = f * multiplier if f is not None else None 304 | if f is not None and p.denominator == 1: 305 | yield output(str(p.numerator) + suffix) 306 | else: 307 | yield output(value) 308 | yield output(str(multiplier) + suffix) 309 | else: # int 310 | before = value // 1000 * 1000 311 | residual = value % 1000 312 | value = before + residual * multiplier 313 | yield output(str(value) + suffix) 314 | value = None 315 | elif current in self.preceding_prefixers: 316 | # apply prefix (positive, minus, etc.) if it precedes a number 317 | if value is not None: 318 | yield output(value) 319 | 320 | if next in self.words or next_is_numeric: 321 | prefix = self.preceding_prefixers[current] 322 | else: 323 | yield output(current) 324 | elif current in self.following_prefixers: 325 | # apply prefix (dollars, cents, etc.) only after a number 326 | if value is not None: 327 | prefix = self.following_prefixers[current] 328 | yield output(value) 329 | else: 330 | yield output(current) 331 | elif current in self.suffixers: 332 | # apply suffix symbols (percent -> '%') 333 | if value is not None: 334 | suffix = self.suffixers[current] 335 | if isinstance(suffix, dict): 336 | if next in suffix: 337 | yield output(str(value) + suffix[next]) 338 | skip = True 339 | else: 340 | yield output(value) 341 | yield output(current) 342 | else: 343 | yield output(str(value) + suffix) 344 | else: 345 | yield output(current) 346 | elif current in self.specials: 347 | if next not in self.words and not next_is_numeric: 348 | # apply special handling only if the next word can be numeric 349 | if value is not None: 350 | yield output(value) 351 | yield output(current) 352 | elif current == "and": 353 | # ignore "and" after hundreds, thousands, etc. 354 | if prev not in self.multipliers: 355 | if value is not None: 356 | yield output(value) 357 | yield output(current) 358 | elif current == "double" or current == "triple": 359 | if next in self.ones or next in self.zeros: 360 | repeats = 2 if current == "double" else 3 361 | ones = self.ones.get(next, 0) 362 | value = str(value or "") + str(ones) * repeats 363 | skip = True 364 | else: 365 | if value is not None: 366 | yield output(value) 367 | yield output(current) 368 | elif current == "point": 369 | if next in self.decimals or next_is_numeric: 370 | value = str(value or "") + "." 
371 | else: 372 | # should all have been covered at this point 373 | raise ValueError(f"Unexpected token: {current}") 374 | else: 375 | # all should have been covered at this point 376 | raise ValueError(f"Unexpected token: {current}") 377 | 378 | if value is not None: 379 | yield output(value) 380 | 381 | def preprocess(self, s: str): 382 | # replace " and a half" with " point five" 383 | results = [] 384 | 385 | segments = re.split(r"\band\s+a\s+half\b", s) 386 | for i, segment in enumerate(segments): 387 | if len(segment.strip()) == 0: 388 | continue 389 | if i == len(segments) - 1: 390 | results.append(segment) 391 | else: 392 | results.append(segment) 393 | last_word = segment.rsplit(maxsplit=2)[-1] 394 | if last_word in self.decimals or last_word in self.multipliers: 395 | results.append("point five") 396 | else: 397 | results.append("and a half") 398 | 399 | s = " ".join(results) 400 | 401 | # put a space at number/letter boundary 402 | s = re.sub(r"([a-z])([0-9])", r"\1 \2", s) 403 | s = re.sub(r"([0-9])([a-z])", r"\1 \2", s) 404 | 405 | # but remove spaces which could be a suffix 406 | s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s) 407 | 408 | return s 409 | 410 | def postprocess(self, s: str): 411 | def combine_cents(m: Match): 412 | try: 413 | currency = m.group(1) 414 | integer = m.group(2) 415 | cents = int(m.group(3)) 416 | return f"{currency}{integer}.{cents:02d}" 417 | except ValueError: 418 | return m.string 419 | 420 | def extract_cents(m: Match): 421 | try: 422 | return f"¢{int(m.group(1))}" 423 | except ValueError: 424 | return m.string 425 | 426 | # apply currency postprocessing; "$2 and ¢7" -> "$2.07" 427 | s = re.sub(r"([€£$])([0-9]+) (?:and )?¢([0-9]{1,2})\b", combine_cents, s) 428 | s = re.sub(r"[€£$]0.([0-9]{1,2})\b", extract_cents, s) 429 | 430 | # write "one(s)" instead of "1(s)", just for the readability 431 | s = re.sub(r"\b1(s?)\b", r"one\1", s) 432 | 433 | return s 434 | 435 | def __call__(self, s: str): 436 | s = self.preprocess(s) 437 | s = " ".join(word for word in self.process_words(s.split()) if word is not None) 438 | s = self.postprocess(s) 439 | 440 | return s 441 | 442 | 443 | class EnglishSpellingNormalizer: 444 | """ 445 | Applies British-American spelling mappings as listed in [1]. 
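For example, "colour" -> "color", assuming that pair is present in english.json.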
446 | 447 | [1] https://www.tysto.com/uk-us-spelling-list.html 448 | """ 449 | 450 | def __init__(self): 451 | mapping_path = os.path.join(os.path.dirname(__file__), "english.json") 452 | self.mapping = json.load(open(mapping_path)) 453 | 454 | def __call__(self, s: str): 455 | return " ".join(self.mapping.get(word, word) for word in s.split()) 456 | 457 | 458 | class EnglishTextNormalizer: 459 | def __init__(self): 460 | self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b" 461 | self.replacers = { 462 | # common contractions 463 | r"\bwon't\b": "will not", 464 | r"\bcan't\b": "can not", 465 | r"\blet's\b": "let us", 466 | r"\bain't\b": "aint", 467 | r"\by'all\b": "you all", 468 | r"\bwanna\b": "want to", 469 | r"\bgotta\b": "got to", 470 | r"\bgonna\b": "going to", 471 | r"\bi'ma\b": "i am going to", 472 | r"\bimma\b": "i am going to", 473 | r"\bwoulda\b": "would have", 474 | r"\bcoulda\b": "could have", 475 | r"\bshoulda\b": "should have", 476 | r"\bma'am\b": "madam", 477 | # contractions in titles/prefixes 478 | r"\bmr\b": "mister ", 479 | r"\bmrs\b": "missus ", 480 | r"\bst\b": "saint ", 481 | r"\bdr\b": "doctor ", 482 | r"\bprof\b": "professor ", 483 | r"\bcapt\b": "captain ", 484 | r"\bgov\b": "governor ", 485 | r"\bald\b": "alderman ", 486 | r"\bgen\b": "general ", 487 | r"\bsen\b": "senator ", 488 | r"\brep\b": "representative ", 489 | r"\bpres\b": "president ", 490 | r"\brev\b": "reverend ", 491 | r"\bhon\b": "honorable ", 492 | r"\basst\b": "assistant ", 493 | r"\bassoc\b": "associate ", 494 | r"\blt\b": "lieutenant ", 495 | r"\bcol\b": "colonel ", 496 | r"\bjr\b": "junior ", 497 | r"\bsr\b": "senior ", 498 | r"\besq\b": "esquire ", 499 | # prefect tenses, ideally it should be any past participles, but it's harder.. 500 | r"'d been\b": " had been", 501 | r"'s been\b": " has been", 502 | r"'d gone\b": " had gone", 503 | r"'s gone\b": " has gone", 504 | r"'d done\b": " had done", # "'s done" is ambiguous 505 | r"'s got\b": " has got", 506 | # general contractions 507 | r"n't\b": " not", 508 | r"'re\b": " are", 509 | r"'s\b": " is", 510 | r"'d\b": " would", 511 | r"'ll\b": " will", 512 | r"'t\b": " not", 513 | r"'ve\b": " have", 514 | r"'m\b": " am", 515 | } 516 | self.standardize_numbers = EnglishNumberNormalizer() 517 | self.standardize_spellings = EnglishSpellingNormalizer() 518 | 519 | def __call__(self, s: str): 520 | s = s.lower() 521 | 522 | s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets 523 | s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis 524 | s = re.sub(self.ignore_patterns, "", s) 525 | s = re.sub(r"\s+'", "'", s) # standardize when there's a space before an apostrophe 526 | 527 | for pattern, replacement in self.replacers.items(): 528 | s = re.sub(pattern, replacement, s) 529 | 530 | s = re.sub(r"(\d),(\d)", r"\1\2", s) # remove commas between digits 531 | s = re.sub(r"\.([^0-9]|$)", r" \1", s) # remove periods not followed by numbers 532 | s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep some symbols for numerics 533 | 534 | s = self.standardize_numbers(s) 535 | s = self.standardize_spellings(s) 536 | 537 | # now remove prefix/suffix symbols that are not preceded/followed by numbers 538 | s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s) 539 | s = re.sub(r"([^0-9])%", r"\1 ", s) 540 | 541 | s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space 542 | 543 | return s 544 | -------------------------------------------------------------------------------- /whisper/decoding.py: 
-------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Dict, List, Tuple, Iterable, Optional, Sequence, Union, TYPE_CHECKING 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import Tensor 8 | from torch.distributions import Categorical 9 | 10 | from .audio import CHUNK_LENGTH 11 | from .tokenizer import Tokenizer, get_tokenizer 12 | from .utils import compression_ratio 13 | 14 | if TYPE_CHECKING: 15 | from .model import Whisper 16 | 17 | 18 | @torch.no_grad() 19 | def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None) -> Tuple[Tensor, List[dict]]: 20 | """ 21 | Detect the spoken language in the audio, and return them as list of strings, along with the ids 22 | of the most probable language tokens and the probability distribution over all language tokens. 23 | This is performed outside the main decode loop in order to not interfere with kv-caching. 24 | 25 | Returns 26 | ------- 27 | language_tokens : Tensor, shape = (n_audio,) 28 | ids of the most probable language tokens, which appears after the startoftranscript token. 29 | language_probs : List[Dict[str, float]], length = n_audio 30 | list of dictionaries containing the probability distribution over all languages. 31 | """ 32 | if tokenizer is None: 33 | tokenizer = get_tokenizer(model.is_multilingual) 34 | if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence: 35 | raise ValueError(f"This model doesn't have language tokens so it can't perform lang id") 36 | 37 | single = mel.ndim == 2 38 | if single: 39 | mel = mel.unsqueeze(0) 40 | 41 | # skip encoder forward pass if already-encoded audio features were given 42 | # if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state): 43 | if mel.shape[-1] != model.dims.n_audio_state: 44 | mel = model.encoder(mel) 45 | 46 | # forward pass using a single token, startoftranscript 47 | n_audio = mel.shape[0] 48 | x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1] 49 | logits = model.logits(x, mel)[:, 0] 50 | 51 | # collect detected languages; suppress all non-language tokens 52 | mask = torch.ones(logits.shape[-1], dtype=torch.bool) 53 | mask[list(tokenizer.all_language_tokens)] = False 54 | logits[:, mask] = -np.inf 55 | language_tokens = logits.argmax(dim=-1) 56 | language_token_probs = logits.softmax(dim=-1).cpu() 57 | language_probs = [ 58 | { 59 | c: language_token_probs[i, j].item() 60 | for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes) 61 | } 62 | for i in range(n_audio) 63 | ] 64 | 65 | if single: 66 | language_tokens = language_tokens[0] 67 | language_probs = language_probs[0] 68 | 69 | return language_tokens, language_probs 70 | 71 | 72 | @dataclass(frozen=True) 73 | class DecodingOptions: 74 | task: str = "transcribe" # whether to perform X->X "transcribe" or X->English "translate" 75 | language: Optional[str] = None # language that the audio is in; uses detected language if None 76 | 77 | # sampling-related options 78 | temperature: float = 0.0 79 | sample_len: Optional[int] = None # maximum number of tokens to sample 80 | best_of: Optional[int] = None # number of independent samples to collect, when t > 0 81 | beam_size: Optional[int] = None # number of beams in beam search, when t == 0 82 | patience: Optional[float] = None # patience in beam search (https://arxiv.org/abs/2204.05424) 83 | 84 | # options for ranking generations 
(either beams or best-of-N samples) 85 | length_penalty: Optional[float] = None # "alpha" in Google NMT, None defaults to length norm 86 | 87 | # prompt, prefix, and token suppression 88 | prompt: Optional[Union[str, List[int]]] = None # text or tokens for the previous context 89 | prefix: Optional[Union[str, List[int]]] = None # text or tokens to prefix the current context 90 | suppress_blank: bool = True # this will suppress blank outputs 91 | 92 | # list of token ids (or comma-separated token ids) to suppress 93 | # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()` 94 | suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1" 95 | 96 | # timestamp sampling options 97 | without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only 98 | max_initial_timestamp: Optional[float] = 1.0 # the initial timestamp cannot be later than this 99 | 100 | # implementation details 101 | fp16: bool = True # use fp16 for most of the calculation 102 | block_ngrams: list = field(default_factory=list) 103 | 104 | # customized 105 | concat_lang_token: Optional[str] = "0" # "0" keeps a single language token; a value like "zh-en" puts two language tokens in the SOT sequence (code-switching case) 106 | 107 | logit_mask: Optional[Tensor] = None # if not None, added to the model output logits to constrain the output vocabulary 108 | 109 | 110 | @dataclass(frozen=True) 111 | class DecodingResult: 112 | audio_features: Tensor 113 | language: str 114 | language_probs: Optional[Dict[str, float]] = None 115 | tokens: List[int] = field(default_factory=list) 116 | text: str = "" 117 | avg_logprob: float = np.nan 118 | no_speech_prob: float = np.nan 119 | temperature: float = np.nan 120 | compression_ratio: float = np.nan 121 | 122 | 123 | class Inference: 124 | def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor: 125 | """Perform a forward pass on the decoder and return per-token logits""" 126 | raise NotImplementedError 127 | 128 | def rearrange_kv_cache(self, source_indices) -> None: 129 | """Update the key-value cache according to the updated beams""" 130 | raise NotImplementedError 131 | 132 | def cleanup_caching(self) -> None: 133 | """Clean up any resources or hooks after decoding is finished""" 134 | pass 135 | 136 | 137 | class PyTorchInference(Inference): 138 | def __init__(self, model: "Whisper", initial_token_length: int, logit_mask=None): 139 | self.model: "Whisper" = model 140 | self.initial_token_length = initial_token_length 141 | self.kv_cache = {} 142 | self.hooks = [] 143 | self.logit_mask = None if logit_mask is None else logit_mask.to(model.device) 144 | 145 | def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor: 146 | if not self.kv_cache: 147 | self.kv_cache, self.hooks = self.model.install_kv_cache_hooks() 148 | 149 | if tokens.shape[-1] > self.initial_token_length: 150 | # only need to use the last token except in the first forward pass 151 | tokens = tokens[:, -1:] 152 | if self.logit_mask is not None: 153 | orig_logit = self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache) 154 | logit_mask = self.logit_mask.to(orig_logit) 155 | return orig_logit + logit_mask 156 | return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache) 157 | 158 | def cleanup_caching(self): 159 | for hook in self.hooks: 160 | hook.remove() 161 | 162 | self.kv_cache = {} 163 | self.hooks = [] 164 | 165 | def rearrange_kv_cache(self, source_indices): 166 | for module, tensor in self.kv_cache.items(): 167 | # update the key/value cache to contain the selected sequences 168 | 
self.kv_cache[module] = tensor[source_indices].detach() 169 | 170 | 171 | class SequenceRanker: 172 | def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]) -> List[int]: 173 | """ 174 | Given a list of groups of samples and their cumulative log probabilities, 175 | return the indices of the samples in each group to select as the final result 176 | """ 177 | raise NotImplementedError 178 | 179 | 180 | class MaximumLikelihoodRanker(SequenceRanker): 181 | """ 182 | Select the sample with the highest log probabilities, penalized using either 183 | a simple length normalization or Google NMT paper's length penalty 184 | """ 185 | 186 | def __init__(self, length_penalty: Optional[float]): 187 | self.length_penalty = length_penalty 188 | 189 | def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]): 190 | def scores(logprobs, lengths): 191 | result = [] 192 | for logprob, length in zip(logprobs, lengths): 193 | if self.length_penalty is None: 194 | penalty = length 195 | else: 196 | # from the Google NMT paper 197 | penalty = ((5 + length) / 6) ** self.length_penalty 198 | result.append(logprob / penalty) 199 | return result 200 | 201 | # get the sequence with the highest score 202 | lengths = [[len(t) for t in s] for s in tokens] 203 | return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)] 204 | 205 | 206 | class TokenDecoder: 207 | def reset(self): 208 | """Initialize any stateful variables for decoding a new sequence""" 209 | 210 | def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]: 211 | """Specify how to select the next token, based on the current trace and logits 212 | 213 | Parameters 214 | ---------- 215 | tokens : Tensor, shape = (n_batch, current_sequence_length) 216 | all tokens in the context so far, including the prefix and sot_sequence tokens 217 | 218 | logits : Tensor, shape = (n_batch, vocab_size) 219 | per-token logits of the probability distribution at the current step 220 | 221 | sum_logprobs : Tensor, shape = (n_batch) 222 | cumulative log probabilities for each sequence 223 | 224 | Returns 225 | ------- 226 | tokens : Tensor, shape = (n_batch, current_sequence_length + 1) 227 | the tokens, appended with the selected next token 228 | 229 | completed : bool 230 | True if all sequences has reached the end of text 231 | 232 | """ 233 | raise NotImplementedError 234 | 235 | def finalize( 236 | self, tokens: Tensor, sum_logprobs: Tensor 237 | ) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]: 238 | """Finalize search and return the final candidate sequences 239 | 240 | Parameters 241 | ---------- 242 | tokens : Tensor, shape = (n_audio, n_group, current_sequence_length) 243 | all tokens in the context so far, including the prefix and sot_sequence 244 | 245 | sum_logprobs : Tensor, shape = (n_audio, n_group) 246 | cumulative log probabilities for each sequence 247 | 248 | Returns 249 | ------- 250 | tokens : Sequence[Sequence[Tensor]], length = n_audio 251 | sequence of Tensors containing candidate token sequences, for each audio input 252 | 253 | sum_logprobs : List[List[float]], length = n_audio 254 | sequence of cumulative log probabilities corresponding to the above 255 | 256 | """ 257 | raise NotImplementedError 258 | 259 | 260 | class GreedyDecoder(TokenDecoder): 261 | def __init__(self, temperature: float, eot: int): 262 | self.temperature = temperature 263 | self.eot = eot 264 | 265 | def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> 
Tuple[Tensor, bool]: 266 | temperature = self.temperature 267 | if temperature == 0: 268 | next_tokens = logits.argmax(dim=-1) 269 | else: 270 | next_tokens = Categorical(logits=logits / temperature).sample() 271 | 272 | logprobs = F.log_softmax(logits.float(), dim=-1) 273 | current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens] 274 | sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot) 275 | 276 | next_tokens[tokens[:, -1] == self.eot] = self.eot 277 | tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1) 278 | 279 | completed = (tokens[:, -1] == self.eot).all() 280 | return tokens, completed 281 | 282 | def finalize(self, tokens: Tensor, sum_logprobs: Tensor): 283 | # make sure each sequence has at least one EOT token at the end 284 | tokens = F.pad(tokens, (0, 1), value=self.eot) 285 | return tokens, sum_logprobs.tolist() 286 | 287 | 288 | class BeamSearchDecoder(TokenDecoder): 289 | def __init__(self, beam_size: int, eot: int, inference: Inference, patience: Optional[float] = None, block_ngrams: list[int]=[]): 290 | self.beam_size = beam_size 291 | self.eot = eot 292 | self.inference = inference 293 | self.patience = patience or 1.0 294 | self.max_candidates: int = round(beam_size * self.patience) 295 | self.finished_sequences = None 296 | self.block_ngrams = block_ngrams 297 | 298 | assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})" 299 | 300 | def reset(self): 301 | self.finished_sequences = None 302 | 303 | def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]: 304 | 305 | if tokens.shape[0] % self.beam_size != 0: 306 | raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0") 307 | 308 | n_audio = tokens.shape[0] // self.beam_size 309 | if self.finished_sequences is None: # for the first update 310 | self.finished_sequences = [{} for _ in range(n_audio)] 311 | 312 | logprobs = F.log_softmax(logits.float(), dim=-1) 313 | next_tokens, source_indices, finished_sequences = [], [], [] 314 | # tokenizer = get_tokenizer(multilingual=True, language='ja', task='transcribe') 315 | # print("n_audio should be 1, and it is: ", n_audio) 316 | for i in range(n_audio): 317 | scores, sources, finished = {}, {}, {} 318 | 319 | # STEP 1: calculate the cumulative log probabilities for possible candidates 320 | for j in range(self.beam_size): 321 | # print(f"beam j={j}:") 322 | idx = i * self.beam_size + j 323 | prefix = tokens[idx].tolist() 324 | # print("prefix: ", prefix) 325 | for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)): 326 | new_logprob = (sum_logprobs[idx] + logprob).item() 327 | sequence = tuple(prefix + [token.item()]) 328 | scores[sequence] = new_logprob 329 | sources[sequence] = idx 330 | # print(sequence) 331 | # print(tokenizer.decode(sequence, skip_special_tokens=True)) 332 | # print(new_logprob) 333 | # STEP 2: rank the candidates and keep the top beam_size sequences for each audio 334 | saved = 0 335 | for sequence in sorted(scores, key=scores.get, reverse=True): 336 | rep = False 337 | for block_ngram in self.block_ngrams: 338 | if 2*block_ngram <= len(sequence) and tuple(sequence[-2*block_ngram:-block_ngram]) == tuple(sequence[-block_ngram:]): 339 | # print("block", sequence) 340 | finished[sequence[:-block_ngram]+(self.eot,)] = -1e20 341 | rep = True 342 | break 343 | if rep: 344 | continue 345 | if sequence[-1] == self.eot: 346 | finished[sequence] = scores[sequence] 347 | else: 348 | sum_logprobs[len(next_tokens)] = scores[sequence] 349 | 
next_tokens.append(sequence) 350 | source_indices.append(sources[sequence]) 351 | 352 | saved += 1 353 | if saved == self.beam_size: 354 | break 355 | 356 | finished_sequences.append(finished) 357 | 358 | tokens = torch.tensor(next_tokens, device=tokens.device) 359 | self.inference.rearrange_kv_cache(source_indices) 360 | 361 | # add newly finished sequences to self.finished_sequences 362 | assert len(self.finished_sequences) == len(finished_sequences) 363 | for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences): 364 | for seq in sorted(newly_finished, key=newly_finished.get, reverse=True): 365 | if len(previously_finished) >= self.max_candidates: 366 | break # the candidate list is full 367 | previously_finished[seq] = newly_finished[seq] 368 | 369 | # mark as completed if all audio has enough number of samples 370 | completed = all( 371 | len(sequences) >= self.max_candidates for sequences in self.finished_sequences 372 | ) 373 | return tokens, completed 374 | 375 | def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor): 376 | # collect all finished sequences, including patience, and add unfinished ones if not enough 377 | sum_logprobs = sum_logprobs.cpu() 378 | for i, sequences in enumerate(self.finished_sequences): 379 | if len(sequences) < self.beam_size: # when not enough sequences are finished 380 | for j in list(np.argsort(sum_logprobs[i]))[::-1]: 381 | sequence = preceding_tokens[i, j].tolist() + [self.eot] 382 | sequences[tuple(sequence)] = sum_logprobs[i][j].item() 383 | if len(sequences) >= self.beam_size: 384 | break 385 | 386 | tokens: List[List[Tensor]] = [ 387 | [torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences 388 | ] 389 | sum_logprobs: List[List[float]] = [ 390 | list(sequences.values()) for sequences in self.finished_sequences 391 | ] 392 | return tokens, sum_logprobs 393 | 394 | 395 | class LogitFilter: 396 | def apply(self, logits: Tensor, tokens: Tensor) -> None: 397 | """Apply any filtering or masking to logits in-place 398 | 399 | Parameters 400 | ---------- 401 | logits : Tensor, shape = (n_batch, vocab_size) 402 | per-token logits of the probability distribution at the current step 403 | 404 | tokens : Tensor, shape = (n_batch, current_sequence_length) 405 | all tokens in the context so far, including the prefix and sot_sequence tokens 406 | 407 | """ 408 | raise NotImplementedError 409 | 410 | 411 | class SuppressBlank(LogitFilter): 412 | def __init__(self, tokenizer: Tokenizer, sample_begin: int): 413 | self.tokenizer = tokenizer 414 | self.sample_begin = sample_begin 415 | 416 | def apply(self, logits: Tensor, tokens: Tensor): 417 | if tokens.shape[1] == self.sample_begin: 418 | logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf 419 | 420 | 421 | class SuppressTokens(LogitFilter): 422 | def __init__(self, suppress_tokens: Sequence[int]): 423 | self.suppress_tokens = list(suppress_tokens) 424 | 425 | def apply(self, logits: Tensor, tokens: Tensor): 426 | logits[:, self.suppress_tokens] = -np.inf 427 | 428 | 429 | class ApplyTimestampRules(LogitFilter): 430 | def __init__( 431 | self, tokenizer: Tokenizer, sample_begin: int, max_initial_timestamp_index: Optional[int] 432 | ): 433 | self.tokenizer = tokenizer 434 | self.sample_begin = sample_begin 435 | self.max_initial_timestamp_index = max_initial_timestamp_index 436 | 437 | def apply(self, logits: Tensor, tokens: Tensor): 438 | # suppress <|notimestamps|> which is handled by without_timestamps 439 | if 
self.tokenizer.no_timestamps is not None: 440 | logits[:, self.tokenizer.no_timestamps] = -np.inf 441 | 442 | # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly 443 | for k in range(tokens.shape[0]): 444 | seq = [t for t in tokens[k, self.sample_begin :].tolist()] 445 | last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin 446 | penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin 447 | 448 | if last_was_timestamp: 449 | if penultimate_was_timestamp: # has to be non-timestamp 450 | logits[k, self.tokenizer.timestamp_begin :] = -np.inf 451 | else: # cannot be normal text tokens 452 | logits[k, : self.tokenizer.eot] = -np.inf 453 | 454 | # apply the `max_initial_timestamp` option 455 | if tokens.shape[1] == self.sample_begin and self.max_initial_timestamp_index is not None: 456 | last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index 457 | logits[:, last_allowed + 1 :] = -np.inf 458 | 459 | # if sum of probability over timestamps is above any other token, sample timestamp 460 | logprobs = F.log_softmax(logits.float(), dim=-1) 461 | for k in range(tokens.shape[0]): 462 | timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(dim=-1) 463 | max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max() 464 | if timestamp_logprob > max_text_token_logprob: 465 | logits[k, : self.tokenizer.timestamp_begin] = -np.inf 466 | 467 | 468 | class DecodingTask: 469 | inference: Inference 470 | sequence_ranker: SequenceRanker 471 | decoder: TokenDecoder 472 | logit_filters: List[LogitFilter] 473 | 474 | def __init__(self, model: "Whisper", options: DecodingOptions): 475 | self.model = model 476 | 477 | language = options.language or "en" 478 | tokenizer = get_tokenizer(model.is_multilingual, language=language, task=options.task, concat_lang_token=options.concat_lang_token) 479 | self.tokenizer: Tokenizer = tokenizer 480 | self.options: DecodingOptions = self._verify_options(options) 481 | self.n_group: int = options.beam_size or options.best_of or 1 482 | self.n_ctx: int = model.dims.n_text_ctx 483 | self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2 484 | 485 | self.sot_sequence: Tuple[int] = tokenizer.sot_sequence 486 | if self.options.without_timestamps: 487 | self.sot_sequence = tokenizer.sot_sequence_including_notimestamps 488 | 489 | self.initial_tokens: Tuple[int] = self._get_initial_tokens() 490 | self.sample_begin: int = len(self.initial_tokens) 491 | self.sot_index: int = self.initial_tokens.index(tokenizer.sot) 492 | 493 | # inference: implements the forward pass through the decoder, including kv caching 494 | self.inference = PyTorchInference(model, len(self.initial_tokens), logit_mask=options.logit_mask) 495 | 496 | # sequence ranker: implements how to rank a group of sampled sequences 497 | self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty) 498 | 499 | # decoder: implements how to select the next tokens, given the autoregressive distribution 500 | if options.beam_size is not None: 501 | self.decoder = BeamSearchDecoder( 502 | options.beam_size, tokenizer.eot, self.inference, options.patience, options.block_ngrams 503 | ) 504 | else: 505 | self.decoder = GreedyDecoder(options.temperature, tokenizer.eot) 506 | 507 | # logit filters: applies various rules to suppress or penalize certain tokens 508 | self.logit_filters = [] 509 | if self.options.suppress_blank: 510 | 
510 |             self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
511 |         if self.options.suppress_tokens:
512 |             self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
513 |         if not options.without_timestamps:
514 |             precision = CHUNK_LENGTH / model.dims.n_audio_ctx  # usually 0.02 seconds
515 |             max_initial_timestamp_index = None
516 |             if options.max_initial_timestamp:
517 |                 max_initial_timestamp_index = round(self.options.max_initial_timestamp / precision)
518 |             self.logit_filters.append(
519 |                 ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index)
520 |             )
521 | 
522 |     def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
523 |         if options.beam_size is not None and options.best_of is not None:
524 |             raise ValueError("beam_size and best_of can't be given together")
525 |         if options.temperature == 0:
526 |             if options.best_of is not None:
527 |                 raise ValueError("best_of is not compatible with greedy sampling (T=0)")
528 |         if options.patience is not None and options.beam_size is None:
529 |             raise ValueError("patience requires beam_size to be given")
530 |         if options.length_penalty is not None and not (0 <= options.length_penalty <= 1):
531 |             raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
532 | 
533 |         return options
534 | 
535 |     def _get_initial_tokens(self) -> Tuple[int]:
536 |         tokens = list(self.sot_sequence)
537 |         prefix = self.options.prefix
538 |         prompt = self.options.prompt
539 | 
540 |         if prefix:
541 |             prefix_tokens = (
542 |                 self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
543 |             )
544 |             if self.sample_len is not None:
545 |                 max_prefix_len = self.n_ctx // 2 - self.sample_len
546 |                 prefix_tokens = prefix_tokens[-max_prefix_len:]
547 |             tokens = tokens + prefix_tokens
548 | 
549 |         if prompt:
550 |             prompt_tokens = (
551 |                 self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
552 |             )
553 |             tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 - 1) :] + tokens
554 | 
555 |         return tuple(tokens)
556 | 
557 |     def _get_suppress_tokens(self) -> Tuple[int]:
558 |         suppress_tokens = self.options.suppress_tokens
559 | 
560 |         if isinstance(suppress_tokens, str):
561 |             suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
562 | 
563 |         if -1 in suppress_tokens:
564 |             suppress_tokens = [t for t in suppress_tokens if t >= 0]
565 |             suppress_tokens.extend(self.tokenizer.non_speech_tokens)
566 |         elif suppress_tokens is None or len(suppress_tokens) == 0:
567 |             suppress_tokens = []  # interpret empty string as an empty list
568 |         else:
569 |             assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
570 | 
571 |         suppress_tokens.extend(
572 |             [self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
573 |         )
574 |         if self.tokenizer.no_speech is not None:
575 |             # no-speech probability is collected separately
576 |             suppress_tokens.append(self.tokenizer.no_speech)
577 | 
578 |         return tuple(sorted(set(suppress_tokens)))
579 | 
580 |     def _get_audio_features(self, mel: Tensor):
581 |         if self.options.fp16:
582 |             mel = mel.half()
583 | 
584 |         # if mel.shape[-2:] == (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
585 |         if mel.shape[-1] == self.model.dims.n_audio_state:
586 |             # encoded audio features are given; skip audio encoding
587 |             audio_features = mel
588 |         else:
589 |             audio_features = self.model.encoder(mel)
590 | 
591 |         if audio_features.dtype != (torch.float16 if self.options.fp16 else torch.float32):
592 |             raise TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}")
593 | 
594 |         return audio_features
595 | 
596 |     def _detect_language(self, audio_features: Tensor, tokens: Tensor):
597 |         languages = [self.options.language] * audio_features.shape[0]
598 |         lang_probs = None
599 | 
600 |         if self.options.language is None or self.options.task == "lang_id":
601 |             lang_tokens, lang_probs = self.model.detect_language(audio_features, self.tokenizer)
602 |             languages = [max(probs, key=probs.get) for probs in lang_probs]
603 |             if self.options.language is None:
604 |                 tokens[:, self.sot_index + 1] = lang_tokens  # write language tokens
605 | 
606 |         return languages, lang_probs
607 | 
608 |     def _main_loop(self, audio_features: Tensor, tokens: Tensor):
609 |         assert audio_features.shape[0] == tokens.shape[0]
610 |         n_batch = tokens.shape[0]
611 |         sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
612 |         no_speech_probs = [np.nan] * n_batch
613 | 
614 |         try:
615 |             for i in range(self.sample_len):
616 |                 logits = self.inference.logits(tokens, audio_features)
617 | 
618 |                 if i == 0 and self.tokenizer.no_speech is not None:  # save no_speech_probs
619 |                     probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
620 |                     no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
621 | 
622 |                 # now we need to consider the logits at the last token only
623 |                 logits = logits[:, -1]
624 | 
625 |                 # apply the logit filters, e.g. for suppressing or applying a penalty to certain tokens
626 |                 for logit_filter in self.logit_filters:
627 |                     logit_filter.apply(logits, tokens)
628 | 
629 |                 # expand the tokens tensor with the selected next tokens
630 |                 # print("tokens before updating decoder:", tokens)
631 |                 tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
632 | 
633 |                 if completed or tokens.shape[-1] > self.n_ctx:
634 |                     break
635 |         finally:
636 |             self.inference.cleanup_caching()
637 | 
638 |         return tokens, sum_logprobs, no_speech_probs
639 | 
640 |     @torch.no_grad()
641 |     def run(self, mel: Tensor) -> List[DecodingResult]:
642 |         self.decoder.reset()
643 |         tokenizer: Tokenizer = self.tokenizer
644 |         n_audio: int = mel.shape[0]
645 | 
646 |         audio_features: Tensor = self._get_audio_features(mel)  # encoder forward pass
647 |         tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
648 | 
649 |         # detect language if requested, overwriting the language token
650 |         languages, language_probs = self._detect_language(audio_features, tokens)
651 |         if self.options.task == "lang_id":
652 |             return [
653 |                 DecodingResult(audio_features=features, language=language, language_probs=probs)
654 |                 for features, language, probs in zip(audio_features, languages, language_probs)
655 |             ]
656 | 
657 |         # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
658 |         audio_features = audio_features.repeat_interleave(self.n_group, dim=0)
659 |         tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
660 | 
661 |         # call the main sampling loop
662 |         tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
663 | 
664 |         # reshape the tensors to have (n_audio, n_group) as the first two dimensions
665 |         audio_features = audio_features[:: self.n_group]
666 |         no_speech_probs = no_speech_probs[:: self.n_group]
667 |         assert audio_features.shape[0] == len(no_speech_probs) == n_audio
668 | 
669 |         tokens = tokens.reshape(n_audio, self.n_group, -1)
670 |         sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
671 | 
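        # at this point tokens has shape (n_audio, n_group, current_sequence_length) and
        # sum_logprobs has shape (n_audio, n_group); the decoder's finalize() collects the
        # finished candidates per audio (appending EOT where needed) for the ranker below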
672 |         # get the final candidates for each group, and slice between the first sampled token and EOT
673 |         tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
674 |         tokens: List[List[Tensor]] = [
675 |             [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens
676 |         ]
677 | 
678 |         # select the top-ranked sample in each group
679 |         selected = self.sequence_ranker.rank(tokens, sum_logprobs)
680 |         tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
681 |         texts: List[str] = [tokenizer.decode(t, skip_special_tokens=True).strip() for t in tokens]
682 | 
683 |         sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
684 |         avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]
685 | 
686 |         fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs)
687 |         if len(set(map(len, fields))) != 1:
688 |             raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
689 | 
690 |         return [
691 |             DecodingResult(
692 |                 audio_features=features,
693 |                 language=language,
694 |                 tokens=tokens,
695 |                 text=text,
696 |                 avg_logprob=avg_logprob,
697 |                 no_speech_prob=no_speech_prob,
698 |                 temperature=self.options.temperature,
699 |                 compression_ratio=compression_ratio(text),
700 |             )
701 |             for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
702 |         ]
703 | 
704 | 
705 | @torch.no_grad()
706 | def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()) -> Union[DecodingResult, List[DecodingResult]]:
707 |     """
708 |     Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
709 | 
710 |     Parameters
711 |     ----------
712 |     model: Whisper
713 |         the Whisper model instance
714 | 
715 |     mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
716 |         A tensor containing the Mel spectrogram(s)
717 | 
718 |     options: DecodingOptions
719 |         A dataclass that contains all necessary options for decoding 30-second segments
720 | 
721 |     Returns
722 |     -------
723 |     result: Union[DecodingResult, List[DecodingResult]]
724 |         The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
725 |     """
726 |     single = mel.ndim == 2
727 |     if single:
728 |         mel = mel.unsqueeze(0)
729 | 
730 |     result = DecodingTask(model, options).run(mel)
731 | 
732 |     if single:
733 |         result = result[0]
734 | 
735 |     return result
736 | 
--------------------------------------------------------------------------------
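A minimal usage sketch for the `decode` API defined above. This is illustrative only: it assumes the package keeps the upstream Whisper helpers and re-exports (`load_model`, `load_audio`, `pad_or_trim`, `log_mel_spectrogram`, `DecodingOptions`, `decode`, and the model's `device` property), and the audio file name is a placeholder. The fork-specific `DecodingOptions` fields referenced in decoding.py (`logit_mask`, `block_ngrams`, `concat_lang_token`) can be set the same way when constrained decoding is wanted.

    import torch
    import whisper  # this repository's whisper package

    model = whisper.load_model("tiny")                             # assumed upstream-style helper
    audio = whisper.pad_or_trim(whisper.load_audio("sample.wav"))  # placeholder file name
    mel = whisper.log_mel_spectrogram(audio).to(model.device)      # shape (80, 3000)

    options = whisper.DecodingOptions(
        task="transcribe",
        language="fr",
        beam_size=5,              # selects BeamSearchDecoder instead of GreedyDecoder
        without_timestamps=True,  # skips the ApplyTimestampRules logit filter
        fp16=torch.cuda.is_available(),
    )

    result = whisper.decode(model, mel, options)  # a 2-D mel yields a single DecodingResult
    print(result.text, result.avg_logprob, result.no_speech_prob)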