├── whisper_s2t
│   ├── backends
│   │   ├── openai
│   │   │   ├── __init__.py
│   │   │   └── model.py
│   │   ├── tensorrt
│   │   │   ├── __init__.py
│   │   │   ├── engine_builder
│   │   │   │   ├── download_utils.py
│   │   │   │   ├── __init__.py
│   │   │   │   └── builder.py
│   │   │   ├── hf_utils.py
│   │   │   ├── tokenizer.py
│   │   │   ├── trt_model.py
│   │   │   └── model.py
│   │   ├── ctranslate2
│   │   │   ├── __init__.py
│   │   │   ├── hf_utils.py
│   │   │   ├── tokenizer.py
│   │   │   └── model.py
│   │   ├── huggingface
│   │   │   ├── __init__.py
│   │   │   └── model.py
│   │   └── __init__.py
│   ├── assets
│   │   ├── silent.mp3
│   │   ├── mel_filters.npz
│   │   ├── vad_pp_cpu.ts
│   │   ├── vad_pp_gpu.ts
│   │   ├── seg_vad_model_cpu.ts
│   │   ├── seg_vad_model_gpu.ts
│   │   ├── frame_vad_model_cpu.ts
│   │   ├── frame_vad_model_gpu.ts
│   │   └── lang_codes.txt
│   ├── configs.py
│   ├── __init__.py
│   ├── speech_segmenter
│   │   ├── seg_vad.py
│   │   ├── frame_vad.py
│   │   └── __init__.py
│   ├── audio.py
│   ├── utils.py
│   └── data.py
├── files
│   └── benchmarks.png
├── benchmark_requirements.txt
├── requirements.txt
├── prepare_benchmark_env.sh
├── setup.py
├── Dockerfile
├── LICENSE
├── install_tensorrt.sh
├── run_benchmark.sh
├── scripts
│   ├── benchmark_openai.py
│   ├── benchmark_whisperx.py
│   ├── benchmark_whisper_s2t_trt.py
│   ├── benchmark_whisper_s2t.py
│   ├── benchmark_whisper_s2t_distil.py
│   ├── benchmark_huggingface.py
│   └── benchmark_huggingface_distil.py
├── .gitignore
├── tools
│   ├── metrics.py
│   └── text_normalizer.py
├── docs.md
└── README.md
/whisper_s2t/backends/openai/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /whisper_s2t/backends/tensorrt/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /whisper_s2t/backends/ctranslate2/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /whisper_s2t/backends/huggingface/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /files/benchmarks.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shashikg/WhisperS2T/HEAD/files/benchmarks.png -------------------------------------------------------------------------------- /whisper_s2t/assets/silent.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shashikg/WhisperS2T/HEAD/whisper_s2t/assets/silent.mp3 -------------------------------------------------------------------------------- /whisper_s2t/assets/mel_filters.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shashikg/WhisperS2T/HEAD/whisper_s2t/assets/mel_filters.npz -------------------------------------------------------------------------------- /whisper_s2t/assets/vad_pp_cpu.ts: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shashikg/WhisperS2T/HEAD/whisper_s2t/assets/vad_pp_cpu.ts -------------------------------------------------------------------------------- /whisper_s2t/assets/vad_pp_gpu.ts: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shashikg/WhisperS2T/HEAD/whisper_s2t/assets/vad_pp_gpu.ts 
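The directory tree above gives the layout of the repository; the remainder of this dump reproduces the individual files. For orientation, here is a minimal usage sketch of the public API defined in whisper_s2t/__init__.py and exercised by the benchmark scripts (load_model plus the model's transcribe_with_vad method). The input filename, batch size, and the final print line are illustrative assumptions, not part of the repository:

import whisper_s2t

# Load Whisper large-v2 on the CTranslate2 backend; 'TensorRT-LLM', 'HuggingFace',
# and 'OpenAI' are the other backend names accepted by load_model.
model = whisper_s2t.load_model(model_identifier="large-v2", backend="CTranslate2")

files = ["sample.wav"]      # hypothetical input audio file
lang_codes = ["en"]         # one language code per file (see assets/lang_codes.txt)
tasks = ["transcribe"]      # one task per file
initial_prompts = [None]    # optional initial prompt per file

# Returns one list of segment dicts per input file; each segment carries a 'text' key,
# as used in scripts/benchmark_whisper_s2t_trt.py below.
out = model.transcribe_with_vad(files,
                                lang_codes=lang_codes,
                                tasks=tasks,
                                initial_prompts=initial_prompts,
                                batch_size=16)

print(" ".join(seg['text'] for seg in out[0]).strip())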
-------------------------------------------------------------------------------- /whisper_s2t/assets/seg_vad_model_cpu.ts: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shashikg/WhisperS2T/HEAD/whisper_s2t/assets/seg_vad_model_cpu.ts -------------------------------------------------------------------------------- /whisper_s2t/assets/seg_vad_model_gpu.ts: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shashikg/WhisperS2T/HEAD/whisper_s2t/assets/seg_vad_model_gpu.ts -------------------------------------------------------------------------------- /whisper_s2t/assets/frame_vad_model_cpu.ts: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shashikg/WhisperS2T/HEAD/whisper_s2t/assets/frame_vad_model_cpu.ts -------------------------------------------------------------------------------- /whisper_s2t/assets/frame_vad_model_gpu.ts: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shashikg/WhisperS2T/HEAD/whisper_s2t/assets/frame_vad_model_gpu.ts -------------------------------------------------------------------------------- /benchmark_requirements.txt: -------------------------------------------------------------------------------- 1 | nemo_text_processing 2 | packaging 3 | ninja 4 | clean-text 5 | contractions 6 | jiwer 7 | nltk 8 | editdistance 9 | unidecode 10 | diff_match_patch -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm==4.66.2 2 | rich==13.7.0 3 | torch==2.1.2+cu121 4 | numpy==1.26.4 5 | platformdirs==4.2.0 6 | ctranslate2==4.0.0 7 | tokenizers==0.15.2 8 | huggingface-hub==0.20.3 9 | accelerate==0.25.0 10 | optimum==1.17.1 11 | transformers==4.36.1 12 | openai-whisper==20231117 13 | nvidia-ml-py==12.535.133 -------------------------------------------------------------------------------- /prepare_benchmark_env.sh: -------------------------------------------------------------------------------- 1 | # Install packages 2 | apt-get update 3 | apt-get install -y libsndfile1 ffmpeg 4 | pip install -U -r requirements.txt 5 | pip install -U git+https://github.com/m-bain/whisperx.git 6 | pip install -U -r benchmark_requirements.txt 7 | pip install flash-attn==2.3.6 --no-build-isolation 8 | 9 | # Download dataset 10 | rm -rf data.zip 11 | wget https://github.com/shashikg/WhisperS2T/releases/download/v1.0.0/data.zip 12 | unzip data.zip 13 | rm -rf data.zip -------------------------------------------------------------------------------- /whisper_s2t/configs.py: -------------------------------------------------------------------------------- 1 | def exact_div(x, y): 2 | assert x % y == 0 3 | return x // y 4 | 5 | # hard-coded audio hyperparameters 6 | N_FFT = 400 7 | INPUT_STRIDE=2 8 | N_MELS=80 9 | HOP_LENGTH = 160 10 | CHUNK_LENGTH = 30 11 | SAMPLE_RATE = 16000 12 | MAX_TEXT_TOKEN_LENGTH = 448 13 | N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk 14 | N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input 15 | 16 | N_SAMPLES_PER_TOKEN = HOP_LENGTH*INPUT_STRIDE # the initial convolutions has stride 2 17 | FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame 18 | TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, 
N_SAMPLES_PER_TOKEN) # 20ms per audio token 19 | TIME_PRECISION = 1/TOKENS_PER_SECOND -------------------------------------------------------------------------------- /whisper_s2t/assets/lang_codes.txt: -------------------------------------------------------------------------------- 1 | af 2 | am 3 | ar 4 | as 5 | az 6 | ba 7 | be 8 | bg 9 | bn 10 | bo 11 | br 12 | bs 13 | ca 14 | cs 15 | cy 16 | da 17 | de 18 | el 19 | en 20 | es 21 | et 22 | eu 23 | fa 24 | fi 25 | fo 26 | fr 27 | gl 28 | gu 29 | ha 30 | haw 31 | he 32 | hi 33 | hr 34 | ht 35 | hu 36 | hy 37 | id 38 | is 39 | it 40 | ja 41 | jw 42 | ka 43 | kk 44 | km 45 | kn 46 | ko 47 | la 48 | lb 49 | ln 50 | lo 51 | lt 52 | lv 53 | mg 54 | mi 55 | mk 56 | ml 57 | mn 58 | mr 59 | ms 60 | mt 61 | my 62 | ne 63 | nl 64 | nn 65 | no 66 | oc 67 | pa 68 | pl 69 | ps 70 | pt 71 | ro 72 | ru 73 | sa 74 | sd 75 | si 76 | sk 77 | sl 78 | sn 79 | so 80 | sq 81 | sr 82 | su 83 | sv 84 | sw 85 | ta 86 | te 87 | tg 88 | th 89 | tk 90 | tl 91 | tr 92 | tt 93 | uk 94 | ur 95 | uz 96 | vi 97 | yi 98 | yo 99 | zh -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open("README.md", "r", encoding="utf-8") as fh: 4 | long_description = fh.read() 5 | 6 | with open("requirements.txt", "r", encoding="utf-8") as f: 7 | requirements = f.read().splitlines() 8 | 9 | setup( 10 | name="whisper_s2t", 11 | version="1.3.1", 12 | description="An Optimized Speech-to-Text Pipeline for the Whisper Model.", 13 | long_description=long_description, 14 | long_description_content_type="text/markdown", 15 | python_requires=">=3.8", 16 | author="Shashi Kant Gupta", 17 | url="https://github.com/shashikg/WhisperS2T", 18 | license="MIT", 19 | packages=find_packages(exclude=["tests*"]), 20 | install_requires=requirements, 21 | package_data={ 22 | '': ['assets/*'], 23 | }, 24 | include_package_data=True, 25 | ) 26 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | ARG BASE_IMAGE=nvidia/cuda 2 | ARG BASE_TAG=12.1.0-devel-ubuntu22.04 3 | 4 | FROM ${BASE_IMAGE}:${BASE_TAG} 5 | ARG WHISPER_S2T_VER=main 6 | ARG SKIP_TENSORRT_LLM 7 | 8 | WORKDIR /workspace 9 | ENTRYPOINT [] 10 | SHELL ["/bin/bash", "-c"] 11 | 12 | COPY ./install_tensorrt.sh install_tensorrt.sh 13 | 14 | RUN apt update && apt-get install -y python3.10 python3-pip libsndfile1 ffmpeg git && \ 15 | pip3 install --no-cache-dir notebook jupyterlab ipywidgets && \ 16 | pip3 install --no-cache-dir git+https://github.com/shashikg/WhisperS2T.git@${WHISPER_S2T_VER} && \ 17 | CUDNN_PATH=$(python3 -c 'import os; import nvidia.cublas.lib; import nvidia.cudnn.lib; print(os.path.dirname(nvidia.cublas.lib.__file__) + ":" + os.path.dirname(nvidia.cudnn.lib.__file__))') && \ 18 | echo 'export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:'${CUDNN_PATH} >> ~/.bashrc 19 | 20 | RUN if [[ -z "$SKIP_TENSORRT_LLM" ]]; then /bin/bash install_tensorrt.sh; fi 21 | 22 | RUN apt-get autoremove -y && \ 23 | apt-get clean && \ 24 | rm -rf /var/lib/apt/lists/* && \ 25 | rm -r install_tensorrt.sh 26 | 27 | CMD ["/bin/bash"] -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Shashi Kant Gupta 4 | 5 
| Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /install_tensorrt.sh: -------------------------------------------------------------------------------- 1 | echo "" 2 | echo "###########################[ Installing Build Tools ]##########################" 3 | apt-get update && apt-get install -y build-essential ca-certificates ccache cmake gnupg2 wget curl gdb || sudo apt-get update && sudo apt-get install -y build-essential ca-certificates ccache cmake gnupg2 wget curl gdb 4 | 5 | echo "" 6 | echo "###########################[ Installing OpenMPI ]###########################" 7 | apt-get update && apt-get -y install openmpi-bin libopenmpi-dev || sudo apt-get update && sudo apt-get -y install openmpi-bin libopenmpi-dev 8 | 9 | echo "" 10 | echo "###########################[ Installing MPI4PY ]###########################" 11 | MPI4PY_VERSION="3.1.5" 12 | RELEASE_URL="https://github.com/mpi4py/mpi4py/archive/refs/tags/${MPI4PY_VERSION}.tar.gz" 13 | curl -L ${RELEASE_URL} | tar -zx -C /tmp 14 | # Bypassing compatibility issues with higher versions (>= 69) of setuptools. 
15 | sed -i 's/>= 40\.9\.0/>= 40.9.0, < 69/g' /tmp/mpi4py-${MPI4PY_VERSION}/pyproject.toml 16 | pip3 install /tmp/mpi4py-${MPI4PY_VERSION} 17 | rm -rf /tmp/mpi4py* 18 | 19 | echo "" 20 | echo "###########################[ Installing TensorRT-LLM ]###########################" 21 | pip3 install --no-cache-dir -U torch==2.1.2 22 | pip3 install --no-cache-dir tensorrt_llm==0.8.0.dev2024012301 --extra-index-url https://pypi.nvidia.com -------------------------------------------------------------------------------- /run_benchmark.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Variable 4 | REPO_DIR=$(pwd) 5 | 6 | # Commons For All GPUs 7 | 8 | echo "WhisperS2T - CTranslate2" 9 | python3 scripts/benchmark_whisper_s2t.py --repo_path="$REPO_DIR" --backend="CTranslate2" --batch_size=56 --eval_mp3="yes" 10 | 11 | echo "WhisperS2T - OpenAI" 12 | python3 scripts/benchmark_whisper_s2t.py --repo_path="$REPO_DIR" --backend="OpenAI" --batch_size=16 --eval_mp3="no" 13 | 14 | echo "WhisperS2T - HuggingFace" 15 | python3 scripts/benchmark_whisper_s2t.py --repo_path="$REPO_DIR" --backend="HuggingFace" --batch_size=48 --eval_mp3="no" 16 | 17 | echo "WhisperS2T - HuggingFace - BT" 18 | python3 scripts/benchmark_whisper_s2t.py --repo_path="$REPO_DIR" --backend="HuggingFace" --batch_size=48 --better_transformer="yes" --eval_mp3="no" 19 | 20 | echo "WhisperX" 21 | python3 scripts/benchmark_whisperx.py --repo_path="$REPO_DIR" --batch_size=56 22 | 23 | echo "HuggingFace" 24 | python3 scripts/benchmark_huggingface.py --repo_path="$REPO_DIR" --batch_size=48 --eval_mp3="no" 25 | 26 | echo "HuggingFace - BT" 27 | python3 scripts/benchmark_huggingface.py --repo_path="$REPO_DIR" --batch_size=48 --better_transformer="yes" --eval_mp3="no" 28 | 29 | # Flash Attention 2 Supported Arch 30 | 31 | echo "WhisperS2T - HuggingFace - FA" 32 | python3 scripts/benchmark_whisper_s2t.py --repo_path="$REPO_DIR" --backend="HuggingFace" --batch_size=48 --flash_attention="yes" --eval_mp3="no" 33 | 34 | echo "HuggingFace - FA" 35 | python3 scripts/benchmark_huggingface.py --repo_path="$REPO_DIR" --batch_size=48 --flash_attention="yes" --eval_mp3="yes" 36 | 37 | echo "OpenAI" 38 | python3 scripts/benchmark_openai.py --repo_path="$REPO_DIR" 39 | -------------------------------------------------------------------------------- /whisper_s2t/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from platformdirs import user_cache_dir 3 | 4 | from .utils import write_outputs 5 | 6 | BASE_PATH = os.path.dirname(__file__) 7 | 8 | CACHE_DIR = user_cache_dir("whisper_s2t") 9 | os.makedirs(CACHE_DIR, exist_ok=True) 10 | 11 | 12 | def load_model(model_identifier="large-v2", 13 | backend='CTranslate2', 14 | **model_kwargs): 15 | 16 | if model_identifier in ['large-v3']: 17 | model_kwargs['n_mels'] = 128 18 | elif (model_identifier in ['distil-large-v2']) and (backend.lower() not in ["huggingface", "hf"]): 19 | print(f"Switching backend to HuggingFace. 
Distill whisper is only supported with HuggingFace backend.") 20 | backend = "huggingface" 21 | 22 | model_kwargs['max_speech_len'] = 15.0 23 | model_kwargs['max_text_token_len'] = 128 24 | 25 | if backend.lower() in ["ctranslate2", "ct2"]: 26 | from .backends.ctranslate2.model import WhisperModelCT2 as WhisperModel 27 | 28 | elif backend.lower() in ["huggingface", "hf"]: 29 | from .backends.huggingface.model import WhisperModelHF as WhisperModel 30 | 31 | if 'distil' in model_identifier: 32 | model_identifier = f"distil-whisper/{model_identifier}" 33 | else: 34 | model_identifier = f"openai/whisper-{model_identifier}" 35 | 36 | elif backend.lower() in ["openai", "oai"]: 37 | from .backends.openai.model import WhisperModelOAI as WhisperModel 38 | 39 | elif backend.lower() in ["tensorrt", "trt", "trt-llm", "tensorrt-llm", "trt_llm", "tensorrt_llm"]: 40 | from .backends.tensorrt.model import WhisperModelTRT as WhisperModel 41 | else: 42 | raise ValueError(f"Backend name '{backend}' is invalid. Only following options are available: ['CTranslate2', 'TensorRT-LLM', 'HuggingFace', 'OpenAI']") 43 | 44 | return WhisperModel(model_identifier, **model_kwargs) 45 | -------------------------------------------------------------------------------- /scripts/benchmark_openai.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def parse_arguments(): 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument('--repo_path', default="", type=str) 6 | args = parser.parse_args() 7 | return args 8 | 9 | 10 | def run(repo_path): 11 | import time, os 12 | from tqdm import tqdm 13 | import pandas as pd 14 | import whisper 15 | 16 | results_dir = f"{repo_path}/results/OpenAI" 17 | os.makedirs(results_dir, exist_ok=True) 18 | 19 | model = whisper.load_model('large-v2') 20 | model = model.cuda().eval() 21 | 22 | # KINCAID46 WAV >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 23 | data = pd.read_csv(f'{repo_path}/data/KINCAID46/manifest_wav.tsv', sep="\t") 24 | files = [f"{repo_path}/{fn}" for fn in data['audio_path']] 25 | 26 | for fn in tqdm(files[:5], desc="warming"): 27 | result = model.transcribe(fn, language='en') 28 | 29 | st = time.time() 30 | pred_text = [] 31 | for fn in tqdm(files, desc="KINCAID"): 32 | result = model.transcribe(fn, language='en') 33 | pred_text.append(result['text'].strip()) 34 | 35 | time_kincaid46_wav = time.time()-st 36 | 37 | data['pred_text'] = pred_text 38 | data.to_csv(f"{results_dir}/KINCAID46_WAV.tsv", sep="\t", index=False) 39 | 40 | # MultiLingualLongform >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 41 | data = pd.read_csv(f'{repo_path}/data/MultiLingualLongform/manifest.tsv', sep="\t") 42 | files = [f"{repo_path}/{fn}" for fn in data['audio_path']] 43 | lang_codes = data['lang_code'].to_list() 44 | 45 | st = time.time() 46 | pred_text = [] 47 | for idx in tqdm(range(len(files)), desc="MultiLingualLongform"): 48 | result = model.transcribe(files[idx], language=lang_codes[idx]) 49 | pred_text.append(result['text'].strip()) 50 | 51 | time_multilingual = time.time()-st 52 | 53 | data['pred_text'] = pred_text 54 | data.to_csv(f"{results_dir}/MultiLingualLongform.tsv", sep="\t", index=False) 55 | 56 | infer_time = [ 57 | ["Dataset", "Time"], 58 | ["KINCAID46 WAV", time_kincaid46_wav], 59 | ["KINCAID46 MP3", 0.0], 60 | ["MultiLingualLongform", time_multilingual] 61 | ] 62 | 63 | infer_time = pd.DataFrame(infer_time[1:], columns=infer_time[0]) 64 | 
infer_time.to_csv(f"{results_dir}/infer_time.tsv", sep="\t", index=False) 65 | 66 | 67 | if __name__ == '__main__': 68 | args = parse_arguments() 69 | run(args.repo_path) -------------------------------------------------------------------------------- /whisper_s2t/backends/openai/model.py: -------------------------------------------------------------------------------- 1 | import whisper 2 | from whisper.decoding import DecodingOptions 3 | 4 | from .. import WhisperModel 5 | from ...configs import * 6 | 7 | 8 | ASR_OPTIONS = { 9 | "beam_size": 1, 10 | "without_timestamps": True, 11 | "return_scores": True, 12 | "return_no_speech_prob": True, 13 | "patience": 1, 14 | "length_penalty": 1, 15 | } 16 | 17 | 18 | class WhisperModelOAI(WhisperModel): 19 | def __init__(self, 20 | model_name: str, 21 | device="cuda", 22 | compute_type="float16", 23 | max_text_token_len=MAX_TEXT_TOKEN_LENGTH, 24 | asr_options={}, 25 | **model_kwargs): 26 | 27 | self.model_name = model_name 28 | self.asr_options = ASR_OPTIONS 29 | self.asr_options.update(asr_options) 30 | 31 | self.model = whisper.load_model(model_name) 32 | self.model.to(device).eval() 33 | 34 | self.decode_options = { 35 | "sample_len": max_text_token_len, 36 | 'fp16': True if compute_type == "float16" else False 37 | } 38 | 39 | for k, v in self.asr_options.items(): 40 | if hasattr(DecodingOptions, k): 41 | self.decode_options[k] = v 42 | 43 | super().__init__( 44 | device=device, 45 | compute_type=compute_type, 46 | max_text_token_len=max_text_token_len, 47 | **model_kwargs 48 | ) 49 | 50 | def update_decode_options(self, params={}): 51 | self.decode_options.update(params) 52 | 53 | if 'sample_len' in params: 54 | self.update_params(params={'max_text_token_len': params['sample_len']}) 55 | 56 | def generate_segment_batched(self, features, prompts, seq_lens, seg_metadata): 57 | 58 | if self.compute_type == "float16": 59 | features = features.to(self.device).half() 60 | 61 | lang_and_task_pairs = {} 62 | for _i, _p in enumerate(prompts): 63 | try: 64 | lang_and_task_pairs[(_p[-3], _p[-2])].append(_i) 65 | except: 66 | lang_and_task_pairs[(_p[-3], _p[-2])] = [_i] 67 | 68 | 69 | response = [{} for _ in prompts] 70 | for (task, lang), idx_list in lang_and_task_pairs.items(): 71 | 72 | results = self.model.decode(features[idx_list].to(self.device), DecodingOptions(task=task, language=lang, **self.decode_options)) 73 | 74 | for idx, result in zip(idx_list, results): 75 | response[idx]['text'] = result.text.strip() 76 | 77 | if self.asr_options['return_scores']: 78 | response[idx]['avg_logprob'] = result.avg_logprob 79 | 80 | if self.asr_options['return_no_speech_prob']: 81 | response[idx]['no_speech_prob'] = result.no_speech_prob 82 | 83 | return response -------------------------------------------------------------------------------- /whisper_s2t/backends/tensorrt/engine_builder/download_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import urllib 3 | import hashlib 4 | import warnings 5 | from typing import List, Optional, Union 6 | 7 | from tqdm import tqdm 8 | 9 | from .... 
import CACHE_DIR 10 | 11 | SAVE_DIR = f"{CACHE_DIR}/models/trt" 12 | os.makedirs(SAVE_DIR, exist_ok=True) 13 | 14 | _MODELS = { 15 | "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt", 16 | "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt", 17 | } 18 | 19 | _TOKENIZER = { 20 | "large-v2": "https://huggingface.co/Systran/faster-whisper-large-v2/raw/main/tokenizer.json", 21 | "large-v3": "https://huggingface.co/Systran/faster-whisper-large-v3/raw/main/tokenizer.json", 22 | } 23 | 24 | def download_model(name): 25 | 26 | url = _MODELS[name] 27 | expected_sha256 = url.split("/")[-2] 28 | 29 | download_path = os.path.join(SAVE_DIR, name) 30 | os.makedirs(download_path, exist_ok=True) 31 | 32 | model_ckpt_path = os.path.join(download_path, "pt_ckpt.pt") 33 | tokenizer_path = os.path.join(download_path, "tokenizer.json") 34 | 35 | if not os.path.exists(tokenizer_path): 36 | with urllib.request.urlopen(_TOKENIZER[name]) as source, open(tokenizer_path, "wb") as output: 37 | with tqdm( 38 | total=int(source.info().get("Content-Length")), 39 | ncols=80, 40 | unit="iB", 41 | unit_scale=True, 42 | unit_divisor=1024, 43 | ) as pbar: 44 | while True: 45 | buffer = source.read(8192) 46 | if not buffer: 47 | break 48 | 49 | output.write(buffer) 50 | pbar.update(len(buffer)) 51 | 52 | if os.path.exists(model_ckpt_path) and not os.path.isfile(model_ckpt_path): 53 | raise RuntimeError(f"{model_ckpt_path} exists and is not a regular file") 54 | 55 | if os.path.isfile(model_ckpt_path): 56 | with open(model_ckpt_path, "rb") as f: 57 | model_bytes = f.read() 58 | 59 | if hashlib.sha256(model_bytes).hexdigest() == expected_sha256: 60 | return model_ckpt_path, tokenizer_path 61 | else: 62 | warnings.warn( 63 | f"{model_ckpt_path} exists, but the SHA256 checksum does not match; re-downloading the file" 64 | ) 65 | 66 | with urllib.request.urlopen(url) as source, open(model_ckpt_path, "wb") as output: 67 | with tqdm( 68 | total=int(source.info().get("Content-Length")), 69 | ncols=80, 70 | unit="iB", 71 | unit_scale=True, 72 | unit_divisor=1024, 73 | ) as pbar: 74 | while True: 75 | buffer = source.read(8192) 76 | if not buffer: 77 | break 78 | 79 | output.write(buffer) 80 | pbar.update(len(buffer)) 81 | 82 | with open(model_ckpt_path, "rb") as f: 83 | model_bytes = f.read() 84 | if hashlib.sha256(model_bytes).hexdigest() != expected_sha256: 85 | raise RuntimeError( 86 | "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model." 87 | ) 88 | 89 | return model_ckpt_path, tokenizer_path -------------------------------------------------------------------------------- /whisper_s2t/backends/huggingface/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from transformers import WhisperProcessor, WhisperForConditionalGeneration 4 | 5 | from .. 
import WhisperModel 6 | from ...configs import * 7 | 8 | 9 | ASR_OPTIONS = { 10 | "beam_size": 1, 11 | "without_timestamps": True, 12 | "return_scores": False, 13 | "return_no_speech_prob": False, 14 | "use_flash_attention": True, 15 | "use_better_transformer": False, 16 | } 17 | 18 | 19 | COMPUTE_TYPE_TO_TORCH_DTYPE = { 20 | "float16": torch.float16 21 | } 22 | 23 | 24 | class WhisperModelHF(WhisperModel): 25 | def __init__(self, 26 | model_name: str, 27 | device="cuda", 28 | compute_type="float16", 29 | max_text_token_len=MAX_TEXT_TOKEN_LENGTH, 30 | asr_options={}, 31 | **model_kwargs): 32 | 33 | self.model_name = model_name 34 | self.asr_options = ASR_OPTIONS 35 | self.asr_options.update(asr_options) 36 | 37 | self.processor = WhisperProcessor.from_pretrained(self.model_name) 38 | self.model = WhisperForConditionalGeneration.from_pretrained(self.model_name, 39 | torch_dtype=COMPUTE_TYPE_TO_TORCH_DTYPE.get(compute_type, torch.float32), 40 | low_cpu_mem_usage=True, 41 | use_safetensors=True, 42 | use_flash_attention_2=self.asr_options["use_flash_attention"]) 43 | self.model.config.forced_decoder_ids = None 44 | self.model.to(device).eval() 45 | 46 | if self.asr_options["use_better_transformer"]: 47 | self.model = self.model.to_bettertransformer() 48 | 49 | self.generate_kwargs = { 50 | "max_new_tokens": max_text_token_len, 51 | "num_beams": self.asr_options['beam_size'], 52 | "return_timestamps": not self.asr_options['without_timestamps'], 53 | } 54 | 55 | super().__init__( 56 | device=device, 57 | compute_type=compute_type, 58 | max_text_token_len=max_text_token_len, 59 | **model_kwargs 60 | ) 61 | 62 | def update_generation_kwargs(self, params={}): 63 | self.generate_kwargs.update(params) 64 | 65 | if 'max_new_tokens' in params: 66 | self.update_params(params={'max_text_token_len': params['max_new_tokens']}) 67 | 68 | def generate_segment_batched(self, features, prompts, seq_lens, seg_metadata): 69 | if self.compute_type == "float16": 70 | features = features.to(self.device).half() 71 | 72 | lang_and_task_pairs = {} 73 | for _i, _p in enumerate(prompts): 74 | try: 75 | lang_and_task_pairs[(_p[-3], _p[-2])].append(_i) 76 | except: 77 | lang_and_task_pairs[(_p[-3], _p[-2])] = [_i] 78 | 79 | 80 | response = [{} for _ in prompts] 81 | for (task, lang), idx_list in lang_and_task_pairs.items(): 82 | predicted_ids = self.model.generate(features[idx_list], 83 | task=task, 84 | language=lang, 85 | **self.generate_kwargs) 86 | 87 | results = self.processor.batch_decode(predicted_ids, skip_special_tokens=True) 88 | 89 | for idx, text in zip(idx_list, results): 90 | response[idx]['text'] = text.strip() 91 | 92 | return response -------------------------------------------------------------------------------- /scripts/benchmark_whisperx.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def parse_arguments(): 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument('--repo_path', default="", type=str) 6 | parser.add_argument('--batch_size', default=16, type=int) 7 | args = parser.parse_args() 8 | return args 9 | 10 | 11 | def run(repo_path, batch_size=16): 12 | import time, os 13 | import whisperx 14 | from tqdm import tqdm 15 | import pandas as pd 16 | 17 | results_dir = f"{repo_path}/results/WhisperX-bs_{batch_size}" 18 | os.makedirs(results_dir, exist_ok=True) 19 | 20 | model = whisperx.load_model("large-v2", "cuda", compute_type="float16", language='en', asr_options={'beam_size': 1}) 21 | 22 | # KINCAID46 WAV 
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 23 | data = pd.read_csv(f'{repo_path}/data/KINCAID46/manifest_wav.tsv', sep="\t") 24 | files = [f"{repo_path}/{fn}" for fn in data['audio_path']] 25 | 26 | for fn in tqdm(files, desc="Warming"): 27 | audio = whisperx.load_audio(fn) 28 | result = model.transcribe(audio, batch_size=batch_size) 29 | 30 | st = time.time() 31 | pred_text = [] 32 | for fn in tqdm(files, desc="KINCAID WAV"): 33 | audio = whisperx.load_audio(fn) 34 | result = model.transcribe(audio, batch_size=batch_size) 35 | pred_text.append(" ".join([_['text'].strip() for _ in result['segments']])) 36 | 37 | time_kincaid46_wav = time.time()-st 38 | 39 | data['pred_text'] = pred_text 40 | data.to_csv(f"{results_dir}/KINCAID46_WAV.tsv", sep="\t", index=False) 41 | 42 | 43 | # KINCAID46 MP3 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 44 | data = pd.read_csv(f'{repo_path}/data/KINCAID46/manifest_mp3.tsv', sep="\t") 45 | files = [f"{repo_path}/{fn}" for fn in data['audio_path']] 46 | 47 | st = time.time() 48 | pred_text = [] 49 | for fn in tqdm(files, desc="KINCAID MP3"): 50 | audio = whisperx.load_audio(fn) 51 | result = model.transcribe(audio, batch_size=batch_size) 52 | pred_text.append(" ".join([_['text'].strip() for _ in result['segments']])) 53 | 54 | time_kincaid46_mp3 = time.time()-st 55 | 56 | data['pred_text'] = pred_text 57 | data.to_csv(f"{results_dir}/KINCAID46_MP3.tsv", sep="\t", index=False) 58 | 59 | 60 | # MultiLingualLongform >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 61 | data = pd.read_csv(f'{repo_path}/data/MultiLingualLongform/manifest.tsv', sep="\t") 62 | files = [f"{repo_path}/{fn}" for fn in data['audio_path']] 63 | lang_codes = data['lang_code'].to_list() 64 | 65 | st = time.time() 66 | pred_text = [] 67 | for idx in tqdm(range(len(files)), desc="MultiLingualLongform"): 68 | audio = whisperx.load_audio(files[idx]) 69 | result = model.transcribe(audio, batch_size=batch_size, language=lang_codes[idx], task='transcribe') 70 | pred_text.append(" ".join([_['text'].strip() for _ in result['segments']])) 71 | 72 | time_multilingual = time.time()-st 73 | 74 | data['pred_text'] = pred_text 75 | data.to_csv(f"{results_dir}/MultiLingualLongform.tsv", sep="\t", index=False) 76 | 77 | infer_time = [ 78 | ["Dataset", "Time"], 79 | ["KINCAID46 WAV", time_kincaid46_wav], 80 | ["KINCAID46 MP3", time_kincaid46_mp3], 81 | ["MultiLingualLongform", time_multilingual] 82 | ] 83 | infer_time = pd.DataFrame(infer_time[1:], columns=infer_time[0]) 84 | infer_time.to_csv(f"{results_dir}/infer_time.tsv", sep="\t", index=False) 85 | 86 | 87 | if __name__ == '__main__': 88 | args = parse_arguments() 89 | run(args.repo_path, batch_size=args.batch_size) -------------------------------------------------------------------------------- /whisper_s2t/backends/tensorrt/hf_utils.py: -------------------------------------------------------------------------------- 1 | # https://github.com/guillaumekln/faster-whisper/blob/master/faster_whisper/utils.py 2 | 3 | import os 4 | import re 5 | import requests 6 | 7 | import huggingface_hub 8 | from typing import List, Optional 9 | 10 | from ... 
import CACHE_DIR 11 | 12 | 13 | os.makedirs(f"{CACHE_DIR}/models", exist_ok=True) 14 | 15 | 16 | _MODELS = { 17 | "tiny.en": "Systran/faster-whisper-tiny.en", 18 | "tiny": "Systran/faster-whisper-tiny", 19 | "base.en": "Systran/faster-whisper-base.en", 20 | "base": "Systran/faster-whisper-base", 21 | "small.en": "Systran/faster-whisper-small.en", 22 | "small": "Systran/faster-whisper-small", 23 | "medium.en": "Systran/faster-whisper-medium.en", 24 | "medium": "Systran/faster-whisper-medium", 25 | "large-v1": "Systran/faster-whisper-large-v1", 26 | "large-v2": "Systran/faster-whisper-large-v2", 27 | "large-v3": "Systran/faster-whisper-large-v3", 28 | "large": "Systran/faster-whisper-large-v3", 29 | } 30 | 31 | 32 | def available_models() -> List[str]: 33 | """Returns the names of available models.""" 34 | return list(_MODELS.keys()) 35 | 36 | 37 | def download_model( 38 | size_or_id: str, 39 | output_dir: Optional[str] = None, 40 | local_files_only: bool = False, 41 | cache_dir: Optional[str] = None, 42 | ): 43 | """Downloads a CTranslate2 Whisper model from the Hugging Face Hub. 44 | 45 | Args: 46 | size_or_id: Size of the model to download from https://huggingface.co/guillaumekln 47 | (tiny, tiny.en, base, base.en, small, small.en medium, medium.en, large-v1, large-v2, 48 | large), or a CTranslate2-converted model ID from the Hugging Face Hub 49 | (e.g. guillaumekln/faster-whisper-large-v2). 50 | output_dir: Directory where the model should be saved. If not set, the model is saved in 51 | the cache directory. 52 | local_files_only: If True, avoid downloading the file and return the path to the local 53 | cached file if it exists. 54 | cache_dir: Path to the folder where cached files are stored. 55 | 56 | Returns: 57 | The path to the downloaded model. 58 | 59 | Raises: 60 | ValueError: if the model size is invalid. 61 | """ 62 | if re.match(r".*/.*", size_or_id): 63 | repo_id = size_or_id 64 | else: 65 | repo_id = _MODELS.get(size_or_id) 66 | if repo_id is None: 67 | raise ValueError( 68 | "Invalid model size '%s', expected one of: %s" 69 | % (size_or_id, ", ".join(_MODELS.keys())) 70 | ) 71 | 72 | allow_patterns = [ 73 | "config.json", 74 | "model.bin", 75 | "tokenizer.json", 76 | "vocabulary.*", 77 | ] 78 | 79 | kwargs = { 80 | "local_files_only": local_files_only, 81 | "allow_patterns": allow_patterns, 82 | } 83 | 84 | if output_dir is not None: 85 | kwargs["local_dir"] = output_dir 86 | kwargs["local_dir_use_symlinks"] = False 87 | 88 | if cache_dir is not None: 89 | kwargs["cache_dir"] = cache_dir 90 | else: 91 | kwargs["cache_dir"] = f"{CACHE_DIR}/models" 92 | 93 | try: 94 | return huggingface_hub.snapshot_download(repo_id, **kwargs) 95 | except ( 96 | huggingface_hub.utils.HfHubHTTPError, 97 | requests.exceptions.ConnectionError, 98 | ) as exception: 99 | print(exception) 100 | logger = get_logger() 101 | logger.warning( 102 | "An error occured while synchronizing the model %s from the Hugging Face Hub:\n%s", 103 | repo_id, 104 | exception, 105 | ) 106 | logger.warning( 107 | "Trying to load the model directly from the local cache, if it exists." 
108 | ) 109 | 110 | kwargs["local_files_only"] = True 111 | return huggingface_hub.snapshot_download(repo_id, **kwargs) 112 | -------------------------------------------------------------------------------- /whisper_s2t/backends/ctranslate2/hf_utils.py: -------------------------------------------------------------------------------- 1 | # https://github.com/guillaumekln/faster-whisper/blob/master/faster_whisper/utils.py 2 | 3 | import os 4 | import re 5 | import requests 6 | 7 | import huggingface_hub 8 | from typing import List, Optional 9 | 10 | from ... import CACHE_DIR 11 | 12 | 13 | os.makedirs(f"{CACHE_DIR}/models", exist_ok=True) 14 | 15 | 16 | _MODELS = { 17 | "tiny.en": "Systran/faster-whisper-tiny.en", 18 | "tiny": "Systran/faster-whisper-tiny", 19 | "base.en": "Systran/faster-whisper-base.en", 20 | "base": "Systran/faster-whisper-base", 21 | "small.en": "Systran/faster-whisper-small.en", 22 | "small": "Systran/faster-whisper-small", 23 | "medium.en": "Systran/faster-whisper-medium.en", 24 | "medium": "Systran/faster-whisper-medium", 25 | "large-v1": "Systran/faster-whisper-large-v1", 26 | "large-v2": "Systran/faster-whisper-large-v2", 27 | "large-v3": "Systran/faster-whisper-large-v3", 28 | "large": "Systran/faster-whisper-large-v3", 29 | } 30 | 31 | 32 | def available_models() -> List[str]: 33 | """Returns the names of available models.""" 34 | return list(_MODELS.keys()) 35 | 36 | 37 | def download_model( 38 | size_or_id: str, 39 | output_dir: Optional[str] = None, 40 | local_files_only: bool = False, 41 | cache_dir: Optional[str] = None, 42 | ): 43 | """Downloads a CTranslate2 Whisper model from the Hugging Face Hub. 44 | 45 | Args: 46 | size_or_id: Size of the model to download from https://huggingface.co/guillaumekln 47 | (tiny, tiny.en, base, base.en, small, small.en medium, medium.en, large-v1, large-v2, 48 | large), or a CTranslate2-converted model ID from the Hugging Face Hub 49 | (e.g. guillaumekln/faster-whisper-large-v2). 50 | output_dir: Directory where the model should be saved. If not set, the model is saved in 51 | the cache directory. 52 | local_files_only: If True, avoid downloading the file and return the path to the local 53 | cached file if it exists. 54 | cache_dir: Path to the folder where cached files are stored. 55 | 56 | Returns: 57 | The path to the downloaded model. 58 | 59 | Raises: 60 | ValueError: if the model size is invalid. 
61 | """ 62 | if re.match(r".*/.*", size_or_id): 63 | repo_id = size_or_id 64 | else: 65 | repo_id = _MODELS.get(size_or_id) 66 | if repo_id is None: 67 | raise ValueError( 68 | "Invalid model size '%s', expected one of: %s" 69 | % (size_or_id, ", ".join(_MODELS.keys())) 70 | ) 71 | 72 | allow_patterns = [ 73 | "config.json", 74 | "model.bin", 75 | "tokenizer.json", 76 | "vocabulary.*", 77 | ] 78 | 79 | kwargs = { 80 | "local_files_only": local_files_only, 81 | "allow_patterns": allow_patterns, 82 | } 83 | 84 | if output_dir is not None: 85 | kwargs["local_dir"] = output_dir 86 | kwargs["local_dir_use_symlinks"] = False 87 | 88 | if cache_dir is not None: 89 | kwargs["cache_dir"] = cache_dir 90 | else: 91 | kwargs["cache_dir"] = f"{CACHE_DIR}/models" 92 | 93 | try: 94 | return huggingface_hub.snapshot_download(repo_id, **kwargs) 95 | except ( 96 | huggingface_hub.utils.HfHubHTTPError, 97 | requests.exceptions.ConnectionError, 98 | ) as exception: 99 | print(exception) 100 | logger = get_logger() 101 | logger.warning( 102 | "An error occured while synchronizing the model %s from the Hugging Face Hub:\n%s", 103 | repo_id, 104 | exception, 105 | ) 106 | logger.warning( 107 | "Trying to load the model directly from the local cache, if it exists." 108 | ) 109 | 110 | kwargs["local_files_only"] = True 111 | return huggingface_hub.snapshot_download(repo_id, **kwargs) 112 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Additional 2 | data 3 | temp 4 | results 5 | temp.ipynb 6 | .vscode 7 | push_to_pypy.sh 8 | 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | cover/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | db.sqlite3-journal 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | .pybuilder/ 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # IPython 90 | profile_default/ 91 | ipython_config.py 92 | 93 | # pyenv 94 | # For a library or package, you might want to ignore these files since the code is 95 | # intended to run in multiple environments; otherwise, check them in: 96 | # .python-version 97 | 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 102 | # install all needed dependencies. 103 | #Pipfile.lock 104 | 105 | # poetry 106 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 107 | # This is especially recommended for binary packages to ensure reproducibility, and is more 108 | # commonly ignored for libraries. 109 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 110 | #poetry.lock 111 | 112 | # pdm 113 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 114 | #pdm.lock 115 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 116 | # in version control. 117 | # https://pdm.fming.dev/#use-with-ide 118 | .pdm.toml 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | -------------------------------------------------------------------------------- /whisper_s2t/speech_segmenter/seg_vad.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | 5 | from . import VADBaseClass 6 | from .. 
import BASE_PATH 7 | 8 | 9 | class SegmentVAD(VADBaseClass): 10 | def __init__(self, 11 | device=None, 12 | win_len=0.32, 13 | win_step=0.08, 14 | batch_size=512, 15 | sampling_rate=16000): 16 | 17 | super().__init__(sampling_rate=sampling_rate) 18 | 19 | if device == None: 20 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 21 | 22 | self.device = device 23 | 24 | if self.device == 'cpu': 25 | # This is a JIT Scripted model of Nvidia's NeMo Marblenet Model: https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/vad_multilingual_marblenet 26 | self.vad_pp = torch.jit.load(os.path.join(BASE_PATH, "assets/vad_pp_cpu.ts")).to(self.device) 27 | self.vad_model = torch.jit.load(os.path.join(BASE_PATH, "assets/seg_vad_model_cpu.ts")).to(self.device) 28 | else: 29 | self.vad_pp = torch.jit.load(os.path.join(BASE_PATH, "assets/vad_pp_gpu.ts")).to(self.device) 30 | self.vad_model = torch.jit.load(os.path.join(BASE_PATH, "assets/seg_vad_model_gpu.ts")).to(self.device) 31 | 32 | self.vad_pp = torch.jit.load(os.path.join(BASE_PATH, "assets/vad_pp.ts")) 33 | self.vad_model = torch.jit.load(os.path.join(BASE_PATH, "assets/segment_vad_model.ts")) 34 | 35 | self.vad_model.eval() 36 | self.vad_model.to(self.device) 37 | 38 | self.vad_pp.eval() 39 | self.vad_pp.to(self.device) 40 | 41 | self.batch_size = batch_size 42 | self.win_len = win_len 43 | self.win_step = win_step 44 | 45 | self._init_params() 46 | 47 | def _init_params(self): 48 | self.signal_win_len = int(self.win_len*self.sampling_rate) 49 | self.signal_win_step = int(self.win_step*self.sampling_rate) 50 | 51 | def update_params(self, params={}): 52 | for key, value in params.items(): 53 | setattr(self, key, value) 54 | 55 | self._init_params() 56 | 57 | def prepare_input_batch(self, audio_signal): 58 | 59 | num_chunks = (self.signal_win_len//2+len(audio_signal))//self.signal_win_step 60 | if num_chunks < (self.signal_win_len//2+len(audio_signal))/self.signal_win_step: 61 | num_chunks += 1 62 | 63 | input_signal = np.zeros((num_chunks, self.signal_win_len), dtype=np.float32) 64 | input_signal_length = np.zeros(num_chunks, dtype=np.int64) 65 | 66 | chunk_idx = 0 67 | for idx in range(-1*self.signal_win_len//2, len(audio_signal), self.signal_win_step): 68 | s_idx = max(idx, 0) 69 | e_idx = min(idx + self.signal_win_len, len(audio_signal)) 70 | input_signal[chunk_idx][:e_idx-s_idx] = audio_signal[s_idx:e_idx] 71 | input_signal_length[chunk_idx] = e_idx-s_idx 72 | chunk_idx += 1 73 | 74 | return input_signal, input_signal_length 75 | 76 | @torch.cuda.amp.autocast() 77 | @torch.no_grad() 78 | def forward(self, input_signal, input_signal_length): 79 | x, x_len = self.vad_pp(torch.Tensor(input_signal).to(self.device), 80 | torch.Tensor(input_signal_length).to(self.device)) 81 | logits = self.vad_model(x, x_len) 82 | logits = torch.softmax(logits, dim=-1) 83 | return logits[:, 1].detach().cpu().numpy() 84 | 85 | def __call__(self, audio_signal): 86 | 87 | audio_duration = len(audio_signal)/self.sampling_rate 88 | 89 | input_signal, input_signal_length = self.prepare_input_batch(audio_signal) 90 | 91 | speech_probs = np.zeros(len(input_signal)) 92 | for s_idx in range(0, len(input_signal), self.batch_size): 93 | speech_probs[s_idx:s_idx+self.batch_size] = self.forward(input_signal=input_signal[s_idx:s_idx+self.batch_size], 94 | input_signal_length=input_signal_length[s_idx:s_idx+self.batch_size]) 95 | 96 | vad_times = [] 97 | for idx, prob in enumerate(speech_probs): 98 | s_time = max(0, (idx-0.5)*self.win_step) 99 | e_time = 
min(audio_duration, (idx+0.5)*self.win_step) 100 | vad_times.append([prob, s_time, e_time]) 101 | 102 | return np.array(vad_times) 103 | -------------------------------------------------------------------------------- /whisper_s2t/speech_segmenter/frame_vad.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | 5 | from . import VADBaseClass 6 | from .. import BASE_PATH 7 | 8 | 9 | class FrameVAD(VADBaseClass): 10 | def __init__(self, 11 | device=None, 12 | chunk_size=15.0, 13 | margin_size=1.0, 14 | frame_size=0.02, 15 | batch_size=4, 16 | sampling_rate=16000): 17 | 18 | super().__init__(sampling_rate=sampling_rate) 19 | 20 | if device == None: 21 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 22 | 23 | self.device = device 24 | 25 | if self.device == 'cpu': 26 | # This is a JIT Scripted model of Nvidia's NeMo Framewise Marblenet Model: https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/vad_multilingual_frame_marblenet 27 | self.vad_pp = torch.jit.load(os.path.join(BASE_PATH, "assets/vad_pp_cpu.ts")).to(self.device) 28 | self.vad_model = torch.jit.load(os.path.join(BASE_PATH, "assets/frame_vad_model_cpu.ts")).to(self.device) 29 | else: 30 | self.vad_pp = torch.jit.load(os.path.join(BASE_PATH, "assets/vad_pp_gpu.ts")).to(self.device) 31 | self.vad_model = torch.jit.load(os.path.join(BASE_PATH, "assets/frame_vad_model_gpu.ts")).to(self.device) 32 | 33 | self.vad_pp.eval() 34 | self.vad_model.eval() 35 | 36 | self.batch_size = batch_size 37 | self.frame_size = frame_size 38 | self.chunk_size = chunk_size 39 | self.margin_size = margin_size 40 | 41 | self._init_params() 42 | 43 | def _init_params(self): 44 | self.signal_chunk_len = int(self.chunk_size*self.sampling_rate) 45 | self.signal_stride = int(self.signal_chunk_len-2*int(self.margin_size*self.sampling_rate)) 46 | 47 | self.margin_logit_len = int(self.margin_size/self.frame_size) 48 | self.signal_to_logit_len = int(self.frame_size*self.sampling_rate) 49 | 50 | self.vad_pp.to(self.device) 51 | self.vad_model.to(self.device) 52 | 53 | def update_params(self, params={}): 54 | for key, value in params.items(): 55 | setattr(self, key, value) 56 | 57 | self._init_params() 58 | 59 | def prepare_input_batch(self, audio_signal): 60 | input_signal = [] 61 | input_signal_length = [] 62 | for s_idx in range(0, len(audio_signal), self.signal_stride): 63 | _signal = audio_signal[s_idx:s_idx+self.signal_chunk_len] 64 | _signal_len = len(_signal) 65 | input_signal.append(_signal) 66 | input_signal_length.append(_signal_len) 67 | 68 | if _signal_len < self.signal_chunk_len: 69 | input_signal[-1] = np.pad(input_signal[-1], (0, self.signal_chunk_len-_signal_len)) 70 | break 71 | 72 | return input_signal, input_signal_length 73 | 74 | @torch.cuda.amp.autocast() 75 | @torch.no_grad() 76 | def forward(self, input_signal, input_signal_length): 77 | 78 | all_logits = [] 79 | for s_idx in range(0, len(input_signal), self.batch_size): 80 | input_signal_pt = torch.stack([torch.tensor(_, device=self.device) for _ in input_signal[s_idx:s_idx+self.batch_size]]) 81 | input_signal_length_pt = torch.tensor(input_signal_length[s_idx:s_idx+self.batch_size], device=self.device) 82 | 83 | x, x_len = self.vad_pp(input_signal_pt, input_signal_length_pt) 84 | logits = self.vad_model(x, x_len) 85 | 86 | for _logits, _len in zip(logits, input_signal_length_pt): 87 | all_logits.append(_logits[:int(_len/self.signal_to_logit_len)]) 88 | 89 | if len(all_logits) > 1 and 
self.margin_logit_len > 0: 90 | all_logits[0] = all_logits[0][:-self.margin_logit_len] 91 | all_logits[-1] = all_logits[-1][self.margin_logit_len:] 92 | 93 | for i in range(1, len(all_logits)-1): 94 | all_logits[i] = all_logits[i][self.margin_logit_len:-self.margin_logit_len] 95 | 96 | all_logits = torch.concatenate(all_logits) 97 | all_logits = torch.softmax(all_logits, dim=-1) 98 | 99 | return all_logits[:, 1].detach().cpu().numpy() 100 | 101 | def __call__(self, audio_signal): 102 | audio_duration = len(audio_signal)/self.sampling_rate 103 | 104 | input_signal, input_signal_length = self.prepare_input_batch(audio_signal) 105 | speech_probs = self.forward(input_signal, input_signal_length) 106 | 107 | vad_times = [] 108 | for idx, prob in enumerate(speech_probs): 109 | s_time = idx*self.frame_size 110 | e_time = min(audio_duration, (idx+1)*self.frame_size) 111 | 112 | if s_time >= e_time: break 113 | 114 | vad_times.append([prob, s_time, e_time]) 115 | 116 | return np.array(vad_times) 117 | -------------------------------------------------------------------------------- /scripts/benchmark_whisper_s2t_trt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def parse_arguments(): 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument('--repo_path', default="", type=str) 6 | parser.add_argument('--batch_size', default=16, type=int) 7 | parser.add_argument('--eval_mp3', default="yes", type=str) 8 | parser.add_argument('--eval_multilingual', default="yes", type=str) 9 | args = parser.parse_args() 10 | return args 11 | 12 | def run(repo_path, batch_size=16, eval_mp3=True, eval_multilingual=True): 13 | import sys, time, os 14 | 15 | if len(repo_path): 16 | sys.path.append(repo_path) 17 | 18 | import whisper_s2t 19 | from whisper_s2t.backends.tensorrt.engine_builder import TRTBuilderConfig 20 | 21 | import pandas as pd 22 | 23 | results_dir = f"{repo_path}/results/WhisperS2T-TensorRT-LLM-bs_{batch_size}" 24 | os.makedirs(results_dir, exist_ok=True) 25 | 26 | trt_build_args = TRTBuilderConfig( 27 | max_batch_size=batch_size, 28 | max_output_len=448 29 | ) 30 | 31 | model_kwargs = { 32 | 'trt_build_args': trt_build_args, 33 | } 34 | 35 | model = whisper_s2t.load_model(model_identifier="large-v2", backend='TensorRT-LLM', **model_kwargs) 36 | 37 | # KINCAID46 WAV >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 38 | data = pd.read_csv(f'{repo_path}/data/KINCAID46/manifest_wav.tsv', sep="\t") 39 | files = [f"{repo_path}/{fn}" for fn in data['audio_path']] 40 | lang_codes = len(files)*['en'] 41 | tasks = len(files)*['transcribe'] 42 | initial_prompts = len(files)*[None] 43 | 44 | _ = model.transcribe_with_vad(files, 45 | lang_codes=lang_codes, 46 | tasks=tasks, 47 | initial_prompts=initial_prompts, 48 | batch_size=batch_size) 49 | 50 | st = time.time() 51 | out = model.transcribe_with_vad(files, 52 | lang_codes=lang_codes, 53 | tasks=tasks, 54 | initial_prompts=initial_prompts, 55 | batch_size=batch_size) 56 | time_kincaid46_wav = time.time()-st 57 | 58 | data['pred_text'] = [" ".join([_['text'] for _ in _transcript]).strip() for _transcript in out] 59 | data.to_csv(f"{results_dir}/KINCAID46_WAV.tsv", sep="\t", index=False) 60 | 61 | 62 | # KINCAID46 MP3 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 63 | if eval_mp3: 64 | data = pd.read_csv(f'{repo_path}/data/KINCAID46/manifest_mp3.tsv', sep="\t") 65 | files = [f"{repo_path}/{fn}" for fn in data['audio_path']] 66 | lang_codes = 
len(files)*['en'] 67 | tasks = len(files)*['transcribe'] 68 | initial_prompts = len(files)*[None] 69 | 70 | st = time.time() 71 | out = model.transcribe_with_vad(files, 72 | lang_codes=lang_codes, 73 | tasks=tasks, 74 | initial_prompts=initial_prompts, 75 | batch_size=batch_size) 76 | time_kincaid46_mp3 = time.time()-st 77 | 78 | data['pred_text'] = [" ".join([_['text'] for _ in _transcript]).strip() for _transcript in out] 79 | data.to_csv(f"{results_dir}/KINCAID46_MP3.tsv", sep="\t", index=False) 80 | else: 81 | time_kincaid46_mp3 = 0.0 82 | 83 | 84 | # MultiLingualLongform 85 | if eval_multilingual: 86 | data = pd.read_csv(f'{repo_path}/data/MultiLingualLongform/manifest.tsv', sep="\t") 87 | files = [f"{repo_path}/{fn}" for fn in data['audio_path']] 88 | lang_codes = data['lang_code'].to_list() 89 | tasks = len(files)*['transcribe'] 90 | initial_prompts = len(files)*[None] 91 | 92 | st = time.time() 93 | out = model.transcribe_with_vad(files, 94 | lang_codes=lang_codes, 95 | tasks=tasks, 96 | initial_prompts=initial_prompts, 97 | batch_size=batch_size) 98 | time_multilingual = time.time()-st 99 | 100 | data['pred_text'] = [" ".join([_['text'] for _ in _transcript]).strip() for _transcript in out] 101 | data.to_csv(f"{results_dir}/MultiLingualLongform.tsv", sep="\t", index=False) 102 | else: 103 | time_multilingual = 0.0 104 | 105 | infer_time = [ 106 | ["Dataset", "Time"], 107 | ["KINCAID46 WAV", time_kincaid46_wav], 108 | ["KINCAID46 MP3", time_kincaid46_mp3], 109 | ["MultiLingualLongform", time_multilingual] 110 | ] 111 | infer_time = pd.DataFrame(infer_time[1:], columns=infer_time[0]) 112 | infer_time.to_csv(f"{results_dir}/infer_time.tsv", sep="\t", index=False) 113 | 114 | 115 | if __name__ == '__main__': 116 | args = parse_arguments() 117 | eval_mp3 = True if args.eval_mp3 == "yes" else False 118 | eval_multilingual = True if args.eval_multilingual == "yes" else False 119 | run(args.repo_path, batch_size=args.batch_size, eval_mp3=eval_mp3, eval_multilingual=eval_multilingual) -------------------------------------------------------------------------------- /whisper_s2t/speech_segmenter/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from abc import ABC, abstractmethod 3 | 4 | from ..audio import load_audio 5 | 6 | 7 | class VADBaseClass(ABC): 8 | def __init__(self, sampling_rate=16000): 9 | self.sampling_rate = sampling_rate 10 | 11 | @abstractmethod 12 | def update_params(self, params={}): 13 | pass 14 | 15 | @abstractmethod 16 | def __call__(self, audio_signal, batch_size=4): 17 | pass 18 | 19 | 20 | class SpeechSegmenter: 21 | def __init__(self, vad_model=None, 22 | device=None, 23 | frame_size=0.02, 24 | min_seg_len=0.08, 25 | max_seg_len=29.0, 26 | max_silent_region=0.6, 27 | padding=0.2, 28 | eos_thresh=0.3, 29 | bos_thresh=0.3, 30 | cut_factor=2, 31 | sampling_rate=16000): 32 | 33 | if vad_model is None: 34 | from .frame_vad import FrameVAD 35 | vad_model = FrameVAD(device=device) 36 | 37 | self.vad_model = vad_model 38 | 39 | self.sampling_rate = sampling_rate 40 | self.padding = padding 41 | self.frame_size = frame_size 42 | self.min_seg_len = min_seg_len 43 | self.max_seg_len = max_seg_len 44 | self.max_silent_region = max_silent_region 45 | 46 | self.eos_thresh = eos_thresh 47 | self.bos_thresh = bos_thresh 48 | 49 | self.cut_factor = cut_factor 50 | self.cut_idx = int(self.max_seg_len/(self.cut_factor*self.frame_size)) 51 | self.max_idx_in_seg = self.cut_factor*self.cut_idx 52 | 53 | def 
update_params(self, params={}): 54 | for key, value in params.items(): 55 | setattr(self, key, value) 56 | 57 | self.cut_idx = int(self.max_seg_len/(self.cut_factor*self.frame_size)) 58 | self.max_idx_in_seg = self.cut_factor*self.cut_idx 59 | 60 | def update_vad_model_params(self, params={}): 61 | self.vad_model.update_params(params=params) 62 | 63 | def okay_to_merge(self, speech_probs, last_seg, curr_seg): 64 | conditions = [ 65 | (speech_probs[curr_seg['start']][1]-speech_probs[last_seg['end']][2]) < self.max_silent_region, 66 | (speech_probs[curr_seg['end']][2]-speech_probs[last_seg['start']][1]) <= self.max_seg_len, 67 | ] 68 | 69 | return all(conditions) 70 | 71 | def get_speech_segments(self, speech_probs): 72 | 73 | speech_flag, start_idx = False, 0 74 | speech_segments = [] 75 | for idx, (speech_prob, st, et) in enumerate(speech_probs): 76 | if speech_flag: 77 | if speech_prob < self.eos_thresh: 78 | speech_flag = False 79 | curr_seg = {'start': start_idx, 'end': idx-1} 80 | 81 | if len(speech_segments) and self.okay_to_merge(speech_probs, speech_segments[-1], curr_seg): 82 | speech_segments[-1]['end'] = curr_seg['end'] 83 | else: 84 | speech_segments.append(curr_seg) 85 | 86 | elif speech_prob >= self.bos_thresh: 87 | speech_flag = True 88 | start_idx = idx 89 | 90 | if speech_flag: 91 | curr_seg = {'start': start_idx, 'end': len(speech_probs)-1} 92 | 93 | if len(speech_segments) and self.okay_to_merge(speech_probs, speech_segments[-1], curr_seg): 94 | speech_segments[-1]['end'] = curr_seg['end'] 95 | else: 96 | speech_segments.append(curr_seg) 97 | 98 | speech_segments = [_ for _ in speech_segments if (speech_probs[_['end']][2]-speech_probs[_['start']][1]) > self.min_seg_len] 99 | 100 | start_ends = [] 101 | for _ in speech_segments: 102 | first_idx = len(start_ends) 103 | start_idx, end_idx = _['start'], _['end'] 104 | while (end_idx-start_idx) > self.max_idx_in_seg: 105 | _start_idx = int(start_idx + self.cut_idx) 106 | _end_idx = int(min(end_idx, start_idx + self.max_idx_in_seg)) 107 | 108 | new_end_idx = _start_idx+np.argmin(speech_probs[_start_idx:_end_idx, 0]) 109 | start_ends.append([speech_probs[start_idx][1], speech_probs[new_end_idx][2]]) 110 | start_idx = new_end_idx+1 111 | 112 | start_ends.append([speech_probs[start_idx][1], speech_probs[end_idx][2]+self.padding]) 113 | start_ends[first_idx][0] = start_ends[first_idx][0]-self.padding 114 | 115 | return start_ends 116 | 117 | 118 | def __call__(self, input_file=None, audio_signal=None): 119 | if audio_signal is None: 120 | audio_signal, audio_duration = load_audio(input_file, sr=self.sampling_rate, return_duration=True) 121 | else: 122 | audio_duration = len(audio_signal)/self.sampling_rate 123 | 124 | speech_probs = self.vad_model(audio_signal) 125 | start_ends = self.get_speech_segments(speech_probs) 126 | 127 | if len(start_ends) == 0: 128 | start_ends = [[0.0, self.max_seg_len]] # Quick fix for silent audio. 
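# (If the VAD returns no segments at all, the fallback above emits a single max-length segment
#  so downstream transcription still receives a non-empty segment list; the edge fixes below
#  then clamp the padded boundaries back into [0, audio_duration].)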
129 | 130 | start_ends[0][0] = max(0.0, start_ends[0][0]) # fix edges 131 | start_ends[-1][1] = min(audio_duration, start_ends[-1][1]) # fix edges 132 | 133 | return start_ends, audio_signal -------------------------------------------------------------------------------- /whisper_s2t/backends/tensorrt/engine_builder/__init__.py: -------------------------------------------------------------------------------- 1 | import json 2 | import hashlib 3 | import os 4 | from pynvml import * 5 | 6 | from rich.console import Console 7 | console = Console() 8 | 9 | from .download_utils import SAVE_DIR, download_model 10 | from ....utils import RunningStatus 11 | 12 | 13 | class TRTBuilderConfig: 14 | def __init__(self, 15 | max_batch_size=24, 16 | max_beam_width=1, 17 | max_input_len=4, 18 | max_output_len=448, 19 | world_size=1, 20 | dtype='float16', 21 | quantize_dir='quantize/1-gpu', 22 | use_gpt_attention_plugin='float16', 23 | use_bert_attention_plugin=None, 24 | use_context_fmha_enc=False, 25 | use_context_fmha_dec=False, 26 | use_gemm_plugin='float16', 27 | use_layernorm_plugin=False, 28 | remove_input_padding=False, 29 | use_weight_only_enc=False, 30 | use_weight_only_dec=False, 31 | weight_only_precision='int8', 32 | int8_kv_cache=False, 33 | debug_mode=False, 34 | **kwargs, 35 | ): 36 | 37 | self.max_batch_size = max_batch_size 38 | self.max_beam_width = max_beam_width 39 | self.max_input_len = max_input_len 40 | self.max_output_len = max_output_len 41 | self.world_size = world_size 42 | self.dtype = dtype 43 | self.quantize_dir = quantize_dir 44 | self.use_gpt_attention_plugin = use_gpt_attention_plugin 45 | self.use_bert_attention_plugin = use_bert_attention_plugin 46 | self.use_context_fmha_enc = use_context_fmha_enc 47 | self.use_context_fmha_dec = use_context_fmha_dec 48 | self.use_gemm_plugin = use_gemm_plugin 49 | self.use_layernorm_plugin = use_layernorm_plugin 50 | self.remove_input_padding = remove_input_padding 51 | self.use_weight_only_enc = use_weight_only_enc 52 | self.use_weight_only_dec = use_weight_only_dec 53 | self.weight_only_precision = weight_only_precision 54 | self.int8_kv_cache = int8_kv_cache 55 | self.debug_mode = debug_mode 56 | 57 | nvmlInit() 58 | self.cuda_compute_capability = list(nvmlDeviceGetCudaComputeCapability(nvmlDeviceGetHandleByIndex(0))) 59 | nvmlShutdown() 60 | 61 | 62 | def identifier(self): 63 | params = vars(self) 64 | return hashlib.md5(json.dumps(params).encode()).hexdigest() 65 | 66 | 67 | def save_trt_build_configs(trt_build_args): 68 | with open(f'{trt_build_args.output_dir}/trt_build_args.json', 'w') as f: 69 | f.write(json.dumps(vars(trt_build_args))) 70 | 71 | 72 | def load_trt_build_config(output_dir): 73 | """ 74 | [TODO]: Add cuda_compute_capability verification check 75 | """ 76 | 77 | with open(f'{output_dir}/trt_build_args.json', 'r') as f: 78 | trt_build_configs = json.load(f) 79 | 80 | trt_build_args = TRTBuilderConfig(**trt_build_configs) 81 | trt_build_args.output_dir = trt_build_configs['output_dir'] 82 | trt_build_args.model_path = trt_build_configs['model_path'] 83 | 84 | return trt_build_args 85 | 86 | 87 | def build_trt_engine(model_name='large-v2', args=None, force=False, log_level='error'): 88 | 89 | if args is None: 90 | console.print(f"args is None, using default configs.") 91 | args = TRTBuilderConfig() 92 | 93 | args.output_dir = os.path.join(SAVE_DIR, model_name, args.identifier()) 94 | args.model_path, tokenizer_path = download_model(model_name) 95 | 96 | if force: 97 | console.print(f"'force' flag is 'True'. 
Removing previous build.") 98 | with RunningStatus("Cleaning", console=console): 99 | os.system(f"rm -rf '{args.output_dir}'") 100 | 101 | if not os.path.exists(args.output_dir): 102 | os.makedirs(args.output_dir) 103 | else: 104 | _files = os.listdir(args.output_dir) 105 | 106 | _failed_export = False 107 | for _req_files in ['tokenizer.json', 108 | 'trt_build_args.json', 109 | 'encoder_config.json', 110 | 'decoder_config.json', 111 | 'encoder.engine', 112 | 'decoder.engine']: 113 | 114 | if _req_files not in _files: 115 | _failed_export = True 116 | break 117 | 118 | if _failed_export: 119 | console.print(f"Export directory exists but seems like a failed export, regenerating the engine files.") 120 | os.system(f"rm -rf '{args.output_dir}'") 121 | os.makedirs(args.output_dir) 122 | else: 123 | return args.output_dir 124 | 125 | os.system(f"cp '{tokenizer_path}' '{args.output_dir}/tokenizer.json'") 126 | save_trt_build_configs(args) 127 | 128 | with RunningStatus("Exporting Model To TensorRT Engine (3-6 mins)", console=console): 129 | out_logs = os.popen(f"python3 -m whisper_s2t.backends.tensorrt.engine_builder.builder --output_dir='{args.output_dir}' --log_level='{log_level}'").read().split("\n") 130 | print_flag = False 131 | for line in out_logs: 132 | if print_flag: 133 | console.print(line) 134 | elif 'TRTBuilderConfig' in line: 135 | print_flag = True 136 | console.print("[TRTBuilderConfig]:") 137 | 138 | return args.output_dir -------------------------------------------------------------------------------- /whisper_s2t/backends/tensorrt/tokenizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import cached_property 3 | 4 | from ... import BASE_PATH 5 | 6 | 7 | _TASKS = ( 8 | "transcribe", 9 | "translate", 10 | ) 11 | 12 | 13 | with open(os.path.join(BASE_PATH, "assets/lang_codes.txt"), 'r') as f: 14 | _LANGUAGE_CODES = [_ for _ in f.read().split("\n") if _] 15 | 16 | 17 | class Tokenizer: 18 | def __init__(self, tokenizer, multilingual): 19 | 20 | self.tokenizer = tokenizer 21 | self.multilingual = multilingual 22 | 23 | if self.multilingual: 24 | self.task_to_token_id = {task: self.tokenizer.token_to_id(f"<|{task}|>") for task in _TASKS} 25 | self.lang_code_to_token_id = {lang: self.tokenizer.token_to_id(f"<|{lang}|>") for lang in _LANGUAGE_CODES} 26 | else: 27 | self.task_to_token_id = None 28 | self.lang_code_to_token_id = None 29 | 30 | @cached_property 31 | def transcribe(self) -> int: 32 | return self.tokenizer.token_to_id("<|transcribe|>") 33 | 34 | @cached_property 35 | def translate(self) -> int: 36 | return self.tokenizer.token_to_id("<|translate|>") 37 | 38 | @cached_property 39 | def silent_token(self) -> int: 40 | return self.encode(" ")[0] 41 | 42 | @cached_property 43 | def sot(self) -> int: 44 | return self.tokenizer.token_to_id("<|startoftranscript|>") 45 | 46 | @cached_property 47 | def sot_lm(self) -> int: 48 | return self.tokenizer.token_to_id("<|startoflm|>") 49 | 50 | @cached_property 51 | def sot_prev(self) -> int: 52 | return self.tokenizer.token_to_id("<|startofprev|>") 53 | 54 | @cached_property 55 | def eot(self) -> int: 56 | return self.tokenizer.token_to_id("<|endoftext|>") 57 | 58 | @cached_property 59 | def no_timestamps(self) -> int: 60 | return self.tokenizer.token_to_id("<|notimestamps|>") 61 | 62 | @property 63 | def timestamp_begin(self) -> int: 64 | return self.no_timestamps + 1 65 | 66 | def sot_sequence(self, task=None, lang=None): 67 | sequence = [self.sot] 68 | 69 | if 
self.multilingual: 70 | sequence.append(self.lang_code_to_token_id[lang]) 71 | sequence.append(self.task_to_token_id[task]) 72 | 73 | return sequence 74 | 75 | def encode(self, text): 76 | return self.tokenizer.encode(text, add_special_tokens=False).ids 77 | 78 | def decode(self, tokens): 79 | text_tokens = [token for token in tokens if token < self.eot] 80 | return self.tokenizer.decode(text_tokens) 81 | 82 | def decode_batch(self, tokens): 83 | res = [] 84 | for tk in tokens: 85 | res.append([token for token in tk if token < self.eot]) 86 | 87 | return self.tokenizer.decode_batch(res) 88 | 89 | def split_tokens_on_unicode(self, text, tokens): 90 | replacement_char = "\ufffd" 91 | 92 | subwords, subword_tokens_list, current_tokens = [], [], [] 93 | unicode_offset, word_finished = 0, False 94 | 95 | for token in tokens: 96 | current_tokens.append(token) 97 | decoded = self.decode(current_tokens) 98 | 99 | try: 100 | replacement_char_index = decoded.index(replacement_char) + unicode_offset 101 | if (replacement_char_index < len(text)) and (text[replacement_char_index] == replacement_char): 102 | word_finished = True 103 | except ValueError: 104 | word_finished = True 105 | 106 | if word_finished: 107 | subwords.append(decoded) 108 | subword_tokens_list.append(current_tokens) 109 | 110 | current_tokens = [] 111 | word_finished = False 112 | unicode_offset += len(decoded) 113 | 114 | return subwords, subword_tokens_list 115 | 116 | def split_tokens_on_spaces(self, text, tokens): 117 | subwords, subword_tokens_list = self.split_tokens_on_unicode(text, tokens) 118 | words = [] 119 | word_tokens = [] 120 | 121 | for subword, subword_tokens in zip(subwords, subword_tokens_list): 122 | conditions = [ 123 | subword_tokens[0] >= self.eot, # special 124 | subword.startswith(" "), # with_space 125 | # subword.strip() in string.punctuation, # punctuation 126 | len(words) == 0 127 | ] 128 | 129 | if any(conditions): 130 | words.append(subword.strip()) 131 | word_tokens.append(subword_tokens) 132 | else: 133 | words[-1] = words[-1] + subword 134 | word_tokens[-1].extend(subword_tokens) 135 | 136 | return words, word_tokens 137 | 138 | def split_to_word_tokens(self, text, tokens, lang_code): 139 | if lang_code in {"zh", "ja", "th", "lo", "my", "yue"}: 140 | # These languages don't typically use spaces, so it is difficult to split words 141 | # without morpheme analysis. Here, we instead split words at any 142 | # position where the tokens are decoded as valid unicode points 143 | return self.split_tokens_on_unicode(text, tokens) 144 | 145 | return self.split_tokens_on_spaces(text, tokens) 146 | 147 | def split_to_word_tokens_batch(self, text_batch, tokens_batch, lang_code_batch): 148 | res = [] 149 | for text, tokens, lang_code in zip(text_batch, tokens_batch, lang_code_batch): 150 | res.append(self.split_to_word_tokens(text, tokens, lang_code)) 151 | 152 | return res -------------------------------------------------------------------------------- /whisper_s2t/backends/ctranslate2/tokenizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import cached_property 3 | 4 | from ... 
import BASE_PATH 5 | 6 | 7 | _TASKS = ( 8 | "transcribe", 9 | "translate", 10 | ) 11 | 12 | 13 | with open(os.path.join(BASE_PATH, "assets/lang_codes.txt"), 'r') as f: 14 | _LANGUAGE_CODES = [_ for _ in f.read().split("\n") if _] 15 | 16 | 17 | class Tokenizer: 18 | def __init__(self, tokenizer, multilingual): 19 | 20 | self.tokenizer = tokenizer 21 | self.multilingual = multilingual 22 | 23 | if self.multilingual: 24 | self.task_to_token_id = {task: self.tokenizer.token_to_id(f"<|{task}|>") for task in _TASKS} 25 | self.lang_code_to_token_id = {lang: self.tokenizer.token_to_id(f"<|{lang}|>") for lang in _LANGUAGE_CODES} 26 | else: 27 | self.task_to_token_id = None 28 | self.lang_code_to_token_id = None 29 | 30 | @cached_property 31 | def transcribe(self) -> int: 32 | return self.tokenizer.token_to_id("<|transcribe|>") 33 | 34 | @cached_property 35 | def translate(self) -> int: 36 | return self.tokenizer.token_to_id("<|translate|>") 37 | 38 | @cached_property 39 | def silent_token(self) -> int: 40 | return self.encode(" ")[0] 41 | 42 | @cached_property 43 | def sot(self) -> int: 44 | return self.tokenizer.token_to_id("<|startoftranscript|>") 45 | 46 | @cached_property 47 | def sot_lm(self) -> int: 48 | return self.tokenizer.token_to_id("<|startoflm|>") 49 | 50 | @cached_property 51 | def sot_prev(self) -> int: 52 | return self.tokenizer.token_to_id("<|startofprev|>") 53 | 54 | @cached_property 55 | def eot(self) -> int: 56 | return self.tokenizer.token_to_id("<|endoftext|>") 57 | 58 | @cached_property 59 | def no_timestamps(self) -> int: 60 | return self.tokenizer.token_to_id("<|notimestamps|>") 61 | 62 | @property 63 | def timestamp_begin(self) -> int: 64 | return self.no_timestamps + 1 65 | 66 | def sot_sequence(self, task=None, lang=None): 67 | sequence = [self.sot] 68 | 69 | if self.multilingual: 70 | sequence.append(self.lang_code_to_token_id[lang]) 71 | sequence.append(self.task_to_token_id[task]) 72 | 73 | return sequence 74 | 75 | def encode(self, text): 76 | return self.tokenizer.encode(text, add_special_tokens=False).ids 77 | 78 | def decode(self, tokens): 79 | text_tokens = [token for token in tokens if token < self.eot] 80 | return self.tokenizer.decode(text_tokens) 81 | 82 | def decode_batch(self, tokens): 83 | res = [] 84 | for tk in tokens: 85 | res.append([token for token in tk if token < self.eot]) 86 | 87 | return self.tokenizer.decode_batch(res) 88 | 89 | def split_tokens_on_unicode(self, text, tokens): 90 | replacement_char = "\ufffd" 91 | 92 | subwords, subword_tokens_list, current_tokens = [], [], [] 93 | unicode_offset, word_finished = 0, False 94 | 95 | for token in tokens: 96 | current_tokens.append(token) 97 | decoded = self.decode(current_tokens) 98 | 99 | try: 100 | replacement_char_index = decoded.index(replacement_char) + unicode_offset 101 | if (replacement_char_index < len(text)) and (text[replacement_char_index] == replacement_char): 102 | word_finished = True 103 | except ValueError: 104 | word_finished = True 105 | 106 | if word_finished: 107 | subwords.append(decoded) 108 | subword_tokens_list.append(current_tokens) 109 | 110 | current_tokens = [] 111 | word_finished = False 112 | unicode_offset += len(decoded) 113 | 114 | return subwords, subword_tokens_list 115 | 116 | def split_tokens_on_spaces(self, text, tokens): 117 | subwords, subword_tokens_list = self.split_tokens_on_unicode(text, tokens) 118 | words = [] 119 | word_tokens = [] 120 | 121 | for subword, subword_tokens in zip(subwords, subword_tokens_list): 122 | conditions = [ 123 | 
subword_tokens[0] >= self.eot, # special 124 | subword.startswith(" "), # with_space 125 | # subword.strip() in string.punctuation, # punctuation 126 | len(words) == 0 127 | ] 128 | 129 | if any(conditions): 130 | words.append(subword.strip()) 131 | word_tokens.append(subword_tokens) 132 | else: 133 | words[-1] = words[-1] + subword 134 | word_tokens[-1].extend(subword_tokens) 135 | 136 | return words, word_tokens 137 | 138 | def split_to_word_tokens(self, text, tokens, lang_code): 139 | if lang_code in {"zh", "ja", "th", "lo", "my", "yue"}: 140 | # These languages don't typically use spaces, so it is difficult to split words 141 | # without morpheme analysis. Here, we instead split words at any 142 | # position where the tokens are decoded as valid unicode points 143 | return self.split_tokens_on_unicode(text, tokens) 144 | 145 | return self.split_tokens_on_spaces(text, tokens) 146 | 147 | def split_to_word_tokens_batch(self, text_batch, tokens_batch, lang_code_batch): 148 | res = [] 149 | for text, tokens, lang_code in zip(text_batch, tokens_batch, lang_code_batch): 150 | res.append(self.split_to_word_tokens(text, tokens, lang_code)) 151 | 152 | return res -------------------------------------------------------------------------------- /tools/metrics.py: -------------------------------------------------------------------------------- 1 | import jiwer 2 | import editdistance 3 | import diff_match_patch 4 | 5 | import numpy as np 6 | from tqdm import tqdm 7 | from nltk import ngrams 8 | from jiwer.transformations import wer_default 9 | 10 | 11 | def WER(references, hypotheses): 12 | measures = jiwer.compute_measures(references, 13 | hypotheses, 14 | truth_transform=wer_default, 15 | hypothesis_transform=wer_default) 16 | 17 | out = { 18 | 'WER': round(100*measures['wer'], 2), 19 | 'IER': round(100*measures['insertions']/(measures['hits']+measures['substitutions']+measures['deletions']), 2), 20 | 'DER': round(100*measures['deletions']/(measures['hits']+measures['substitutions']+measures['deletions']), 2), 21 | 'SER': round(100*measures['substitutions']/(measures['hits']+measures['substitutions']+measures['deletions']), 2) 22 | } 23 | 24 | return out 25 | 26 | def CER(references, hypotheses): 27 | errors = 0 28 | words = 0 29 | 30 | for h, r in zip(hypotheses, references): 31 | h_list = list(h) 32 | r_list = list(r) 33 | words += len(r_list) 34 | errors += editdistance.eval(h_list, r_list) 35 | 36 | return round(100*errors/words, 2) 37 | 38 | def NGramDuplicates(text, ngram_size=5): 39 | all_ngrams = list(ngrams(text.split(), ngram_size)) 40 | return len(all_ngrams)-len(set(all_ngrams)) 41 | 42 | def NGramInsertions(references, hypotheses, ngram_size=5): 43 | 44 | repeated_ngrams = 0 45 | for r, h in zip(references, hypotheses): 46 | all_ngrams = list(ngrams(r.split(), ngram_size)) 47 | ref_counts = {} 48 | for ngram in all_ngrams: 49 | try: 50 | ref_counts[ngram] += 1 51 | except: 52 | ref_counts[ngram] = 1 53 | 54 | all_ngrams = list(ngrams(h.split(), ngram_size)) 55 | hyp_counts = {} 56 | for ngram in all_ngrams: 57 | try: 58 | hyp_counts[ngram] += 1 59 | except: 60 | hyp_counts[ngram] = 1 61 | 62 | for k, v in hyp_counts.items(): 63 | if (v > 1) and (ref_counts.get(k, 1) < v): 64 | repeated_ngrams += (v-ref_counts.get(k, 1)) 65 | 66 | return repeated_ngrams 67 | 68 | def evaluate(references, hypotheses, cer=False, ngram_size=5): 69 | scores = WER(references, hypotheses) 70 | if cer: 71 | scores.update({'CER': CER(references, hypotheses), f'{ngram_size}-GramInsertions': 
NGramInsertions(references, hypotheses, ngram_size=ngram_size)}) 72 | else: 73 | scores.update({f'{ngram_size}-GramInsertions': NGramInsertions(references, hypotheses, ngram_size=ngram_size)}) 74 | 75 | return scores 76 | 77 | def word_alignment_accuracy_single(references, hypotheses, collar=0.2): 78 | # Find diffs between ref and hyp 79 | r_list = [_['word'].replace(" ", "_") for _ in references] 80 | h_list = [_['word'].replace(" ", "_") for _ in hypotheses] 81 | 82 | orig_words = '\n'.join(r_list) + '\n' 83 | pred_words = '\n'.join(h_list) + '\n' 84 | 85 | diff = diff_match_patch.diff_match_patch() 86 | diff.Diff_Timeout = 0 87 | orig_enc, pred_enc, enc = diff.diff_linesToChars(orig_words, pred_words) 88 | diffs = diff.diff_main(orig_enc, pred_enc, False) 89 | diff.diff_charsToLines(diffs, enc) 90 | 91 | diffs_post = [(d[0], d[1].replace('\n', ' ').strip().split()) for d in diffs] 92 | 93 | # Find words which got HIT and their matching 94 | r_idx, h_idx = 0, 0 95 | word_idx_match = {} 96 | for case, words in diffs_post: 97 | if case == -1: 98 | r_idx += len(words) 99 | elif case == 1: 100 | h_idx += len(words) 101 | else: 102 | for _ in words: 103 | word_idx_match[r_idx] = h_idx 104 | r_idx += 1 105 | h_idx += 1 106 | 107 | 108 | # Find words whose alignments overlap with each other 109 | overlapped_words = 0 110 | within_collar_words = 0 111 | for r_idx, h_idx in word_idx_match.items(): 112 | if (hypotheses[h_idx]['start']<references[r_idx]['end']) and (hypotheses[h_idx]['end']>references[r_idx]['start']): 113 | overlapped_words += 1 114 | 115 | if (hypotheses[h_idx]['start']>=references[r_idx]['start']-collar) and (hypotheses[h_idx]['end']<=references[r_idx]['end']+collar): 116 | within_collar_words += 1 117 | 118 | 119 | results = { 120 | 'acc_overlapped': round(100*overlapped_words/len(word_idx_match), 2), 121 | 'acc_within_collar': round(100*within_collar_words/len(word_idx_match), 2), 122 | 'overlapped_words': overlapped_words, 123 | 'within_collar_words': within_collar_words, 124 | 'total_hit_words': len(word_idx_match), 125 | } 126 | 127 | return results 128 | 129 | def word_alignment_accuracy(references, hypotheses, collar=0.2): 130 | overlapped_words = 0 131 | within_collar_words = 0 132 | total_hit_words = 0 133 | 134 | for r, h in tqdm(zip(references, hypotheses), total=len(references)): 135 | res = word_alignment_accuracy_single(r, h, collar=collar) 136 | overlapped_words += res['overlapped_words'] 137 | within_collar_words += res['within_collar_words'] 138 | total_hit_words += res['total_hit_words'] 139 | 140 | results = { 141 | 'acc_overlapped': round(100*overlapped_words/total_hit_words, 2), 142 | 'acc_within_collar': round(100*within_collar_words/total_hit_words, 2), 143 | 'overlapped_words': overlapped_words, 144 | 'within_collar_words': within_collar_words, 145 | 'total_hit_words': total_hit_words, 146 | } 147 | 148 | return results -------------------------------------------------------------------------------- /whisper_s2t/audio.py: -------------------------------------------------------------------------------- 1 | import os 2 | import wave 3 | import tempfile 4 | import subprocess 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from multiprocessing.dummy import Pool 12 | 13 | from . 
import BASE_PATH 14 | from .configs import * 15 | 16 | silent_file = f"{BASE_PATH}/assets/silent.mp3" 17 | 18 | RESAMPLING_ENGINE = 'soxr' 19 | with tempfile.TemporaryDirectory() as tmpdir: 20 | ffmpeg_install_link = "https://github.com/shashikg/WhisperS2T?tab=readme-ov-file#for-ubuntu" 21 | 22 | try: 23 | subprocess.check_output(['ffmpeg', '-version']) 24 | except: 25 | raise RuntimeError(f"Seems 'ffmpeg' is not installed. Please install ffmpeg before using this package!\nCheck: {ffmpeg_install_link}") 26 | 27 | ret_code = os.system(f'ffmpeg -hide_banner -loglevel panic -i "{silent_file}" -threads 1 -acodec pcm_s16le -ac 1 -af aresample=resampler={RESAMPLING_ENGINE} -ar 1600 "{tmpdir}/tmp.wav" -y') 28 | 29 | if ret_code != 0: 30 | print(f"'ffmpeg' failed with soxr resampler, trying 'swr' resampler.") 31 | RESAMPLING_ENGINE = 'swr' 32 | 33 | ret_code = os.system(f'ffmpeg -hide_banner -loglevel panic -i "{silent_file}" -threads 1 -acodec pcm_s16le -ac 1 -af aresample=resampler={RESAMPLING_ENGINE} -ar 1600 "{tmpdir}/tmp.wav" -y') 34 | 35 | if ret_code != 0: 36 | raise RuntimeError(f"Seems 'ffmpeg' is not installed properly. Please uninstall and install it again.\nCheck: {ffmpeg_install_link}") 37 | else: 38 | print(f"Using 'swr' resampler. This may degrade performance.") 39 | 40 | 41 | def load_audio(input_file, sr=16000, return_duration=False): 42 | 43 | try: 44 | with wave.open(input_file, 'rb') as wf: 45 | if (wf.getframerate() != sr) or (wf.getnchannels() != 1): 46 | raise Exception("Not a 16kHz wav mono channel file!") 47 | 48 | frames = wf.getnframes() 49 | x = wf.readframes(int(frames)) 50 | except: 51 | with tempfile.TemporaryDirectory() as tmpdir: 52 | wav_file = f"{tmpdir}/tmp.wav" 53 | ret_code = os.system(f'ffmpeg -hide_banner -loglevel panic -i "{input_file}" -threads 1 -acodec pcm_s16le -ac 1 -af aresample=resampler={RESAMPLING_ENGINE} -ar {sr} "{wav_file}" -y') 54 | if ret_code != 0: raise RuntimeError("ffmpeg failed to resample the input audio file, make sure ffmpeg is compiled properly!") 55 | 56 | with wave.open(wav_file, 'rb') as wf: 57 | frames = wf.getnframes() 58 | x = wf.readframes(int(frames)) 59 | 60 | audio_signal = np.frombuffer(x, np.int16).flatten().astype(np.float32)/32768.0 61 | audio_duration = len(audio_signal)/sr 62 | 63 | if return_duration: 64 | return audio_signal, audio_duration 65 | else: 66 | return audio_signal 67 | 68 | 69 | THREAD_POOL_AUDIO_LOADER = Pool(2) 70 | def audio_batch_generator(audio_files): 71 | return THREAD_POOL_AUDIO_LOADER.imap(load_audio, audio_files) 72 | 73 | 74 | def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): 75 | """ 76 | Pad or trim the audio array to N_SAMPLES, as expected by the encoder. 
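    Illustrative example: with SAMPLE_RATE = 16000, a 10-second signal (160,000 samples) is zero-padded up to N_SAMPLES (480,000), while a 40-second signal (640,000 samples) is trimmed down to its first 480,000 samples.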
77 | """ 78 | 79 | if torch.is_tensor(array): 80 | if array.shape[axis] > length: 81 | array = array.index_select( 82 | dim=axis, index=torch.arange(length, device=array.device) 83 | ) 84 | 85 | if array.shape[axis] < length: 86 | pad_widths = [(0, 0)] * array.ndim 87 | pad_widths[axis] = (0, length - array.shape[axis]) 88 | array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes]) 89 | else: 90 | if array.shape[axis] > length: 91 | array = array.take(indices=range(length), axis=axis) 92 | 93 | if array.shape[axis] < length: 94 | pad_widths = [(0, 0)] * array.ndim 95 | pad_widths[axis] = (0, length - array.shape[axis]) 96 | array = np.pad(array, pad_widths) 97 | 98 | return array 99 | 100 | 101 | class TorchSTFT(nn.Module): 102 | def __init__(self, n_fft, hop_length): 103 | super().__init__() 104 | 105 | self.n_fft = n_fft 106 | self.hop_length = hop_length 107 | 108 | window = torch.hann_window(n_fft) 109 | self.register_buffer("window", window) 110 | 111 | def forward(self, x): 112 | return torch.stft(x, self.n_fft, self.hop_length, window=self.window, return_complex=True) 113 | 114 | 115 | class LogMelSpectogram(nn.Module): 116 | def __init__(self, 117 | n_mels=N_MELS, 118 | n_fft=N_FFT, 119 | hop_length=HOP_LENGTH, 120 | padding=0): 121 | 122 | super().__init__() 123 | 124 | self.n_fft = n_fft 125 | self.n_mels = n_mels 126 | self.hop_length = hop_length 127 | self.padding = padding 128 | 129 | mel_filters = np.load(os.path.join(BASE_PATH, "assets/mel_filters.npz")) 130 | mel_filters = torch.from_numpy(mel_filters[f"mel_{n_mels}"]) 131 | self.register_buffer("mel_filters", mel_filters) 132 | 133 | self.stft = TorchSTFT(n_fft, hop_length) 134 | 135 | def get_seq_len(self, seq_len): 136 | seq_len = torch.floor(seq_len/self.hop_length) 137 | return seq_len.to(dtype=torch.long) 138 | 139 | @torch.no_grad() 140 | def forward(self, x, seq_len): 141 | 142 | seq_len = self.get_seq_len(seq_len.float()) 143 | 144 | if self.padding > 0: 145 | x = F.pad(x, (0, self.padding)) 146 | 147 | x = self.stft(x) 148 | 149 | x = x[..., :-1].abs()**2 150 | x = self.mel_filters@x # mels 151 | 152 | x = torch.clamp(x, min=1e-10).log10() # log_mels 153 | x = torch.maximum(x, torch.amax(x, dim=(1, 2), keepdims=True) - 8.0) # clip 154 | x = (x + 4.0) / 4.0 # scale 155 | 156 | return x, seq_len -------------------------------------------------------------------------------- /scripts/benchmark_whisper_s2t.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def parse_arguments(): 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument('--repo_path', default="", type=str) 6 | parser.add_argument('--backend', default="CTranslate2", type=str) 7 | parser.add_argument('--batch_size', default=16, type=int) 8 | parser.add_argument('--flash_attention', default="no", type=str) 9 | parser.add_argument('--better_transformer', default="no", type=str) 10 | parser.add_argument('--eval_mp3', default="no", type=str) 11 | parser.add_argument('--eval_multilingual', default="yes", type=str) 12 | args = parser.parse_args() 13 | return args 14 | 15 | def run(repo_path, backend, flash_attention=False, better_transformer=False, batch_size=16, eval_mp3=False, eval_multilingual=True): 16 | import sys, time, os 17 | 18 | if len(repo_path): 19 | sys.path.append(repo_path) 20 | 21 | import whisper_s2t 22 | import pandas as pd 23 | 24 | if backend.lower() in ["huggingface", "hf"]: 25 | asr_options = { 26 | "use_flash_attention": flash_attention, 27 | 
"use_better_transformer": better_transformer 28 | } 29 | 30 | if flash_attention: 31 | results_dir = f"{repo_path}/results/WhisperS2T-{backend}-bs_{batch_size}-fa" 32 | elif better_transformer: 33 | results_dir = f"{repo_path}/results/WhisperS2T-{backend}-bs_{batch_size}-bt" 34 | else: 35 | results_dir = f"{repo_path}/results/WhisperS2T-{backend}-bs_{batch_size}" 36 | else: 37 | asr_options = {} 38 | results_dir = f"{repo_path}/results/WhisperS2T-{backend}-bs_{batch_size}" 39 | 40 | os.makedirs(results_dir, exist_ok=True) 41 | 42 | model = whisper_s2t.load_model("large-v2", backend=backend, asr_options=asr_options) 43 | 44 | # KINCAID46 WAV >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 45 | data = pd.read_csv(f'{repo_path}/data/KINCAID46/manifest_wav.tsv', sep="\t") 46 | files = [f"{repo_path}/{fn}" for fn in data['audio_path']] 47 | lang_codes = len(files)*['en'] 48 | tasks = len(files)*['transcribe'] 49 | initial_prompts = len(files)*[None] 50 | 51 | _ = model.transcribe_with_vad(files, 52 | lang_codes=lang_codes, 53 | tasks=tasks, 54 | initial_prompts=initial_prompts, 55 | batch_size=batch_size) 56 | 57 | st = time.time() 58 | out = model.transcribe_with_vad(files, 59 | lang_codes=lang_codes, 60 | tasks=tasks, 61 | initial_prompts=initial_prompts, 62 | batch_size=batch_size) 63 | time_kincaid46_wav = time.time()-st 64 | 65 | data['pred_text'] = [" ".join([_['text'] for _ in _transcript]).strip() for _transcript in out] 66 | data.to_csv(f"{results_dir}/KINCAID46_WAV.tsv", sep="\t", index=False) 67 | 68 | 69 | # KINCAID46 MP3 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 70 | if eval_mp3: 71 | data = pd.read_csv(f'{repo_path}/data/KINCAID46/manifest_mp3.tsv', sep="\t") 72 | files = [f"{repo_path}/{fn}" for fn in data['audio_path']] 73 | lang_codes = len(files)*['en'] 74 | tasks = len(files)*['transcribe'] 75 | initial_prompts = len(files)*[None] 76 | 77 | st = time.time() 78 | out = model.transcribe_with_vad(files, 79 | lang_codes=lang_codes, 80 | tasks=tasks, 81 | initial_prompts=initial_prompts, 82 | batch_size=batch_size) 83 | time_kincaid46_mp3 = time.time()-st 84 | 85 | data['pred_text'] = [" ".join([_['text'] for _ in _transcript]).strip() for _transcript in out] 86 | data.to_csv(f"{results_dir}/KINCAID46_MP3.tsv", sep="\t", index=False) 87 | else: 88 | time_kincaid46_mp3 = 0.0 89 | 90 | 91 | # MultiLingualLongform 92 | if eval_multilingual: 93 | data = pd.read_csv(f'{repo_path}/data/MultiLingualLongform/manifest.tsv', sep="\t") 94 | files = [f"{repo_path}/{fn}" for fn in data['audio_path']] 95 | lang_codes = data['lang_code'].to_list() 96 | tasks = len(files)*['transcribe'] 97 | initial_prompts = len(files)*[None] 98 | 99 | st = time.time() 100 | out = model.transcribe_with_vad(files, 101 | lang_codes=lang_codes, 102 | tasks=tasks, 103 | initial_prompts=initial_prompts, 104 | batch_size=batch_size) 105 | time_multilingual = time.time()-st 106 | 107 | data['pred_text'] = [" ".join([_['text'] for _ in _transcript]).strip() for _transcript in out] 108 | data.to_csv(f"{results_dir}/MultiLingualLongform.tsv", sep="\t", index=False) 109 | else: 110 | time_multilingual = 0.0 111 | 112 | infer_time = [ 113 | ["Dataset", "Time"], 114 | ["KINCAID46 WAV", time_kincaid46_wav], 115 | ["KINCAID46 MP3", time_kincaid46_mp3], 116 | ["MultiLingualLongform", time_multilingual] 117 | ] 118 | infer_time = pd.DataFrame(infer_time[1:], columns=infer_time[0]) 119 | infer_time.to_csv(f"{results_dir}/infer_time.tsv", sep="\t", index=False) 120 | 121 | 
122 | if __name__ == '__main__': 123 | args = parse_arguments() 124 | eval_mp3 = True if args.eval_mp3 == "yes" else False 125 | eval_multilingual = True if args.eval_multilingual == "yes" else False 126 | flash_attention = True if args.flash_attention == "yes" else False 127 | better_transformer = True if args.better_transformer == "yes" else False 128 | 129 | run(args.repo_path, args.backend, flash_attention=flash_attention, better_transformer=better_transformer, batch_size=args.batch_size, eval_mp3=eval_mp3, eval_multilingual=eval_multilingual) -------------------------------------------------------------------------------- /scripts/benchmark_whisper_s2t_distil.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def parse_arguments(): 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument('--repo_path', default="", type=str) 6 | parser.add_argument('--backend', default="HuggingFace", type=str) 7 | parser.add_argument('--batch_size', default=16, type=int) 8 | parser.add_argument('--flash_attention', default="yes", type=str) 9 | parser.add_argument('--better_transformer', default="no", type=str) 10 | parser.add_argument('--eval_mp3', default="no", type=str) 11 | parser.add_argument('--eval_multilingual', default="no", type=str) 12 | args = parser.parse_args() 13 | return args 14 | 15 | def run(repo_path, backend, flash_attention=False, better_transformer=False, batch_size=16, eval_mp3=False, eval_multilingual=True): 16 | import sys, time, os 17 | 18 | if len(repo_path): 19 | sys.path.append(repo_path) 20 | 21 | import whisper_s2t 22 | import pandas as pd 23 | 24 | if backend.lower() in ["huggingface", "hf"]: 25 | asr_options = { 26 | "use_flash_attention": flash_attention, 27 | "use_better_transformer": better_transformer 28 | } 29 | 30 | if flash_attention: 31 | results_dir = f"{repo_path}/results/WhisperS2T-{backend}DistilWhisper-bs_{batch_size}-fa" 32 | elif better_transformer: 33 | results_dir = f"{repo_path}/results/WhisperS2T-{backend}DistilWhisper-bs_{batch_size}-bt" 34 | else: 35 | results_dir = f"{repo_path}/results/WhisperS2T-{backend}DistilWhisper-bs_{batch_size}" 36 | else: 37 | asr_options = {} 38 | results_dir = f"{repo_path}/results/WhisperS2T-{backend}-bs_{batch_size}" 39 | 40 | os.makedirs(results_dir, exist_ok=True) 41 | 42 | model = whisper_s2t.load_model("distil-large-v2", backend=backend, asr_options=asr_options) 43 | 44 | # KINCAID46 WAV >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 45 | data = pd.read_csv(f'{repo_path}/data/KINCAID46/manifest_wav.tsv', sep="\t") 46 | files = [f"{repo_path}/{fn}" for fn in data['audio_path']] 47 | lang_codes = len(files)*['en'] 48 | tasks = len(files)*['transcribe'] 49 | initial_prompts = len(files)*[None] 50 | 51 | _ = model.transcribe_with_vad(files, 52 | lang_codes=lang_codes, 53 | tasks=tasks, 54 | initial_prompts=initial_prompts, 55 | batch_size=batch_size) 56 | 57 | st = time.time() 58 | out = model.transcribe_with_vad(files, 59 | lang_codes=lang_codes, 60 | tasks=tasks, 61 | initial_prompts=initial_prompts, 62 | batch_size=batch_size) 63 | time_kincaid46_wav = time.time()-st 64 | 65 | data['pred_text'] = [" ".join([_['text'] for _ in _transcript]).strip() for _transcript in out] 66 | data.to_csv(f"{results_dir}/KINCAID46_WAV.tsv", sep="\t", index=False) 67 | 68 | 69 | # KINCAID46 MP3 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 70 | if eval_mp3: 71 | data = 
pd.read_csv(f'{repo_path}/data/KINCAID46/manifest_mp3.tsv', sep="\t") 72 | files = [f"{repo_path}/{fn}" for fn in data['audio_path']] 73 | lang_codes = len(files)*['en'] 74 | tasks = len(files)*['transcribe'] 75 | initial_prompts = len(files)*[None] 76 | 77 | st = time.time() 78 | out = model.transcribe_with_vad(files, 79 | lang_codes=lang_codes, 80 | tasks=tasks, 81 | initial_prompts=initial_prompts, 82 | batch_size=batch_size) 83 | time_kincaid46_mp3 = time.time()-st 84 | 85 | data['pred_text'] = [" ".join([_['text'] for _ in _transcript]).strip() for _transcript in out] 86 | data.to_csv(f"{results_dir}/KINCAID46_MP3.tsv", sep="\t", index=False) 87 | else: 88 | time_kincaid46_mp3 = 0.0 89 | 90 | 91 | # MultiLingualLongform 92 | if eval_multilingual: 93 | data = pd.read_csv(f'{repo_path}/data/MultiLingualLongform/manifest.tsv', sep="\t") 94 | files = [f"{repo_path}/{fn}" for fn in data['audio_path']] 95 | lang_codes = data['lang_code'].to_list() 96 | tasks = len(files)*['transcribe'] 97 | initial_prompts = len(files)*[None] 98 | 99 | st = time.time() 100 | out = model.transcribe_with_vad(files, 101 | lang_codes=lang_codes, 102 | tasks=tasks, 103 | initial_prompts=initial_prompts, 104 | batch_size=batch_size) 105 | time_multilingual = time.time()-st 106 | 107 | data['pred_text'] = [" ".join([_['text'] for _ in _transcript]).strip() for _transcript in out] 108 | data.to_csv(f"{results_dir}/MultiLingualLongform.tsv", sep="\t", index=False) 109 | else: 110 | time_multilingual = 0.0 111 | 112 | infer_time = [ 113 | ["Dataset", "Time"], 114 | ["KINCAID46 WAV", time_kincaid46_wav], 115 | ["KINCAID46 MP3", time_kincaid46_mp3], 116 | ["MultiLingualLongform", time_multilingual] 117 | ] 118 | infer_time = pd.DataFrame(infer_time[1:], columns=infer_time[0]) 119 | infer_time.to_csv(f"{results_dir}/infer_time.tsv", sep="\t", index=False) 120 | 121 | 122 | if __name__ == '__main__': 123 | args = parse_arguments() 124 | eval_mp3 = True if args.eval_mp3 == "yes" else False 125 | eval_multilingual = True if args.eval_multilingual == "yes" else False 126 | flash_attention = True if args.flash_attention == "yes" else False 127 | better_transformer = True if args.better_transformer == "yes" else False 128 | 129 | run(args.repo_path, args.backend, flash_attention=flash_attention, better_transformer=better_transformer, batch_size=args.batch_size, eval_mp3=eval_mp3, eval_multilingual=eval_multilingual) -------------------------------------------------------------------------------- /whisper_s2t/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from rich.console import Console 4 | from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn 5 | 6 | 7 | class RunningStatus: 8 | def __init__(self, status_text, console=None): 9 | self.status_text = status_text 10 | if console: 11 | self.console = console 12 | else: 13 | self.console = Console() 14 | 15 | def __enter__(self): 16 | self.progress = Progress( 17 | SpinnerColumn(), 18 | *Progress.get_default_columns(), 19 | TimeElapsedColumn(), 20 | console=self.console, 21 | transient=False 22 | ) 23 | self.task = self.progress.add_task(f"{self.status_text}", total=None) 24 | self.progress.start() 25 | return self 26 | 27 | def __exit__(self, exc_type, exc_val, exc_tb): 28 | self.progress.update(self.task, advance=1.0) # Complete the progress bar 29 | self.progress.stop() # Stop the progress display 30 | 31 | 32 | def format_timestamp(seconds, always_include_hours=False, 
decimal_marker="."): 33 | 34 | assert seconds >= 0, "non-negative timestamp expected" 35 | 36 | milliseconds = round(seconds * 1000.0) 37 | 38 | hours = milliseconds // 3_600_000 39 | milliseconds -= hours * 3_600_000 40 | 41 | minutes = milliseconds // 60_000 42 | milliseconds -= minutes * 60_000 43 | 44 | seconds = milliseconds // 1_000 45 | milliseconds -= seconds * 1_000 46 | 47 | hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else "" 48 | return ( 49 | f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}" 50 | ) 51 | 52 | 53 | def get_single_sentence_in_one_utterance(transcript, end_punct_marks=["?", "."]): 54 | if 'word_timestamps' not in transcript[0]: 55 | print(f"Word Timestamp not available, one utterance can have multiple sentences.") 56 | return transcript 57 | 58 | new_transcript = [] 59 | 60 | all_words = [] 61 | for utt in transcript: 62 | all_words += utt['word_timestamps'] 63 | 64 | curr_utt = [] 65 | for word in all_words: 66 | curr_utt.append(word) 67 | if len(word['word']) and word['word'][-1] in end_punct_marks: 68 | if len(curr_utt): 69 | new_transcript.append({ 70 | 'text': " ".join([_['word'] for _ in curr_utt]), 71 | 'start_time': curr_utt[0]['start'], 72 | 'end_time': curr_utt[-1]['end'] 73 | }) 74 | 75 | curr_utt = [] 76 | 77 | if len(curr_utt): 78 | new_transcript.append({ 79 | 'text': " ".join([_['word'] for _ in curr_utt]), 80 | 'start_time': curr_utt[0]['start'], 81 | 'end_time': curr_utt[-1]['end'] 82 | }) 83 | 84 | return new_transcript 85 | 86 | 87 | def ExportVTT(transcript, file, single_sentence_in_one_utterance=False, end_punct_marks=["?", "."]): 88 | 89 | if single_sentence_in_one_utterance: 90 | transcript = get_single_sentence_in_one_utterance(transcript, end_punct_marks=end_punct_marks) 91 | 92 | with open(file, 'w', encoding="utf-8") as f: 93 | f.write("WEBVTT\n\n") 94 | for _utt in transcript: 95 | f.write(f"{format_timestamp(_utt['start_time'])} --> {format_timestamp(_utt['end_time'])}\n{_utt['text']}\n\n") 96 | 97 | 98 | def ExportSRT(transcript, file, single_sentence_in_one_utterance=False, end_punct_marks=["?", "."]): 99 | 100 | if single_sentence_in_one_utterance: 101 | transcript = get_single_sentence_in_one_utterance(transcript, end_punct_marks=end_punct_marks) 102 | 103 | with open(file, 'w', encoding="utf-8") as f: 104 | for i, _utt in enumerate(transcript): 105 | 106 | f.write(f"{i}\n") 107 | f.write(f"{format_timestamp(_utt['start_time'], always_include_hours=True, decimal_marker=',')} --> ") 108 | f.write(f"{format_timestamp(_utt['end_time'], always_include_hours=True, decimal_marker=',')}\n") 109 | f.write(f"{_utt['text']}\n\n") 110 | 111 | 112 | def ExportJSON(transcript, file): 113 | 114 | with open(file, 'w', encoding="utf-8") as f: 115 | f.write(json.dumps(transcript)) 116 | 117 | 118 | def ExportTSV(transcript, file, single_sentence_in_one_utterance=False, end_punct_marks=["?", "."]): 119 | 120 | if single_sentence_in_one_utterance: 121 | transcript = get_single_sentence_in_one_utterance(transcript, end_punct_marks=end_punct_marks) 122 | 123 | keys = ['start_time', 'end_time', 'text'] 124 | if len(transcript): 125 | for k in transcript[0].keys(): 126 | if k not in keys: keys.append(k) 127 | 128 | with open(file, 'w', encoding="utf-8") as f: 129 | f.write("\t".join(keys)) 130 | for _utt in transcript: 131 | f.write("\n" + "\t".join([_utt[k].strip().replace("\t", " ") if k == 'text' else str(_utt[k]) for k in keys])) 132 | 133 | 134 | def ExportTXT(transcript, file, 
single_sentence_in_one_utterance=False, end_punct_marks=["?", "."]): 135 | 136 | if single_sentence_in_one_utterance: 137 | transcript = get_single_sentence_in_one_utterance(transcript, end_punct_marks=end_punct_marks) 138 | 139 | with open(file, 'w', encoding="utf-8") as f: 140 | for _utt in transcript: 141 | f.write(f"[{format_timestamp(_utt['start_time'])} --> {format_timestamp(_utt['end_time'])}]: {_utt['text']}\n\n") 142 | 143 | 144 | TranscriptExporter = { 145 | 'txt': ExportTXT, 146 | 'vtt': ExportVTT, 147 | 'srt': ExportSRT, 148 | 'tsv': ExportTSV, 149 | 'json': ExportJSON 150 | } 151 | 152 | 153 | def write_outputs(transcripts, format='json', ip_files=None, op_files=None, save_dir="./", **kwargs): 154 | if (op_files is None) or (len(op_files) != len(transcripts)): 155 | os.makedirs(save_dir, exist_ok=True) 156 | 157 | op_files = [] 158 | 159 | if (ip_files is None) or (len(ip_files) != len(transcripts)): 160 | for i in range(len(transcripts)): 161 | op_files.append(os.path.join(save_dir, f"{i}.{format}")) 162 | else: 163 | for i, _ip_fn in enumerate(ip_files): 164 | base_name = ".".join(os.path.basename(_ip_fn).split(".")[:-1]) 165 | op_files.append(os.path.join(save_dir, f"{i}_{base_name}.{format}")) 166 | 167 | 168 | for transcript, file_name in zip(transcripts, op_files): 169 | TranscriptExporter[format](transcript, file_name, **kwargs) -------------------------------------------------------------------------------- /tools/text_normalizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import json 4 | import contractions 5 | from cleantext import clean 6 | from nemo_text_processing.text_normalization.normalize import Normalizer 7 | 8 | 9 | def BlankFunction(text): 10 | return text 11 | 12 | class EnglishSpellingNormalizer: 13 | """ 14 | [Note]: Taken from OpenAI Whisper repo: https://github.com/openai/whisper/blob/main/whisper/normalizers/english.py#L450 15 | 16 | Applies British-American spelling mappings as listed in [1]. 17 | 18 | [1] https://www.tysto.com/uk-us-spelling-list.html 19 | """ 20 | 21 | def __init__(self): 22 | mapping_path = os.path.join(os.path.dirname(__file__), "english.json") 23 | self.mapping = json.load(open(mapping_path)) 24 | 25 | def __call__(self, s: str): 26 | return " ".join(self.mapping.get(word, word) for word in s.split()) 27 | 28 | 29 | filler_words = ["um", "uh", "uhuh", "mhm", "ah", "mm", "mmm", "mn", "hmm", "hm", "huh", "heh", "eh", "eah", "ha", "aw", "ye", "ey", "aha", "ugh"] 30 | filler_words = {k: False for k in filler_words} 31 | def clean_text(text, replacers={}, remove_filler_words=True): 32 | text = text.lower() 33 | 34 | for pattern, replacement in replacers.items(): 35 | text = re.sub(pattern, replacement, text) 36 | 37 | text = contractions.fix(text) 38 | 39 | text = re.sub('(\.)(com|org|net|ai|su)', r' dot \2', text) # fix websites 40 | text = re.sub('(www\.)([a-z])', r'www dot \2', text) # fix websites 41 | text = re.sub('-|~|\[.+?\]|\(.+?\)|\{.+?\}|\$|\^|\+|\=|\>|\<|\|', ' ', text) # remove anything insider brackets [..] (..) 
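# (The substitutions below spell out a letter/letter slash as "or", replace the unicode fractions ½ and ¼ with words,
#  and temporarily swap apostrophes for the "APSTROPH" placeholder so that clean()'s punctuation stripping
#  (no_punct=True) does not delete them; the placeholder is restored to an apostrophe after cleaning and filler-word removal.)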
42 | 43 | text = re.sub('([a-z]/)([a-z])', r'\1 or \2', text) 44 | text = text.replace("½", "half") 45 | text = text.replace("¼", "quarter") 46 | text = text.replace(" ok ", " okay ") 47 | 48 | text = re.sub("\u2019|'", "APSTROPH", text) 49 | 50 | text = clean(text, 51 | fix_unicode=False, 52 | to_ascii=False, 53 | lower=False, 54 | no_line_breaks=True, 55 | no_punct=True, 56 | no_urls=False, 57 | no_emails=False, 58 | no_phone_numbers=False, 59 | no_numbers=False, 60 | no_digits=False, 61 | no_currency_symbols=False) 62 | 63 | if remove_filler_words: 64 | words = [] 65 | for w in text.split(" "): 66 | if filler_words.get(w, True) and len(w.strip()): 67 | words.append(w) 68 | 69 | text = " ".join(words) 70 | 71 | text = text.replace("APSTROPH", "'") 72 | text = text.lower() 73 | return text 74 | 75 | class TextNormalizer: 76 | def __init__(self, lang='en', remove_filler_words=True): 77 | """ 78 | [Note]: Taken from OpenAI Whisper repo: https://github.com/openai/whisper/blob/main/whisper/normalizers/english.py#L465 79 | """ 80 | 81 | self.standardize_spellings = BlankFunction 82 | self.normalizer = Normalizer(lang=lang, input_case='cased') 83 | self.remove_filler_words = remove_filler_words 84 | 85 | self.replacers = {} 86 | if lang=='en': 87 | self.standardize_spellings = EnglishSpellingNormalizer() 88 | self.replacers = { 89 | # common contractions 90 | r"\bwon't\b": "will not", 91 | r"\bcan't\b": "can not", 92 | r"\blet's\b": "let us", 93 | r"\bain't\b": "aint", 94 | r"\by'all\b": "you all", 95 | r"\bwanna\b": "want to", 96 | r"\bgotta\b": "got to", 97 | r"\bgonna\b": "going to", 98 | r"\bi'ma\b": "i am going to", 99 | r"\bimma\b": "i am going to", 100 | r"\bwoulda\b": "would have", 101 | r"\bcoulda\b": "could have", 102 | r"\bshoulda\b": "should have", 103 | r"\bma'am\b": "madam", 104 | # contractions in titles/prefixes 105 | r"\bmr\b": "mister ", 106 | r"\bmrs\b": "missus ", 107 | r"\bst\b": "saint ", 108 | r"\bdr\b": "doctor ", 109 | r"\bprof\b": "professor ", 110 | r"\bcapt\b": "captain ", 111 | r"\bgov\b": "governor ", 112 | r"\bald\b": "alderman ", 113 | r"\bgen\b": "general ", 114 | r"\bsen\b": "senator ", 115 | r"\brep\b": "representative ", 116 | r"\bpres\b": "president ", 117 | r"\brev\b": "reverend ", 118 | r"\bhon\b": "honorable ", 119 | r"\basst\b": "assistant ", 120 | r"\bassoc\b": "associate ", 121 | r"\blt\b": "lieutenant ", 122 | r"\bcol\b": "colonel ", 123 | r"\bjr\b": "junior ", 124 | r"\bsr\b": "senior ", 125 | r"\besq\b": "esquire ", 126 | # prefect tenses, ideally it should be any past participles, but it's harder.. 127 | r"'d been\b": " had been", 128 | r"'s been\b": " has been", 129 | r"'d gone\b": " had gone", 130 | r"'s gone\b": " has gone", 131 | r"'d done\b": " had done", # "'s done" is ambiguous 132 | r"'s got\b": " has got", 133 | # general contractions 134 | r"n't\b": " not", 135 | r"'re\b": " are", 136 | r"'s\b": " is", 137 | r"'d\b": " would", 138 | r"'ll\b": " will", 139 | r"'t\b": " not", 140 | r"'ve\b": " have", 141 | r"'m\b": " am", 142 | } 143 | 144 | def __call__(self, txt): 145 | norm_txt = [] 146 | for sent in txt.split(". 
"): 147 | norm_sent = [] 148 | for sub_sent in sent.split(", "): 149 | if len(re.sub('[0123456789]', '', sub_sent)) != len(sub_sent): 150 | sub_sent = self.normalizer.normalize(sub_sent, verbose=False).strip() 151 | 152 | if len(sub_sent): 153 | norm_sent.append(sub_sent.strip()) 154 | 155 | norm_txt.append(", ".join(norm_sent)) 156 | 157 | norm_txt = [self.standardize_spellings(clean_text(_, replacers=self.replacers, remove_filler_words=self.remove_filler_words)) for _ in norm_txt] 158 | 159 | return " ".join(norm_txt) -------------------------------------------------------------------------------- /scripts/benchmark_huggingface.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from rich.console import Console 3 | console = Console() 4 | 5 | def parse_arguments(): 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('--repo_path', default="", type=str) 8 | parser.add_argument('--batch_size', default=16, type=int) 9 | parser.add_argument('--flash_attention', default="no", type=str) 10 | parser.add_argument('--better_transformer', default="no", type=str) 11 | parser.add_argument('--eval_mp3', default="no", type=str) 12 | parser.add_argument('--eval_multilingual', default="yes", type=str) 13 | args = parser.parse_args() 14 | return args 15 | 16 | 17 | def run(repo_path, flash_attention=False, better_transformer=False, batch_size=16, eval_mp3=False, eval_multilingual=True): 18 | import torch 19 | import time, os 20 | import pandas as pd 21 | from transformers import pipeline 22 | 23 | # Load Model >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 24 | model_kwargs = { 25 | "use_safetensors": True, 26 | "low_cpu_mem_usage": True 27 | } 28 | 29 | results_dir = f"{repo_path}/results/HuggingFace-bs_{batch_size}" 30 | 31 | if flash_attention: 32 | results_dir = f"{results_dir}-fa" 33 | model_kwargs["use_flash_attention_2"] = True 34 | 35 | ASR = pipeline("automatic-speech-recognition", 36 | "openai/whisper-large-v2", 37 | num_workers=1, 38 | torch_dtype=torch.float16, 39 | device="cuda", 40 | model_kwargs=model_kwargs) 41 | 42 | if (not flash_attention) and better_transformer: 43 | ASR.model = ASR.model.to_bettertransformer() 44 | results_dir = f"{results_dir}-bt" 45 | 46 | os.makedirs(results_dir, exist_ok=True) 47 | 48 | # KINCAID46 WAV >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 49 | data = pd.read_csv(f'{repo_path}/data/KINCAID46/manifest_wav.tsv', sep="\t") 50 | files = [f"{repo_path}/{fn}" for fn in data['audio_path']] 51 | 52 | with console.status("Warming"): 53 | st = time.time() 54 | _ = ASR(files, 55 | batch_size=batch_size, 56 | chunk_length_s=30, 57 | generate_kwargs={'num_beams': 1, 'language': 'en'}, 58 | return_timestamps=False) 59 | 60 | print(f"[Warming Time]: {time.time()-st}") 61 | 62 | with console.status("KINCAID WAV"): 63 | st = time.time() 64 | outputs = ASR(files, 65 | batch_size=batch_size, 66 | chunk_length_s=30, 67 | generate_kwargs={'num_beams': 1, 'language': 'en'}, 68 | return_timestamps=False) 69 | 70 | time_kincaid46_wav = time.time()-st 71 | print(f"[KINCAID WAV Time]: {time_kincaid46_wav}") 72 | 73 | data['pred_text'] = [_['text'].strip() for _ in outputs] 74 | data.to_csv(f"{results_dir}/KINCAID46_WAV.tsv", sep="\t", index=False) 75 | 76 | 77 | # KINCAID46 MP3 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 78 | if eval_mp3: 79 | data = pd.read_csv(f'{repo_path}/data/KINCAID46/manifest_mp3.tsv', sep="\t") 80 | files 
= [f"{repo_path}/{fn}" for fn in data['audio_path']] 81 | 82 | with console.status("KINCAID MP3"): 83 | st = time.time() 84 | outputs = ASR(files, 85 | batch_size=batch_size, 86 | chunk_length_s=30, 87 | generate_kwargs={'num_beams': 1, 'language': 'en'}, 88 | return_timestamps=False) 89 | 90 | time_kincaid46_mp3 = time.time()-st 91 | 92 | print(f"[KINCAID MP3 Time]: {time_kincaid46_mp3}") 93 | 94 | data['pred_text'] = [_['text'].strip() for _ in outputs] 95 | data.to_csv(f"{results_dir}/KINCAID46_MP3.tsv", sep="\t", index=False) 96 | else: 97 | time_kincaid46_mp3 = 0.0 98 | 99 | # MultiLingualLongform >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 100 | if eval_multilingual: 101 | data = pd.read_csv(f'{repo_path}/data/MultiLingualLongform/manifest.tsv', sep="\t") 102 | files = [f"{repo_path}/{fn}" for fn in data['audio_path']] 103 | lang_codes = data['lang_code'].to_list() 104 | 105 | with console.status("MultiLingualLongform"): 106 | st = time.time() 107 | 108 | curr_files = [files[0]] 109 | curr_lang = lang_codes[0] 110 | outputs = [] 111 | for fn, lang in zip(files[1:], lang_codes[1:]): 112 | if lang != curr_lang: 113 | _outputs = ASR(curr_files, 114 | batch_size=batch_size, 115 | chunk_length_s=30, 116 | generate_kwargs={'num_beams': 1, 'language': curr_lang}, 117 | return_timestamps=False) 118 | outputs.extend(_outputs) 119 | 120 | curr_files = [fn] 121 | curr_lang = lang 122 | else: 123 | curr_files.append(fn) 124 | 125 | _outputs = ASR(curr_files, 126 | batch_size=batch_size, 127 | chunk_length_s=30, 128 | generate_kwargs={'num_beams': 1, 'language': curr_lang}, 129 | return_timestamps=False) 130 | 131 | outputs.extend(_outputs) 132 | 133 | time_multilingual = time.time()-st 134 | print(f"[MultiLingualLongform Time]: {time_multilingual}") 135 | 136 | data['pred_text'] = [_['text'].strip() for _ in outputs] 137 | data.to_csv(f"{results_dir}/MultiLingualLongform.tsv", sep="\t", index=False) 138 | else: 139 | time_multilingual = 0.0 140 | 141 | infer_time = [ 142 | ["Dataset", "Time"], 143 | ["KINCAID46 WAV", time_kincaid46_wav], 144 | ["KINCAID46 MP3", time_kincaid46_mp3], 145 | ["MultiLingualLongform", time_multilingual] 146 | ] 147 | 148 | infer_time = pd.DataFrame(infer_time[1:], columns=infer_time[0]) 149 | infer_time.to_csv(f"{results_dir}/infer_time.tsv", sep="\t", index=False) 150 | 151 | 152 | if __name__ == '__main__': 153 | args = parse_arguments() 154 | eval_mp3 = True if args.eval_mp3 == "yes" else False 155 | eval_multilingual = True if args.eval_multilingual == "yes" else False 156 | flash_attention = True if args.flash_attention == "yes" else False 157 | better_transformer = True if args.better_transformer == "yes" else False 158 | 159 | run(args.repo_path, flash_attention=flash_attention, better_transformer=better_transformer, batch_size=args.batch_size, eval_mp3=eval_mp3, eval_multilingual=eval_multilingual) -------------------------------------------------------------------------------- /scripts/benchmark_huggingface_distil.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from rich.console import Console 3 | console = Console() 4 | 5 | def parse_arguments(): 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('--repo_path', default="", type=str) 8 | parser.add_argument('--batch_size', default=16, type=int) 9 | parser.add_argument('--flash_attention', default="yes", type=str) 10 | parser.add_argument('--better_transformer', default="no", type=str) 11 | 
parser.add_argument('--eval_mp3', default="no", type=str) 12 | parser.add_argument('--eval_multilingual', default="no", type=str) 13 | args = parser.parse_args() 14 | return args 15 | 16 | 17 | def run(repo_path, flash_attention=False, better_transformer=False, batch_size=16, eval_mp3=False, eval_multilingual=True): 18 | import torch 19 | import time, os 20 | import pandas as pd 21 | from transformers import pipeline 22 | 23 | # Load Model >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 24 | model_kwargs = { 25 | "use_safetensors": True, 26 | "low_cpu_mem_usage": True 27 | } 28 | 29 | results_dir = f"{repo_path}/results/HuggingFaceDistilWhisper-bs_{batch_size}" 30 | 31 | if flash_attention: 32 | results_dir = f"{results_dir}-fa" 33 | model_kwargs["use_flash_attention_2"] = True 34 | 35 | ASR = pipeline("automatic-speech-recognition", 36 | f"distil-whisper/distil-large-v2", 37 | num_workers=1, 38 | torch_dtype=torch.float16, 39 | device="cuda", 40 | model_kwargs=model_kwargs) 41 | 42 | if (not flash_attention) and better_transformer: 43 | ASR.model = ASR.model.to_bettertransformer() 44 | results_dir = f"{results_dir}-bt" 45 | 46 | os.makedirs(results_dir, exist_ok=True) 47 | 48 | # KINCAID46 WAV >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 49 | data = pd.read_csv(f'{repo_path}/data/KINCAID46/manifest_wav.tsv', sep="\t") 50 | files = [f"{repo_path}/{fn}" for fn in data['audio_path']] 51 | 52 | with console.status("Warming"): 53 | st = time.time() 54 | _ = ASR(files, 55 | batch_size=batch_size, 56 | chunk_length_s=15, 57 | generate_kwargs={'num_beams': 1, 'language': 'en'}, 58 | return_timestamps=False) 59 | 60 | print(f"[Warming Time]: {time.time()-st}") 61 | 62 | with console.status("KINCAID WAV"): 63 | st = time.time() 64 | outputs = ASR(files, 65 | batch_size=batch_size, 66 | chunk_length_s=15, 67 | generate_kwargs={'num_beams': 1, 'language': 'en'}, 68 | return_timestamps=False) 69 | 70 | time_kincaid46_wav = time.time()-st 71 | print(f"[KINCAID WAV Time]: {time_kincaid46_wav}") 72 | 73 | data['pred_text'] = [_['text'].strip() for _ in outputs] 74 | data.to_csv(f"{results_dir}/KINCAID46_WAV.tsv", sep="\t", index=False) 75 | 76 | 77 | # KINCAID46 MP3 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 78 | if eval_mp3: 79 | data = pd.read_csv(f'{repo_path}/data/KINCAID46/manifest_mp3.tsv', sep="\t") 80 | files = [f"{repo_path}/{fn}" for fn in data['audio_path']] 81 | 82 | with console.status("KINCAID MP3"): 83 | st = time.time() 84 | outputs = ASR(files, 85 | batch_size=batch_size, 86 | chunk_length_s=30, 87 | generate_kwargs={'num_beams': 1, 'language': 'en'}, 88 | return_timestamps=False) 89 | 90 | time_kincaid46_mp3 = time.time()-st 91 | 92 | print(f"[KINCAID MP3 Time]: {time_kincaid46_mp3}") 93 | 94 | data['pred_text'] = [_['text'].strip() for _ in outputs] 95 | data.to_csv(f"{results_dir}/KINCAID46_MP3.tsv", sep="\t", index=False) 96 | else: 97 | time_kincaid46_mp3 = 0.0 98 | 99 | # MultiLingualLongform >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 100 | if eval_multilingual: 101 | data = pd.read_csv(f'{repo_path}/data/MultiLingualLongform/manifest.tsv', sep="\t") 102 | files = [f"{repo_path}/{fn}" for fn in data['audio_path']] 103 | lang_codes = data['lang_code'].to_list() 104 | 105 | with console.status("MultiLingualLongform"): 106 | st = time.time() 107 | 108 | curr_files = [files[0]] 109 | curr_lang = lang_codes[0] 110 | outputs = [] 111 | for fn, lang in zip(files[1:], 
lang_codes[1:]): 112 | if lang != curr_lang: 113 | _outputs = ASR(curr_files, 114 | batch_size=batch_size, 115 | chunk_length_s=30, 116 | generate_kwargs={'num_beams': 1, 'language': curr_lang}, 117 | return_timestamps=False) 118 | outputs.extend(_outputs) 119 | 120 | curr_files = [fn] 121 | curr_lang = lang 122 | else: 123 | curr_files.append(fn) 124 | 125 | _outputs = ASR(curr_files, 126 | batch_size=batch_size, 127 | chunk_length_s=30, 128 | generate_kwargs={'num_beams': 1, 'language': curr_lang}, 129 | return_timestamps=False) 130 | 131 | outputs.extend(_outputs) 132 | 133 | time_multilingual = time.time()-st 134 | print(f"[MultiLingualLongform Time]: {time_multilingual}") 135 | 136 | data['pred_text'] = [_['text'].strip() for _ in outputs] 137 | data.to_csv(f"{results_dir}/MultiLingualLongform.tsv", sep="\t", index=False) 138 | else: 139 | time_multilingual = 0.0 140 | 141 | infer_time = [ 142 | ["Dataset", "Time"], 143 | ["KINCAID46 WAV", time_kincaid46_wav], 144 | ["KINCAID46 MP3", time_kincaid46_mp3], 145 | ["MultiLingualLongform", time_multilingual] 146 | ] 147 | 148 | infer_time = pd.DataFrame(infer_time[1:], columns=infer_time[0]) 149 | infer_time.to_csv(f"{results_dir}/infer_time.tsv", sep="\t", index=False) 150 | 151 | 152 | if __name__ == '__main__': 153 | args = parse_arguments() 154 | eval_mp3 = True if args.eval_mp3 == "yes" else False 155 | eval_multilingual = True if args.eval_multilingual == "yes" else False 156 | flash_attention = True if args.flash_attention == "yes" else False 157 | better_transformer = True if args.better_transformer == "yes" else False 158 | 159 | run(args.repo_path, flash_attention=flash_attention, better_transformer=better_transformer, batch_size=args.batch_size, eval_mp3=eval_mp3, eval_multilingual=eval_multilingual) -------------------------------------------------------------------------------- /docs.md: -------------------------------------------------------------------------------- 1 | # Detailed Usage and Documentation 2 | 3 | 1. [Basic Usage](#basic-usage) 4 | 1. [Using Custom VAD Model](#using-custom-vad-model) 5 | 1. [Run Without VAD Model](#run-without-vad-model) 6 | 1. [Passing Custom Model Configuration](#passing-custom-model-configuration) 7 | 1. [Return Word-Alignments](#return-word-alignments) 8 | 1. [Write Transcripts To a File](#write-transcripts-to-a-file) 9 | 10 | ## Basic Usage 11 | 12 | Load WhisperS2T with CTranslate2 backend with default parameters: 13 | 14 | ```py 15 | import whisper_s2t 16 | 17 | model = whisper_s2t.load_model(model_identifier="large-v2", backend='CTranslate2') 18 | 19 | files = ['sample_1.wav'] 20 | lang_codes = ['en'] 21 | tasks = ['transcribe'] 22 | initial_prompts = [None] 23 | 24 | out = model.transcribe_with_vad(files, 25 | lang_codes=lang_codes, # pass lang_codes for each file 26 | tasks=tasks, # pass transcribe/translate 27 | initial_prompts=initial_prompts, # to do prompting (currently only supported for CTranslate2 backend) 28 | batch_size=16) 29 | 30 | print(out[0][0]) # Print first utterance for first file 31 | """ 32 | [Console Output] 33 | 34 | {'text': "Let's bring in Phil Mackie who is there at the palace. We're looking at Teresa and Philip May. Philip, can you see how he's being transferred from the helicopters? It looks like, as you said, the beast. It's got its headlights on because the sun is beginning to set now, certainly sinking behind some clouds. 
It's about a quarter of a mile away down the Grand Drive", 35 | 'avg_logprob': -0.25426941679184695, 36 | 'no_speech_prob': 8.147954940795898e-05, 37 | 'start_time': 0.0, 38 | 'end_time': 24.8} 39 | """ 40 | 41 | ``` 42 | 43 | Switch to HuggingFace backend (by default it will use FlashAttention2). Note: FlashAttention2 only works with Ampere/Hopper Nvidia GPUs. 44 | 45 | ```py 46 | model = whisper_s2t.load_model(model_identifier="large-v2", backend='HuggingFace') # Supported backends ['CTranslate2', 'HuggingFace', 'OpenAI'] 47 | ``` 48 | 49 | ## Using Custom VAD Model 50 | 51 | Wrap your VAD model (say `CustomVAD`) using the base class as `whisper_s2t.speech_segmenter.VADBaseClass`. See [whisper_s2t/speech_segmenter/frame_vad.py](whisper_s2t/speech_segmenter/frame_vad.py) for example. The `def __call__` must take audio_signal as input and returns a numpy array of size **T x 3** where T is the frame length. Each row should have following data `[speech_prob, frame_start_time, frame_end_time]`. Next pass your vad_model while initialising the whisper model. 52 | 53 | ```py 54 | vad_model = CustomVAD() 55 | model = whisper_s2t.load_model(model_identifier="large-v2", backend='CTranslate2', vad_model=vad_model) 56 | ``` 57 | 58 | ## Run Without VAD Model 59 | 60 | For some languages VAD model can give poor performance. For those cases, it's better to disable VAD. 61 | 62 | ```py 63 | out = model.transcribe(files, 64 | lang_codes=lang_codes, # pass lang_codes for each file 65 | tasks=tasks, # pass transcribe/translate 66 | initial_prompts=initial_prompts, # to do prompting (currently only supported for CTranslate2 backend) 67 | batch_size=24) 68 | 69 | print(out[0][0]) 70 | """ 71 | {'text': "Let's bring in Phil Mackie who is there at the palace. We're looking at Theresa and Philip May. Philip, can you see how he's being transferred from the helicopters? It looks like, as you said, the beast. It's got its headlights on because the sun is beginning to set now, certainly sinking behind some clouds. It's about a quarter of a mile away down the Grand Drive leading up into the courtyard. So you've seen the pictures there of the Prime Minister", 72 | 'avg_logprob': -0.25300603330135346, 73 | 'no_speech_prob': 1.9311904907226562e-05, 74 | 'start_time': 0, 75 | 'end_time': 29.0} 76 | """ 77 | ``` 78 | 79 | VAD parameters can also be tweaked using: 80 | 81 | ```py 82 | speech_segmenter_options = { 83 | 'eos_thresh': 0.1, 84 | 'bos_thresh': 0.1, 85 | } 86 | 87 | model = whisper_s2t.load_model(speech_segmenter_options=speech_segmenter_options) 88 | ``` 89 | 90 | ## Passing Custom Model Configuration 91 | 92 | Custom model configs can be passed as keyword arguments when loading the model: 93 | 94 | ```py 95 | import whisper_s2t 96 | from whisper_s2t.backends.ctranslate2.model import BEST_ASR_CONFIG 97 | 98 | model_kwargs = { 99 | 'compute_type': 'int8', # Note int8 is only supported for CTranslate2 backend, for others only float16 is supported for lower precision. 100 | 'asr_options': BEST_ASR_CONFIG 101 | } 102 | 103 | model = whisper_s2t.load_model(model_identifier="large-v2", backend='CTranslate2', **model_kwargs) 104 | ``` 105 | 106 | OR to update the configs after loading the model: 107 | 108 | ```py 109 | model.update_params(model_kwargs) 110 | ``` 111 | 112 | ## Return Word-Alignments 113 | 114 | Only for CTranslate2 and TensorRT backend. 
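The full example below enables this via `asr_options={'word_timestamps': True}`; each utterance dict then additionally carries a `word_timestamps` list. As a quick reference, a minimal sketch of consuming that list (assuming the output structure shown in the console output of the example below):

```py
# Minimal sketch (not part of the example below): assumes `out` was produced with
# asr_options={'word_timestamps': True}, so each utterance dict carries a
# 'word_timestamps' list of {'word', 'start', 'end', 'prob'} entries.
for w in out[0][0]['word_timestamps']:
    print(f"{w['start']:6.2f}s - {w['end']:6.2f}s  {w['word']}  (prob={w['prob']:.2f})")
```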
115 | 116 | ```py 117 | import whisper_s2t 118 | 119 | model = whisper_s2t.load_model(model_identifier="large-v2", asr_options={'word_timestamps': True}) 120 | 121 | files = ['sample_1.wav'] 122 | lang_codes = ['en'] 123 | tasks = ['transcribe'] 124 | initial_prompts = [None] 125 | 126 | out = model.transcribe_with_vad(files, 127 | lang_codes=lang_codes, # pass lang_codes for each file 128 | tasks=tasks, # pass transcribe/translate 129 | initial_prompts=initial_prompts, # to do prompting (currently only supported for CTranslate2 backend) 130 | batch_size=24) 131 | 132 | print(out[0][0]) # Print first utterance for first file 133 | """ 134 | [Console Output] 135 | 136 | {'text': "Let's bring in Phil Mackie who is there at the palace. We're looking at Teresa and Philip May. Philip, can you see how he's being transferred from the helicopters? It looks like, as you said, the beast. It's got its headlights on because the sun is beginning to set now, certainly sinking behind some clouds. It's about a quarter of a mile away down the Grand Drive", 137 | 'avg_logprob': -0.2544597674565143, 138 | 'no_speech_prob': 8.213520050048828e-05, 139 | 'word_timestamps': [{'word': "Let's", 140 | 'start': 0.0, 141 | 'end': 0.24, 142 | 'prob': 0.63}, 143 | {'word': 'bring', 'start': 0.24, 'end': 0.4, 'prob': 0.96}, 144 | {'word': 'in', 'start': 0.4, 'end': 0.52, 'prob': 0.71}, 145 | {'word': 'Phil', 'start': 0.52, 'end': 0.66, 'prob': 0.46}, 146 | {'word': 'Mackie', 'start': 0.66, 'end': 1.02, 'prob': 0.27}, 147 | {'word': 'who', 'start': 1.02, 'end': 1.2, 'prob': 0.01}, 148 | . 149 | . 150 | . 151 | . 152 | } 153 | """ 154 | ``` 155 | 156 | ## Write Transcripts To a File 157 | 158 | Predicted transcripts can be easily exported to following output formats: `vtt, srt, json, tsv`. 
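For instance, all four formats can be written in one pass by looping over them — a minimal sketch reusing the `out`, `files`, and `whisper_s2t.write_outputs(...)` call shown in the full example below:

```py
# Minimal sketch: assumes `out` and `files` from the example below, and the
# whisper_s2t.write_outputs(out, format=..., ip_files=..., save_dir=...) call shown there.
for fmt in ['vtt', 'srt', 'json', 'tsv']:
    whisper_s2t.write_outputs(out, format=fmt, ip_files=files, save_dir="./save_dir")
```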
159 | 160 | ```py 161 | files = ['file.wav'] 162 | lang_codes = ['en'] 163 | tasks = ['transcribe'] 164 | initial_prompts = [None] 165 | 166 | out = model.transcribe_with_vad(files, 167 | lang_codes=lang_codes, 168 | tasks=tasks, 169 | initial_prompts=initial_prompts, 170 | batch_size=24) 171 | 172 | whisper_s2t.write_outputs(out, format='vtt', ip_files=files, save_dir="./save_dir") # Save outputs 173 | 174 | whisper_s2t.write_outputs(out, format='vtt', op_files=op_files) # custom output file names 175 | ``` -------------------------------------------------------------------------------- /whisper_s2t/backends/tensorrt/trt_model.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import tensorrt_llm 4 | 5 | from pathlib import Path 6 | from collections import OrderedDict 7 | 8 | from tensorrt_llm._utils import str_dtype_to_torch, str_dtype_to_trt, trt_dtype_to_torch 9 | from tensorrt_llm.runtime import ModelConfig, SamplingConfig 10 | from tensorrt_llm.runtime.session import Session, TensorInfo 11 | 12 | 13 | class WhisperEncoding: 14 | 15 | def __init__(self, engine_dir): 16 | self.session = self.get_session(engine_dir) 17 | 18 | def get_session(self, engine_dir): 19 | config_path = engine_dir / 'encoder_config.json' 20 | with open(config_path, 'r') as f: 21 | config = json.load(f) 22 | 23 | dtype = config['builder_config']['precision'] 24 | n_mels = config['builder_config']['n_mels'] 25 | num_languages = config['builder_config']['num_languages'] 26 | 27 | self.dtype = dtype 28 | self.n_mels = n_mels 29 | self.num_languages = num_languages 30 | 31 | serialize_path = engine_dir / f'encoder.engine' 32 | 33 | with open(serialize_path, 'rb') as f: 34 | session = Session.from_serialized_engine(f.read()) 35 | 36 | return session 37 | 38 | def get_audio_features(self, mel): 39 | 40 | input_lengths = torch.tensor( 41 | [mel.shape[2] // 2 for _ in range(mel.shape[0])], 42 | dtype=torch.int32, 43 | device=mel.device) 44 | 45 | inputs = OrderedDict() 46 | inputs['x'] = mel 47 | inputs['input_lengths'] = input_lengths 48 | 49 | output_list = [ 50 | TensorInfo('x', str_dtype_to_trt(self.dtype), mel.shape), 51 | TensorInfo('input_lengths', str_dtype_to_trt('int32'), 52 | input_lengths.shape) 53 | ] 54 | 55 | output_info = (self.session).infer_shapes(output_list) 56 | 57 | outputs = { 58 | t.name: torch.empty(tuple(t.shape), 59 | dtype=trt_dtype_to_torch(t.dtype), 60 | device='cuda') 61 | for t in output_info 62 | } 63 | stream = torch.cuda.current_stream() 64 | ok = self.session.run(inputs=inputs, 65 | outputs=outputs, 66 | stream=stream.cuda_stream) 67 | assert ok, 'Engine execution failed' 68 | stream.synchronize() 69 | audio_features = outputs['output'] 70 | return audio_features 71 | 72 | 73 | class WhisperDecoding: 74 | 75 | def __init__(self, engine_dir, runtime_mapping, debug_mode=False): 76 | 77 | self.decoder_config = self.get_config(engine_dir) 78 | self.decoder_generation_session = self.get_session( 79 | engine_dir, runtime_mapping, debug_mode) 80 | 81 | def get_config(self, engine_dir): 82 | config_path = engine_dir / 'decoder_config.json' 83 | with open(config_path, 'r') as f: 84 | config = json.load(f) 85 | decoder_config = OrderedDict() 86 | decoder_config.update(config['plugin_config']) 87 | decoder_config.update(config['builder_config']) 88 | return decoder_config 89 | 90 | def get_session(self, engine_dir, runtime_mapping, debug_mode=False): 91 | dtype = self.decoder_config['precision'] 92 | serialize_path = 
engine_dir / f'decoder.engine' 93 | with open(serialize_path, "rb") as f: 94 | decoder_engine_buffer = f.read() 95 | 96 | decoder_model_config = ModelConfig( 97 | num_heads=self.decoder_config['num_heads'], 98 | num_kv_heads=self.decoder_config['num_heads'], 99 | hidden_size=self.decoder_config['hidden_size'], 100 | vocab_size=self.decoder_config['vocab_size'], 101 | num_layers=self.decoder_config['num_layers'], 102 | gpt_attention_plugin=self.decoder_config['gpt_attention_plugin'], 103 | remove_input_padding=self.decoder_config['remove_input_padding'], 104 | cross_attention=self.decoder_config['cross_attention'], 105 | has_position_embedding=self. 106 | decoder_config['has_position_embedding'], 107 | has_token_type_embedding=self. 108 | decoder_config['has_token_type_embedding'], 109 | ) 110 | decoder_generation_session = tensorrt_llm.runtime.GenerationSession( 111 | decoder_model_config, 112 | decoder_engine_buffer, 113 | runtime_mapping, 114 | debug_mode=debug_mode) 115 | 116 | return decoder_generation_session 117 | 118 | def generate(self, 119 | decoder_input_ids, 120 | encoder_outputs, 121 | sampling_config): 122 | 123 | encoder_input_lengths = torch.tensor( 124 | [encoder_outputs.shape[1] for x in range(encoder_outputs.shape[0])], 125 | dtype=torch.int32, 126 | device='cuda') 127 | 128 | decoder_input_lengths = torch.tensor([ 129 | decoder_input_ids.shape[-1] 130 | for _ in range(decoder_input_ids.shape[0]) 131 | ], 132 | dtype=torch.int32, 133 | device='cuda') 134 | decoder_max_input_length = torch.max(decoder_input_lengths).item() 135 | 136 | self.decoder_generation_session.setup( 137 | decoder_input_lengths.size(0), 138 | decoder_max_input_length, 139 | sampling_config.max_new_tokens, 140 | beam_width=sampling_config.num_beams, 141 | encoder_max_input_length=encoder_outputs.shape[1]) 142 | 143 | torch.cuda.synchronize() 144 | 145 | decoder_input_ids = decoder_input_ids.type(torch.int32).cuda() 146 | output_ids = self.decoder_generation_session.decode( 147 | decoder_input_ids, 148 | decoder_input_lengths, 149 | sampling_config, 150 | encoder_output=encoder_outputs, 151 | encoder_input_lengths=encoder_input_lengths, 152 | ) 153 | torch.cuda.synchronize() 154 | 155 | # get the list of int from output_ids tensor 156 | output_ids = output_ids.cpu().numpy().tolist() 157 | return output_ids 158 | 159 | 160 | class WhisperTRT: 161 | def __init__(self, engine_dir, compute_type='float16'): 162 | world_size = 1 163 | runtime_rank = tensorrt_llm.mpi_rank() 164 | runtime_mapping = tensorrt_llm.Mapping(world_size, runtime_rank) 165 | torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node) 166 | engine_dir = Path(engine_dir) 167 | 168 | self.encoder = WhisperEncoding(engine_dir) 169 | self.decoder = WhisperDecoding(engine_dir, runtime_mapping) 170 | self.n_mels = self.encoder.n_mels 171 | self.is_multilingual = True 172 | self.compute_type = compute_type 173 | 174 | def encode(self, mel): 175 | return self.encoder.get_audio_features(mel.type(str_dtype_to_torch(self.compute_type))) 176 | 177 | def generate(self, features, prompts, **generate_kwargs): 178 | if features.shape[1] == self.n_mels: 179 | features = self.encode(features) 180 | 181 | decoder_input_ids = torch.tensor(prompts) 182 | 183 | sampling_config = SamplingConfig(**generate_kwargs) 184 | 185 | output_ids = self.decoder.generate(decoder_input_ids, 186 | features, 187 | sampling_config) 188 | 189 | return output_ids -------------------------------------------------------------------------------- 
/whisper_s2t/backends/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | from abc import ABC, abstractmethod 4 | 5 | from ..configs import * 6 | from ..data import WhisperDataLoader 7 | from ..audio import LogMelSpectogram 8 | from ..speech_segmenter import SpeechSegmenter 9 | 10 | 11 | class NoneTokenizer: 12 | def __init__(self): 13 | self.sot_prev = 0 14 | self.silent_token = 0 15 | self.no_timestamps = 0 16 | self.timestamp_begin = 0 17 | 18 | def sot_sequence(self, task=None, lang=None): 19 | return [task, lang] 20 | 21 | def encode(self, text): 22 | return [0] 23 | 24 | 25 | def fix_batch_param(param, default_value, N): 26 | if param is None: 27 | param = N*[default_value] 28 | elif type(param) == type(default_value): 29 | param = N*[param] 30 | elif len(param) != N: 31 | param = N*[param[0]] 32 | 33 | return param 34 | 35 | 36 | class WhisperModel(ABC): 37 | def __init__(self, 38 | tokenizer=None, 39 | vad_model=None, 40 | n_mels=80, 41 | device="cuda", 42 | device_index=0, 43 | compute_type="float16", 44 | merge_chunks=True, 45 | dta_padding=3.0, 46 | use_dynamic_time_axis = False, 47 | max_speech_len=29.0, 48 | max_text_token_len=MAX_TEXT_TOKEN_LENGTH, 49 | without_timestamps=True, 50 | speech_segmenter_options={}): 51 | 52 | # Configure Params 53 | self.device = device 54 | self.device_index = device_index 55 | self.compute_type = compute_type 56 | 57 | self.n_mels = n_mels 58 | self.merge_chunks = merge_chunks 59 | self.max_speech_len = max_speech_len 60 | 61 | self.dta_padding = dta_padding 62 | self.use_dynamic_time_axis = use_dynamic_time_axis 63 | 64 | self.without_timestamps = without_timestamps 65 | self.max_text_token_len = max_text_token_len 66 | 67 | self.vad_model = vad_model 68 | self.speech_segmenter_options = speech_segmenter_options 69 | self.speech_segmenter_options['max_seg_len'] = self.max_speech_len 70 | 71 | # Tokenizer 72 | if tokenizer is None: 73 | tokenizer = NoneTokenizer() 74 | 75 | self.tokenizer = tokenizer 76 | 77 | self._init_dependables() 78 | 79 | 80 | def _init_dependables(self): 81 | # Rescaled Params 82 | self.dta_padding = int(self.dta_padding*SAMPLE_RATE) 83 | self.max_initial_prompt_len = self.max_text_token_len//2 -1 84 | 85 | # Load Pre Processor 86 | self.preprocessor = LogMelSpectogram(n_mels=self.n_mels).to(self.device) 87 | 88 | # Load Speech Segmenter 89 | self.speech_segmenter = SpeechSegmenter(self.vad_model, device=self.device, **self.speech_segmenter_options) 90 | 91 | # Load Data Loader 92 | self.data_loader = WhisperDataLoader( 93 | self.device, self.tokenizer, self.speech_segmenter, 94 | dta_padding=self.dta_padding, 95 | without_timestamps=self.without_timestamps, 96 | max_speech_len=self.max_speech_len, 97 | max_initial_prompt_len=self.max_initial_prompt_len, 98 | use_dynamic_time_axis=self.use_dynamic_time_axis, 99 | merge_chunks=self.merge_chunks 100 | ) 101 | 102 | def update_params(self, params={}): 103 | for key, value in params.items(): 104 | setattr(self, key, value) 105 | 106 | self._init_dependables() 107 | 108 | 109 | @abstractmethod 110 | def generate_segment_batched(self, features, prompts): 111 | pass 112 | 113 | @torch.no_grad() 114 | def transcribe(self, audio_files, lang_codes=None, tasks=None, initial_prompts=None, batch_size=8): 115 | 116 | # if lang_codes == None: 117 | # lang_codes = len(audio_files)*['en'] 118 | 119 | # if tasks == None: 120 | # tasks = len(audio_files)*['transcribe'] 121 | 122 | # if initial_prompts 
== None: 123 | # initial_prompts = len(audio_files)*[None] 124 | 125 | # responses = [] 126 | # for signals, prompts, seq_len in self.data_loader(audio_files, lang_codes, tasks, initial_prompts, batch_size=batch_size, use_vad=False): 127 | # mels, seq_len = self.preprocessor(signals, seq_len) 128 | # res = self.generate_segment_batched(mels.to(self.device), prompts) 129 | # responses.extend(res) 130 | 131 | # return responses 132 | 133 | lang_codes = fix_batch_param(lang_codes, 'en', len(audio_files)) 134 | tasks = fix_batch_param(tasks, 'transcribe', len(audio_files)) 135 | initial_prompts = fix_batch_param(initial_prompts, None, len(audio_files)) 136 | 137 | responses = [[] for _ in audio_files] 138 | 139 | pbar_pos = 0 140 | with tqdm(total=len(audio_files)*100, desc=f"Transcribing") as pbar: 141 | for signals, prompts, seq_len, seg_metadata, pbar_update in self.data_loader(audio_files, lang_codes, tasks, initial_prompts, batch_size=batch_size, use_vad=False): 142 | mels, seq_len = self.preprocessor(signals, seq_len) 143 | res = self.generate_segment_batched(mels.to(self.device), prompts, seq_len, seg_metadata) 144 | 145 | for res_idx, _seg_metadata in enumerate(seg_metadata): 146 | responses[_seg_metadata['file_id']].append({**res[res_idx], 147 | 'start_time': round(_seg_metadata['start_time'], 3), 148 | 'end_time': round(_seg_metadata['end_time'], 3)}) 149 | 150 | if (pbar_pos) <= pbar.total: 151 | pbar_pos += pbar_update 152 | pbar.update(pbar_update) 153 | 154 | pbar.update(pbar.total-pbar_pos) 155 | 156 | return responses 157 | 158 | @torch.no_grad() 159 | def transcribe_with_vad(self, audio_files, lang_codes=None, tasks=None, initial_prompts=None, batch_size=8): 160 | 161 | lang_codes = fix_batch_param(lang_codes, 'en', len(audio_files)) 162 | tasks = fix_batch_param(tasks, 'transcribe', len(audio_files)) 163 | initial_prompts = fix_batch_param(initial_prompts, None, len(audio_files)) 164 | 165 | responses = [[] for _ in audio_files] 166 | 167 | pbar_pos = 0 168 | with tqdm(total=len(audio_files)*100, desc=f"Transcribing") as pbar: 169 | for signals, prompts, seq_len, seg_metadata, pbar_update in self.data_loader(audio_files, lang_codes, tasks, initial_prompts, batch_size=batch_size): 170 | mels, seq_len = self.preprocessor(signals, seq_len) 171 | res = self.generate_segment_batched(mels.to(self.device), prompts, seq_len, seg_metadata) 172 | 173 | for res_idx, _seg_metadata in enumerate(seg_metadata): 174 | responses[_seg_metadata['file_id']].append({**res[res_idx], 175 | 'start_time': round(_seg_metadata['start_time'], 3), 176 | 'end_time': round(_seg_metadata['end_time'], 3)}) 177 | 178 | if (pbar_pos) <= pbar.total: 179 | pbar_pos += pbar_update 180 | pbar.update(pbar_update) 181 | 182 | pbar.update(pbar.total-pbar_pos) 183 | 184 | return responses -------------------------------------------------------------------------------- /whisper_s2t/backends/ctranslate2/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tokenizers 3 | import ctranslate2 4 | import numpy as np 5 | 6 | from .tokenizer import Tokenizer 7 | from .hf_utils import download_model 8 | 9 | 10 | from .. 
import WhisperModel 11 | from ...configs import * 12 | 13 | 14 | FAST_ASR_OPTIONS = { 15 | "beam_size": 1, 16 | "best_of": 1, # Placeholder 17 | "patience": 1, 18 | "length_penalty": 1, 19 | "repetition_penalty": 1.01, 20 | "no_repeat_ngram_size": 0, 21 | "compression_ratio_threshold": 2.4, # Placeholder 22 | "log_prob_threshold": -1.0, # Placeholder 23 | "no_speech_threshold": 0.5, # Placeholder 24 | "prefix": None, # Placeholder 25 | "suppress_blank": True, 26 | "suppress_tokens": [-1], 27 | "without_timestamps": True, 28 | "max_initial_timestamp": 1.0, 29 | "word_timestamps": False, # Placeholder 30 | "sampling_temperature": 1.0, 31 | "return_scores": True, 32 | "return_no_speech_prob": True, 33 | "word_aligner_model": 'tiny', 34 | } 35 | 36 | 37 | BEST_ASR_CONFIG = { 38 | "beam_size": 5, 39 | "best_of": 1, # Placeholder 40 | "patience": 2, 41 | "length_penalty": 1, 42 | "repetition_penalty": 1.01, 43 | "no_repeat_ngram_size": 0, 44 | "compression_ratio_threshold": 2.4, # Placeholder 45 | "log_prob_threshold": -1.0, # Placeholder 46 | "no_speech_threshold": 0.5, # Placeholder 47 | "prefix": None, # Placeholder 48 | "suppress_blank": True, 49 | "suppress_tokens": [-1], 50 | "without_timestamps": True, 51 | "max_initial_timestamp": 1.0, 52 | "word_timestamps": False, # Placeholder 53 | "sampling_temperature": 1.0, 54 | "return_scores": True, 55 | "return_no_speech_prob": True, 56 | "word_aligner_model": 'tiny', 57 | } 58 | 59 | 60 | class WhisperModelCT2(WhisperModel): 61 | def __init__(self, 62 | model_name_or_path: str, 63 | cpu_threads=4, 64 | num_workers=1, 65 | device="cuda", 66 | device_index=0, 67 | compute_type="float16", 68 | max_text_token_len=MAX_TEXT_TOKEN_LENGTH, 69 | asr_options={}, 70 | **model_kwargs): 71 | 72 | 73 | # Get local model path or download from huggingface 74 | if os.path.isdir(model_name_or_path): 75 | self.model_path = model_name_or_path 76 | else: 77 | self.model_path = download_model(model_name_or_path) 78 | 79 | # Load model 80 | self.model = ctranslate2.models.Whisper(self.model_path, 81 | device=device, 82 | device_index=device_index, 83 | compute_type=compute_type, 84 | intra_threads=cpu_threads, 85 | inter_threads=num_workers) 86 | 87 | # Load tokenizer 88 | tokenizer_file = os.path.join(self.model_path, "tokenizer.json") 89 | tokenizer = Tokenizer(tokenizers.Tokenizer.from_file(tokenizer_file), self.model.is_multilingual) 90 | 91 | # ASR Options 92 | self.asr_options = FAST_ASR_OPTIONS 93 | self.asr_options.update(asr_options) 94 | 95 | if self.asr_options['word_timestamps']: 96 | self.aligner_model_path = download_model(self.asr_options['word_aligner_model']) 97 | self.aligner_model = ctranslate2.models.Whisper(self.aligner_model_path, 98 | device=device, 99 | device_index=device_index, 100 | compute_type=compute_type, 101 | intra_threads=cpu_threads, 102 | inter_threads=num_workers) 103 | 104 | self.generate_kwargs = { 105 | "max_length": max_text_token_len, 106 | "return_scores": self.asr_options['return_scores'], 107 | "return_no_speech_prob": self.asr_options['return_no_speech_prob'], 108 | "length_penalty": self.asr_options['length_penalty'], 109 | "repetition_penalty": self.asr_options['repetition_penalty'], 110 | "no_repeat_ngram_size": self.asr_options['no_repeat_ngram_size'], 111 | "beam_size": self.asr_options['beam_size'], 112 | "patience": self.asr_options['patience'], 113 | "suppress_blank": self.asr_options['suppress_blank'], 114 | "suppress_tokens": self.asr_options['suppress_tokens'], 115 | "max_initial_timestamp_index": 
int(round(self.asr_options['max_initial_timestamp']/TIME_PRECISION)), 116 | "sampling_temperature": self.asr_options['sampling_temperature'], 117 | } 118 | 119 | super().__init__( 120 | tokenizer=tokenizer, 121 | device=device, 122 | device_index=device_index, 123 | compute_type=compute_type, 124 | max_text_token_len=max_text_token_len, 125 | **model_kwargs 126 | ) 127 | 128 | def update_generation_kwargs(self, params={}): 129 | self.generate_kwargs.update(params) 130 | 131 | if 'max_text_token_len' in params: 132 | self.update_params(params={'max_text_token_len': params['max_text_token_len']}) 133 | 134 | def encode(self, features): 135 | """ 136 | [Not Used] 137 | """ 138 | 139 | features = ctranslate2.StorageView.from_array(features.contiguous()) 140 | return self.model.encode(features) 141 | 142 | def assign_word_timings(self, alignments, text_token_probs, words, word_tokens): 143 | text_indices = np.array([pair[0] for pair in alignments]) 144 | time_indices = np.array([pair[1] for pair in alignments]) 145 | 146 | if len(word_tokens) <= 1: 147 | return [] 148 | 149 | word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0)) 150 | if len(word_boundaries) <= 1: 151 | return [] 152 | 153 | jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool) 154 | jump_times = time_indices[jumps]*TIME_PRECISION 155 | start_times = jump_times[word_boundaries[:-1]] 156 | end_times = jump_times[word_boundaries[1:]] 157 | word_probs = [ 158 | np.mean(text_token_probs[i:j]) 159 | for i, j in zip(word_boundaries[:-1], word_boundaries[1:]) 160 | ] 161 | 162 | return [ 163 | dict( 164 | word=word, start=round(start, 2), end=round(end, 2), prob=round(prob, 2) 165 | ) 166 | for word, start, end, prob in zip( 167 | words, start_times, end_times, word_probs 168 | ) 169 | ] 170 | 171 | def align_words(self, features, texts, text_tokens, sot_seqs, seq_lens, seg_metadata): 172 | lang_codes = [_['lang_code'] for _ in seg_metadata] 173 | word_tokens = self.tokenizer.split_to_word_tokens_batch(texts, text_tokens, lang_codes) 174 | 175 | start_seq_wise_req = {} 176 | for _idx, _sot_seq in enumerate(sot_seqs): 177 | try: 178 | # print(_sot_seq) 179 | start_seq_wise_req[_sot_seq].append(_idx) 180 | except: 181 | start_seq_wise_req[_sot_seq] = [_idx] 182 | 183 | token_alignments = [[] for _ in seg_metadata] 184 | for start_seq, req_idx in start_seq_wise_req.items(): 185 | res = self.aligner_model.align(ctranslate2.StorageView.from_array(features[req_idx]), 186 | start_sequence=list(start_seq), 187 | text_tokens=[text_tokens[_] for _ in req_idx], 188 | num_frames=list(seq_lens[req_idx].detach().cpu().numpy()), 189 | median_filter_width=7) 190 | 191 | for _res, _req_idx in zip(res, req_idx): 192 | token_alignments[_req_idx] = _res 193 | 194 | word_timings = [] 195 | for _idx, _seg_metadata in enumerate(seg_metadata): 196 | _word_timings = self.assign_word_timings(token_alignments[_idx].alignments, 197 | token_alignments[_idx].text_token_probs, 198 | word_tokens[_idx][0], 199 | word_tokens[_idx][1]) 200 | 201 | stitched_seg = _seg_metadata['stitched_seg'] 202 | 203 | current_seg_idx = 0 204 | current_offset = _seg_metadata['start_time'] 205 | 206 | for w in _word_timings: 207 | while (w['start'] + current_offset) >= stitched_seg[current_seg_idx][1]: 208 | current_seg_idx += 1 209 | current_offset += (stitched_seg[current_seg_idx][0]-stitched_seg[current_seg_idx-1][1]) 210 | 211 | w['start'] += current_offset 212 | w['end'] += current_offset 213 | 214 | 
word_timings.append(_word_timings) 215 | 216 | return word_timings 217 | 218 | def generate_segment_batched(self, features, prompts, seq_lens, seg_metadata): 219 | 220 | if self.device == 'cpu': 221 | features = np.ascontiguousarray(features.detach().numpy()) 222 | else: 223 | features = features.contiguous() 224 | 225 | result = self.model.generate(ctranslate2.StorageView.from_array(features), 226 | prompts, 227 | **self.generate_kwargs) 228 | 229 | texts = self.tokenizer.decode_batch([x.sequences_ids[0] for x in result]) 230 | 231 | response = [] 232 | for idx, r in enumerate(result): 233 | response.append({'text': texts[idx].strip()}) 234 | 235 | if self.generate_kwargs['return_scores']: 236 | seq_len = len(r.sequences_ids[0]) 237 | cum_logprob = r.scores[0]*(seq_len**self.generate_kwargs['length_penalty']) 238 | response[-1]['avg_logprob'] = cum_logprob/(seq_len + 1) 239 | 240 | if self.generate_kwargs['return_no_speech_prob']: 241 | response[-1]['no_speech_prob'] = r.no_speech_prob 242 | 243 | if self.asr_options['word_timestamps']: 244 | text_tokens = [x.sequences_ids[0]+[self.tokenizer.eot] for x in result] 245 | sot_seqs = [tuple(_[-4:]) for _ in prompts] 246 | word_timings = self.align_words(features, texts, text_tokens, sot_seqs, seq_lens, seg_metadata) 247 | 248 | for _response, _word_timings in zip(response, word_timings): 249 | _response['word_timestamps'] = _word_timings 250 | 251 | return response -------------------------------------------------------------------------------- /whisper_s2t/backends/tensorrt/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tokenizers 3 | import ctranslate2 4 | import numpy as np 5 | 6 | from .trt_model import WhisperTRT 7 | from .tokenizer import Tokenizer 8 | from .hf_utils import download_model 9 | from .engine_builder import build_trt_engine, TRTBuilderConfig, load_trt_build_config 10 | 11 | 12 | from .. 
import WhisperModel 13 | from ...configs import * 14 | 15 | 16 | FAST_ASR_OPTIONS = { 17 | "beam_size": 1, 18 | "best_of": 1, # Placeholder 19 | "patience": 1, 20 | "length_penalty": 1, 21 | "repetition_penalty": 1.01, 22 | "no_repeat_ngram_size": 0, 23 | "compression_ratio_threshold": 2.4, # Placeholder 24 | "log_prob_threshold": -1.0, # Placeholder 25 | "no_speech_threshold": 0.5, # Placeholder 26 | "prefix": None, # Placeholder 27 | "suppress_blank": True, 28 | "suppress_tokens": [-1], 29 | "without_timestamps": True, 30 | "max_initial_timestamp": 1.0, 31 | "word_timestamps": False, # Placeholder 32 | "sampling_temperature": 1.0, 33 | "return_scores": True, 34 | "return_no_speech_prob": True, 35 | "word_aligner_model": 'tiny', 36 | } 37 | 38 | 39 | BEST_ASR_CONFIG = { 40 | "beam_size": 5, 41 | "best_of": 1, # Placeholder 42 | "patience": 2, 43 | "length_penalty": 1, 44 | "repetition_penalty": 1.01, 45 | "no_repeat_ngram_size": 0, 46 | "compression_ratio_threshold": 2.4, # Placeholder 47 | "log_prob_threshold": -1.0, # Placeholder 48 | "no_speech_threshold": 0.5, # Placeholder 49 | "prefix": None, # Placeholder 50 | "suppress_blank": True, 51 | "suppress_tokens": [-1], 52 | "without_timestamps": True, 53 | "max_initial_timestamp": 1.0, 54 | "word_timestamps": False, # Placeholder 55 | "sampling_temperature": 1.0, 56 | "return_scores": True, 57 | "return_no_speech_prob": True, 58 | "word_aligner_model": 'tiny', 59 | } 60 | 61 | 62 | class WhisperModelTRT(WhisperModel): 63 | def __init__(self, 64 | model_name_or_path: str, 65 | cpu_threads=4, 66 | num_workers=1, 67 | device="cuda", 68 | device_index=0, 69 | compute_type="float16", 70 | max_text_token_len=MAX_TEXT_TOKEN_LENGTH, 71 | asr_options={}, 72 | **model_kwargs): 73 | 74 | # ASR Options 75 | self.asr_options = FAST_ASR_OPTIONS 76 | self.asr_options.update(asr_options) 77 | 78 | # Get local model path or build a new engine 79 | if os.path.isdir(model_name_or_path): 80 | self.model_path = model_name_or_path 81 | trt_build_args = load_trt_build_config(self.model_path) 82 | else: 83 | trt_build_args = model_kwargs.get('trt_build_args', None) 84 | if trt_build_args is None: 85 | print(f"'trt_build_args' not provided in model_kwargs, using default configs.") 86 | trt_build_args = TRTBuilderConfig( 87 | max_output_len=max_text_token_len, 88 | max_beam_width=self.asr_options["beam_size"] 89 | ) 90 | 91 | self.model_path = build_trt_engine(model_name=model_name_or_path, args=trt_build_args) 92 | 93 | if 'trt_build_args' in model_kwargs: 94 | del model_kwargs['trt_build_args'] 95 | 96 | self.trt_build_args = trt_build_args 97 | 98 | # Update params according to TRT Build Args 99 | if max_text_token_len > self.trt_build_args.max_output_len: 100 | print(f"'max_text_token_len' cannot be larger than 'self.trt_build_args.max_output_len'. Setting 'max_text_token_len' to {self.trt_build_args.max_output_len}.") 101 | max_text_token_len = self.trt_build_args.max_output_len 102 | 103 | if self.asr_options["beam_size"] > self.trt_build_args.max_beam_width: 104 | print(f"'beam_size' cannot be larger than 'self.trt_build_args.max_beam_width'. 
Setting 'beam_size' to {self.trt_build_args.max_beam_width}.") 105 | self.asr_options["beam_size"] = self.trt_build_args.max_beam_width 106 | 107 | # Load model 108 | self.model = WhisperTRT(self.model_path) 109 | 110 | # Load tokenizer 111 | tokenizer_file = os.path.join(self.model_path, "tokenizer.json") 112 | tokenizer = Tokenizer(tokenizers.Tokenizer.from_file(tokenizer_file), self.model.is_multilingual) 113 | 114 | if self.asr_options['word_timestamps']: 115 | self.aligner_model_path = download_model(self.asr_options['word_aligner_model']) 116 | self.aligner_model = ctranslate2.models.Whisper(self.aligner_model_path, 117 | device=device, 118 | device_index=device_index, 119 | compute_type=compute_type, 120 | intra_threads=cpu_threads, 121 | inter_threads=num_workers) 122 | 123 | self.generate_kwargs = { 124 | "end_id": tokenizer.eot, 125 | "pad_id": tokenizer.eot, 126 | "max_new_tokens": max_text_token_len, 127 | "length_penalty": self.asr_options['length_penalty'], 128 | "repetition_penalty": self.asr_options['repetition_penalty'], 129 | "num_beams": self.asr_options['beam_size'], 130 | "stop_words_list": self.asr_options['suppress_blank'], 131 | "bad_words_list": self.asr_options['suppress_tokens'], 132 | "temperature": self.asr_options['sampling_temperature'], 133 | } 134 | 135 | super().__init__( 136 | tokenizer=tokenizer, 137 | device=device, 138 | device_index=device_index, 139 | compute_type=compute_type, 140 | max_text_token_len=max_text_token_len, 141 | **model_kwargs 142 | ) 143 | 144 | def update_generation_kwargs(self, params={}): 145 | self.generate_kwargs.update(params) 146 | 147 | if 'max_text_token_len' in params: 148 | self.update_params(params={'max_text_token_len': params['max_text_token_len']}) 149 | 150 | def encode(self, features): 151 | """ 152 | [Not Used] 153 | """ 154 | 155 | return self.model.encode(features) 156 | 157 | def assign_word_timings(self, alignments, text_token_probs, words, word_tokens): 158 | text_indices = np.array([pair[0] for pair in alignments]) 159 | time_indices = np.array([pair[1] for pair in alignments]) 160 | 161 | if len(word_tokens) <= 1: 162 | return [] 163 | 164 | word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0)) 165 | if len(word_boundaries) <= 1: 166 | return [] 167 | 168 | jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool) 169 | jump_times = time_indices[jumps]*TIME_PRECISION 170 | start_times = jump_times[word_boundaries[:-1]] 171 | end_times = jump_times[word_boundaries[1:]] 172 | word_probs = [ 173 | np.mean(text_token_probs[i:j]) 174 | for i, j in zip(word_boundaries[:-1], word_boundaries[1:]) 175 | ] 176 | 177 | return [ 178 | dict( 179 | word=word, start=round(start, 2), end=round(end, 2), prob=round(prob, 2) 180 | ) 181 | for word, start, end, prob in zip( 182 | words, start_times, end_times, word_probs 183 | ) 184 | ] 185 | 186 | def align_words(self, features, texts, text_tokens, sot_seqs, seq_lens, seg_metadata): 187 | lang_codes = [_['lang_code'] for _ in seg_metadata] 188 | word_tokens = self.tokenizer.split_to_word_tokens_batch(texts, text_tokens, lang_codes) 189 | 190 | start_seq_wise_req = {} 191 | for _idx, _sot_seq in enumerate(sot_seqs): 192 | try: 193 | # print(_sot_seq) 194 | start_seq_wise_req[_sot_seq].append(_idx) 195 | except: 196 | start_seq_wise_req[_sot_seq] = [_idx] 197 | 198 | token_alignments = [[] for _ in seg_metadata] 199 | for start_seq, req_idx in start_seq_wise_req.items(): 200 | res = 
self.aligner_model.align(ctranslate2.StorageView.from_array(features[req_idx]), 201 | start_sequence=list(start_seq), 202 | text_tokens=[text_tokens[_] for _ in req_idx], 203 | num_frames=list(seq_lens[req_idx].detach().cpu().numpy()), 204 | median_filter_width=7) 205 | 206 | for _res, _req_idx in zip(res, req_idx): 207 | token_alignments[_req_idx] = _res 208 | 209 | word_timings = [] 210 | for _idx, _seg_metadata in enumerate(seg_metadata): 211 | _word_timings = self.assign_word_timings(token_alignments[_idx].alignments, 212 | token_alignments[_idx].text_token_probs, 213 | word_tokens[_idx][0], 214 | word_tokens[_idx][1]) 215 | 216 | stitched_seg = _seg_metadata['stitched_seg'] 217 | 218 | current_seg_idx = 0 219 | current_offset = _seg_metadata['start_time'] 220 | 221 | for w in _word_timings: 222 | while (w['start'] + current_offset) >= stitched_seg[current_seg_idx][1]: 223 | current_seg_idx += 1 224 | current_offset += (stitched_seg[current_seg_idx][0]-stitched_seg[current_seg_idx-1][1]) 225 | 226 | w['start'] += current_offset 227 | w['end'] += current_offset 228 | 229 | word_timings.append(_word_timings) 230 | 231 | return word_timings 232 | 233 | def generate_segment_batched(self, features, prompts, seq_lens, seg_metadata): 234 | 235 | result = self.model.generate(features, 236 | prompts, 237 | **self.generate_kwargs) 238 | 239 | texts = self.tokenizer.decode_batch([x[0] for x in result]) 240 | 241 | response = [] 242 | for idx, r in enumerate(result): 243 | response.append({'text': texts[idx].strip()}) 244 | 245 | if self.asr_options['word_timestamps']: 246 | text_tokens = [[_t for _t in x[0] if _t < self.tokenizer.eot]+[self.tokenizer.eot] for x in result] 247 | sot_seqs = [tuple(_[-4:]) for _ in prompts] 248 | word_timings = self.align_words(features, texts, text_tokens, sot_seqs, seq_lens, seg_metadata) 249 | 250 | for _response, _word_timings in zip(response, word_timings): 251 | _response['word_timestamps'] = _word_timings 252 | 253 | return response -------------------------------------------------------------------------------- /whisper_s2t/backends/tensorrt/engine_builder/builder.py: -------------------------------------------------------------------------------- 1 | # Modified from: https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/whisper 2 | 3 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | # SPDX-License-Identifier: Apache-2.0 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | 18 | import os 19 | import time 20 | import argparse 21 | 22 | import torch 23 | import tensorrt_llm 24 | from tensorrt_llm import str_dtype_to_torch, str_dtype_to_trt 25 | 26 | from tensorrt_llm.logger import logger 27 | from tensorrt_llm.builder import Builder 28 | from tensorrt_llm.network import net_guard 29 | from tensorrt_llm.models import quantize_model 30 | from tensorrt_llm.quantization import QuantMode 31 | from tensorrt_llm.plugin.plugin import ContextFMHAType 32 | from tensorrt_llm.functional import LayerNormPositionType, LayerNormType 33 | 34 | from . import load_trt_build_config 35 | from .model_utils import load_encoder_weight, load_decoder_weight 36 | 37 | 38 | def get_export_size(output_path): 39 | return os.popen(f'du -sh {output_path}').read().split("\t")[0] 40 | 41 | 42 | def serialize_engine(engine, path): 43 | with open(path, 'wb') as f: 44 | f.write(engine) 45 | 46 | 47 | def build_encoder(model, args): 48 | 49 | model_metadata = model['dims'] 50 | model_params = model['model_state_dict'] 51 | 52 | # cast params according dtype 53 | for k, v in model_params.items(): 54 | model_params[k] = v.to(str_dtype_to_torch(args.dtype)) 55 | 56 | builder = Builder() 57 | 58 | max_batch_size = args.max_batch_size 59 | hidden_states = model_metadata['n_audio_state'] 60 | num_heads = model_metadata['n_audio_head'] 61 | num_layers = model_metadata['n_audio_layer'] 62 | 63 | model_is_multilingual = (model_metadata['n_vocab'] >= 51865) 64 | 65 | builder_config = builder.create_builder_config( 66 | name="encoder", 67 | precision=args.dtype, 68 | tensor_parallel=1, 69 | num_layers=num_layers, 70 | num_heads=num_heads, 71 | hidden_size=hidden_states, 72 | max_batch_size=max_batch_size, 73 | int8=args.quant_mode_enc.has_act_or_weight_quant(), 74 | n_mels=model_metadata['n_mels'], 75 | num_languages=model_metadata['n_vocab'] - 51765 - 76 | int(model_is_multilingual), 77 | ) 78 | 79 | tensorrt_llm_whisper_encoder = tensorrt_llm.models.WhisperEncoder( 80 | model_metadata['n_mels'], model_metadata['n_audio_ctx'], 81 | model_metadata['n_audio_state'], model_metadata['n_audio_head'], 82 | model_metadata['n_audio_layer'], str_dtype_to_trt(args.dtype)) 83 | 84 | 85 | if args.use_weight_only_enc: 86 | tensorrt_llm_whisper_encoder = quantize_model( 87 | tensorrt_llm_whisper_encoder, args.quant_mode_enc) 88 | 89 | load_encoder_weight(tensorrt_llm_whisper_encoder, model_metadata, 90 | model_params, model_metadata['n_audio_layer']) 91 | 92 | network = builder.create_network() 93 | 94 | if args.use_gemm_plugin: 95 | network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin) 96 | 97 | if args.use_layernorm_plugin: 98 | network.plugin_config.set_layernorm_plugin(dtype=args.use_layernorm_plugin) 99 | 100 | if args.use_bert_attention_plugin: 101 | network.plugin_config.set_bert_attention_plugin(dtype=args.use_bert_attention_plugin) 102 | 103 | if args.use_context_fmha_enc: 104 | network.plugin_config.set_context_fmha(ContextFMHAType.enabled) 105 | 106 | if args.remove_input_padding: 107 | network.plugin_config.enable_remove_input_padding() 108 | 109 | if args.use_weight_only_enc: 110 | network.plugin_config.set_weight_only_quant_matmul_plugin( 111 | dtype=args.dtype) 112 | 113 | with net_guard(network): 114 | inputs = tensorrt_llm_whisper_encoder.prepare_inputs( 115 | args.max_batch_size) 116 | 117 | tensorrt_llm_whisper_encoder(*inputs) 118 | 119 | if args.debug_mode: 120 | for k, v in tensorrt_llm_whisper_encoder.named_network_outputs(): 121 | network._mark_output(v, k, 
str_dtype_to_trt(args.dtype)) 122 | 123 | engine = None 124 | engine = builder.build_engine(network, builder_config) 125 | 126 | config_path = os.path.join(args.output_dir, 'encoder_config.json') 127 | builder.save_config(builder_config, config_path) 128 | 129 | serialize_engine(engine, os.path.join(args.output_dir, "encoder.engine")) 130 | 131 | 132 | def build_decoder(model, args): 133 | 134 | model_metadata = model['dims'] 135 | model_params = model['model_state_dict'] 136 | 137 | # cast params according dtype 138 | for k, v in model_params.items(): 139 | model_params[k] = v.to(str_dtype_to_torch(args.dtype)) 140 | 141 | builder = Builder() 142 | 143 | timing_cache_file = os.path.join(args.output_dir, 'decoder_model.cache') 144 | builder_config = builder.create_builder_config( 145 | name="decoder", 146 | precision=args.dtype, 147 | timing_cache=timing_cache_file, 148 | tensor_parallel=args.world_size, 149 | num_layers=model_metadata['n_text_layer'], 150 | num_heads=model_metadata['n_text_head'], 151 | hidden_size=model_metadata['n_text_state'], 152 | vocab_size=model_metadata['n_vocab'], 153 | hidden_act="gelu", 154 | max_position_embeddings=model_metadata['n_text_ctx'], 155 | apply_query_key_layer_scaling=False, 156 | max_batch_size=args.max_batch_size, 157 | max_input_len=args.max_input_len, 158 | max_output_len=args.max_output_len, 159 | opt_level=None, 160 | cross_attention=True, 161 | has_position_embedding=True, 162 | has_token_type_embedding=False, 163 | int8=args.quant_mode_dec.has_act_or_weight_quant() 164 | ) 165 | 166 | tensorrt_llm_whisper_decoder = tensorrt_llm.models.DecoderModel( 167 | num_layers=model_metadata['n_text_layer'], 168 | num_heads=model_metadata['n_text_head'], 169 | hidden_size=model_metadata['n_text_state'], 170 | ffn_hidden_size=4 * model_metadata['n_text_state'], 171 | encoder_hidden_size=model_metadata['n_text_state'], 172 | encoder_num_heads=model_metadata['n_text_head'], 173 | vocab_size=model_metadata['n_vocab'], 174 | head_size=model_metadata['n_text_state'] // 175 | model_metadata['n_text_head'], 176 | max_position_embeddings=model_metadata['n_text_ctx'], 177 | has_position_embedding=True, 178 | relative_attention=False, 179 | max_distance=0, 180 | num_buckets=0, 181 | has_embedding_layernorm=False, 182 | has_embedding_scale=False, 183 | q_scaling=1.0, 184 | has_attention_qkvo_bias=True, 185 | has_mlp_bias=True, 186 | has_model_final_layernorm=True, 187 | layernorm_eps=1e-5, 188 | layernorm_position=LayerNormPositionType.pre_layernorm, 189 | layernorm_type=LayerNormType.LayerNorm, 190 | hidden_act="gelu", 191 | rescale_before_lm_head=False, 192 | dtype=str_dtype_to_trt(args.dtype), 193 | logits_dtype=str_dtype_to_trt(args.dtype)) 194 | 195 | if args.use_weight_only_dec: 196 | tensorrt_llm_whisper_decoder = quantize_model( 197 | tensorrt_llm_whisper_decoder, args.quant_mode_dec) 198 | 199 | load_decoder_weight( 200 | tensorrt_llm_whisper_decoder, 201 | model_params, 202 | ) 203 | 204 | network = builder.create_network() 205 | 206 | if args.use_gemm_plugin: 207 | network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin) 208 | 209 | if args.use_layernorm_plugin: 210 | network.plugin_config.set_layernorm_plugin(dtype=args.use_layernorm_plugin) 211 | 212 | if args.use_gpt_attention_plugin: 213 | network.plugin_config.set_gpt_attention_plugin(dtype=args.use_gpt_attention_plugin) 214 | 215 | if args.use_context_fmha_dec: 216 | network.plugin_config.set_context_fmha(ContextFMHAType.enabled) 217 | 218 | if args.remove_input_padding: 219 | 
network.plugin_config.enable_remove_input_padding() 220 | 221 | with net_guard(network): 222 | inputs = tensorrt_llm_whisper_decoder.prepare_inputs( 223 | args.max_batch_size, 224 | args.max_beam_width, 225 | args.max_input_len, 226 | args.max_output_len, 227 | model_metadata['n_audio_ctx'], 228 | ) 229 | 230 | tensorrt_llm_whisper_decoder(*inputs) 231 | 232 | if args.debug_mode: 233 | for k, v in tensorrt_llm_whisper_decoder.named_network_outputs(): 234 | network._mark_output(v, k, str_dtype_to_trt(args.dtype)) 235 | 236 | engine = None 237 | engine = builder.build_engine(network, builder_config) 238 | 239 | config_path = os.path.join(args.output_dir, 'decoder_config.json') 240 | builder.save_config(builder_config, config_path) 241 | 242 | serialize_engine(engine, os.path.join(args.output_dir, "decoder.engine")) 243 | 244 | 245 | def run(args=None, log_level='error'): 246 | 247 | logger.set_level(log_level) 248 | 249 | if args.use_weight_only_enc: 250 | args.quant_mode_enc = QuantMode.from_description( 251 | quantize_weights=True, 252 | quantize_activations=False, 253 | use_int4_weights="int4" in args.weight_only_precision) 254 | else: 255 | args.quant_mode_enc = QuantMode(0) 256 | 257 | if args.use_weight_only_dec: 258 | args.quant_mode_dec = QuantMode.from_description( 259 | quantize_weights=True, 260 | quantize_activations=False, 261 | use_int4_weights="int4" in args.weight_only_precision) 262 | else: 263 | args.quant_mode_dec = QuantMode(0) 264 | 265 | if args.int8_kv_cache: 266 | args.quant_mode_dec = args.quant_mode.set_int8_kv_cache() 267 | 268 | model = torch.load(args.model_path) 269 | 270 | _t = time.time() 271 | build_encoder(model, args) 272 | _te = time.time()-_t 273 | 274 | _t = time.time() 275 | build_decoder(model, args) 276 | _td = time.time()-_t 277 | 278 | print(f"Time taken for building Encoder: {_te:.2f} seconds.") 279 | print(f"Time taken for building Decoder: {_td:.2f} seconds.") 280 | print(f"Exported model size: {get_export_size(args.output_dir)}") 281 | 282 | 283 | if __name__ == '__main__': 284 | parser = argparse.ArgumentParser() 285 | parser.add_argument('--output_dir', type=str) 286 | parser.add_argument('--log_level', type=str) 287 | args = parser.parse_args() 288 | 289 | trt_build_args = load_trt_build_config(args.output_dir) 290 | 291 | print(f"[TRTBuilderConfig]:") 292 | print(vars(trt_build_args)) 293 | 294 | run(args=trt_build_args, log_level=args.log_level) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

WhisperS2T ⚡

2 |

An Optimized Speech-to-Text Pipeline for the Whisper Model Supporting Multiple Inference Engines!

3 |

4 | 5 | Downloads 6 | 7 | 8 | GitHub Contributors 9 | 10 | 11 | PyPi Release Version 12 | 13 | 14 | Issues 15 | 16 |

17 |

18 | 19 | WhisperS2T is an optimized lightning-fast open-sourced **Speech-to-Text** (ASR) pipeline. It is tailored for the whisper model to provide faster whisper transcription. It's designed to be exceptionally fast than other implementation, boasting a **2.3X speed improvement over [WhisperX](https://github.com/m-bain/whisperX/tree/main) and a 3X speed boost compared to [HuggingFace Pipeline](https://huggingface.co/openai/whisper-large-v2) with FlashAttention 2 ([Insanely Fast Whisper](https://github.com/Vaibhavs10/insanely-fast-whisper))**. Moreover, it includes several heuristics to enhance transcription accuracy. 20 | 21 | [**Whisper**](https://github.com/openai/whisper) is a general-purpose speech recognition model developed by OpenAI and not me. It is trained on a large dataset of diverse audio and is also a multitasking model that can perform multilingual speech recognition, speech translation, and language identification. 22 | 23 | 24 | ## Release Notes 25 | 26 | * [Feb 25, 2024]: Added prebuilt docker images and transcript exporter to `txt, json, tsv, srt, vtt`. (Check complete [release note](https://github.com/shashikg/WhisperS2T/releases/tag/v1.3.1)) 27 | * [Jan 28, 2024]: Added support for TensorRT-LLM backend. 28 | * [Dec 23, 2023]: Added support for word alignment for CTranslate2 backend (check [benchmark](https://github.com/shashikg/WhisperS2T/releases/tag/v1.2.0)). 29 | * [Dec 19, 2023]: Added support for Whisper-Large-V3 and Distil-Whisper-Large-V2 (check [benchmark](https://github.com/shashikg/WhisperS2T/releases/tag/v1.1.0)). 30 | * [Dec 17, 2023]: Released WhisperS2T! 31 | 32 | ## Quickstart 33 | 34 | Checkout the Google Colab notebooks provided here: [notebooks](notebooks) 35 | 36 | ## Future Roadmaps 37 | 38 | - [x] Ready to use docker container. 39 | - [ ] WhisperS2T-Server: Optimized end-to-end deployment ready server codebase. 40 | - [ ] In depth documentation, use github pages to host it. 41 | - [ ] Explore possibility of integrating Meta's SeamlessM4T model. 42 | - [ ] Add more datasets for WER benchmarking. 43 | 44 | ## Benchmark and Technical Report 45 | 46 | Stay tuned for a technical report comparing WhisperS2T against other whisper pipelines. Meanwhile, check some quick benchmarks on A30 GPU. See `scripts/` directory for the benchmarking scripts that I used. 47 | 48 | ![A30 Benchmark](https://github.com/shashikg/WhisperS2T/assets/22556187/caecbb38-b69e-4daa-bcdc-16beb9456de5) 49 | 50 | **NOTE:** I conducted all the benchmarks using the `without_timestamps` parameter set as `True`. Adjusting this parameter to `False` may enhance the Word Error Rate (WER) of the HuggingFace pipeline but at the expense of increased inference time. Notably, the improvements in inference speed were achieved solely through a **superior pipeline design**, without any specific optimization made to the backend inference engines (such as CTranslate2, FlashAttention2, etc.). For instance, WhisperS2T (utilizing FlashAttention2) demonstrates significantly superior inference speed compared to the HuggingFace pipeline (also using FlashAttention2), despite both leveraging the same inference engine—HuggingFace whisper model with FlashAttention2. Additionally, there is a noticeable difference in the WER as well. 51 | 52 | 53 | ## Features 54 | 55 | - 🔄 **Multi-Backend Support:** Support for various Whisper model backends including Original OpenAI Model, HuggingFace Model with FlashAttention2, and CTranslate2 Model. 
56 | - 🎙️ **Easy Integration of Custom VAD Models:** Seamlessly add custom Voice Activity Detection (VAD) models to enhance control and accuracy in speech recognition.
57 | - 🎧 **Effortless Handling of Small or Large Audio Files:** Intelligently batch smaller speech segments from various files, ensuring optimal performance.
58 | - ⏳ **Streamlined Processing for Large Audio Files:** Asynchronously loads large audio files in the background while transcribing segmented batches, notably reducing loading times.
59 | - 🌐 **Batching Support with Multiple Language/Task Decoding:** Decode multiple languages or perform both transcription and translation in a single batch for improved versatility and reduced transcription time. (Best support with the CTranslate2 backend)
60 | - 🧠 **Reduction in Hallucination:** Optimized parameters and heuristics to decrease repeated text output or hallucinations. (Some heuristics work only with the CTranslate2 backend)
61 | - ⏱️ **Dynamic Time Length Support (Experimental):** Process variable-length inputs in a given input batch instead of a fixed 30 seconds, providing flexibility and saving computation time during transcription. (Only with the CTranslate2 backend)
62 | 
63 | 
64 | ## Getting Started
65 | 
66 | ### From Docker Container
67 | 
68 | #### Prebuilt containers
69 | 
70 | ```sh
71 | docker pull shashikg/whisper_s2t:dev-trtllm
72 | ```
73 | 
74 | Docker Hub repo: [https://hub.docker.com/r/shashikg/whisper_s2t/tags](https://hub.docker.com/r/shashikg/whisper_s2t/tags)
75 | 
76 | #### Building your own container
77 | 
78 | Build from the `main` branch:
79 | 
80 | ```sh
81 | docker build --build-arg WHISPER_S2T_VER=main --build-arg SKIP_TENSORRT_LLM=1 -t whisper_s2t:main .
82 | ```
83 | 
84 | Build from a specific release, e.g. `v1.3.0`:
85 | 
86 | ```sh
87 | git checkout v1.3.0
88 | docker build --build-arg WHISPER_S2T_VER=v1.3.0 --build-arg SKIP_TENSORRT_LLM=1 -t whisper_s2t:1.3.0 .
89 | ```
90 | 
91 | To build the container with TensorRT-LLM support:
92 | 
93 | ```sh
94 | docker build --build-arg WHISPER_S2T_VER=main -t whisper_s2t:main-trtllm .
95 | ```
96 | 
97 | ### Local Installation
98 | 
99 | Install the audio packages required for resampling and loading audio files.
100 | 
101 | #### For Ubuntu
102 | ```sh
103 | apt-get install -y libsndfile1 ffmpeg
104 | ```
105 | 
106 | #### For macOS
107 | ```sh
108 | brew install ffmpeg
109 | ```
110 | 
111 | #### For Ubuntu/macOS/Windows/Any Other (with Conda for Python)
112 | ```sh
113 | conda install conda-forge::ffmpeg
114 | ```
115 | 
116 | To install or update to the latest released version of WhisperS2T, use the following command:
117 | 
118 | ```sh
119 | pip install -U whisper-s2t
120 | ```
121 | 
122 | Or, to install from the latest commit in this repo:
123 | 
124 | ```sh
125 | pip install -U git+https://github.com/shashikg/WhisperS2T.git
126 | ```
127 | 
128 | **NOTE:** If your cuDNN and cuBLAS installation was done using pip wheels, you can run the following to add the cuDNN path to `LD_LIBRARY_PATH`:
129 | 
130 | ```sh
131 | export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:`python3 -c 'import os; import nvidia.cublas.lib; import nvidia.cudnn.lib; print(os.path.dirname(nvidia.cublas.lib.__file__) + ":" + os.path.dirname(nvidia.cudnn.lib.__file__))'`
132 | ```
133 | 
134 | **To use the TensorRT-LLM Backend**
135 | 
136 | For the TensorRT-LLM backend, you will need to install TensorRT and TensorRT-LLM.
137 | 
138 | ```sh
139 | bash /install_tensorrt.sh
140 | ```
141 | 
142 | The provided bash script should work on most Debian-based systems. If it doesn't, or if you are on a different system, please follow the official TensorRT-LLM instructions [here](https://github.com/NVIDIA/TensorRT-LLM/tree/main).
143 | 
144 | ### Usage
145 | 
146 | #### CTranslate2 Backend
147 | 
148 | ```py
149 | import whisper_s2t
150 | 
151 | model = whisper_s2t.load_model(model_identifier="large-v2", backend='CTranslate2')
152 | 
153 | files = ['data/KINCAID46/audio/1.wav']
154 | lang_codes = ['en']
155 | tasks = ['transcribe']
156 | initial_prompts = [None]
157 | 
158 | out = model.transcribe_with_vad(files,
159 |                                 lang_codes=lang_codes,
160 |                                 tasks=tasks,
161 |                                 initial_prompts=initial_prompts,
162 |                                 batch_size=32)
163 | 
164 | print(out[0][0]) # Print first utterance for first file
165 | """
166 | [Console Output]
167 | 
168 | {'text': "Let's bring in Phil Mackie who is there at the palace. We're looking at Teresa and Philip May. Philip, can you see how he's being transferred from the helicopters? It looks like, as you said, the beast. It's got its headlights on because the sun is beginning to set now, certainly sinking behind some clouds. It's about a quarter of a mile away down the Grand Drive",
169 |  'avg_logprob': -0.25426941679184695,
170 |  'no_speech_prob': 8.147954940795898e-05,
171 |  'start_time': 0.0,
172 |  'end_time': 24.8}
173 | """
174 | ```
175 | 
176 | To use word alignment, load the model like this:
177 | 
178 | ```py
179 | model = whisper_s2t.load_model("large-v2", asr_options={'word_timestamps': True})
180 | ```
181 | 
182 | #### TensorRT-LLM Backend
183 | 
184 | ```py
185 | import whisper_s2t
186 | 
187 | model = whisper_s2t.load_model(model_identifier="large-v2", backend='TensorRT-LLM')
188 | 
189 | files = ['data/KINCAID46/audio/1.wav']
190 | lang_codes = ['en']
191 | tasks = ['transcribe']
192 | initial_prompts = [None]
193 | 
194 | out = model.transcribe_with_vad(files,
195 |                                 lang_codes=lang_codes,
196 |                                 tasks=tasks,
197 |                                 initial_prompts=initial_prompts,
198 |                                 batch_size=24)
199 | 
200 | print(out[0][0]) # Print first utterance for first file
201 | """
202 | [Console Output]
203 | 
204 | {'text': "Let's bring in Phil Mackie who is there at the palace. We're looking at Teresa and Philip May. Philip, can you see how he's being transferred from the helicopters? It looks like, as you said, the beast. It's got its headlights on because the sun is beginning to set now, certainly sinking behind some clouds. It's about a quarter of a mile away down the Grand Drive",
205 |  'start_time': 0.0,
206 |  'end_time': 24.8}
207 | """
208 | ```
209 | 
210 | Check this [Documentation](docs.md) for more details.
211 | 
212 | **NOTE:** On the first run, the model may show slightly slower inference speed; after 1-2 runs it will reach its full speed. This is due to the JIT tracing of the VAD model.
213 | 
214 | 
215 | ## Acknowledgements
216 | - [**OpenAI Whisper Team**](https://github.com/openai/whisper): Thanks to the OpenAI Whisper Team for open-sourcing the whisper model.
217 | - [**HuggingFace Team**](https://huggingface.co/docs/transformers/model_doc/whisper): Thanks to the HuggingFace Team for their integration of FlashAttention2 and the Whisper model in the transformers library.
218 | - [**CTranslate2 Team**](https://github.com/OpenNMT/CTranslate2/): Thanks to the CTranslate2 Team for providing a faster inference engine for the Transformer architecture.
219 | - [**NVIDIA NeMo Team**](https://github.com/NVIDIA/NeMo): Thanks to the NVIDIA NeMo Team for their contribution of the open-source VAD model used in this pipeline. 220 | - [**NVIDIA TensorRT-LLM Team**](https://github.com/NVIDIA/TensorRT-LLM/): Thanks to the NVIDIA TensorRT-LLM Team for their awesome LLM inference optimizations. 221 | 222 | 223 | ## License 224 | 225 | This project is licensed under MIT License - see the [LICENSE](LICENSE) file for details. 226 | 227 | -------------------------------------------------------------------------------- /whisper_s2t/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import numpy as np 4 | from tqdm import tqdm 5 | 6 | from .configs import * 7 | from .audio import pad_or_trim, audio_batch_generator, load_audio 8 | 9 | 10 | def stitch_speech_segments(start_ends, max_len=27.0, max_silent_region=None): 11 | 12 | speech_duration = [end - start for start, end in start_ends] 13 | 14 | stitched_speech_segments = [] 15 | 16 | curr_seg = [0] 17 | curr_dur = speech_duration[0] 18 | idx = 1 19 | while idx < len(start_ends): 20 | if curr_dur + speech_duration[idx] > max_len: 21 | stitched_speech_segments.append([start_ends[_] for _ in curr_seg]) 22 | curr_seg = [idx] 23 | curr_dur = speech_duration[idx] 24 | else: 25 | curr_dur += speech_duration[idx] 26 | curr_seg.append(idx) 27 | 28 | idx += 1 29 | 30 | stitched_speech_segments.append([start_ends[_] for _ in curr_seg]) 31 | 32 | if max_silent_region is None: 33 | return stitched_speech_segments 34 | 35 | stitched_speech_segments_joined = [] 36 | for segs in stitched_speech_segments: 37 | _segs = [] 38 | curr_seg_start_time, curr_seg_end_time = segs[0] 39 | for i in range(1, len(segs)): 40 | if (segs[i][0] - curr_seg_end_time) >= max_silent_region: 41 | _segs.append((curr_seg_start_time, curr_seg_end_time)) 42 | curr_seg_start_time = segs[i][0] 43 | 44 | curr_seg_end_time = segs[i][1] 45 | 46 | _segs.append((curr_seg_start_time, curr_seg_end_time)) 47 | 48 | stitched_speech_segments_joined.append(_segs) 49 | 50 | 51 | return stitched_speech_segments_joined 52 | 53 | 54 | class BasicSegmenter: 55 | def __init__(self, max_seg_len=29.0, sampling_rate=16000): 56 | self.max_seg_len = max_seg_len 57 | self.sampling_rate = sampling_rate 58 | 59 | def __call__(self, input_file=None, audio_signal=None): 60 | if audio_signal is None: 61 | audio_signal, audio_duration = load_audio(input_file, sr=self.sampling_rate, return_duration=True) 62 | else: 63 | audio_duration = len(audio_signal)/self.sampling_rate 64 | 65 | start_ends = [] 66 | for i in range(0, int(audio_duration), int(self.max_seg_len)): 67 | start_ends.append([i, i + self.max_seg_len]) 68 | 69 | start_ends[-1][1] = min(audio_duration, start_ends[-1][1]) # fix edge 70 | 71 | return start_ends, audio_signal 72 | 73 | 74 | class WhisperDataset(torch.utils.data.Dataset): 75 | def __init__(self, audio_files, lang_codes, tasks, initial_prompts, tokenizer, max_initial_prompt_len, 76 | device="cuda", 77 | dta_padding=48000, 78 | without_timestamps=True, 79 | use_dynamic_time_axis=False): 80 | 81 | self.audio_files = audio_files 82 | self.lang_codes = lang_codes 83 | self.tasks = tasks 84 | self.initial_prompts = initial_prompts 85 | self.tokenizer = tokenizer 86 | self.device = device 87 | self.dta_padding = dta_padding 88 | self.without_timestamps = without_timestamps 89 | self.use_dynamic_time_axis = use_dynamic_time_axis 90 | self.max_initial_prompt_len = max_initial_prompt_len 91 | 92 | if 
type(audio_files[0]) == str: 93 | self.get_audio_signal = self._get_audio_signal_from_file 94 | else: 95 | self.get_audio_signal = self._get_audio_signal_from_array 96 | 97 | def _get_audio_signal_from_array(self, item): 98 | return self.audio_files[item] 99 | 100 | def _get_audio_signal_from_file(self, item): 101 | return load_audio(self.audio_files[item]) 102 | 103 | def __len__(self): 104 | return len(self.audio_files) 105 | 106 | def __getitem__(self, item): 107 | audio = self.get_audio_signal(item) 108 | seq_len = audio.shape[-1] 109 | 110 | if self.initial_prompts[item]: 111 | initial_prompt = " " + self.initial_prompts[item].strip() 112 | initial_prompt_tokens = self.tokenizer.encode(initial_prompt)[-self.max_initial_prompt_len:] 113 | else: 114 | initial_prompt_tokens = [] 115 | 116 | prompt = self.tokenizer.sot_sequence(task=self.tasks[item], lang=self.lang_codes[item]) 117 | 118 | if self.without_timestamps: 119 | prompt = prompt + [self.tokenizer.no_timestamps] 120 | 121 | return audio, prompt, initial_prompt_tokens, seq_len 122 | 123 | 124 | class WhisperDataLoader: 125 | def __init__(self, device, tokenizer, speech_segmenter, 126 | dta_padding=3.0, 127 | without_timestamps=True, 128 | max_speech_len=29.0, 129 | max_initial_prompt_len=223, 130 | merge_chunks=True, 131 | use_dynamic_time_axis=False): 132 | 133 | self.device = device 134 | self.tokenizer = tokenizer 135 | self.speech_segmenter = speech_segmenter 136 | self.basic_segmenter = BasicSegmenter(max_seg_len=max_speech_len) 137 | self.dta_padding = int(dta_padding*SAMPLE_RATE) 138 | self.without_timestamps = without_timestamps 139 | self.max_speech_len = max_speech_len 140 | self.max_initial_prompt_len = max_initial_prompt_len 141 | self.use_dynamic_time_axis = use_dynamic_time_axis 142 | self.merge_chunks = merge_chunks 143 | 144 | def data_collate_fn(self, batch): 145 | if self.use_dynamic_time_axis: 146 | max_len = min(max([_[3] for _ in batch]) + self.dta_padding, N_SAMPLES) 147 | else: 148 | max_len = N_SAMPLES 149 | 150 | signal_batch = torch.stack([torch.from_numpy(pad_or_trim(_[0], length=max_len)).to(self.device) for _ in batch]) 151 | seq_len = torch.tensor([_[3] for _ in batch]).to(self.device) 152 | 153 | prompt_batch = [] 154 | initial_prompt_max_len = max([len(_[2]) for _ in batch]) 155 | if initial_prompt_max_len: 156 | for _ in batch: prompt_batch.append([self.tokenizer.sot_prev] + (initial_prompt_max_len-len(_[2]))*[self.tokenizer.silent_token] + _[2] + _[1]) 157 | else: 158 | for _ in batch: prompt_batch.append(_[1]) 159 | 160 | if len(batch[0]) == 5: 161 | seg_metadata = [_[4] for _ in batch] 162 | return signal_batch, prompt_batch, seq_len, seg_metadata 163 | else: 164 | return signal_batch, prompt_batch, seq_len 165 | 166 | def get_segmented_audio_signal(self, start_ends, audio_signal, file_id, lang, task, initial_prompt, sr=16000): 167 | 168 | if initial_prompt: 169 | initial_prompt = " " + initial_prompt.strip() 170 | initial_prompt_tokens = self.tokenizer.encode(initial_prompt)[-self.max_initial_prompt_len:] 171 | else: 172 | initial_prompt_tokens = [] 173 | 174 | prompt = self.tokenizer.sot_sequence(task=task, lang=lang) 175 | 176 | if self.without_timestamps: 177 | prompt.append(self.tokenizer.no_timestamps) 178 | else: 179 | prompt.append(self.tokenizer.timestamp_begin) 180 | 181 | segmented_audio_signal = [] 182 | 183 | if self.merge_chunks: 184 | stitched_speech_segments = stitch_speech_segments(start_ends, max_len=self.max_speech_len) 185 | for stitched_seg in stitched_speech_segments: 186 | 
audio = [] 187 | for st, et in stitched_seg: 188 | audio.append(audio_signal[int(st*sr):int(et*sr)]) 189 | 190 | audio = np.concatenate(audio) 191 | seq_len = audio.shape[-1] 192 | seg_metadata = { 193 | 'file_id': file_id, 194 | 'start_time': stitched_seg[0][0], 195 | 'end_time': stitched_seg[-1][1], 196 | 'stitched_seg': stitched_seg, 197 | 'lang_code': lang 198 | } 199 | segmented_audio_signal.append((audio, prompt, initial_prompt_tokens, seq_len, seg_metadata)) 200 | else: 201 | for st, et in start_ends: 202 | audio = audio_signal[int(st*sr):int(et*sr)] 203 | seq_len = audio.shape[-1] 204 | segmented_audio_signal.append((audio, prompt, initial_prompt_tokens, seq_len, {'file_id': file_id, 'start_time': st, 'end_time': et})) 205 | 206 | return segmented_audio_signal 207 | 208 | def get_data_loader_with_vad(self, audio_files, lang_codes, tasks, initial_prompts, batch_size=16): 209 | 210 | segmented_audio_signal = [] 211 | pbar_update_len = {} 212 | for file_id, (audio_signal, lang, task, initial_prompt) in enumerate(zip(audio_batch_generator(audio_files), lang_codes, tasks, initial_prompts)): 213 | start_ends, audio_signal = self.speech_segmenter(audio_signal=audio_signal) 214 | new_segmented_audio_signal = self.get_segmented_audio_signal(start_ends, audio_signal, file_id, lang, task, initial_prompt) 215 | pbar_update_len[file_id] = 1/len(new_segmented_audio_signal) 216 | 217 | segmented_audio_signal = segmented_audio_signal + new_segmented_audio_signal 218 | 219 | while len(segmented_audio_signal) > batch_size: 220 | batch = segmented_audio_signal[:batch_size] 221 | segmented_audio_signal = segmented_audio_signal[batch_size:] 222 | 223 | signal_batch, prompt_batch, seq_len, seg_metadata = self.data_collate_fn(batch) 224 | pbar_update = int(sum([pbar_update_len[_['file_id']] for _ in seg_metadata])*100) 225 | 226 | yield signal_batch, prompt_batch, seq_len, seg_metadata, pbar_update 227 | 228 | signal_batch, prompt_batch, seq_len, seg_metadata = self.data_collate_fn(segmented_audio_signal) 229 | pbar_update = int(sum([pbar_update_len[_['file_id']] for _ in seg_metadata])*100) 230 | 231 | yield signal_batch, prompt_batch, seq_len, seg_metadata, pbar_update 232 | 233 | def get_data_loader(self, audio_files, lang_codes, tasks, initial_prompts, batch_size=16): 234 | 235 | # dataset = WhisperDataset(audio_files, lang_codes, tasks, initial_prompts, self.tokenizer, 236 | # without_timestamps=self.without_timestamps, 237 | # max_initial_prompt_len=self.max_initial_prompt_len, 238 | # use_dynamic_time_axis=self.use_dynamic_time_axis) 239 | 240 | # data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, collate_fn=self.data_collate_fn) 241 | 242 | # return tqdm(data_loader, desc=f"Transcribing") 243 | 244 | segmented_audio_signal = [] 245 | pbar_update_len = {} 246 | for file_id, (audio_signal, lang, task, initial_prompt) in enumerate(zip(audio_batch_generator(audio_files), lang_codes, tasks, initial_prompts)): 247 | start_ends, audio_signal = self.basic_segmenter(audio_signal=audio_signal) 248 | new_segmented_audio_signal = self.get_segmented_audio_signal(start_ends, audio_signal, file_id, lang, task, initial_prompt) 249 | pbar_update_len[file_id] = 1/len(new_segmented_audio_signal) 250 | 251 | segmented_audio_signal = segmented_audio_signal + new_segmented_audio_signal 252 | 253 | while len(segmented_audio_signal) > batch_size: 254 | batch = segmented_audio_signal[:batch_size] 255 | segmented_audio_signal = segmented_audio_signal[batch_size:] 256 | 257 | signal_batch, 
prompt_batch, seq_len, seg_metadata = self.data_collate_fn(batch) 258 | pbar_update = int(sum([pbar_update_len[_['file_id']] for _ in seg_metadata])*100) 259 | 260 | yield signal_batch, prompt_batch, seq_len, seg_metadata, pbar_update 261 | 262 | signal_batch, prompt_batch, seq_len, seg_metadata = self.data_collate_fn(segmented_audio_signal) 263 | pbar_update = int(sum([pbar_update_len[_['file_id']] for _ in seg_metadata])*100) 264 | 265 | yield signal_batch, prompt_batch, seq_len, seg_metadata, pbar_update 266 | 267 | def __call__(self, audio_files, lang_codes, tasks, initial_prompts, batch_size=16, use_vad=True): 268 | if use_vad: 269 | return self.get_data_loader_with_vad(audio_files, lang_codes, tasks, initial_prompts, batch_size=batch_size) 270 | else: 271 | return self.get_data_loader(audio_files, lang_codes, tasks, initial_prompts, batch_size=batch_size) --------------------------------------------------------------------------------
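As a usage illustration of the `whisper_s2t/data.py` module above (this sketch is not part of the repository), the snippet below shows how `stitch_speech_segments` greedily groups VAD-detected `(start, end)` spans so that the total speech duration per group stays within `max_len` seconds. The sample `start_ends` values are made up for the example.

```py
# Illustrative sketch only -- the start_ends values below are hypothetical.
from whisper_s2t.data import stitch_speech_segments

# Hypothetical VAD output: three speech regions of 10 seconds each.
start_ends = [(0.0, 10.0), (12.0, 22.0), (25.0, 35.0)]

# Greedy grouping: segments are appended to the current group until adding the
# next one would push the summed speech duration past max_len seconds.
groups = stitch_speech_segments(start_ends, max_len=27.0)
print(groups)
# Expected: [[(0.0, 10.0), (12.0, 22.0)], [(25.0, 35.0)]]
# The first two segments total 20 s (<= 27 s); adding the third would reach
# 30 s (> 27 s), so it starts a new group.
```

When `max_silent_region` is also given, segments within each group that are separated by less than that many seconds of silence are merged into a single contiguous `(start, end)` span. `WhisperDataLoader` relies on this grouping (when `merge_chunks=True`) to build the segment batches it yields for transcription.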