├── tests ├── __init__.py ├── onnx_asr │ ├── __init__.py │ ├── test_load_model_errors.py │ └── test_recognize.py └── preprocessors │ ├── __init__.py │ ├── conftest.py │ ├── test_resample_preprocessor.py │ ├── test_kaldi_preprocessor.py │ ├── test_whisper_preprocessor.py │ ├── test_nemo_preprocessor.py │ └── test_gigaam_preprocessor.py ├── src └── onnx_asr │ ├── py.typed │ ├── __init__.py │ ├── preprocessors │ ├── __init__.py │ ├── resampler.py │ └── preprocessor.py │ ├── models │ ├── __init__.py │ ├── pyannote.py │ ├── tone.py │ ├── kaldi.py │ ├── silero.py │ ├── gigaam.py │ ├── whisper.py │ └── nemo.py │ ├── cli.py │ ├── vad.py │ ├── utils.py │ ├── adapters.py │ ├── asr.py │ └── loader.py ├── .vscode └── settings.json ├── preprocessors ├── __init__.py ├── build.py ├── whisper.py ├── gigaam.py ├── nemo.py ├── kaldi.py └── resample.py ├── .devcontainer └── devcontainer.json ├── .github └── workflows │ ├── python-publish.yml │ └── python-package.yml ├── LICENSE ├── examples └── performance benchmark.ipynb ├── pyproject.toml ├── .gitignore └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/onnx_asr/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/preprocessors/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/onnx_asr/py.typed: -------------------------------------------------------------------------------- 1 | # Marker file for PEP-561 (inline types) 2 | -------------------------------------------------------------------------------- /src/onnx_asr/__init__.py: -------------------------------------------------------------------------------- 1 | """Automatic Speech Recognition in Python using ONNX models.""" 2 | 3 | from .loader import load_model, load_vad 4 | 5 | __all__ = ["load_model", "load_vad"] 6 | -------------------------------------------------------------------------------- /src/onnx_asr/preprocessors/__init__.py: -------------------------------------------------------------------------------- 1 | """ASR preprocessor implementations.""" 2 | 3 | from .preprocessor import Preprocessor, PreprocessorRuntimeConfig 4 | from .resampler import Resampler 5 | 6 | __all__ = ["Preprocessor", "PreprocessorRuntimeConfig", "Resampler"] 7 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.testing.unittestEnabled": false, 3 | "python.testing.pytestEnabled": true, 4 | "python.analysis.autoImportCompletions": true, 5 | "python.analysis.typeCheckingMode": "standard", 6 | "[python]": { 7 | "editor.defaultFormatter": "charliermarsh.ruff" 8 | } 9 | } -------------------------------------------------------------------------------- /tests/preprocessors/conftest.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | 5 | def create_waveforms(base_sec: int): 6 | rng = np.random.default_rng(0) 7 | return [rng.random((base_sec * 16_000 + x), dtype=np.float32) * 2 - 1 for x in [0, 1, 79, 80, -1, -10000]] 8 | 9 | 10 | def pytest_generate_tests(metafunc: 
pytest.Metafunc): 11 | if "waveforms" in metafunc.fixturenames: 12 | batch = create_waveforms(30 if "whisper" in metafunc.module.__name__ else 5) 13 | metafunc.parametrize("waveforms", [waveform.reshape(1, -1) for waveform in batch] + [batch]) 14 | -------------------------------------------------------------------------------- /preprocessors/__init__.py: -------------------------------------------------------------------------------- 1 | from .gigaam import GigaamPreprocessorV2, GigaamPreprocessorV3 2 | from .kaldi import KaldiPreprocessor 3 | from .nemo import NemoPreprocessor80, NemoPreprocessor128 4 | from .resample import ResamplePreprocessor8, ResamplePreprocessor16 5 | from .whisper import WhisperPreprocessor80, WhisperPreprocessor128 6 | 7 | __all__ = [ 8 | "GigaamPreprocessorV2", 9 | "GigaamPreprocessorV3", 10 | "KaldiPreprocessor", 11 | "NemoPreprocessor80", 12 | "NemoPreprocessor128", 13 | "ResamplePreprocessor8", 14 | "ResamplePreprocessor16", 15 | "WhisperPreprocessor80", 16 | "WhisperPreprocessor128", 17 | ] 18 | -------------------------------------------------------------------------------- /src/onnx_asr/models/__init__.py: -------------------------------------------------------------------------------- 1 | """ASR model implementations.""" 2 | 3 | from .gigaam import GigaamV2Ctc, GigaamV2Rnnt, GigaamV3E2eCtc, GigaamV3E2eRnnt 4 | from .kaldi import KaldiTransducer 5 | from .nemo import NemoConformerAED, NemoConformerCtc, NemoConformerRnnt, NemoConformerTdt 6 | from .pyannote import PyAnnoteVad 7 | from .silero import SileroVad 8 | from .tone import TOneCtc 9 | from .whisper import WhisperHf, WhisperOrt 10 | 11 | __all__ = [ 12 | "GigaamV2Ctc", 13 | "GigaamV2Rnnt", 14 | "GigaamV3E2eCtc", 15 | "GigaamV3E2eRnnt", 16 | "KaldiTransducer", 17 | "NemoConformerAED", 18 | "NemoConformerCtc", 19 | "NemoConformerRnnt", 20 | "NemoConformerTdt", 21 | "PyAnnoteVad", 22 | "SileroVad", 23 | "TOneCtc", 24 | "WhisperHf", 25 | "WhisperOrt", 26 | ] 27 | -------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | // For format details, see https://aka.ms/devcontainer.json. For config options, see the 2 | // README at: https://github.com/devcontainers/templates/tree/main/src/python 3 | { 4 | "name": "Python 3", 5 | // Or use a Dockerfile or Docker Compose file. 
More info: https://containers.dev/guide/dockerfile 6 | "image": "mcr.microsoft.com/devcontainers/python:1-3.12-bullseye", 7 | "features": { 8 | "ghcr.io/devcontainers/features/git-lfs:1": {} 9 | }, 10 | "customizations": { 11 | "vscode": { 12 | "extensions": [ 13 | "ms-toolsai.jupyter", 14 | "charliermarsh.ruff", 15 | "tamasfe.even-better-toml", 16 | "vincent-templier.vscode-netron" 17 | ] 18 | } 19 | }, 20 | "postCreateCommand": "pipx install pdm && pip install --no-cache-dir --user ipykernel ipywidgets" 21 | } -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package to PyPI when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [published] 9 | 10 | permissions: 11 | contents: read 12 | 13 | jobs: 14 | pypi-publish: 15 | name: upload release to PyPI 16 | runs-on: ubuntu-latest 17 | permissions: 18 | contents: read 19 | id-token: write 20 | 21 | steps: 22 | - uses: actions/checkout@v4 23 | - uses: pdm-project/setup-pdm@v4 24 | with: 25 | python-version: "3.12" 26 | 27 | - name: Publish package distributions to PyPI 28 | run: pdm publish 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Ilya Stupakov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/onnx_asr/models/pyannote.py: -------------------------------------------------------------------------------- 1 | """PyAnnote VAD implementation.""" 2 | 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import numpy.typing as npt 7 | import onnxruntime as rt 8 | 9 | from onnx_asr.utils import OnnxSessionOptions, is_float32_array 10 | from onnx_asr.vad import Vad 11 | 12 | 13 | class PyAnnoteVad(Vad): 14 | """PyAnnote VAD implementation.""" 15 | 16 | def __init__(self, model_files: dict[str, Path], onnx_options: OnnxSessionOptions): 17 | """Create PyAnnote VAD. 18 | 19 | Args: 20 | model_files: Dict with paths to model files. 
21 | onnx_options: Options for onnxruntime InferenceSession. 22 | 23 | """ 24 | self._model = rt.InferenceSession(model_files["model"], **onnx_options) 25 | 26 | @staticmethod 27 | def _get_model_files(quantization: str | None = None) -> dict[str, str]: 28 | suffix = "?" + quantization if quantization else "" 29 | return {"model": f"**/model{suffix}.onnx"} 30 | 31 | def _encode(self, waveforms: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]: 32 | (logits,) = self._model.run(["logits"], {"input_values": waveforms[:, None]}) 33 | assert is_float32_array(logits) 34 | return logits 35 | -------------------------------------------------------------------------------- /tests/onnx_asr/test_load_model_errors.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import onnx_asr 4 | 5 | 6 | def test_model_not_supported_error() -> None: 7 | with pytest.raises(onnx_asr.loader.ModelNotSupportedError): 8 | onnx_asr.load_model("xxx") 9 | 10 | 11 | def test_model_path_not_found_error() -> None: 12 | with pytest.raises(onnx_asr.loader.ModelPathNotDirectoryError): 13 | onnx_asr.load_model("whisper", "pyproject.toml") 14 | 15 | 16 | def test_model_file_not_found_error() -> None: 17 | with pytest.raises(onnx_asr.loader.ModelFileNotFoundError): 18 | onnx_asr.load_model("onnx-community/whisper-tiny", quantization="xxx") 19 | 20 | 21 | def test_more_than_one_model_file_found_error() -> None: 22 | with pytest.raises(onnx_asr.loader.MoreThanOneModelFileFoundError): 23 | onnx_asr.load_model("onnx-community/whisper-tiny", quantization="*int8") 24 | 25 | 26 | def test_no_model_name_or_path_specified_error() -> None: 27 | with pytest.raises(onnx_asr.loader.NoModelNameOrPathSpecifiedError): 28 | onnx_asr.load_model("whisper") 29 | 30 | 31 | def test_no_model_name_and_empty_path_specified_error() -> None: 32 | with pytest.raises(onnx_asr.loader.NoModelNameOrPathSpecifiedError): 33 | onnx_asr.load_model("whisper", "./xxx") 34 | 35 | 36 | def test_invalid_model_type_in_config_error() -> None: 37 | with pytest.raises(onnx_asr.loader.InvalidModelTypeInConfigError): 38 | onnx_asr.load_model("onnx-community/pyannote-segmentation-3.0") 39 | -------------------------------------------------------------------------------- /preprocessors/build.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import onnx 4 | import onnxscript 5 | 6 | import preprocessors 7 | 8 | 9 | def save_model(function: onnxscript.OnnxFunction, filename: Path): 10 | model = function.to_model_proto() 11 | 12 | model = onnxscript.optimizer.optimize(model, input_size_limit=100) 13 | model = onnxscript.optimizer.optimize(model, input_size_limit=100) 14 | 15 | model.producer_name = "OnnxScript" 16 | model.producer_version = onnxscript.__version__ 17 | model.metadata_props.add(key="model_author", value="Ilya Stupakov") 18 | model.metadata_props.add(key="model_license", value="MIT License") 19 | 20 | onnx.checker.check_model(model, full_check=True) 21 | onnx.save_model(model, filename) 22 | 23 | 24 | def build(): 25 | preprocessors_dir = Path("src/onnx_asr/preprocessors") 26 | save_model(preprocessors.KaldiPreprocessor, preprocessors_dir.joinpath("kaldi.onnx")) 27 | save_model(preprocessors.GigaamPreprocessorV2, preprocessors_dir.joinpath("gigaam_v2.onnx")) 28 | save_model(preprocessors.GigaamPreprocessorV3, preprocessors_dir.joinpath("gigaam_v3.onnx")) 29 | save_model(preprocessors.NemoPreprocessor80, 
preprocessors_dir.joinpath("nemo80.onnx")) 30 | save_model(preprocessors.NemoPreprocessor128, preprocessors_dir.joinpath("nemo128.onnx")) 31 | save_model(preprocessors.WhisperPreprocessor80, preprocessors_dir.joinpath("whisper80.onnx")) 32 | save_model(preprocessors.WhisperPreprocessor128, preprocessors_dir.joinpath("whisper128.onnx")) 33 | save_model(preprocessors.ResamplePreprocessor8, preprocessors_dir.joinpath("resample8.onnx")) 34 | save_model(preprocessors.ResamplePreprocessor16, preprocessors_dir.joinpath("resample16.onnx")) 35 | -------------------------------------------------------------------------------- /src/onnx_asr/preprocessors/resampler.py: -------------------------------------------------------------------------------- 1 | """Waveform resampler implementations.""" 2 | 3 | from importlib.resources import files 4 | from typing import Literal 5 | 6 | import numpy as np 7 | import numpy.typing as npt 8 | import onnxruntime as rt 9 | 10 | from onnx_asr.utils import OnnxSessionOptions, SampleRates, is_float32_array, is_int64_array 11 | 12 | 13 | class Resampler: 14 | """Waveform resampler to 8/16 kHz implementation.""" 15 | 16 | def __init__(self, sample_rate: Literal[8_000, 16_000], onnx_options: OnnxSessionOptions): 17 | """Create waveform resampler. 18 | 19 | Args: 20 | sample_rate: Target sample rate. 21 | onnx_options: Options for onnxruntime InferenceSession. 22 | 23 | """ 24 | self._target_sample_rate = sample_rate 25 | self._preprocessor = rt.InferenceSession( 26 | files(__package__).joinpath(f"resample{sample_rate // 1000}.onnx").read_bytes(), **onnx_options 27 | ) 28 | 29 | def __call__( 30 | self, waveforms: npt.NDArray[np.float32], waveforms_lens: npt.NDArray[np.int64], sample_rate: SampleRates 31 | ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int64]]: 32 | """Resample waveform.""" 33 | if sample_rate == self._target_sample_rate: 34 | return waveforms, waveforms_lens 35 | 36 | resampled, resampled_lens = self._preprocessor.run( 37 | ["resampled", "resampled_lens"], 38 | {"waveforms": waveforms, "waveforms_lens": waveforms_lens, "sample_rate": np.array([sample_rate], dtype=np.int64)}, 39 | ) 40 | assert is_float32_array(resampled) 41 | assert is_int64_array(resampled_lens) 42 | return resampled, resampled_lens 43 | -------------------------------------------------------------------------------- /src/onnx_asr/cli.py: -------------------------------------------------------------------------------- 1 | """CLI for ASR models.""" 2 | 3 | import argparse 4 | import pathlib 5 | from importlib.metadata import version 6 | from typing import get_args 7 | 8 | import onnx_asr 9 | from onnx_asr.loader import ModelNames, ModelTypes, VadNames 10 | 11 | 12 | def run() -> None: 13 | """Run CLI for ASR models.""" 14 | parser = argparse.ArgumentParser(prog="onnx_asr", description="Automatic Speech Recognition in Python using ONNX models.") 15 | parser.add_argument( 16 | "model", 17 | help=f"Model name or type {(*get_args(ModelNames), *get_args(ModelTypes), 'onnx-community/whisper-...')}", 18 | ) 19 | parser.add_argument( 20 | "filename", 21 | help="Path to wav file (only PCM_U8, PCM_16, PCM_24 and PCM_32 formats are supported).", 22 | nargs="+", 23 | ) 24 | parser.add_argument("-p", "--model_path", type=pathlib.Path, help="Path to directory with model files") 25 | parser.add_argument("-q", "--quantization", help="Model quantization ('int8' for example)") 26 | parser.add_argument("--vad", help="Use VAD model", choices=get_args(VadNames)) 27 | parser.add_argument("--version", 
action="version", version=f"%(prog)s {version('onnx_asr')}") 28 | args = parser.parse_args() 29 | 30 | model = onnx_asr.load_model(args.model, args.model_path, quantization=args.quantization) 31 | if args.vad: 32 | vad = onnx_asr.load_vad(args.vad) 33 | for segment in model.with_vad(vad, batch_size=1).recognize(args.filename): 34 | for res in segment: 35 | print(f"[{res.start:5.1f}, {res.end:5.1f}]: {res.text}") # noqa: T201 36 | print() # noqa: T201 37 | else: 38 | for text in model.recognize(args.filename): 39 | print(text) # noqa: T201 40 | -------------------------------------------------------------------------------- /src/onnx_asr/vad.py: -------------------------------------------------------------------------------- 1 | """Base VAD classes.""" 2 | 3 | from abc import ABC, abstractmethod 4 | from collections.abc import Iterator 5 | from dataclasses import dataclass 6 | from itertools import islice 7 | from typing import Literal 8 | 9 | import numpy as np 10 | import numpy.typing as npt 11 | 12 | from .asr import Asr, TimestampedResult 13 | from .utils import pad_list 14 | 15 | 16 | @dataclass 17 | class SegmentResult: 18 | """Segment recognition result.""" 19 | 20 | start: float 21 | end: float 22 | text: str 23 | 24 | 25 | @dataclass 26 | class TimestampedSegmentResult(TimestampedResult, SegmentResult): 27 | """Timestamped segment recognition result.""" 28 | 29 | 30 | class Vad(ABC): 31 | """Base VAD class.""" 32 | 33 | @abstractmethod 34 | def segment_batch( 35 | self, 36 | waveforms: npt.NDArray[np.float32], 37 | waveforms_len: npt.NDArray[np.int64], 38 | sample_rate: Literal[8_000, 16_000], 39 | **kwargs: float, 40 | ) -> Iterator[Iterator[tuple[int, int]]]: 41 | """Segment waveforms batch.""" 42 | ... 43 | 44 | def recognize_batch( 45 | self, 46 | asr: Asr, 47 | waveforms: npt.NDArray[np.float32], 48 | waveforms_len: npt.NDArray[np.int64], 49 | sample_rate: Literal[8_000, 16_000], 50 | language: str | None, 51 | batch_size: float = 8, 52 | **kwargs: float, 53 | ) -> Iterator[Iterator[TimestampedSegmentResult]]: 54 | """Segment and recognize waveforms batch.""" 55 | 56 | def recognize( 57 | waveform: npt.NDArray[np.float32], segment: Iterator[tuple[int, int]] 58 | ) -> Iterator[TimestampedSegmentResult]: 59 | while batch := tuple(islice(segment, int(batch_size))): 60 | yield from ( 61 | TimestampedSegmentResult(start / sample_rate, end / sample_rate, res.text, res.timestamps, res.tokens) 62 | for res, (start, end) in zip( 63 | asr.recognize_batch(*pad_list([waveform[start:end] for start, end in batch]), language), 64 | batch, 65 | strict=True, 66 | ) 67 | ) 68 | 69 | return map(recognize, waveforms, self.segment_batch(waveforms, waveforms_len, sample_rate, **kwargs)) 70 | -------------------------------------------------------------------------------- /examples/performance benchmark.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "461aa578", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%pip install onnx_asr[gpu,hub]" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "id": "f73bf5a8", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import onnx_asr\n", 21 | "from onnx_asr.utils import read_wav_files\n", 22 | "\n", 23 | "model = onnx_asr.load_model(\"gigaam-v3-ctc\", providers=[\"CUDAExecutionProvider\", \"CPUExecutionProvider\"])" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | 
"execution_count": null, 29 | "id": "87b2c36a", 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "for k in [1, 4, 16]:\n", 34 | " print(f\"Batch size: {k}\")\n", 35 | "\n", 36 | " data = [\"test.wav\"] * k\n", 37 | " waveform = read_wav_files(data, 16_000)\n", 38 | " resampled = model.resampler(*waveform)\n", 39 | " preprocessed = model.asr._preprocessor(*resampled)\n", 40 | " encoded = model.asr._encode(*preprocessed)\n", 41 | "\n", 42 | " t1 = %timeit -o -q model.resampler(*waveform)\n", 43 | " t2 = %timeit -o -q model.asr._preprocessor(*resampled)\n", 44 | " t3 = %timeit -o -q model.asr._encode(*preprocessed)\n", 45 | " t4 = %timeit -o -q list(model.asr._decoding(*encoded, None))\n", 46 | " t5 = %timeit -o -q model.recognize(data)\n", 47 | "\n", 48 | " for res in [t1, t2, t3, t4, t5]:\n", 49 | " res.timings = [t / k for t in res.timings]\n", 50 | "\n", 51 | " print(f\"resampler: {t1}\")\n", 52 | " print(f\"preprocessor: {t2}\")\n", 53 | " print(f\"encoder: {t3}\")\n", 54 | " print(f\"decoding: {t4}\")\n", 55 | " print(f\"full recognize: {t5}\")\n", 56 | " print()" 57 | ] 58 | } 59 | ], 60 | "metadata": { 61 | "kernelspec": { 62 | "display_name": "Python 3", 63 | "language": "python", 64 | "name": "python3" 65 | }, 66 | "language_info": { 67 | "codemirror_mode": { 68 | "name": "ipython", 69 | "version": 3 70 | }, 71 | "file_extension": ".py", 72 | "mimetype": "text/x-python", 73 | "name": "python", 74 | "nbconvert_exporter": "python", 75 | "pygments_lexer": "ipython3", 76 | "version": "3.12.11" 77 | } 78 | }, 79 | "nbformat": 4, 80 | "nbformat_minor": 5 81 | } 82 | -------------------------------------------------------------------------------- /tests/preprocessors/test_resample_preprocessor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | import torchaudio 5 | 6 | from onnx_asr.preprocessors import Resampler 7 | from onnx_asr.utils import pad_list 8 | from preprocessors import resample 9 | 10 | 11 | def onnx_preprocessor8(waveforms, waveforms_lens, sample_rate): 12 | return resample.ResamplePreprocessor8(waveforms, waveforms_lens, sample_rate) 13 | 14 | 15 | def onnx_preprocessor16(waveforms, waveforms_lens, sample_rate): 16 | return resample.ResamplePreprocessor16(waveforms, waveforms_lens, sample_rate) 17 | 18 | 19 | @pytest.fixture(scope="module") 20 | def preprocessor(request): 21 | match request.param: 22 | case "onnx_func8": 23 | return onnx_preprocessor8 24 | case "onnx_func16": 25 | return onnx_preprocessor16 26 | case "onnx_model8": 27 | return Resampler(8_000, {}) 28 | case "onnx_model16": 29 | return Resampler(16_000, {}) 30 | 31 | 32 | @pytest.mark.parametrize( 33 | "preprocessor", 34 | [ 35 | "onnx_func8", 36 | "onnx_model8", 37 | ], 38 | indirect=True, 39 | ) 40 | @pytest.mark.parametrize("sample_rate", [8_000, 11_025, 16_000, 22_050, 24_000, 32_000, 44_100, 48_000]) 41 | def test_resample8_preprocessor(preprocessor, sample_rate, waveforms): 42 | expected = [ 43 | torchaudio.functional.resample(torch.tensor(waveform).unsqueeze(0), sample_rate, 8_000)[0].numpy() 44 | for waveform in waveforms 45 | ] 46 | expected, expected_lens = pad_list(expected) 47 | actual, actual_lens = preprocessor(*pad_list(waveforms), sample_rate) 48 | 49 | np.testing.assert_equal(actual_lens, expected_lens) 50 | np.testing.assert_allclose(actual, expected, atol=1e-6) 51 | 52 | 53 | @pytest.mark.parametrize( 54 | "preprocessor", 55 | [ 56 | "onnx_func16", 57 | "onnx_model16", 58 | ], 59 | 
indirect=True, 60 | ) 61 | @pytest.mark.parametrize("sample_rate", [8_000, 11_025, 16_000, 22_050, 24_000, 32_000, 44_100, 48_000]) 62 | def test_resample16_preprocessor(preprocessor, sample_rate, waveforms): 63 | expected = [ 64 | torchaudio.functional.resample(torch.tensor(waveform).unsqueeze(0), sample_rate, 16_000)[0].numpy() 65 | for waveform in waveforms 66 | ] 67 | expected, expected_lens = pad_list(expected) 68 | actual, actual_lens = preprocessor(*pad_list(waveforms), sample_rate) 69 | 70 | np.testing.assert_equal(actual_lens, expected_lens) 71 | np.testing.assert_allclose(actual, expected, atol=1e-6) 72 | -------------------------------------------------------------------------------- /src/onnx_asr/preprocessors/preprocessor.py: -------------------------------------------------------------------------------- 1 | """ASR preprocessor implementations.""" 2 | 3 | from concurrent.futures import ThreadPoolExecutor 4 | from importlib.resources import files 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | import numpy.typing as npt 9 | import onnxruntime as rt 10 | 11 | from onnx_asr.utils import OnnxSessionOptions, is_float32_array, is_int64_array 12 | 13 | 14 | class PreprocessorRuntimeConfig(OnnxSessionOptions, total=False): 15 | """Preprocessor runtime config.""" 16 | 17 | max_concurrent_workers: int | None 18 | """Max parallel preprocessing threads (None - auto, 1 - without parallel processing).""" 19 | 20 | 21 | class Preprocessor: 22 | """ASR preprocessor implementation.""" 23 | 24 | def __init__(self, name: str, runtime_config: PreprocessorRuntimeConfig): 25 | """Create ASR preprocessor. 26 | 27 | Args: 28 | name: Preprocessor name. 29 | runtime_config: Runtime configuration. 30 | 31 | """ 32 | self._max_concurrent_workers = runtime_config.pop("max_concurrent_workers", 1) 33 | if name == "identity": 34 | self._preprocessor = None 35 | else: 36 | filename = str(Path(name).with_suffix(".onnx")) 37 | self._preprocessor = rt.InferenceSession(files(__package__).joinpath(filename).read_bytes(), **runtime_config) 38 | 39 | def _preprocess( 40 | self, waveforms: npt.NDArray[np.float32], waveforms_lens: npt.NDArray[np.int64] 41 | ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int64]]: 42 | if not self._preprocessor: 43 | return waveforms, waveforms_lens 44 | 45 | features, features_lens = self._preprocessor.run( 46 | ["features", "features_lens"], {"waveforms": waveforms, "waveforms_lens": waveforms_lens} 47 | ) 48 | assert is_float32_array(features) 49 | assert is_int64_array(features_lens) 50 | return features, features_lens 51 | 52 | def __call__( 53 | self, waveforms: npt.NDArray[np.float32], waveforms_lens: npt.NDArray[np.int64] 54 | ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int64]]: 55 | """Convert waveforms to model features.""" 56 | if self._preprocessor is None or waveforms.shape[0] == 1 or self._max_concurrent_workers == 1: 57 | return self._preprocess(waveforms, waveforms_lens) 58 | 59 | with ThreadPoolExecutor(max_workers=self._max_concurrent_workers) as executor: 60 | features, features_lens = zip( 61 | *executor.map(self._preprocess, waveforms[:, None], waveforms_lens[:, None]), strict=True 62 | ) 63 | return np.concatenate(features, axis=0), np.concatenate(features_lens, axis=0) 64 | -------------------------------------------------------------------------------- /tests/onnx_asr/test_recognize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | import onnx_asr 5 | 
import onnx_asr.utils 6 | from onnx_asr.adapters import TextResultsAsrAdapter 7 | from onnx_asr.vad import Vad 8 | 9 | models = [ 10 | "gigaam-v2-ctc", 11 | "gigaam-v2-rnnt", 12 | "nemo-fastconformer-ru-ctc", 13 | "nemo-fastconformer-ru-rnnt", 14 | "istupakov/canary-180m-flash-onnx", 15 | "alphacep/vosk-model-ru", 16 | "alphacep/vosk-model-small-ru", 17 | "t-tech/t-one", 18 | "whisper-base", 19 | "onnx-community/whisper-tiny", 20 | ] 21 | 22 | 23 | @pytest.fixture(scope="module", params=models) 24 | def model(request: pytest.FixtureRequest) -> TextResultsAsrAdapter: 25 | if request.param == "t-tech/t-one": 26 | quantization = None 27 | elif request.param == "onnx-community/whisper-tiny": 28 | quantization = "uint8" 29 | else: 30 | quantization = "int8" 31 | 32 | return onnx_asr.load_model(request.param, quantization=quantization, providers=["CPUExecutionProvider"]) 33 | 34 | 35 | @pytest.fixture(scope="module") 36 | def vad() -> Vad: 37 | return onnx_asr.load_vad("silero") 38 | 39 | 40 | def test_supported_only_mono_audio_error(model: TextResultsAsrAdapter) -> None: 41 | rng = np.random.default_rng(0) 42 | waveform = rng.random((1 * 16_000, 2), dtype=np.float32) 43 | 44 | with pytest.raises(onnx_asr.utils.SupportedOnlyMonoAudioError): 45 | model.recognize(waveform) 46 | 47 | 48 | def test_wrong_sample_rate_error(model: TextResultsAsrAdapter) -> None: 49 | rng = np.random.default_rng(0) 50 | waveform = rng.random((1 * 16_000), dtype=np.float32) 51 | 52 | with pytest.raises(onnx_asr.utils.WrongSampleRateError): 53 | model.recognize(waveform, sample_rate=25_000) # type: ignore 54 | 55 | 56 | def test_recognize(model: TextResultsAsrAdapter) -> None: 57 | rng = np.random.default_rng(0) 58 | waveform = rng.random((1 * 16_000), dtype=np.float32) 59 | 60 | result = model.recognize(waveform) 61 | assert isinstance(result, str) 62 | 63 | 64 | def test_recognize_with_vad(model: TextResultsAsrAdapter, vad: Vad) -> None: 65 | rng = np.random.default_rng(0) 66 | waveform = rng.random((1 * 16_000), dtype=np.float32) 67 | 68 | result = list(model.with_vad(vad).recognize(waveform)) 69 | assert isinstance(result, list) 70 | 71 | 72 | def test_concurrent_preprocessor() -> None: 73 | model = onnx_asr.load_model( 74 | "alphacep/vosk-model-small-ru", quantization="int8", preprocessor_config={"max_concurrent_workers": 2} 75 | ) 76 | rng = np.random.default_rng(0) 77 | waveform = rng.random((1 * 16_000), dtype=np.float32) 78 | 79 | result = model.recognize([waveform] * 2) 80 | assert isinstance(result, list) 81 | -------------------------------------------------------------------------------- /tests/preprocessors/test_kaldi_preprocessor.py: -------------------------------------------------------------------------------- 1 | import kaldi_native_fbank as knf 2 | import numpy as np 3 | import pytest 4 | import torch 5 | import torchaudio 6 | 7 | from onnx_asr.preprocessors import Preprocessor 8 | from onnx_asr.utils import pad_list 9 | from preprocessors import kaldi 10 | 11 | 12 | def pad_features(arrays): 13 | lens = np.array([array.shape[0] for array in arrays]) 14 | max_len = lens.max() 15 | return np.stack([np.pad(array, ((0, max_len - array.shape[0]), (0, 0))) for array in arrays]), lens 16 | 17 | 18 | def preprocessor_origin(waveforms, lens): 19 | opts = knf.FbankOptions() 20 | opts.frame_opts.dither = kaldi.dither 21 | opts.frame_opts.snip_edges = kaldi.snip_edges 22 | opts.mel_opts.num_bins = kaldi.num_mel_bins 23 | opts.mel_opts.high_freq = kaldi.high_freq 24 | 25 | results = [] 26 | for waveform, len in 
zip(waveforms, lens, strict=True): 27 | fbank = knf.OnlineFbank(opts) 28 | fbank.accept_waveform(kaldi.sample_rate, waveform[:len]) 29 | fbank.input_finished() 30 | results.append(np.array([fbank.get_frame(i) for i in range(fbank.num_frames_ready)])) 31 | 32 | return pad_features(results) 33 | 34 | 35 | def preprocessor_torch(waveforms, lens): 36 | results = [] 37 | for waveform, len in zip(waveforms, lens, strict=True): 38 | results.append( 39 | torchaudio.compliance.kaldi.fbank( 40 | torch.from_numpy(waveform[:len]).unsqueeze(0).contiguous(), 41 | dither=kaldi.dither, 42 | snip_edges=kaldi.snip_edges, 43 | num_mel_bins=kaldi.num_mel_bins, 44 | high_freq=kaldi.high_freq, 45 | ).numpy() 46 | ) 47 | 48 | return pad_features(results) 49 | 50 | 51 | @pytest.fixture( 52 | scope="module", 53 | params=[ 54 | "torch", 55 | "onnx_func", 56 | "onnx_model", 57 | "onnx_model_mt", 58 | ], 59 | ) 60 | def preprocessor(request): 61 | match request.param: 62 | case "torch": 63 | return preprocessor_torch 64 | case "onnx_func": 65 | return kaldi.KaldiPreprocessor 66 | case "onnx_model": 67 | return Preprocessor("kaldi", {}) 68 | case "onnx_model_mt": 69 | return Preprocessor("kaldi", {"max_concurrent_workers": 2}) 70 | 71 | 72 | def test_kaldi_preprocessor(preprocessor, waveforms): 73 | waveforms, lens = pad_list(waveforms) 74 | expected, expected_lens = preprocessor_origin(waveforms, lens) 75 | actual, actual_lens = preprocessor(waveforms, lens) 76 | 77 | assert expected.shape[1] == max(expected_lens) 78 | np.testing.assert_equal(actual_lens, expected_lens) 79 | np.testing.assert_allclose(actual, expected, atol=5e-4, rtol=1e-4) 80 | -------------------------------------------------------------------------------- /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: CI 5 | 6 | permissions: 7 | contents: read 8 | 9 | on: 10 | push: 11 | branches: [ "main" ] 12 | pull_request: 13 | branches: [ "main" ] 14 | 15 | env: 16 | HF_HUB_CACHE: hub_cache 17 | 18 | jobs: 19 | build: 20 | name: Build package 21 | runs-on: ubuntu-latest 22 | steps: 23 | - uses: actions/checkout@v4 24 | - uses: pdm-project/setup-pdm@v4 25 | with: 26 | cache: true 27 | python-version: "3.12" 28 | - name: Cache HF models 29 | uses: actions/cache@v4 30 | with: 31 | path: ${{ env.HF_HUB_CACHE }} 32 | key: cache-hf-models-${{ hashFiles('tests/onnx_asr/*.py') }} 33 | 34 | - name: Install dependencies 35 | run: pdm sync 36 | - name: Lint code with Ruff 37 | run: pdm run ruff check --output-format=github 38 | - name: Check code formatting with Ruff 39 | run: pdm run ruff format --diff 40 | - name: Check types with MyPy 41 | run: pdm run mypy . 
42 | - name: Test with pytest 43 | run: pdm run pytest --cov=onnx_asr 44 | - name: Build package dist 45 | run: pdm build 46 | 47 | - uses: actions/upload-artifact@v4 48 | with: 49 | name: wheels 50 | path: ./dist/onnx_asr-*.whl 51 | retention-days: 3 52 | 53 | test: 54 | name: Test package 55 | needs: build 56 | runs-on: ${{ matrix.os }}-latest 57 | strategy: 58 | matrix: 59 | os: [Ubuntu, Windows, macOS] 60 | python-version: ["3.10", "3.11", "3.12", "3.13"] 61 | numpy-version: [numpy==1.*, numpy==2.*] 62 | exclude: 63 | - python-version: "3.13" 64 | numpy-version: numpy==1.* 65 | steps: 66 | - uses: actions/checkout@v4 67 | - name: Set up Python ${{ matrix.python-version }} 68 | uses: actions/setup-python@v5 69 | with: 70 | python-version: ${{ matrix.python-version }} 71 | cache: 'pip' 72 | - uses: actions/download-artifact@v4 73 | with: 74 | name: wheels 75 | path: ./dist 76 | - name: Cache HF models 77 | uses: actions/cache@v4 78 | with: 79 | path: ${{ env.HF_HUB_CACHE }} 80 | key: cache-hf-models-${{ hashFiles('tests/onnx_asr/*.py') }} 81 | 82 | - name: Install package 83 | shell: bash 84 | run: pip install pytest ${{ matrix.numpy-version }} $(find ./dist -iname onnx_asr-*.whl)[cpu,hub] 85 | - name: Test with pytest 86 | run: pytest ./tests/onnx_asr 87 | 88 | complete: 89 | name: Complete 90 | needs: test 91 | runs-on: ubuntu-latest 92 | steps: 93 | - uses: actions/checkout@v4 94 | -------------------------------------------------------------------------------- /preprocessors/whisper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torchaudio 3 | from onnx import TensorProto, numpy_helper 4 | from onnxscript import FLOAT, INT64, script 5 | from onnxscript import opset17 as op 6 | 7 | chunk_length = 30 8 | sample_rate = 16_000 9 | n_fft = 400 10 | win_length = 400 11 | hop_length = 160 12 | 13 | clamp_min = 1e-10 14 | ln10 = 2.302585092994046 15 | 16 | melscale_fbanks80 = torchaudio.functional.melscale_fbanks( 17 | n_fft // 2 + 1, 0, sample_rate // 2, 80, sample_rate, "slaney", "slaney" 18 | ) 19 | melscale_fbanks128 = torchaudio.functional.melscale_fbanks( 20 | n_fft // 2 + 1, 0, sample_rate // 2, 128, sample_rate, "slaney", "slaney" 21 | ) 22 | 23 | 24 | @script() 25 | def whisper_preprocessor( 26 | waveforms: FLOAT["batch_size", "N"], waveforms_lens: INT64["batch_size"], melscale_fbanks: FLOAT[n_fft // 2 + 1, "M"] 27 | ): 28 | waveforms = op.Pad( 29 | waveforms, pads=(chunk_length * sample_rate - op.Shape(waveforms, start=1, end=2)) * op.Constant(value=[0, 0, 0, 1]) 30 | ) 31 | waveforms = op.Pad( 32 | waveforms, 33 | pads=op.Constant(value=[0, n_fft // 2, 0, n_fft // 2]), 34 | mode="reflect", 35 | ) 36 | 37 | hann_window = op.HannWindow(win_length, output_datatype=TensorProto.DOUBLE) 38 | image = op.STFT(op.CastLike(waveforms, hann_window), hop_length, hann_window)[:, :-1] 39 | spectrogram = op.ReduceSumSquare(image, axes=(-1,), keepdims=0) 40 | 41 | mel_spectrogram = op.MatMul(op.CastLike(spectrogram, melscale_fbanks), melscale_fbanks) 42 | log_mel_spectrogram = op.Log(op.Clip(mel_spectrogram, clamp_min)) / ln10 43 | log_mel_spectrogram = (op.Max(log_mel_spectrogram, op.ReduceMax(log_mel_spectrogram) - 8) + 4) / 4.0 44 | 45 | return op.Transpose(log_mel_spectrogram, perm=[0, 2, 1]), op.ConstantOfShape( 46 | op.Shape(waveforms_lens), value=numpy_helper.from_array(np.array([chunk_length * sample_rate // hop_length])) 47 | ) 48 | 49 | 50 | @script(doc_string="LogMelSpectrogram feature extractor for Whisper models") 
51 | def WhisperPreprocessor80( 52 | waveforms: FLOAT["batch_size", "N"], 53 | waveforms_lens: INT64["batch_size"], 54 | ) -> tuple[FLOAT["batch_size", 80, "T"], INT64["batch_size"]]: 55 | features, features_lens = whisper_preprocessor( 56 | waveforms, 57 | waveforms_lens, 58 | op.Constant(value=numpy_helper.from_array(melscale_fbanks80.numpy(), "melscale_fbanks")), 59 | ) 60 | return features, features_lens 61 | 62 | 63 | @script(doc_string="LogMelSpectrogram feature extractor for Whisper models") 64 | def WhisperPreprocessor128( 65 | waveforms: FLOAT["batch_size", "N"], 66 | waveforms_lens: INT64["batch_size"], 67 | ) -> tuple[FLOAT["batch_size", 128, "T"], INT64["batch_size"]]: 68 | features, features_lens = whisper_preprocessor( 69 | waveforms, 70 | waveforms_lens, 71 | op.Constant(value=numpy_helper.from_array(melscale_fbanks128.numpy(), "melscale_fbanks")), 72 | ) 73 | return features, features_lens 74 | -------------------------------------------------------------------------------- /preprocessors/gigaam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio 3 | from onnx import TensorProto, numpy_helper 4 | from onnxscript import FLOAT, INT64, script 5 | from onnxscript import opset17 as op 6 | 7 | sample_rate = 16_000 8 | n_fft_v2 = sample_rate // 40 9 | n_fft_v3 = sample_rate // 50 10 | win_length_v2 = sample_rate // 40 11 | win_length_v3 = sample_rate // 50 12 | hop_length = sample_rate // 100 13 | n_mels = 64 14 | 15 | f_min = 0 16 | f_max = 8_000 17 | 18 | clamp_min = 1e-9 19 | clamp_max = 1e9 20 | 21 | melscale_fbanks_v2 = torchaudio.functional.melscale_fbanks(n_fft_v2 // 2 + 1, f_min, f_max, n_mels, sample_rate) 22 | melscale_fbanks_v3 = ( 23 | torchaudio.functional.melscale_fbanks(n_fft_v3 // 2 + 1, f_min, f_max, n_mels, sample_rate).bfloat16().float() 24 | ) 25 | hann_window_v3 = torch.hann_window(win_length_v3).bfloat16().double() 26 | 27 | 28 | @script(doc_string="LogMelSpectrogram feature extractor for GigaAM v2 models") 29 | def GigaamPreprocessorV2( 30 | waveforms: FLOAT["batch_size", "N"], 31 | waveforms_lens: INT64["batch_size"], 32 | ) -> tuple[FLOAT["batch_size", n_mels, "T"], INT64["batch_size"]]: 33 | waveforms = op.Pad( 34 | waveforms, 35 | pads=op.Constant(value=[0, n_fft_v2 // 2, 0, n_fft_v2 // 2]), 36 | mode="reflect", 37 | ) 38 | 39 | hann_window = op.HannWindow(win_length_v2, output_datatype=TensorProto.DOUBLE) 40 | image = op.STFT(op.CastLike(waveforms, hann_window), hop_length, hann_window) 41 | spectrogram = op.ReduceSumSquare(image, axes=(-1,), keepdims=0) 42 | 43 | melscale_fbanks_tensor = op.Constant(value=numpy_helper.from_array(melscale_fbanks_v2.numpy(), "melscale_fbanks")) 44 | mel_spectrogram = op.MatMul(op.CastLike(spectrogram, melscale_fbanks_tensor), melscale_fbanks_tensor) 45 | log_mel_spectrogram = op.Log(op.Clip(mel_spectrogram, clamp_min, clamp_max)) 46 | 47 | features_lens = waveforms_lens / hop_length + 1 48 | features = op.Transpose(log_mel_spectrogram, perm=(0, 2, 1)) 49 | return features, features_lens 50 | 51 | 52 | @script(doc_string="LogMelSpectrogram feature extractor for GigaAM v3 models") 53 | def GigaamPreprocessorV3( 54 | waveforms: FLOAT["batch_size", "N"], 55 | waveforms_lens: INT64["batch_size"], 56 | ) -> tuple[FLOAT["batch_size", n_mels, "T"], INT64["batch_size"]]: 57 | hann_window = op.Constant(value=numpy_helper.from_array(hann_window_v3.numpy(), "hann_window")) 58 | image = op.STFT(op.CastLike(waveforms, hann_window), hop_length, hann_window) 59 | 
spectrogram = op.ReduceSumSquare(image, axes=(-1,), keepdims=0) 60 | 61 | melscale_fbanks_tensor = op.Constant(value=numpy_helper.from_array(melscale_fbanks_v3.numpy(), "melscale_fbanks")) 62 | mel_spectrogram = op.MatMul(op.CastLike(spectrogram, melscale_fbanks_tensor), melscale_fbanks_tensor) 63 | log_mel_spectrogram = op.Log(op.Clip(mel_spectrogram, clamp_min, clamp_max)) 64 | 65 | features_lens = (waveforms_lens - win_length_v3) / hop_length + 1 66 | features = op.Transpose(log_mel_spectrogram, perm=(0, 2, 1)) 67 | return features, features_lens 68 | -------------------------------------------------------------------------------- /src/onnx_asr/models/tone.py: -------------------------------------------------------------------------------- 1 | """T-one model implementations.""" 2 | 3 | from pathlib import Path 4 | from typing import Literal 5 | 6 | import numpy as np 7 | import numpy.typing as npt 8 | import onnxruntime as rt 9 | 10 | from onnx_asr.asr import AsrRuntimeConfig, _AsrWithCtcDecoding 11 | from onnx_asr.utils import is_float16_array, is_float32_array 12 | 13 | 14 | class TOneCtc(_AsrWithCtcDecoding): 15 | """T-one CTC model implementation.""" 16 | 17 | def __init__(self, model_files: dict[str, Path], runtime_config: AsrRuntimeConfig): 18 | """Create T-one CTC model. 19 | 20 | Args: 21 | model_files: Dict with paths to model files. 22 | runtime_config: Runtime configuration. 23 | 24 | """ 25 | super().__init__(model_files, runtime_config) 26 | self._model = rt.InferenceSession(model_files["model"], **runtime_config.onnx_options) 27 | 28 | shapes = {x.name: x.shape for x in self._model.get_inputs()} 29 | self._chunk_size = shapes["signal"][1] 30 | self._state_size = shapes["state"][1] 31 | 32 | self._vocab: dict[int, str] = dict(enumerate(self.config["decoder_params"]["vocabulary"])) # type: ignore[typeddict-item] 33 | self._vocab_size = len(self._vocab) + 1 34 | self._blank_idx = int(self.config["pad_token_id"]) # type: ignore[typeddict-item] 35 | 36 | @staticmethod 37 | def _get_model_files(quantization: str | None = None) -> dict[str, str]: 38 | suffix = "?" 
+ quantization if quantization else "" 39 | return {"model": f"model{suffix}.onnx"} 40 | 41 | @staticmethod 42 | def _get_sample_rate() -> Literal[8_000, 16_000]: 43 | return 8_000 44 | 45 | @property 46 | def _preprocessor_name(self) -> str: 47 | return "identity" 48 | 49 | @property 50 | def _subsampling_factor(self) -> int: 51 | return int(self.config["encoder_params"]["reduction_kernel_size"]) # type: ignore[typeddict-item] 52 | 53 | def _encode_chunk( 54 | self, waveforms: npt.NDArray[np.float32], state: npt.NDArray[np.float16] 55 | ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.float16]]: 56 | (logprobs, new_state) = self._model.run( 57 | ["logprobs", "state_next"], {"signal": (waveforms[..., None] * (2**15 - 1)).astype(np.int32), "state": state} 58 | ) 59 | assert is_float32_array(logprobs) 60 | assert is_float16_array(new_state) 61 | return logprobs, new_state 62 | 63 | def _encode( 64 | self, waveforms: npt.NDArray[np.float32], waveforms_len: npt.NDArray[np.int64] 65 | ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int64]]: 66 | waveforms = np.pad(waveforms, ((0, 0), (self._chunk_size, self._chunk_size + (-waveforms.shape[1]) % self._chunk_size))) 67 | 68 | res = [] 69 | state = np.zeros((waveforms.shape[0], self._state_size), dtype=np.float16) 70 | for chunk in np.split(waveforms, waveforms.shape[1] // self._chunk_size, axis=1): 71 | logprobs, state = self._encode_chunk(chunk, state) 72 | res.append(logprobs) 73 | 74 | return np.hstack(res[1:]), res[0].shape[1] * ((waveforms_len + self._chunk_size - 1) // self._chunk_size + 1) 75 | -------------------------------------------------------------------------------- /src/onnx_asr/models/kaldi.py: -------------------------------------------------------------------------------- 1 | """Kaldi model implementations.""" 2 | 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import numpy.typing as npt 7 | import onnxruntime as rt 8 | 9 | from onnx_asr.asr import AsrRuntimeConfig, _AsrWithTransducerDecoding 10 | from onnx_asr.utils import is_float32_array, is_int64_array 11 | 12 | _STATE_TYPE = dict[tuple[int, ...], npt.NDArray[np.float32]] 13 | 14 | 15 | class KaldiTransducer(_AsrWithTransducerDecoding[_STATE_TYPE]): 16 | """Kaldi Transducer model implementation.""" 17 | 18 | CONTEXT_SIZE = 2 19 | 20 | def __init__(self, model_files: dict[str, Path], runtime_config: AsrRuntimeConfig): 21 | """Create Kaldi Transducer model. 22 | 23 | Args: 24 | model_files: Dict with paths to model files. 25 | runtime_config: Runtime configuration. 26 | 27 | """ 28 | super().__init__(model_files, runtime_config) 29 | self._encoder = rt.InferenceSession(model_files["encoder"], **runtime_config.onnx_options) 30 | self._decoder = rt.InferenceSession(model_files["decoder"], **runtime_config.onnx_options) 31 | self._joiner = rt.InferenceSession(model_files["joiner"], **runtime_config.onnx_options) 32 | 33 | @staticmethod 34 | def _get_model_files(quantization: str | None = None) -> dict[str, str]: 35 | suffix = "?" 
+ quantization if quantization else "" 36 | return { 37 | "encoder": f"*/encoder{suffix}.onnx", 38 | "decoder": f"*/decoder{suffix}.onnx", 39 | "joiner": f"*/joiner{suffix}.onnx", 40 | "vocab": "*/tokens.txt", 41 | } 42 | 43 | @property 44 | def _preprocessor_name(self) -> str: 45 | assert self.config.get("features_size", 80) == 80 46 | return "kaldi" 47 | 48 | @property 49 | def _subsampling_factor(self) -> int: 50 | return self.config.get("subsampling_factor", 4) 51 | 52 | @property 53 | def _max_tokens_per_step(self) -> int: 54 | return self.config.get("max_tokens_per_step", 1) 55 | 56 | def _encode( 57 | self, features: npt.NDArray[np.float32], features_lens: npt.NDArray[np.int64] 58 | ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int64]]: 59 | encoder_out, encoder_out_lens = self._encoder.run( 60 | ["encoder_out", "encoder_out_lens"], {"x": features, "x_lens": features_lens} 61 | ) 62 | assert is_float32_array(encoder_out) 63 | assert is_int64_array(encoder_out_lens) 64 | return encoder_out, encoder_out_lens 65 | 66 | def _create_state(self) -> _STATE_TYPE: 67 | return {} 68 | 69 | def _decode( 70 | self, prev_tokens: list[int], prev_state: _STATE_TYPE, encoder_out: npt.NDArray[np.float32] 71 | ) -> tuple[npt.NDArray[np.float32], int, _STATE_TYPE]: 72 | context = (-1, self._blank_idx, *prev_tokens)[-self.CONTEXT_SIZE :] 73 | 74 | decoder_out = prev_state.get(context) 75 | if decoder_out is None: 76 | (_decoder_out,) = self._decoder.run(["decoder_out"], {"y": [context]}) 77 | assert is_float32_array(_decoder_out) 78 | prev_state[context] = (decoder_out := _decoder_out) 79 | 80 | (logit,) = self._joiner.run(["logit"], {"encoder_out": encoder_out[None, :], "decoder_out": decoder_out}) 81 | assert is_float32_array(logit) 82 | return np.squeeze(logit), -1, prev_state 83 | -------------------------------------------------------------------------------- /preprocessors/nemo.py: -------------------------------------------------------------------------------- 1 | import torchaudio 2 | from onnx import TensorProto, numpy_helper 3 | from onnxscript import FLOAT, INT64, script 4 | from onnxscript import opset17 as op 5 | 6 | sample_rate = 16_000 7 | n_fft = 512 8 | win_length = 400 9 | hop_length = 160 10 | preemph = 0.97 11 | 12 | log_zero_guard_value = float(2**-24) 13 | 14 | melscale_fbanks80 = torchaudio.functional.melscale_fbanks( 15 | n_fft // 2 + 1, 0, sample_rate // 2, 80, sample_rate, "slaney", "slaney" 16 | ) 17 | melscale_fbanks128 = torchaudio.functional.melscale_fbanks( 18 | n_fft // 2 + 1, 0, sample_rate // 2, 128, sample_rate, "slaney", "slaney" 19 | ) 20 | 21 | 22 | @script() 23 | def normalize(x: FLOAT["batch_size", "M", "T"], lens: INT64["batch_size"]): 24 | lens_3d = op.Unsqueeze(lens, [1, 2]) 25 | mask = op.Range(0, op.Shape(x, start=2, end=3), 1) < lens_3d 26 | lens_3d = op.CastLike(lens_3d, x) 27 | mean = op.ReduceSum(op.Where(mask, x, 0), axes=[-1], keepdims=1) / lens_3d 28 | var = op.ReduceSumSquare(op.Where(mask, x - mean, 0), axes=(-1,), keepdims=1) / (lens_3d - 1) 29 | return op.Where(mask, (x - mean) / (op.Sqrt(var) + 1e-5), 0) 30 | 31 | 32 | @script() 33 | def nemo_preprocessor( 34 | waveforms: FLOAT["batch_size", "N"], waveforms_lens: INT64["batch_size"], melscale_fbanks: FLOAT[n_fft // 2 + 1, "M"] 35 | ): 36 | if preemph != 0.0: 37 | timemask = op.Range(0, op.Shape(waveforms, start=1, end=2), 1) < op.Unsqueeze(waveforms_lens, [1]) 38 | waveforms = op.Concat(waveforms[:, :1], waveforms[:, 1:] - preemph * waveforms[:, :-1], axis=-1) 39 | waveforms = 
op.Where(timemask, waveforms, 0) 40 | 41 | waveforms = op.Pad( 42 | waveforms, 43 | pads=op.Constant(value=[0, n_fft // 2, 0, n_fft // 2]), 44 | mode="constant", 45 | ) 46 | hann_window = op.Pad( 47 | op.HannWindow(win_length, periodic=0, output_datatype=TensorProto.DOUBLE), 48 | pads=op.Constant(value=[n_fft // 2 - win_length // 2, n_fft // 2 - win_length // 2]), 49 | ) 50 | image = op.STFT(op.CastLike(waveforms, hann_window), hop_length, hann_window) 51 | spectrogram = op.ReduceSumSquare(image, axes=(-1,), keepdims=0) 52 | 53 | mel_spectrogram = op.MatMul(op.CastLike(spectrogram, melscale_fbanks), melscale_fbanks) 54 | log_mel_spectrogram = op.Log(mel_spectrogram + log_zero_guard_value) 55 | 56 | features_lens = waveforms_lens / hop_length 57 | return normalize(op.Transpose(log_mel_spectrogram, perm=(0, 2, 1)), features_lens), features_lens 58 | 59 | 60 | @script(doc_string="LogMelSpectrogram feature extractor for Nemo models") 61 | def NemoPreprocessor80( 62 | waveforms: FLOAT["batch_size", "N"], 63 | waveforms_lens: INT64["batch_size"], 64 | ) -> tuple[FLOAT["batch_size", 80, "T"], INT64["batch_size"]]: 65 | features, features_lens = nemo_preprocessor( 66 | waveforms, 67 | waveforms_lens, 68 | op.Constant(value=numpy_helper.from_array(melscale_fbanks80.numpy(), "melscale_fbanks")), 69 | ) 70 | return features, features_lens 71 | 72 | 73 | @script(doc_string="LogMelSpectrogram feature extractor for Nemo models") 74 | def NemoPreprocessor128( 75 | waveforms: FLOAT["batch_size", "N"], 76 | waveforms_lens: INT64["batch_size"], 77 | ) -> tuple[FLOAT["batch_size", 128, "T"], INT64["batch_size"]]: 78 | features, features_lens = nemo_preprocessor( 79 | waveforms, 80 | waveforms_lens, 81 | op.Constant(value=numpy_helper.from_array(melscale_fbanks128.numpy(), "melscale_fbanks")), 82 | ) 83 | return features, features_lens 84 | -------------------------------------------------------------------------------- /tests/preprocessors/test_whisper_preprocessor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | import torchaudio 5 | from whisper.audio import N_FRAMES, N_SAMPLES, log_mel_spectrogram, mel_filters, pad_or_trim 6 | 7 | from onnx_asr.preprocessors import Preprocessor 8 | from onnx_asr.utils import pad_list 9 | from preprocessors import whisper 10 | 11 | 12 | def preprocessor_origin(waveforms, lens, n_mels): 13 | waveforms = pad_or_trim(waveforms, N_SAMPLES) 14 | return log_mel_spectrogram(waveforms, n_mels).numpy(), np.full_like(lens, N_FRAMES) 15 | 16 | 17 | def preprocessor_torch(waveforms, lens, n_mels): 18 | waveforms = torch.from_numpy(waveforms) 19 | waveforms = waveforms[:, : whisper.chunk_length * whisper.sample_rate] 20 | waveforms = torch.nn.functional.pad(waveforms, (0, whisper.chunk_length * whisper.sample_rate - waveforms.shape[-1])) 21 | spectrogram = torchaudio.functional.spectrogram( 22 | waveforms, 23 | pad=0, 24 | window=torch.hann_window(whisper.win_length), 25 | n_fft=whisper.n_fft, 26 | hop_length=whisper.hop_length, 27 | win_length=whisper.win_length, 28 | power=2, 29 | normalized=False, 30 | )[..., :-1] 31 | mel_spectrogram = torch.matmul( 32 | spectrogram.transpose(-1, -2), whisper.melscale_fbanks80 if n_mels == 80 else whisper.melscale_fbanks128 33 | ).transpose(-1, -2) 34 | log_mel_spectrogram = torch.clamp(mel_spectrogram, min=whisper.clamp_min).log10() 35 | features = (torch.maximum(log_mel_spectrogram, log_mel_spectrogram.max() - 8.0) + 4.0) / 4.0 36 | return features, 
np.full_like(lens, whisper.chunk_length * whisper.sample_rate // whisper.hop_length) 37 | 38 | 39 | def preprocessor_torch80(waveforms, lens): 40 | return preprocessor_torch(waveforms, lens, 80) 41 | 42 | 43 | def preprocessor_torch128(waveforms, lens): 44 | return preprocessor_torch(waveforms, lens, 128) 45 | 46 | 47 | @pytest.fixture(scope="module") 48 | def preprocessor(request): 49 | match request.param: 50 | case "torch 80": 51 | return preprocessor_torch80 52 | case "torch 128": 53 | return preprocessor_torch128 54 | case "onnx_func 80": 55 | return whisper.WhisperPreprocessor80 56 | case "onnx_func 128": 57 | return whisper.WhisperPreprocessor128 58 | case "onnx_model 80": 59 | return Preprocessor("whisper80", {}) 60 | case "onnx_model 128": 61 | return Preprocessor("whisper128", {}) 62 | 63 | 64 | @pytest.mark.parametrize( 65 | ("n_mels", "preprocessor"), 66 | [ 67 | (80, "torch 80"), 68 | (128, "torch 128"), 69 | (80, "onnx_func 80"), 70 | (128, "onnx_func 128"), 71 | (80, "onnx_model 80"), 72 | (128, "onnx_model 128"), 73 | ], 74 | indirect=["preprocessor"], 75 | ) 76 | def test_whisper_preprocessor(n_mels, preprocessor, waveforms): 77 | waveforms, lens = pad_list(waveforms) 78 | expected, expected_lens = preprocessor_origin(waveforms, lens, n_mels) 79 | actual, actual_lens = preprocessor(waveforms, lens) 80 | 81 | assert expected.shape[2] == max(expected_lens) 82 | np.testing.assert_equal(actual_lens, expected_lens) 83 | np.testing.assert_allclose(actual, expected, atol=5e-5) 84 | 85 | 86 | @pytest.mark.parametrize(("n_mels", "melscale_fbanks"), [(80, whisper.melscale_fbanks80), (128, whisper.melscale_fbanks128)]) 87 | def test_whisper_melscale_fbanks(n_mels, melscale_fbanks): 88 | expected = mel_filters("cpu", n_mels).T.numpy() 89 | actual = melscale_fbanks.numpy() 90 | 91 | np.testing.assert_allclose(actual, expected, atol=5e-7) 92 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "onnx-asr" 3 | dynamic = ["version"] 4 | description = "Automatic Speech Recognition in Python using ONNX models" 5 | authors = [{ name = "Ilya Stupakov", email = "istupakov@gmail.com" }] 6 | keywords = ["asr", "speech recognition", "onnx", "stt"] 7 | dependencies = ["numpy"] 8 | requires-python = ">=3.10" 9 | readme = "README.md" 10 | license = "MIT" 11 | license-files = ["LICENSE"] 12 | classifiers = [ 13 | "Development Status :: 4 - Beta", 14 | "Intended Audience :: Developers", 15 | "Intended Audience :: Education", 16 | "Intended Audience :: Science/Research", 17 | "Operating System :: OS Independent", 18 | "Programming Language :: Python", 19 | "Programming Language :: Python :: 3 :: Only", 20 | "Programming Language :: Python :: 3.10", 21 | "Programming Language :: Python :: 3.11", 22 | "Programming Language :: Python :: 3.12", 23 | "Programming Language :: Python :: 3.13", 24 | "Topic :: Multimedia :: Sound/Audio :: Speech", 25 | "Topic :: Scientific/Engineering", 26 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 27 | "Typing :: Typed", 28 | ] 29 | 30 | [project.urls] 31 | Documentation = "https://github.com/istupakov/onnx-asr#readme" 32 | "Release notes" = "https://github.com/istupakov/onnx-asr/releases" 33 | Issues = "https://github.com/istupakov/onnx-asr/issues" 34 | Source = "https://github.com/istupakov/onnx-asr" 35 | 36 | [build-system] 37 | requires = ["pdm-backend"] 38 | build-backend = "pdm.backend" 39 | 40 | 
[project.optional-dependencies] 41 | cpu = ["onnxruntime>=1.18.1"] 42 | gpu = ["onnxruntime-gpu>=1.18.1"] 43 | hub = ["huggingface-hub"] 44 | 45 | [project.scripts] 46 | onnx-asr = "onnx_asr.cli:run" 47 | 48 | [dependency-groups] 49 | build = [ 50 | "onnx>=1.17.0", 51 | "onnxscript>=0.2.5", 52 | "torch>=2.6.0", 53 | "torchaudio>=2.6.0", 54 | ] 55 | asrs = [ 56 | "kaldi-native-fbank>=1.21.1", 57 | "nemo-toolkit[asr]>=2.2.1", 58 | "openai-whisper>=20240930", 59 | ] 60 | test = [ 61 | "pytest>=8.3.5", 62 | "pytest-cov>=6.1.1", 63 | "onnxruntime>=1.21.1", 64 | { include-group = "build" }, 65 | { include-group = "asrs" }, 66 | ] 67 | lint = ["ruff>=0.11.6", "mypy>=1.16.1"] 68 | 69 | [tool.pdm] 70 | distribution = true 71 | 72 | [tool.pdm.version] 73 | source = "scm" 74 | 75 | [tool.pdm.build] 76 | source-includes = ["preprocessors", "tests"] 77 | 78 | [tool.pdm.scripts] 79 | build_preprocessors = { call = "preprocessors.build:build" } 80 | post_install = { composite = ["build_preprocessors"] } 81 | pre_build = { composite = ["pdm sync --no-self --group build"] } 82 | lint = { composite = ["ruff format --diff", "ruff check", "mypy ."] } 83 | 84 | [[tool.pdm.source]] 85 | name = "torch-cpu" 86 | url = "https://download.pytorch.org/whl/cpu" 87 | include_packages = ["torch", "torchaudio"] 88 | exclude_packages = ["*"] 89 | 90 | [tool.pdm.resolution.overrides] 91 | numpy = "~=2.1" 92 | 93 | [tool.mypy] 94 | python_version = "3.10" 95 | strict = true 96 | pretty = true 97 | untyped_calls_exclude = "onnxruntime" 98 | exclude = ['^preprocessors.', '^tests.preprocessors.'] 99 | 100 | [[tool.mypy.overrides]] 101 | module = ["onnxruntime.*"] 102 | follow_untyped_imports = true 103 | implicit_reexport = true 104 | 105 | [tool.ruff] 106 | line-length = 130 107 | indent-width = 4 108 | target-version = "py310" 109 | 110 | [tool.ruff.lint] 111 | select = ["ALL"] 112 | ignore = [ 113 | "A001", 114 | "A002", 115 | "ANN204", 116 | "ARG002", 117 | "ARG004", 118 | "D203", 119 | "D213", 120 | "COM812", 121 | "PLR", 122 | "PLW2901", 123 | "S101", 124 | "S105", 125 | "SLF001", 126 | "TC", 127 | ] 128 | 129 | [tool.ruff.lint.per-file-ignores] 130 | "tests/*" = ["ANN", "D", "PGH003", "PLR0911", "PLR2004"] 131 | "preprocessors/*" = ["ANN", "D", "F821", "N802", "N806"] 132 | "*.ipynb" = ["ANN", "D", "ERA", "RUF001", "T"] 133 | 134 | [tool.pytest.ini_options] 135 | filterwarnings = ["ignore::DeprecationWarning:sys.*"] 136 | -------------------------------------------------------------------------------- /preprocessors/kaldi.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchaudio 4 | from onnx import TensorProto, numpy_helper 5 | from onnxscript import FLOAT, INT64, graph, script 6 | from onnxscript import opset17 as op 7 | 8 | sample_rate = 16_000 9 | n_fft = 512 10 | win_length = 400 11 | hop_length = 160 12 | num_mel_bins = 80 13 | 14 | snip_edges = False 15 | dither = 0.0 16 | remove_dc_offset = True 17 | preemphasis_coefficient = 0.97 18 | 19 | low_freq = 20 20 | high_freq = -400 21 | 22 | float_eps = float(np.finfo(np.float32).eps) 23 | 24 | mel_banks, _ = torchaudio.compliance.kaldi.get_mel_banks(num_mel_bins, n_fft, sample_rate, low_freq, high_freq, 0, 0, 1) 25 | mel_banks = torch.nn.functional.pad(mel_banks, (0, 1)).T 26 | 27 | 28 | @script() 29 | def symmetric_pad(waveforms: FLOAT["batch_size", "N"], lens: INT64["batch_size"]): 30 | @graph() 31 | def pad(waveform: FLOAT["N"], len: INT64[1]): 32 | pad_left = 
op.Constant(value=win_length // 2 - hop_length // 2) 33 | pad_right = op.Constant(value=win_length // 2) 34 | 35 | return op.Concat( 36 | waveform[pad_left - 1 :: -1], waveform[:len], waveform[len - 1 : len - pad_right - 1 : -1], waveform[len:], axis=-1 37 | ) 38 | 39 | return op.Cast(op.Scan(waveforms, op.Unsqueeze(lens, axes=-1), body=pad, num_scan_inputs=2), to=TensorProto.FLOAT) 40 | 41 | 42 | @script() 43 | def sliding_window(waveform: FLOAT["batch_size", "N"]): 44 | samples = op.Shape(waveform, start=1, end=2)[0] 45 | X0 = waveform[:, : win_length - hop_length] 46 | X = op.Reshape( 47 | waveform[:, win_length - hop_length : samples - (samples + hop_length - win_length) % hop_length], 48 | shape=op.Constant(value=[0, -1, hop_length]), 49 | ) 50 | 51 | @graph() 52 | def sliding_buffer(prev: FLOAT["batch_size", win_length - hop_length], curr: FLOAT["batch_size", hop_length]): 53 | hop_len = op.Constant(value=hop_length // 1) 54 | frame = op.Concat(prev, curr, axis=-1) 55 | next = frame[:, hop_len:] 56 | return next, frame 57 | 58 | _, frames = op.Scan(X0, X, body=sliding_buffer, num_scan_inputs=1, scan_input_axes=(1,), scan_output_axes=(1,)) 59 | return op.Cast(frames, to=TensorProto.FLOAT) 60 | 61 | 62 | @script() 63 | def normalize(frames: FLOAT["batch_size", "T", win_length]): 64 | if dither != 0.0: 65 | frames = frames + op.RandomNormalLike(frames, scale=dither) 66 | 67 | if remove_dc_offset: 68 | mean = op.ReduceMean(frames, axes=(-1,)) 69 | frames = frames - mean 70 | 71 | if preemphasis_coefficient != 0.0: 72 | offset = op.Pad(frames, pads=[0, 0, 1, 0, 0, -1], mode="edge") 73 | frames = frames - preemphasis_coefficient * offset 74 | 75 | return frames 76 | 77 | 78 | @script(doc_string="LogMelSpectrogram feature extractor for Kaldi models") 79 | def KaldiPreprocessor( 80 | waveforms: FLOAT["batch_size", "N"], 81 | waveforms_lens: INT64["batch_size"], 82 | ) -> tuple[FLOAT["batch_size", "T", num_mel_bins], INT64["batch_size"]]: 83 | waveforms = symmetric_pad(waveforms, waveforms_lens) 84 | frames = sliding_window(waveforms) 85 | frames = normalize(frames) 86 | 87 | povey_window = op.Pow(op.HannWindow(win_length, periodic=0), 0.85) 88 | frames = povey_window * frames 89 | 90 | image = op.DFT(op.Unsqueeze(frames, axes=-1), n_fft, axis=-2, onesided=1) 91 | spectrogram = op.ReduceSumSquare(image, axes=(-1,), keepdims=0) 92 | 93 | mel_banks_tensor = op.Constant(value=numpy_helper.from_array(mel_banks.numpy(), "mel_banks")) 94 | mel_spectrogram = op.MatMul(spectrogram, mel_banks_tensor) 95 | log_mel_spectrogram = op.Log(op.Clip(mel_spectrogram, min=float_eps)) 96 | 97 | features_lens = (waveforms_lens + hop_length / 2) / hop_length 98 | mask = op.Unsqueeze(op.Range(0, op.Shape(log_mel_spectrogram, start=1, end=2), 1), [0, 2]) < op.Unsqueeze( 99 | features_lens, [1, 2] 100 | ) 101 | features = op.Where(mask, log_mel_spectrogram, 0) 102 | return features, features_lens 103 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually 
these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # Ruff stuff: 171 | .ruff_cache/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | 176 | # ML models 177 | /models/ 178 | *.onnx 179 | 180 | /*.ipynb 181 | *.wav 182 | *.mp3 183 | -------------------------------------------------------------------------------- /tests/preprocessors/test_nemo_preprocessor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | import torchaudio 5 | from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor 6 | 7 | from onnx_asr.preprocessors import Preprocessor 8 | from onnx_asr.utils import pad_list 9 | from preprocessors import nemo 10 | 11 | 12 | @pytest.fixture(scope="module") 13 | def preprocessor_origin(request): 14 | preprocessor = AudioToMelSpectrogramPreprocessor( 15 | window_size=nemo.win_length / nemo.sample_rate, 16 | window_stride=nemo.hop_length / nemo.sample_rate, 17 | features=request.param, 18 | n_fft=nemo.n_fft, 19 | pad_to=0, 20 | ) 21 | preprocessor.eval() 22 | return preprocessor 23 | 24 | 25 | def preprocessor_torch(waveforms, lens, n_mels): 26 | waveforms = torch.from_numpy(waveforms) 27 | if nemo.preemph != 0.0: 28 | timemask = torch.arange(waveforms.shape[-1]).unsqueeze(0) < torch.from_numpy(lens).unsqueeze(1) 29 | waveforms = torch.cat((waveforms[:, :1], waveforms[:, 1:] - nemo.preemph * waveforms[:, :-1]), dim=1) 30 | waveforms = waveforms.masked_fill(~timemask, 0.0) 31 | 32 | spectrogram = torchaudio.functional.spectrogram( 33 | waveforms, 34 | pad=0, 35 | window=torch.hann_window(nemo.win_length, periodic=False), 36 | n_fft=nemo.n_fft, 37 | hop_length=nemo.hop_length, 38 | win_length=nemo.win_length, 39 | power=2, 40 | normalized=False, 41 | pad_mode="constant", 42 | ) 43 | mel_spectrogram = torch.matmul( 44 | spectrogram.transpose(-1, -2), nemo.melscale_fbanks80 if n_mels == 80 else nemo.melscale_fbanks128 45 | ).transpose(-1, -2) 46 | log_mel_spectrogram = torch.log(mel_spectrogram + nemo.log_zero_guard_value) 47 | 48 | features_lens = torch.from_numpy(lens) // nemo.hop_length 49 | mask = torch.arange(log_mel_spectrogram.shape[-1]) < features_lens[:, None, None] 50 | mean = torch.where(mask, log_mel_spectrogram, 0).sum(dim=-1, keepdim=True) / features_lens[:, None, None] 51 | var = torch.where(mask, (log_mel_spectrogram - mean) ** 2, 0).sum(dim=-1, keepdim=True) / 
(features_lens[:, None, None] - 1) 52 | features = torch.where(mask, (log_mel_spectrogram - mean) / (var.sqrt() + 1e-5), 0).numpy() 53 | return features, features_lens.numpy() 54 | 55 | 56 | def preprocessor_torch80(waveforms, lens): 57 | return preprocessor_torch(waveforms, lens, 80) 58 | 59 | 60 | def preprocessor_torch128(waveforms, lens): 61 | return preprocessor_torch(waveforms, lens, 128) 62 | 63 | 64 | @pytest.fixture(scope="module") 65 | def preprocessor(request): 66 | match request.param: 67 | case "torch 80": 68 | return preprocessor_torch80 69 | case "torch 128": 70 | return preprocessor_torch128 71 | case "onnx_func 80": 72 | return nemo.NemoPreprocessor80 73 | case "onnx_func 128": 74 | return nemo.NemoPreprocessor128 75 | case "onnx_model 80": 76 | return Preprocessor("nemo80", {}) 77 | case "onnx_model 128": 78 | return Preprocessor("nemo128", {}) 79 | case "onnx_model_mt 80": 80 | return Preprocessor("nemo80", {"max_concurrent_workers": 2}) 81 | case "onnx_model_mt 128": 82 | return Preprocessor("nemo128", {"max_concurrent_workers": 2}) 83 | 84 | 85 | @pytest.mark.parametrize( 86 | ("preprocessor_origin", "preprocessor"), 87 | [ 88 | (80, "torch 80"), 89 | (128, "torch 128"), 90 | (80, "onnx_func 80"), 91 | (128, "onnx_func 128"), 92 | (80, "onnx_model 80"), 93 | (128, "onnx_model 128"), 94 | (80, "onnx_model_mt 80"), 95 | (128, "onnx_model_mt 128"), 96 | ], 97 | indirect=["preprocessor_origin", "preprocessor"], 98 | ) 99 | def test_nemo_preprocessor(preprocessor_origin, preprocessor, waveforms): 100 | waveforms, lens = pad_list(waveforms) 101 | expected, expected_lens = preprocessor_origin(input_signal=torch.from_numpy(waveforms), length=torch.from_numpy(lens)) 102 | actual, actual_lens = preprocessor(waveforms, lens) 103 | 104 | np.testing.assert_equal(actual_lens, expected_lens.numpy()) 105 | np.testing.assert_allclose(actual, expected.numpy(), atol=1e-4, rtol=1e-4) 106 | 107 | 108 | @pytest.mark.parametrize( 109 | ("preprocessor_origin", "melscale_fbanks"), 110 | [(80, nemo.melscale_fbanks80), (128, nemo.melscale_fbanks128)], 111 | indirect=["preprocessor_origin"], 112 | ) 113 | def test_nemo_melscale_fbanks(preprocessor_origin, melscale_fbanks): 114 | expected = preprocessor_origin.filter_banks[0].T.numpy() 115 | actual = melscale_fbanks.numpy() 116 | 117 | np.testing.assert_allclose(actual, expected, atol=5e-7) 118 | -------------------------------------------------------------------------------- /src/onnx_asr/utils.py: -------------------------------------------------------------------------------- 1 | """Utils for ASR.""" 2 | 3 | import wave 4 | from collections.abc import Sequence 5 | from typing import Any, Literal, TypedDict, TypeGuard, get_args 6 | 7 | import numpy as np 8 | import numpy.typing as npt 9 | import onnxruntime as rt 10 | 11 | SampleRates = Literal[8_000, 11_025, 16_000, 22_050, 24_000, 32_000, 44_100, 48_000] 12 | 13 | 14 | def is_supported_sample_rate(sample_rate: int) -> TypeGuard[SampleRates]: 15 | """Sample rate is supported.""" 16 | return sample_rate in get_args(SampleRates) 17 | 18 | 19 | def is_float16_array(x: object) -> TypeGuard[npt.NDArray[np.float16]]: 20 | """Numpy array is float32.""" 21 | return isinstance(x, np.ndarray) and x.dtype == np.float16 22 | 23 | 24 | def is_float32_array(x: object) -> TypeGuard[npt.NDArray[np.float32]]: 25 | """Numpy array is float32.""" 26 | return isinstance(x, np.ndarray) and x.dtype == np.float32 27 | 28 | 29 | def is_int32_array(x: object) -> TypeGuard[npt.NDArray[np.int32]]: 30 | """Numpy array is 
int32.""" 31 | return isinstance(x, np.ndarray) and x.dtype == np.int32 32 | 33 | 34 | def is_int64_array(x: object) -> TypeGuard[npt.NDArray[np.int64]]: 35 | """Numpy array is int64.""" 36 | return isinstance(x, np.ndarray) and x.dtype == np.int64 37 | 38 | 39 | class SupportedOnlyMonoAudioError(ValueError): 40 | """Supported only mono audio error.""" 41 | 42 | def __init__(self) -> None: 43 | """Create error.""" 44 | super().__init__("Supported only mono audio.") 45 | 46 | 47 | class WrongSampleRateError(ValueError): 48 | """Wrong sample rate error.""" 49 | 50 | def __init__(self) -> None: 51 | """Create error.""" 52 | super().__init__(f"Supported only {get_args(SampleRates)} sample rates.") 53 | 54 | 55 | class DifferentSampleRatesError(ValueError): 56 | """Different sample rates error.""" 57 | 58 | def __init__(self) -> None: 59 | """Create error.""" 60 | super().__init__("All sample rates in a batch must be the same.") 61 | 62 | 63 | class OnnxSessionOptions(TypedDict, total=False): 64 | """Options for onnxruntime InferenceSession.""" 65 | 66 | sess_options: rt.SessionOptions | None 67 | providers: Sequence[str | tuple[str, dict[Any, Any]]] | None 68 | provider_options: Sequence[dict[Any, Any]] | None 69 | 70 | 71 | def get_onnx_device(session: rt.InferenceSession) -> tuple[str, int]: 72 | """Get ONNX device type and id from Session.""" 73 | provider = session.get_providers()[0] 74 | match provider: 75 | case "CUDAExecutionProvider" | "ROCMExecutionProvider": 76 | device_type = "cuda" 77 | case "DmlExecutionProvider": 78 | device_type = "dml" 79 | case _: 80 | device_type = "cpu" 81 | 82 | return device_type, int(session.get_provider_options()[provider].get("device_id", 0)) 83 | 84 | 85 | def read_wav(filename: str) -> tuple[npt.NDArray[np.float32], int]: 86 | """Read PCM wav file to Numpy array.""" 87 | with wave.open(filename, mode="rb") as f: 88 | data = f.readframes(f.getnframes()) 89 | zero_value = 0 90 | if f.getsampwidth() == 1: 91 | buffer = np.frombuffer(data, dtype="u1") 92 | zero_value = 1 93 | elif f.getsampwidth() == 3: 94 | buffer = np.zeros((len(data) // 3, 4), dtype="V1") 95 | buffer[:, -3:] = np.frombuffer(data, dtype="V1").reshape(-1, f.getsampwidth()) 96 | buffer = buffer.view(dtype=" tuple[npt.NDArray[np.float32], npt.NDArray[np.int64], SampleRates]: 107 | """Convert list of waveform or filenames to Numpy array with common length.""" 108 | results = [] 109 | sample_rates = [] 110 | for x in waveforms: 111 | if isinstance(x, str): 112 | waveform, sample_rate = read_wav(x) 113 | if waveform.shape[1] != 1: 114 | raise SupportedOnlyMonoAudioError 115 | results.append(waveform[:, 0]) 116 | sample_rates.append(sample_rate) 117 | else: 118 | if x.ndim != 1: 119 | raise SupportedOnlyMonoAudioError 120 | results.append(x) 121 | sample_rates.append(numpy_sample_rate) 122 | 123 | if len(set(sample_rates)) > 1: 124 | raise DifferentSampleRatesError 125 | 126 | if is_supported_sample_rate(sample_rates[0]): 127 | return *pad_list(results), sample_rates[0] 128 | raise WrongSampleRateError 129 | 130 | 131 | def pad_list(arrays: list[npt.NDArray[np.float32]]) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int64]]: 132 | """Pad list of Numpy arrays to common length.""" 133 | lens = np.array([array.shape[0] for array in arrays], dtype=np.int64) 134 | 135 | result = np.zeros((len(arrays), lens.max()), dtype=np.float32) 136 | for i, x in enumerate(arrays): 137 | result[i, : x.shape[0]] = x[: min(x.shape[0], result.shape[1])] 138 | 139 | return result, lens 140 | 
-------------------------------------------------------------------------------- /tests/preprocessors/test_gigaam_preprocessor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | import torchaudio 5 | 6 | from onnx_asr.preprocessors import Preprocessor 7 | from onnx_asr.utils import pad_list 8 | from preprocessors import gigaam 9 | 10 | 11 | def preprocessor_origin_v2(waveforms, lens): 12 | transform = torchaudio.transforms.MelSpectrogram( 13 | sample_rate=gigaam.sample_rate, 14 | n_fft=gigaam.n_fft_v2, 15 | win_length=gigaam.win_length_v2, 16 | hop_length=gigaam.hop_length, 17 | n_mels=gigaam.n_mels, 18 | ) 19 | features_lens = torch.from_numpy(lens).div(gigaam.hop_length, rounding_mode="floor").add(1).long().numpy() 20 | return torch.log(transform(torch.from_numpy(waveforms)).clamp_(gigaam.clamp_min, gigaam.clamp_max)).numpy(), features_lens 21 | 22 | 23 | def preprocessor_origin_v3(waveforms, lens): 24 | transform = ( 25 | torchaudio.transforms.MelSpectrogram( 26 | sample_rate=gigaam.sample_rate, 27 | n_fft=gigaam.n_fft_v3, 28 | win_length=gigaam.win_length_v3, 29 | hop_length=gigaam.hop_length, 30 | n_mels=gigaam.n_mels, 31 | center=False, 32 | ) 33 | .bfloat16() 34 | .float() 35 | ) 36 | features_lens = ( 37 | torch.from_numpy(lens - gigaam.win_length_v3).div(gigaam.hop_length, rounding_mode="floor").add(1).long().numpy() 38 | ) 39 | return torch.log(transform(torch.from_numpy(waveforms)).clamp_(gigaam.clamp_min, gigaam.clamp_max)).numpy(), features_lens 40 | 41 | 42 | def preprocessor_torch_v2(waveforms, lens): 43 | waveforms = torch.from_numpy(waveforms) 44 | spectrogram = torchaudio.functional.spectrogram( 45 | waveforms, 46 | pad=0, 47 | window=torch.hann_window(gigaam.win_length_v2), 48 | n_fft=gigaam.n_fft_v2, 49 | hop_length=gigaam.hop_length, 50 | win_length=gigaam.win_length_v2, 51 | power=2, 52 | normalized=False, 53 | ) 54 | mel_spectrogram = torch.matmul(spectrogram.transpose(-1, -2), gigaam.melscale_fbanks_v2).transpose(-1, -2) 55 | return torch.log(mel_spectrogram.clamp_(gigaam.clamp_min, gigaam.clamp_max)).numpy(), lens // gigaam.hop_length + 1 56 | 57 | 58 | def preprocessor_torch_v3(waveforms, lens): 59 | waveforms = torch.from_numpy(waveforms) 60 | spectrogram = torchaudio.functional.spectrogram( 61 | waveforms, 62 | pad=0, 63 | window=torch.hann_window(gigaam.win_length_v3).bfloat16().float(), 64 | n_fft=gigaam.n_fft_v3, 65 | hop_length=gigaam.hop_length, 66 | win_length=gigaam.win_length_v3, 67 | power=2, 68 | normalized=False, 69 | center=False, 70 | ) 71 | mel_spectrogram = torch.matmul(spectrogram.transpose(-1, -2), gigaam.melscale_fbanks_v3).transpose(-1, -2) 72 | return torch.log(mel_spectrogram.clamp_(gigaam.clamp_min, gigaam.clamp_max)).numpy(), ( 73 | lens - gigaam.win_length_v3 74 | ) // gigaam.hop_length + 1 75 | 76 | 77 | @pytest.fixture(scope="module") 78 | def preprocessor(request): 79 | match request.param: 80 | case "torch_v2": 81 | return preprocessor_torch_v2 82 | case "onnx_func_v2": 83 | return gigaam.GigaamPreprocessorV2 84 | case "onnx_model_v2": 85 | return Preprocessor("gigaam_v2", {}) 86 | case "onnx_model_v2_mt": 87 | return Preprocessor("gigaam_v2", {"max_concurrent_workers": 2}) 88 | case "torch_v3": 89 | return preprocessor_torch_v3 90 | case "onnx_func_v3": 91 | return gigaam.GigaamPreprocessorV3 92 | case "onnx_model_v3": 93 | return Preprocessor("gigaam_v3", {}) 94 | case "onnx_model_v3_mt": 95 | return Preprocessor("gigaam_v3", 
{"max_concurrent_workers": 2}) 96 | 97 | 98 | @pytest.mark.parametrize( 99 | ("preprocessor", "equal"), 100 | [ 101 | ("torch_v2", True), 102 | ("onnx_func_v2", False), 103 | ("onnx_model_v2", False), 104 | ("onnx_model_v2_mt", False), 105 | ], 106 | indirect=["preprocessor"], 107 | ) 108 | def test_gigaam_preprocessor_v2(preprocessor, equal, waveforms): 109 | waveforms, lens = pad_list(waveforms) 110 | expected, expected_lens = preprocessor_origin_v2(waveforms, lens) 111 | actual, actual_lens = preprocessor(waveforms, lens) 112 | 113 | assert expected.shape[2] == max(expected_lens) 114 | np.testing.assert_equal(actual_lens, expected_lens) 115 | if equal: 116 | np.testing.assert_equal(actual, expected) 117 | else: 118 | np.testing.assert_allclose(actual, expected, atol=5e-5) 119 | 120 | 121 | @pytest.mark.parametrize( 122 | ("preprocessor", "equal"), 123 | [ 124 | ("torch_v3", True), 125 | ("onnx_func_v3", False), 126 | ("onnx_model_v3", False), 127 | ("onnx_model_v3_mt", False), 128 | ], 129 | indirect=["preprocessor"], 130 | ) 131 | def test_gigaam_preprocessor_v3(preprocessor, equal, waveforms): 132 | waveforms, lens = pad_list(waveforms) 133 | expected, expected_lens = preprocessor_origin_v3(waveforms, lens) 134 | actual, actual_lens = preprocessor(waveforms, lens) 135 | 136 | assert expected.shape[2] == max(expected_lens) 137 | np.testing.assert_equal(actual_lens, expected_lens) 138 | if equal: 139 | np.testing.assert_equal(actual, expected) 140 | else: 141 | np.testing.assert_allclose(actual, expected, atol=5e-5, rtol=5e-6) 142 | -------------------------------------------------------------------------------- /src/onnx_asr/models/silero.py: -------------------------------------------------------------------------------- 1 | """Silero VAD implementation.""" 2 | 3 | from collections.abc import Iterable, Iterator 4 | from itertools import chain 5 | from pathlib import Path 6 | from typing import Literal 7 | 8 | import numpy as np 9 | import numpy.typing as npt 10 | import onnxruntime as rt 11 | 12 | from onnx_asr.utils import OnnxSessionOptions, is_float32_array 13 | from onnx_asr.vad import Vad 14 | 15 | 16 | class SileroVad(Vad): 17 | """Silero VAD implementation.""" 18 | 19 | INF = 10**15 20 | 21 | def __init__(self, model_files: dict[str, Path], onnx_options: OnnxSessionOptions): 22 | """Create Silero VAD. 23 | 24 | Args: 25 | model_files: Dict with paths to model files. 26 | onnx_options: Options for onnxruntime InferenceSession. 27 | 28 | """ 29 | self._model = rt.InferenceSession(model_files["model"], **onnx_options) 30 | 31 | @staticmethod 32 | def _get_model_files(quantization: str | None = None) -> dict[str, str]: 33 | suffix = "?" 
+ quantization if quantization else "" 34 | return {"model": f"**/model{suffix}.onnx"} 35 | 36 | def _encode( 37 | self, waveforms: npt.NDArray[np.float32], sample_rate: int, hop_size: int, context_size: int 38 | ) -> Iterator[npt.NDArray[np.float32]]: 39 | frames = np.lib.stride_tricks.sliding_window_view(waveforms, context_size + hop_size, axis=-1)[ 40 | :, hop_size - context_size :: hop_size 41 | ] 42 | 43 | state: npt.NDArray[np.float32] = np.zeros((2, frames.shape[0], 128), dtype=np.float32) 44 | 45 | def process(frame: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]: 46 | nonlocal state 47 | output, new_state = self._model.run(["output", "stateN"], {"input": frame, "state": state, "sr": [sample_rate]}) 48 | assert is_float32_array(output) 49 | assert is_float32_array(new_state) 50 | state = new_state 51 | return output[:, 0] 52 | 53 | yield process(np.pad(waveforms[:, :hop_size], ((0, 0), (context_size, 0)))) 54 | 55 | for i in range(frames.shape[1]): 56 | yield process(frames[:, i]) 57 | 58 | if last_frame := waveforms.shape[1] % hop_size: 59 | yield process(np.pad(waveforms[:, -last_frame - context_size :], ((0, 0), (0, hop_size - last_frame)))) 60 | 61 | def _find_segments( 62 | self, 63 | probs: Iterable[np.float32], 64 | hop_size: int, 65 | *, 66 | threshold: float = 0.5, 67 | neg_threshold: float | None = None, 68 | **kwargs: float, 69 | ) -> Iterator[tuple[int, int]]: 70 | if neg_threshold is None: 71 | neg_threshold = threshold - 0.15 72 | 73 | state = 0 74 | start = 0 75 | for i, p in enumerate(chain(probs, (np.float32(0),))): 76 | if state == 0 and p >= threshold: 77 | state = 1 78 | start = i * hop_size 79 | elif state == 1 and p < neg_threshold: 80 | state = 0 81 | yield start, i * hop_size 82 | 83 | def _merge_segments( 84 | self, 85 | segments: Iterator[tuple[int, int]], 86 | waveform_len: int, 87 | sample_rate: int, 88 | *, 89 | min_speech_duration_ms: float = 250, 90 | max_speech_duration_s: float = 20, 91 | min_silence_duration_ms: float = 100, 92 | speech_pad_ms: float = 30, 93 | **kwargs: float, 94 | ) -> Iterator[tuple[int, int]]: 95 | speech_pad = int(speech_pad_ms * sample_rate // 1000) 96 | min_speech_duration = int(min_speech_duration_ms * sample_rate // 1000) - 2 * speech_pad 97 | max_speech_duration = int(max_speech_duration_s * sample_rate) - 2 * speech_pad 98 | min_silence_duration = int(min_silence_duration_ms * sample_rate // 1000) + 2 * speech_pad 99 | 100 | cur_start, cur_end = -self.INF, -self.INF 101 | for start, end in chain(segments, ((waveform_len, waveform_len), (self.INF, self.INF))): 102 | if start - cur_end < min_silence_duration and end - cur_start < max_speech_duration: 103 | cur_end = end 104 | else: 105 | if cur_end - cur_start > min_speech_duration: 106 | yield max(cur_start - speech_pad, 0), min(cur_end + speech_pad, waveform_len) 107 | while end - start > max_speech_duration: 108 | yield max(start - speech_pad, 0), start + max_speech_duration - speech_pad 109 | start += max_speech_duration 110 | cur_start, cur_end = start, end 111 | 112 | def segment_batch( 113 | self, 114 | waveforms: npt.NDArray[np.float32], 115 | waveforms_len: npt.NDArray[np.int64], 116 | sample_rate: Literal[8_000, 16_000], 117 | **kwargs: float, 118 | ) -> Iterator[Iterator[tuple[int, int]]]: 119 | """Segment waveforms batch.""" 120 | hop_size = 512 if sample_rate == 16_000 else 256 121 | context_size = 64 if sample_rate == 16_000 else 32 122 | 123 | def segment(probs: Iterable[np.float32], waveform_len: np.int64, **kwargs: float) -> Iterator[tuple[int, int]]: 
124 | return self._merge_segments(self._find_segments(probs, hop_size, **kwargs), int(waveform_len), sample_rate, **kwargs) 125 | 126 | encoding = self._encode(waveforms, sample_rate, hop_size, context_size) 127 | if len(waveforms) == 1: 128 | yield segment((probs[0] for probs in encoding), waveforms_len[0], **kwargs) 129 | else: 130 | yield from ( 131 | segment(probs, waveform_len, **kwargs) 132 | for probs, waveform_len in zip(zip(*encoding, strict=True), waveforms_len, strict=True) 133 | ) 134 | -------------------------------------------------------------------------------- /src/onnx_asr/adapters.py: -------------------------------------------------------------------------------- 1 | """ASR adapter classes.""" 2 | 3 | from __future__ import annotations 4 | 5 | from abc import ABC, abstractmethod 6 | from collections.abc import Iterator 7 | from typing import Generic, TypeVar, overload 8 | 9 | import numpy as np 10 | import numpy.typing as npt 11 | 12 | from .asr import Asr, TimestampedResult 13 | from .preprocessors import Resampler 14 | from .utils import SampleRates, read_wav_files 15 | from .vad import SegmentResult, TimestampedSegmentResult, Vad 16 | 17 | R = TypeVar("R") 18 | 19 | 20 | class AsrAdapter(ABC, Generic[R]): 21 | """Base ASR adapter class.""" 22 | 23 | asr: Asr 24 | resampler: Resampler 25 | 26 | def __init__(self, asr: Asr, resampler: Resampler): 27 | """Create ASR adapter.""" 28 | self.asr = asr 29 | self.resampler = resampler 30 | 31 | def with_vad(self, vad: Vad, **kwargs: float) -> SegmentResultsAsrAdapter: 32 | """ASR with VAD adapter (text results).""" 33 | return SegmentResultsAsrAdapter(self.asr, vad, self.resampler, **kwargs) 34 | 35 | @abstractmethod 36 | def _recognize_batch( 37 | self, waveforms: npt.NDArray[np.float32], waveforms_len: npt.NDArray[np.int64], language: str | None 38 | ) -> Iterator[R]: ... 39 | 40 | @overload 41 | def recognize( 42 | self, 43 | waveform: str | npt.NDArray[np.float32], 44 | *, 45 | sample_rate: SampleRates = 16_000, 46 | language: str | None = None, 47 | ) -> R: ... 48 | 49 | @overload 50 | def recognize( 51 | self, 52 | waveform: list[str | npt.NDArray[np.float32]], 53 | *, 54 | sample_rate: SampleRates = 16_000, 55 | language: str | None = None, 56 | ) -> list[R]: ... 57 | 58 | def recognize( 59 | self, 60 | waveform: str | npt.NDArray[np.float32] | list[str | npt.NDArray[np.float32]], 61 | *, 62 | sample_rate: SampleRates = 16_000, 63 | language: str | None = None, 64 | ) -> R | list[R]: 65 | """Recognize speech (single or batch). 66 | 67 | Args: 68 | waveform: Path to wav file (only PCM_U8, PCM_16, PCM_24 and PCM_32 formats are supported) 69 | or Numpy array with PCM waveform. 70 | A list of file paths or numpy arrays for batch recognition are also supported. 71 | sample_rate: Sample rate for Numpy arrays in waveform. 72 | language: Speech language (only for Whisper models). 73 | 74 | Returns: 75 | Speech recognition results (single or list for batch recognition). 
76 | 77 | """ 78 | if isinstance(waveform, list): 79 | if not waveform: 80 | return [] 81 | return list(self._recognize_batch(*self.resampler(*read_wav_files(waveform, sample_rate)), language)) 82 | return next(self._recognize_batch(*self.resampler(*read_wav_files([waveform], sample_rate)), language)) 83 | 84 | 85 | class TimestampedResultsAsrAdapter(AsrAdapter[TimestampedResult]): 86 | """ASR adapter (timestamped results).""" 87 | 88 | def _recognize_batch( 89 | self, waveforms: npt.NDArray[np.float32], waveforms_len: npt.NDArray[np.int64], language: str | None 90 | ) -> Iterator[TimestampedResult]: 91 | return self.asr.recognize_batch(waveforms, waveforms_len, language) 92 | 93 | 94 | class TextResultsAsrAdapter(AsrAdapter[str]): 95 | """ASR adapter (text results).""" 96 | 97 | def with_timestamps(self) -> TimestampedResultsAsrAdapter: 98 | """ASR adapter (timestamped results).""" 99 | return TimestampedResultsAsrAdapter(self.asr, self.resampler) 100 | 101 | def _recognize_batch( 102 | self, waveforms: npt.NDArray[np.float32], waveforms_len: npt.NDArray[np.int64], language: str | None 103 | ) -> Iterator[str]: 104 | return (res.text for res in self.asr.recognize_batch(waveforms, waveforms_len, language)) 105 | 106 | 107 | class TimestampedSegmentResultsAsrAdapter(AsrAdapter[Iterator[TimestampedSegmentResult]]): 108 | """ASR with VAD adapter (timestamped results).""" 109 | 110 | vad: Vad 111 | 112 | def __init__(self, asr: Asr, vad: Vad, resampler: Resampler, **kwargs: float): 113 | """Create ASR adapter.""" 114 | super().__init__(asr, resampler) 115 | self.vad = vad 116 | self._vadargs = kwargs 117 | 118 | def _recognize_batch( 119 | self, waveforms: npt.NDArray[np.float32], waveforms_len: npt.NDArray[np.int64], language: str | None 120 | ) -> Iterator[Iterator[TimestampedSegmentResult]]: 121 | return self.vad.recognize_batch( 122 | self.asr, waveforms, waveforms_len, self.asr._get_sample_rate(), language, **self._vadargs 123 | ) 124 | 125 | 126 | class SegmentResultsAsrAdapter(AsrAdapter[Iterator[SegmentResult]]): 127 | """ASR with VAD adapter (text results).""" 128 | 129 | vad: Vad 130 | 131 | def __init__(self, asr: Asr, vad: Vad, resampler: Resampler, **kwargs: float): 132 | """Create ASR adapter.""" 133 | super().__init__(asr, resampler) 134 | self.vad = vad 135 | self._vadargs = kwargs 136 | 137 | def with_timestamps(self) -> TimestampedSegmentResultsAsrAdapter: 138 | """ASR with VAD adapter (timestamped results).""" 139 | return TimestampedSegmentResultsAsrAdapter(self.asr, self.vad, self.resampler, **self._vadargs) 140 | 141 | def _recognize_batch( 142 | self, waveforms: npt.NDArray[np.float32], waveforms_len: npt.NDArray[np.int64], language: str | None 143 | ) -> Iterator[Iterator[SegmentResult]]: 144 | return ( 145 | (SegmentResult(res.start, res.end, res.text) for res in results) 146 | for results in self.vad.recognize_batch( 147 | self.asr, waveforms, waveforms_len, self.asr._get_sample_rate(), language, **self._vadargs 148 | ) 149 | ) 150 | -------------------------------------------------------------------------------- /src/onnx_asr/models/gigaam.py: -------------------------------------------------------------------------------- 1 | """GigaAM v2+ model implementations.""" 2 | 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import numpy.typing as npt 7 | import onnxruntime as rt 8 | 9 | from onnx_asr.asr import AsrRuntimeConfig, _AsrWithCtcDecoding, _AsrWithDecoding, _AsrWithTransducerDecoding 10 | from onnx_asr.utils import is_float32_array, is_int32_array 11 | 
12 | 13 | class _GigaamV2(_AsrWithDecoding): 14 | @staticmethod 15 | def _get_model_files(quantization: str | None = None) -> dict[str, str]: 16 | return {"vocab": "v?_vocab.txt"} 17 | 18 | @property 19 | def _preprocessor_name(self) -> str: 20 | assert self.config.get("features_size", 64) == 64 21 | version = self.config.get("version", "v2") 22 | return f"gigaam_{version}" 23 | 24 | @property 25 | def _subsampling_factor(self) -> int: 26 | return self.config.get("subsampling_factor", 4) 27 | 28 | 29 | class GigaamV2Ctc(_AsrWithCtcDecoding, _GigaamV2): 30 | """GigaAM v2+ CTC model implementation.""" 31 | 32 | def __init__(self, model_files: dict[str, Path], runtime_config: AsrRuntimeConfig): 33 | """Create GigaAM v2+ CTC model. 34 | 35 | Args: 36 | model_files: Dict with paths to model files. 37 | runtime_config: Runtime configuration. 38 | 39 | """ 40 | super().__init__(model_files, runtime_config) 41 | self._model = rt.InferenceSession(model_files["model"], **runtime_config.onnx_options) 42 | 43 | @staticmethod 44 | def _get_model_files(quantization: str | None = None) -> dict[str, str]: 45 | suffix = "?" + quantization if quantization else "" 46 | return {"model": f"v?_ctc{suffix}.onnx"} | _GigaamV2._get_model_files(quantization) 47 | 48 | def _encode( 49 | self, features: npt.NDArray[np.float32], features_lens: npt.NDArray[np.int64] 50 | ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int64]]: 51 | (log_probs,) = self._model.run(["log_probs"], {"features": features, "feature_lengths": features_lens}) 52 | assert is_float32_array(log_probs) 53 | return log_probs, (features_lens - 1) // self._subsampling_factor + 1 54 | 55 | 56 | _STATE_TYPE = list[npt.NDArray[np.float32]] 57 | 58 | 59 | class GigaamV2Rnnt(_AsrWithTransducerDecoding[_STATE_TYPE], _GigaamV2): 60 | """GigaAM v2+ RNN-T model implementation.""" 61 | 62 | PRED_HIDDEN = 320 63 | 64 | def __init__(self, model_files: dict[str, Path], runtime_config: AsrRuntimeConfig): 65 | """Create GigaAM v2+ RNN-T model. 66 | 67 | Args: 68 | model_files: Dict with paths to model files. 69 | runtime_config: Runtime configuration. 70 | 71 | """ 72 | super().__init__(model_files, runtime_config) 73 | self._encoder = rt.InferenceSession(model_files["encoder"], **runtime_config.onnx_options) 74 | self._decoder = rt.InferenceSession(model_files["decoder"], **runtime_config.onnx_options) 75 | self._joiner = rt.InferenceSession(model_files["joint"], **runtime_config.onnx_options) 76 | 77 | @staticmethod 78 | def _get_model_files(quantization: str | None = None) -> dict[str, str]: 79 | suffix = "?" 
+ quantization if quantization else "" 80 | return { 81 | "encoder": f"v?_rnnt_encoder{suffix}.onnx", 82 | "decoder": f"v?_rnnt_decoder{suffix}.onnx", 83 | "joint": f"v?_rnnt_joint{suffix}.onnx", 84 | } | _GigaamV2._get_model_files(quantization) 85 | 86 | @property 87 | def _max_tokens_per_step(self) -> int: 88 | return self.config.get("max_tokens_per_step", 3) 89 | 90 | def _encode( 91 | self, features: npt.NDArray[np.float32], features_lens: npt.NDArray[np.int64] 92 | ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int64]]: 93 | encoder_out, encoder_out_lens = self._encoder.run( 94 | ["encoded", "encoded_len"], {"audio_signal": features, "length": features_lens} 95 | ) 96 | assert is_float32_array(encoder_out) 97 | assert is_int32_array(encoder_out_lens) 98 | return encoder_out.transpose(0, 2, 1), encoder_out_lens.astype(np.int64) 99 | 100 | def _create_state(self) -> _STATE_TYPE: 101 | return [ 102 | np.zeros(shape=(1, 1, self.PRED_HIDDEN), dtype=np.float32), 103 | np.zeros(shape=(1, 1, self.PRED_HIDDEN), dtype=np.float32), 104 | ] 105 | 106 | def _decode( 107 | self, prev_tokens: list[int], prev_state: _STATE_TYPE, encoder_out: npt.NDArray[np.float32] 108 | ) -> tuple[npt.NDArray[np.float32], int, _STATE_TYPE]: 109 | if len(prev_state) == 2: 110 | decoder_out, state1, state2 = self._decoder.run( 111 | ["dec", "h", "c"], 112 | {"x": [[prev_tokens[-1] if prev_tokens else self._blank_idx]], "h.1": prev_state[0], "c.1": prev_state[1]}, 113 | ) 114 | assert is_float32_array(decoder_out) 115 | assert is_float32_array(state1) 116 | assert is_float32_array(state2) 117 | prev_state[:] = (decoder_out, state1, state2) 118 | else: 119 | decoder_out, state1, state2 = prev_state 120 | 121 | (joint,) = self._joiner.run(["joint"], {"enc": encoder_out[None, :, None], "dec": decoder_out.transpose(0, 2, 1)}) 122 | assert is_float32_array(joint) 123 | return np.squeeze(joint), -1, [state1, state2] 124 | 125 | 126 | class GigaamV3E2eCtc(GigaamV2Ctc): 127 | """GigaAM v3 E2E CTC model implementation.""" 128 | 129 | @staticmethod 130 | def _get_model_files(quantization: str | None = None) -> dict[str, str]: 131 | suffix = "?" + quantization if quantization else "" 132 | return {"model": f"v3_e2e_ctc{suffix}.onnx", "vocab": "v3_e2e_ctc_vocab.txt"} 133 | 134 | 135 | class GigaamV3E2eRnnt(GigaamV2Rnnt): 136 | """GigaAM v3 E2E RNN-T model implementation.""" 137 | 138 | @staticmethod 139 | def _get_model_files(quantization: str | None = None) -> dict[str, str]: 140 | suffix = "?" 
+ quantization if quantization else "" 141 | return { 142 | "encoder": f"v3_e2e_rnnt_encoder{suffix}.onnx", 143 | "decoder": f"v3_e2e_rnnt_decoder{suffix}.onnx", 144 | "joint": f"v3_e2e_rnnt_joint{suffix}.onnx", 145 | "vocab": "v3_e2e_rnnt_vocab.txt", 146 | } 147 | -------------------------------------------------------------------------------- /src/onnx_asr/asr.py: -------------------------------------------------------------------------------- 1 | """Base ASR classes.""" 2 | 3 | import json 4 | import re 5 | from abc import ABC, abstractmethod 6 | from collections.abc import Iterable, Iterator 7 | from dataclasses import dataclass, field 8 | from pathlib import Path 9 | from typing import Generic, Literal, TypedDict, TypeVar 10 | 11 | import numpy as np 12 | import numpy.typing as npt 13 | 14 | from onnx_asr.preprocessors import Preprocessor, PreprocessorRuntimeConfig 15 | from onnx_asr.utils import OnnxSessionOptions 16 | 17 | S = TypeVar("S") 18 | 19 | 20 | @dataclass 21 | class TimestampedResult: 22 | """Timestamped recognition result.""" 23 | 24 | text: str 25 | timestamps: list[float] | None = None 26 | tokens: list[str] | None = None 27 | 28 | 29 | class AsrConfig(TypedDict, total=False): 30 | """Config for ASR model.""" 31 | 32 | model_type: str 33 | features_size: int 34 | subsampling_factor: int 35 | max_tokens_per_step: int 36 | max_sequence_length: int 37 | 38 | 39 | @dataclass() 40 | class AsrRuntimeConfig: 41 | """ASR runtime config.""" 42 | 43 | onnx_options: OnnxSessionOptions = field(default_factory=OnnxSessionOptions) 44 | preprocessor_config: PreprocessorRuntimeConfig = field(default_factory=PreprocessorRuntimeConfig) 45 | 46 | 47 | class Asr(ABC): 48 | """Base ASR class.""" 49 | 50 | def __init__(self, model_files: dict[str, Path], runtime_config: AsrRuntimeConfig): 51 | """Init base ASR class.""" 52 | if "config" in model_files: 53 | with model_files["config"].open("rt", encoding="utf-8") as f: 54 | self.config: AsrConfig = json.load(f) 55 | else: 56 | self.config = {} 57 | 58 | self._preprocessor = Preprocessor(self._preprocessor_name, runtime_config.preprocessor_config) 59 | 60 | @staticmethod 61 | def _get_sample_rate() -> Literal[8_000, 16_000]: 62 | return 16_000 63 | 64 | @property 65 | @abstractmethod 66 | def _preprocessor_name(self) -> str: ... 67 | 68 | @abstractmethod 69 | def recognize_batch( 70 | self, waveforms: npt.NDArray[np.float32], waveforms_len: npt.NDArray[np.int64], language: str | None 71 | ) -> Iterator[TimestampedResult]: 72 | """Recognize waveforms batch.""" 73 | ... 74 | 75 | 76 | class _AsrWithDecoding(Asr): 77 | DECODE_SPACE_PATTERN = re.compile(r"\A\s|\s\B|(\s)\b") 78 | window_step = 0.01 79 | 80 | def __init__(self, model_files: dict[str, Path], runtime_config: AsrRuntimeConfig): 81 | super().__init__(model_files, runtime_config) 82 | 83 | if "vocab" in model_files: 84 | with Path(model_files["vocab"]).open("rt", encoding="utf-8") as f: 85 | self._vocab = {int(id): token.replace("\u2581", " ") for token, id in (line.strip("\n").split(" ") for line in f)} 86 | self._vocab_size = len(self._vocab) 87 | if (blank_idx := next((id for id, token in self._vocab.items() if token == ""), None)) is not None: 88 | self._blank_idx = blank_idx 89 | 90 | @property 91 | @abstractmethod 92 | def _subsampling_factor(self) -> int: ... 93 | 94 | @abstractmethod 95 | def _encode( 96 | self, features: npt.NDArray[np.float32], features_lens: npt.NDArray[np.int64] 97 | ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int64]]: ... 
98 | 99 | @abstractmethod 100 | def _decoding( 101 | self, encoder_out: npt.NDArray[np.float32], encoder_out_lens: npt.NDArray[np.int64], language: str | None 102 | ) -> Iterator[tuple[Iterable[int], Iterable[int]]]: ... 103 | 104 | def _decode_tokens(self, ids: Iterable[int], indices: Iterable[int]) -> TimestampedResult: 105 | tokens = [self._vocab[i] for i in ids] 106 | timestamps = (self.window_step * self._subsampling_factor * np.asarray(indices)).tolist() 107 | text = re.sub(self.DECODE_SPACE_PATTERN, lambda x: " " if x.group(1) else "", "".join(tokens)) 108 | return TimestampedResult(text, timestamps, tokens) 109 | 110 | def recognize_batch( 111 | self, waveforms: npt.NDArray[np.float32], waveforms_len: npt.NDArray[np.int64], language: str | None 112 | ) -> Iterator[TimestampedResult]: 113 | encoder_out, encoder_out_lens = self._encode(*self._preprocessor(waveforms, waveforms_len)) 114 | return ( 115 | self._decode_tokens(tokens, indices) for tokens, indices in self._decoding(encoder_out, encoder_out_lens, language) 116 | ) 117 | 118 | 119 | class _AsrWithCtcDecoding(_AsrWithDecoding): 120 | def _decoding( 121 | self, encoder_out: npt.NDArray[np.float32], encoder_out_lens: npt.NDArray[np.int64], language: str | None 122 | ) -> Iterator[tuple[Iterable[int], Iterable[int]]]: 123 | assert encoder_out.shape[-1] <= self._vocab_size 124 | assert encoder_out.shape[1] >= max(encoder_out_lens) 125 | 126 | batch_tokens = encoder_out.argmax(axis=-1) 127 | batch_mask = np.diff(batch_tokens, axis=-1, append=self._blank_idx) != 0 128 | batch_mask &= batch_tokens != self._blank_idx 129 | batch_indices = (mask[:mask_len].nonzero() for mask, mask_len in zip(batch_mask, encoder_out_lens, strict=True)) 130 | return ((tokens[indices], indices[0]) for tokens, indices in zip(batch_tokens, batch_indices, strict=True)) 131 | 132 | 133 | class _AsrWithTransducerDecoding(_AsrWithDecoding, Generic[S]): 134 | @property 135 | @abstractmethod 136 | def _max_tokens_per_step(self) -> int: ... 137 | 138 | @abstractmethod 139 | def _create_state(self) -> S: ... 140 | 141 | @abstractmethod 142 | def _decode( 143 | self, prev_tokens: list[int], prev_state: S, encoder_out: npt.NDArray[np.float32] 144 | ) -> tuple[npt.NDArray[np.float32], int, S]: ... 
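    # Greedy transducer search (implemented by _decoding below): the decoder state starts fresh
    # for every utterance, and each encoder frame is scored via _decode. A non-blank argmax token
    # is emitted with its frame index as the timestamp and updates the decoder state; the frame
    # pointer advances when _decode returns a positive step, when blank is predicted, or once
    # _max_tokens_per_step tokens have been emitted for the current frame.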
145 | 146 | def _decoding( 147 | self, encoder_out: npt.NDArray[np.float32], encoder_out_lens: npt.NDArray[np.int64], language: str | None 148 | ) -> Iterator[tuple[Iterable[int], Iterable[int]]]: 149 | for encodings, encodings_len in zip(encoder_out, encoder_out_lens, strict=True): 150 | prev_state = self._create_state() 151 | tokens: list[int] = [] 152 | timestamps: list[int] = [] 153 | 154 | t = 0 155 | emitted_tokens = 0 156 | while t < encodings_len: 157 | probs, step, state = self._decode(tokens, prev_state, encodings[t]) 158 | assert probs.shape[-1] <= self._vocab_size 159 | 160 | token = probs.argmax() 161 | 162 | if token != self._blank_idx: 163 | prev_state = state 164 | tokens.append(int(token)) 165 | timestamps.append(t) 166 | emitted_tokens += 1 167 | 168 | if step > 0: 169 | t += step 170 | emitted_tokens = 0 171 | elif token == self._blank_idx or emitted_tokens == self._max_tokens_per_step: 172 | t += 1 173 | emitted_tokens = 0 174 | 175 | yield tokens, timestamps 176 | -------------------------------------------------------------------------------- /preprocessors/resample.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Sequence # noqa: UP035 3 | 4 | from onnxscript import FLOAT, INT64, script 5 | from onnxscript import opset17 as op 6 | 7 | lowpass_filter_width: float = 6.0 8 | rolloff: float = 0.99 9 | 10 | 11 | @script() 12 | def sinc_resample_kernel(orig_freq: FLOAT, new_freq: FLOAT): 13 | base_freq = op.Min(orig_freq, new_freq) * rolloff 14 | width = op.Ceil(lowpass_filter_width * orig_freq / base_freq) 15 | 16 | idx = op.Range(-width, width + orig_freq, 1) / orig_freq 17 | t = op.Unsqueeze(op.Range(0, -new_freq, -1) / new_freq, -1) + idx 18 | t = op.Clip(t * base_freq, -lowpass_filter_width, lowpass_filter_width) 19 | t = t * op.Constant(value=math.pi) 20 | 21 | window = op.Cos(t / (lowpass_filter_width * 2.0)) ** 2 22 | kernels = op.Where(t == 0.0, 1.0, op.Sin(t) / (t + 1e-20)) 23 | kernels = kernels * window * base_freq / orig_freq 24 | 25 | return op.Unsqueeze(kernels, [1, 2]) 26 | 27 | 28 | @script() 29 | def resample( 30 | waveforms: FLOAT["batch_size", "N"], 31 | waveforms_lens: INT64["batch_size"], 32 | orig_freq: int, 33 | new_freq: int, 34 | pads: Sequence[int], 35 | strides: Sequence[int], 36 | ): 37 | kernel = sinc_resample_kernel(op.Cast(orig_freq, to=FLOAT.dtype), op.Cast(new_freq, to=FLOAT.dtype)) 38 | conv = op.Conv(op.Unsqueeze(waveforms, axes=[1, 2]), kernel, pads=pads, strides=strides) 39 | 40 | resampled = op.Flatten(op.Transpose(conv, perm=(0, 3, 2, 1))) 41 | resampled_lens = (new_freq * waveforms_lens + orig_freq - 1) / orig_freq 42 | 43 | new_len = (new_freq * op.Shape(waveforms, start=1, end=2)[0] + orig_freq - 1) / orig_freq 44 | mask = op.Unsqueeze(op.Range(0, new_len, 1), 0) < op.Unsqueeze(resampled_lens, 1) 45 | 46 | return op.Where(mask, resampled[:, :new_len], 0.0), resampled_lens 47 | 48 | 49 | def kernel_args(orig_freq, new_freq): 50 | gcd = math.gcd(orig_freq, new_freq) 51 | orig_freq //= gcd 52 | new_freq //= gcd 53 | base_freq = min(orig_freq, new_freq) * rolloff 54 | width = math.ceil(lowpass_filter_width * orig_freq / base_freq) 55 | return orig_freq, new_freq, (0, width, 0, width + orig_freq), (1, orig_freq) 56 | 57 | 58 | @script(doc_string="Resampling waveform to 8 kHz") 59 | def ResamplePreprocessor8( 60 | waveforms: FLOAT["batch_size", "N"], waveforms_lens: INT64["batch_size"], sample_rate: INT64 61 | ) -> tuple[FLOAT["batch_size", "M"], 
INT64["batch_size"]]: 62 | if sample_rate == 11_025: 63 | res, lens = resample( 64 | waveforms, 65 | waveforms_lens, 66 | kernel_args(11_025, 8_000)[0], 67 | kernel_args(11_025, 8_000)[1], 68 | kernel_args(11_025, 8_000)[2], 69 | kernel_args(11_025, 8_000)[3], 70 | ) 71 | elif sample_rate == 16_000: 72 | res, lens = resample( 73 | waveforms, 74 | waveforms_lens, 75 | kernel_args(16_000, 8_000)[0], 76 | kernel_args(16_000, 8_000)[1], 77 | kernel_args(16_000, 8_000)[2], 78 | kernel_args(16_000, 8_000)[3], 79 | ) 80 | elif sample_rate == 22_050: 81 | res, lens = resample( 82 | waveforms, 83 | waveforms_lens, 84 | kernel_args(22_050, 8_000)[0], 85 | kernel_args(22_050, 8_000)[1], 86 | kernel_args(22_050, 8_000)[2], 87 | kernel_args(22_050, 8_000)[3], 88 | ) 89 | elif sample_rate == 24_000: 90 | res, lens = resample( 91 | waveforms, 92 | waveforms_lens, 93 | kernel_args(24_000, 8_000)[0], 94 | kernel_args(24_000, 8_000)[1], 95 | kernel_args(24_000, 8_000)[2], 96 | kernel_args(24_000, 8_000)[3], 97 | ) 98 | elif sample_rate == 32_000: 99 | res, lens = resample( 100 | waveforms, 101 | waveforms_lens, 102 | kernel_args(32_000, 8_000)[0], 103 | kernel_args(32_000, 8_000)[1], 104 | kernel_args(32_000, 8_000)[2], 105 | kernel_args(32_000, 8_000)[3], 106 | ) 107 | elif sample_rate == 44_100: 108 | res, lens = resample( 109 | waveforms, 110 | waveforms_lens, 111 | kernel_args(44_100, 8_000)[0], 112 | kernel_args(44_100, 8_000)[1], 113 | kernel_args(44_100, 8_000)[2], 114 | kernel_args(44_100, 8_000)[3], 115 | ) 116 | elif sample_rate == 48_000: 117 | res, lens = resample( 118 | waveforms, 119 | waveforms_lens, 120 | kernel_args(48_000, 8_000)[0], 121 | kernel_args(48_000, 8_000)[1], 122 | kernel_args(48_000, 8_000)[2], 123 | kernel_args(48_000, 8_000)[3], 124 | ) 125 | else: 126 | res, lens = waveforms, waveforms_lens 127 | 128 | resampled, resampled_lens = op.Identity(res), op.Identity(lens) 129 | return resampled, resampled_lens 130 | 131 | 132 | @script(doc_string="Resampling waveform to 16 kHz") 133 | def ResamplePreprocessor16( 134 | waveforms: FLOAT["batch_size", "N"], waveforms_lens: INT64["batch_size"], sample_rate: INT64 135 | ) -> tuple[FLOAT["batch_size", "M"], INT64["batch_size"]]: 136 | if sample_rate == 8_000: 137 | res, lens = resample( 138 | waveforms, 139 | waveforms_lens, 140 | kernel_args(8_000, 16_000)[0], 141 | kernel_args(8_000, 16_000)[1], 142 | kernel_args(8_000, 16_000)[2], 143 | kernel_args(8_000, 16_000)[3], 144 | ) 145 | elif sample_rate == 11_025: 146 | res, lens = resample( 147 | waveforms, 148 | waveforms_lens, 149 | kernel_args(11_025, 16_000)[0], 150 | kernel_args(11_025, 16_000)[1], 151 | kernel_args(11_025, 16_000)[2], 152 | kernel_args(11_025, 16_000)[3], 153 | ) 154 | elif sample_rate == 22_050: 155 | res, lens = resample( 156 | waveforms, 157 | waveforms_lens, 158 | kernel_args(22_050, 16_000)[0], 159 | kernel_args(22_050, 16_000)[1], 160 | kernel_args(22_050, 16_000)[2], 161 | kernel_args(22_050, 16_000)[3], 162 | ) 163 | elif sample_rate == 24_000: 164 | res, lens = resample( 165 | waveforms, 166 | waveforms_lens, 167 | kernel_args(24_000, 16_000)[0], 168 | kernel_args(24_000, 16_000)[1], 169 | kernel_args(24_000, 16_000)[2], 170 | kernel_args(24_000, 16_000)[3], 171 | ) 172 | elif sample_rate == 32_000: 173 | res, lens = resample( 174 | waveforms, 175 | waveforms_lens, 176 | kernel_args(32_000, 16_000)[0], 177 | kernel_args(32_000, 16_000)[1], 178 | kernel_args(32_000, 16_000)[2], 179 | kernel_args(32_000, 16_000)[3], 180 | ) 181 | elif sample_rate == 44_100: 
182 | res, lens = resample( 183 | waveforms, 184 | waveforms_lens, 185 | kernel_args(44_100, 16_000)[0], 186 | kernel_args(44_100, 16_000)[1], 187 | kernel_args(44_100, 16_000)[2], 188 | kernel_args(44_100, 16_000)[3], 189 | ) 190 | elif sample_rate == 48_000: 191 | res, lens = resample( 192 | waveforms, 193 | waveforms_lens, 194 | kernel_args(48_000, 16_000)[0], 195 | kernel_args(48_000, 16_000)[1], 196 | kernel_args(48_000, 16_000)[2], 197 | kernel_args(48_000, 16_000)[3], 198 | ) 199 | else: 200 | res, lens = waveforms, waveforms_lens 201 | 202 | resampled, resampled_lens = op.Identity(res), op.Identity(lens) 203 | return resampled, resampled_lens 204 | -------------------------------------------------------------------------------- /src/onnx_asr/models/whisper.py: -------------------------------------------------------------------------------- 1 | """Whisper model implementations.""" 2 | 3 | import json 4 | import typing 5 | from abc import abstractmethod 6 | from collections.abc import Iterator 7 | from pathlib import Path 8 | 9 | import numpy as np 10 | import numpy.typing as npt 11 | import onnxruntime as rt 12 | from onnxruntime import OrtValue 13 | 14 | from onnx_asr.asr import Asr, AsrRuntimeConfig, TimestampedResult 15 | from onnx_asr.utils import get_onnx_device, is_float32_array, is_int32_array 16 | 17 | 18 | @typing.no_type_check 19 | def bytes_to_unicode() -> dict[int, str]: 20 | """Magic func copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode.""" 21 | bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) 22 | cs = bs[:] 23 | n = 0 24 | for b in range(2**8): 25 | if b not in bs: 26 | bs.append(b) 27 | cs.append(2**8 + n) 28 | n += 1 29 | cs = [chr(n) for n in cs] 30 | return dict(zip(bs, cs)) # noqa: B905 31 | 32 | 33 | class _Whisper(Asr): 34 | def __init__(self, model_files: dict[str, Path], runtime_config: AsrRuntimeConfig): 35 | super().__init__(model_files, runtime_config) 36 | 37 | with model_files["vocab"].open("rt", encoding="utf-8") as f: 38 | self._tokens: dict[str, int] = json.load(f) 39 | 40 | with model_files["added_tokens"].open("rt", encoding="utf-8") as f: 41 | self._tokens |= json.load(f) 42 | 43 | self._vocab = {id: token for token, id in self._tokens.items()} 44 | self._bos_token_id = self._tokens["<|startoftranscript|>"] 45 | self._eos_token_id = self._tokens["<|endoftext|>"] 46 | self._byte_decoder = {v: k for k, v in bytes_to_unicode().items()} 47 | self._transcribe_input = np.array( 48 | [ 49 | [ 50 | self._bos_token_id, 51 | self._eos_token_id, 52 | self._tokens["<|transcribe|>"], 53 | self._tokens["<|notimestamps|>"], 54 | ] 55 | ], 56 | dtype=np.int64, 57 | ) 58 | self._detect_lang_input = np.array([[self._bos_token_id]], dtype=np.int64) 59 | 60 | @staticmethod 61 | def _get_model_files(quantization: str | None = None) -> dict[str, str]: 62 | return {"vocab": "vocab.json", "added_tokens": "added_tokens.json"} 63 | 64 | def _encode(self, waveforms: npt.NDArray[np.float32], waveforms_len: npt.NDArray[np.int64]) -> OrtValue: 65 | input_features, _ = self._preprocessor(waveforms, waveforms_len) 66 | return OrtValue.ortvalue_from_numpy(input_features) 67 | 68 | @abstractmethod 69 | def _decoding( 70 | self, input_features: OrtValue, tokens: npt.NDArray[np.int64], max_length: int = 448 71 | ) -> npt.NDArray[np.int64]: ... 
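    # Token-to-text conversion (see _decode_tokens below): special <|...|> markers are dropped,
    # the remaining byte-level BPE strings are mapped back to raw bytes via the bytes_to_unicode
    # table, UTF-8 decoded with errors replaced, and a single leading space is stripped.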
72 | 73 | def _decode_tokens(self, tokens: npt.NDArray[np.int64]) -> TimestampedResult: 74 | text = "".join(token for id in tokens if (token := self._vocab[id]) and not token.startswith("<|")) 75 | return TimestampedResult( 76 | bytearray([self._byte_decoder[c] for c in text]).decode("utf-8", errors="replace").removeprefix(" ") 77 | ) 78 | 79 | def recognize_batch( 80 | self, waveforms: npt.NDArray[np.float32], waveforms_len: npt.NDArray[np.int64], language: str | None 81 | ) -> Iterator[TimestampedResult]: 82 | input_encoding = self._encode(waveforms, waveforms_len) 83 | input_tokens = np.repeat(self._transcribe_input, len(waveforms), axis=0) 84 | 85 | if language: 86 | input_tokens[:, 1] = self._tokens[f"<|{language}|>"] 87 | else: 88 | input_tokens_detect_lang = np.repeat(self._detect_lang_input, len(waveforms), axis=0) 89 | input_tokens[:, 1] = self._decoding(input_encoding, input_tokens_detect_lang, 3)[:, 1] 90 | 91 | return map(self._decode_tokens, self._decoding(input_encoding, input_tokens)) 92 | 93 | 94 | class WhisperOrt(_Whisper): 95 | """Whisper (exported via onnxruntime) model implementation.""" 96 | 97 | def __init__(self, model_files: dict[str, Path], runtime_config: AsrRuntimeConfig): 98 | """Create Whisper model. 99 | 100 | Args: 101 | model_files: Dict with paths to model files. 102 | runtime_config: Runtime configuration. 103 | 104 | """ 105 | super().__init__(model_files, runtime_config) 106 | self._model = rt.InferenceSession(model_files["model"], **runtime_config.onnx_options) 107 | 108 | @staticmethod 109 | def _get_model_files(quantization: str | None = None) -> dict[str, str]: 110 | suffix = "?" + quantization if quantization else "" 111 | return {"model": f"whisper-*_beamsearch{suffix}.onnx"} | _Whisper._get_model_files(quantization) 112 | 113 | @property 114 | def _preprocessor_name(self) -> str: 115 | return f"whisper{self.config.get('features_size', 80)}" 116 | 117 | def _decoding(self, input_features: OrtValue, tokens: npt.NDArray[np.int64], max_length: int = 448) -> npt.NDArray[np.int64]: 118 | (sequences,) = self._model.run( 119 | ["sequences"], 120 | { 121 | "input_features": input_features, 122 | "max_length": [max_length], 123 | "min_length": [0], 124 | "num_beams": [1], 125 | "num_return_sequences": [1], 126 | "length_penalty": [1.0], 127 | "repetition_penalty": [1.0], 128 | "decoder_input_ids": tokens.astype(np.int32), 129 | }, 130 | ) 131 | assert is_int32_array(sequences) 132 | return sequences[:, 0, :].astype(np.int64) 133 | 134 | 135 | class WhisperHf(_Whisper): 136 | """Whisper (exported via optimum) model implementation.""" 137 | 138 | def __init__(self, model_files: dict[str, Path], runtime_config: AsrRuntimeConfig): 139 | """Create Whisper model. 140 | 141 | Args: 142 | model_files: Dict with paths to model files. 143 | runtime_config: Runtime configuration. 144 | 145 | """ 146 | super().__init__(model_files, runtime_config) 147 | self._encoder = rt.InferenceSession(model_files["encoder"], **runtime_config.onnx_options) 148 | self._decoder = rt.InferenceSession(model_files["decoder"], **runtime_config.onnx_options) 149 | self._device_type, self._device_id = get_onnx_device(self._encoder) 150 | 151 | @staticmethod 152 | def _get_model_files(quantization: str | None = None) -> dict[str, str]: 153 | suffix = "?" 
+ quantization if quantization else "" 154 | return { 155 | "encoder": f"**/encoder_model{suffix}.onnx", 156 | "decoder": f"**/decoder_model_merged{suffix}.onnx", 157 | } | _Whisper._get_model_files(suffix) 158 | 159 | @property 160 | def _preprocessor_name(self) -> str: 161 | return f"whisper{self.config.get('num_mel_bins', 80)}" 162 | 163 | def _encode(self, waveforms: npt.NDArray[np.float32], waveforms_len: npt.NDArray[np.int64]) -> OrtValue: 164 | input_features = super()._encode(waveforms, waveforms_len) 165 | binding = self._encoder.io_binding() 166 | binding.bind_ortvalue_input("input_features", input_features) 167 | binding.bind_output("last_hidden_state", self._device_type, self._device_id) 168 | self._encoder.run_with_iobinding(binding) 169 | last_hidden_state: OrtValue = binding.get_outputs()[0] 170 | return last_hidden_state 171 | 172 | def _create_state(self) -> dict[str, OrtValue]: 173 | return { 174 | x.name: OrtValue.ortvalue_from_numpy(np.zeros((0, x.shape[1], 0, x.shape[3]), dtype=np.float32)) 175 | for x in self._decoder.get_inputs() 176 | if x.name.startswith("past_key_values.") 177 | } 178 | 179 | def _decode( 180 | self, 181 | tokens: npt.NDArray[np.int64], 182 | prev_state: dict[str, OrtValue], 183 | encoder_out: OrtValue, 184 | ) -> tuple[npt.NDArray[np.float32], dict[str, OrtValue]]: 185 | use_cache = any(x.shape()[0] for x in prev_state.values()) 186 | 187 | binding = self._decoder.io_binding() 188 | binding.bind_cpu_input("input_ids", tokens[:, -1:] if use_cache else tokens) 189 | binding.bind_ortvalue_input("encoder_hidden_states", encoder_out) 190 | binding.bind_output("logits") 191 | if prev_state: 192 | binding.bind_cpu_input("use_cache_branch", np.array([use_cache])) 193 | for key, value in prev_state.items(): 194 | binding.bind_ortvalue_input(key, value) 195 | binding.bind_output(key.replace("past_key_values.", "present."), self._device_type, self._device_id) 196 | 197 | self._decoder.run_with_iobinding(binding) 198 | outputs = binding.get_outputs() 199 | logits = outputs[0].numpy() 200 | assert is_float32_array(logits) 201 | return logits, { 202 | key: next_value if next_value.shape()[0] else prev_value 203 | for (key, prev_value), next_value in zip(prev_state.items(), outputs[1:], strict=True) 204 | } 205 | 206 | def _decoding(self, input_features: OrtValue, tokens: npt.NDArray[np.int64], max_length: int = 448) -> npt.NDArray[np.int64]: 207 | state = self._create_state() 208 | for _ in range(tokens.shape[-1], max_length): 209 | logits, state = self._decode(tokens, state, input_features) 210 | next_tokens = logits[:, -1].argmax(axis=-1) 211 | next_tokens[tokens[:, -1] == self._eos_token_id] = self._eos_token_id 212 | tokens = np.hstack((tokens, next_tokens[:, None])) 213 | if (tokens[:, -1] == self._eos_token_id).all(): 214 | break 215 | 216 | return tokens 217 | -------------------------------------------------------------------------------- /src/onnx_asr/models/nemo.py: -------------------------------------------------------------------------------- 1 | """NeMo model implementations.""" 2 | 3 | from collections.abc import Iterable, Iterator 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | import numpy.typing as npt 8 | import onnxruntime as rt 9 | 10 | from onnx_asr.asr import AsrRuntimeConfig, _AsrWithCtcDecoding, _AsrWithDecoding, _AsrWithTransducerDecoding 11 | from onnx_asr.utils import is_float32_array, is_int64_array 12 | 13 | 14 | class _NemoConformer(_AsrWithDecoding): 15 | @staticmethod 16 | def _get_model_files(quantization: str | 
None = None) -> dict[str, str]: 17 | return {"vocab": "vocab.txt"} 18 | 19 | @property 20 | def _preprocessor_name(self) -> str: 21 | return f"nemo{self.config.get('features_size', 80)}" 22 | 23 | @property 24 | def _subsampling_factor(self) -> int: 25 | return self.config.get("subsampling_factor", 8) 26 | 27 | 28 | class NemoConformerCtc(_AsrWithCtcDecoding, _NemoConformer): 29 | """NeMo Conformer CTC model implementations.""" 30 | 31 | def __init__(self, model_files: dict[str, Path], runtime_config: AsrRuntimeConfig): 32 | """Create NeMo Conformer CTC model. 33 | 34 | Args: 35 | model_files: Dict with paths to model files. 36 | runtime_config: Runtime configuration. 37 | 38 | """ 39 | super().__init__(model_files, runtime_config) 40 | self._model = rt.InferenceSession(model_files["model"], **runtime_config.onnx_options) 41 | 42 | @staticmethod 43 | def _get_model_files(quantization: str | None = None) -> dict[str, str]: 44 | suffix = "?" + quantization if quantization else "" 45 | return {"model": f"model{suffix}.onnx"} | _NemoConformer._get_model_files(quantization) 46 | 47 | def _encode( 48 | self, features: npt.NDArray[np.float32], features_lens: npt.NDArray[np.int64] 49 | ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int64]]: 50 | (logprobs,) = self._model.run(["logprobs"], {"audio_signal": features, "length": features_lens}) 51 | assert is_float32_array(logprobs) 52 | return logprobs, (features_lens - 1) // self._subsampling_factor + 1 53 | 54 | 55 | _STATE_TYPE = tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]] 56 | 57 | 58 | class NemoConformerRnnt(_AsrWithTransducerDecoding[_STATE_TYPE], _NemoConformer): 59 | """NeMo Conformer RNN-T model implementations.""" 60 | 61 | def __init__(self, model_files: dict[str, Path], runtime_config: AsrRuntimeConfig): 62 | """Create NeMo Conformer RNN-T model. 63 | 64 | Args: 65 | model_files: Dict with paths to model files. 66 | runtime_config: Runtime configuration. 67 | 68 | """ 69 | super().__init__(model_files, runtime_config) 70 | self._encoder = rt.InferenceSession(model_files["encoder"], **runtime_config.onnx_options) 71 | self._decoder_joint = rt.InferenceSession(model_files["decoder_joint"], **runtime_config.onnx_options) 72 | 73 | @staticmethod 74 | def _get_model_files(quantization: str | None = None) -> dict[str, str]: 75 | suffix = "?" 
+ quantization if quantization else "" 76 | return { 77 | "encoder": f"encoder-model{suffix}.onnx", 78 | "decoder_joint": f"decoder_joint-model{suffix}.onnx", 79 | } | _NemoConformer._get_model_files(quantization) 80 | 81 | @property 82 | def _max_tokens_per_step(self) -> int: 83 | return self.config.get("max_tokens_per_step", 10) 84 | 85 | def _encode( 86 | self, features: npt.NDArray[np.float32], features_lens: npt.NDArray[np.int64] 87 | ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int64]]: 88 | encoder_out, encoder_out_lens = self._encoder.run( 89 | ["outputs", "encoded_lengths"], {"audio_signal": features, "length": features_lens} 90 | ) 91 | assert is_float32_array(encoder_out) 92 | assert is_int64_array(encoder_out_lens) 93 | return encoder_out.transpose(0, 2, 1), encoder_out_lens 94 | 95 | def _create_state(self) -> _STATE_TYPE: 96 | shapes = {x.name: x.shape for x in self._decoder_joint.get_inputs()} 97 | return ( 98 | np.zeros(shape=(shapes["input_states_1"][0], 1, shapes["input_states_1"][2]), dtype=np.float32), 99 | np.zeros(shape=(shapes["input_states_2"][0], 1, shapes["input_states_2"][2]), dtype=np.float32), 100 | ) 101 | 102 | def _decode( 103 | self, prev_tokens: list[int], prev_state: _STATE_TYPE, encoder_out: npt.NDArray[np.float32] 104 | ) -> tuple[npt.NDArray[np.float32], int, _STATE_TYPE]: 105 | outputs, state1, state2 = self._decoder_joint.run( 106 | ["outputs", "output_states_1", "output_states_2"], 107 | { 108 | "encoder_outputs": encoder_out[None, :, None], 109 | "targets": [[prev_tokens[-1] if prev_tokens else self._blank_idx]], 110 | "target_length": [1], 111 | "input_states_1": prev_state[0], 112 | "input_states_2": prev_state[1], 113 | }, 114 | ) 115 | assert is_float32_array(outputs) 116 | assert is_float32_array(state1) 117 | assert is_float32_array(state2) 118 | return np.squeeze(outputs), -1, (state1, state2) 119 | 120 | 121 | class NemoConformerTdt(NemoConformerRnnt): 122 | """NeMo Conformer TDT model implementations.""" 123 | 124 | def _decode( 125 | self, prev_tokens: list[int], prev_state: _STATE_TYPE, encoder_out: npt.NDArray[np.float32] 126 | ) -> tuple[npt.NDArray[np.float32], int, _STATE_TYPE]: 127 | output, _, state = super()._decode(prev_tokens, prev_state, encoder_out) 128 | return output[: self._vocab_size], int(output[self._vocab_size :].argmax()), state 129 | 130 | 131 | class NemoConformerAED(_NemoConformer): 132 | """NeMo Conformer AED model implementations.""" 133 | 134 | def __init__(self, model_files: dict[str, Path], runtime_config: AsrRuntimeConfig): 135 | """Create NeMo Conformer AED model. 136 | 137 | Args: 138 | model_files: Dict with paths to model files. 139 | runtime_config: Runtime configuration. 
140 | 141 | """ 142 | super().__init__(model_files, runtime_config) 143 | self._encoder = rt.InferenceSession(model_files["encoder"], **runtime_config.onnx_options) 144 | self._decoder = rt.InferenceSession(model_files["decoder"], **runtime_config.onnx_options) 145 | 146 | self._tokens = {token: id for id, token in self._vocab.items()} 147 | self._eos_token_id = self._tokens["<|endoftext|>"] 148 | self._transcribe_input = np.array( 149 | [ 150 | [ 151 | self._tokens["<|startofcontext|>"], 152 | self._tokens["<|startoftranscript|>"], 153 | self._tokens["<|emo:undefined|>"], 154 | self._tokens["<|en|>"], 155 | self._tokens["<|en|>"], 156 | self._tokens["<|pnc|>"], 157 | self._tokens["<|noitn|>"], 158 | self._tokens["<|notimestamp|>"], 159 | self._tokens["<|nodiarize|>"], 160 | ] 161 | ], 162 | dtype=np.int64, 163 | ) 164 | 165 | @staticmethod 166 | def _get_model_files(quantization: str | None = None) -> dict[str, str]: 167 | suffix = "?" + quantization if quantization else "" 168 | return { 169 | "encoder": f"encoder-model{suffix}.onnx", 170 | "decoder": f"decoder-model{suffix}.onnx", 171 | } | _NemoConformer._get_model_files(quantization) 172 | 173 | @property 174 | def _max_sequence_length(self) -> int: 175 | return self.config.get("max_sequence_length", 1024) 176 | 177 | def _encode( 178 | self, features: npt.NDArray[np.float32], features_lens: npt.NDArray[np.int64] 179 | ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int64]]: 180 | encoder_embeddings, encoder_mask = self._encoder.run( 181 | ["encoder_embeddings", "encoder_mask"], {"audio_signal": features, "length": features_lens} 182 | ) 183 | assert is_float32_array(encoder_embeddings) 184 | assert is_int64_array(encoder_mask) 185 | return encoder_embeddings, encoder_mask 186 | 187 | def _decode( 188 | self, 189 | input_ids: npt.NDArray[np.int64], 190 | encoder_embeddings: npt.NDArray[np.float32], 191 | encoder_mask: npt.NDArray[np.int64], 192 | decoder_mems: npt.NDArray[np.float32], 193 | ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]]: 194 | logits, decoder_hidden_states = self._decoder.run( 195 | ["logits", "decoder_hidden_states"], 196 | { 197 | "input_ids": input_ids if decoder_mems.shape[2] == 0 else input_ids[:, -1:], 198 | "encoder_embeddings": encoder_embeddings, 199 | "encoder_mask": encoder_mask, 200 | "decoder_mems": decoder_mems, 201 | }, 202 | ) 203 | assert is_float32_array(logits) 204 | assert is_float32_array(decoder_hidden_states) 205 | return logits, decoder_hidden_states 206 | 207 | def _decoding( 208 | self, encoder_embeddings: npt.NDArray[np.float32], encoder_mask: npt.NDArray[np.int64], language: str | None 209 | ) -> Iterator[tuple[Iterable[int], Iterable[int]]]: 210 | batch_size = encoder_embeddings.shape[0] 211 | tokens = np.repeat(self._transcribe_input, batch_size, axis=0) 212 | 213 | if language: 214 | tokens[:, 3] = self._tokens[f"<|{language}|>"] 215 | tokens[:, 4] = self._tokens[f"<|{language}|>"] 216 | 217 | shapes = {x.name: x.shape for x in self._decoder.get_inputs()} 218 | decoder_mems = np.empty((shapes["decoder_mems"][0], batch_size, 0, shapes["decoder_mems"][3]), dtype=np.float32) 219 | while tokens.shape[1] < self._max_sequence_length: 220 | logits, decoder_mems = self._decode(tokens, encoder_embeddings, encoder_mask, decoder_mems) 221 | 222 | next_tokens = np.argmax(logits[:, -1], axis=-1) 223 | if (next_tokens == self._eos_token_id).all(): 224 | break 225 | 226 | tokens = np.concatenate((tokens, next_tokens[:, None]), axis=-1) 227 | 228 | return (([id for id in tok if not 
self._vocab[id].startswith("<|")], []) for tok in tokens) 229 | -------------------------------------------------------------------------------- /src/onnx_asr/loader.py: -------------------------------------------------------------------------------- 1 | """Loader for ASR models.""" 2 | 3 | import json 4 | from collections.abc import Sequence 5 | from pathlib import Path 6 | from typing import Any, Literal, get_args 7 | 8 | import onnxruntime as rt 9 | 10 | from onnx_asr.adapters import TextResultsAsrAdapter 11 | from onnx_asr.asr import AsrRuntimeConfig 12 | from onnx_asr.models import ( 13 | GigaamV2Ctc, 14 | GigaamV2Rnnt, 15 | GigaamV3E2eCtc, 16 | GigaamV3E2eRnnt, 17 | KaldiTransducer, 18 | NemoConformerAED, 19 | NemoConformerCtc, 20 | NemoConformerRnnt, 21 | NemoConformerTdt, 22 | PyAnnoteVad, 23 | SileroVad, 24 | TOneCtc, 25 | WhisperHf, 26 | WhisperOrt, 27 | ) 28 | from onnx_asr.preprocessors import Resampler 29 | from onnx_asr.preprocessors.preprocessor import PreprocessorRuntimeConfig 30 | from onnx_asr.utils import OnnxSessionOptions 31 | from onnx_asr.vad import Vad 32 | 33 | ModelNames = Literal[ 34 | "gigaam-v2-ctc", 35 | "gigaam-v2-rnnt", 36 | "gigaam-v3-ctc", 37 | "gigaam-v3-rnnt", 38 | "gigaam-v3-e2e-ctc", 39 | "gigaam-v3-e2e-rnnt", 40 | "nemo-fastconformer-ru-ctc", 41 | "nemo-fastconformer-ru-rnnt", 42 | "nemo-parakeet-ctc-0.6b", 43 | "nemo-parakeet-rnnt-0.6b", 44 | "nemo-parakeet-tdt-0.6b-v2", 45 | "nemo-parakeet-tdt-0.6b-v3", 46 | "nemo-canary-1b-v2", 47 | "alphacep/vosk-model-ru", 48 | "alphacep/vosk-model-small-ru", 49 | "t-tech/t-one", 50 | "whisper-base", 51 | ] 52 | ModelTypes = Literal[ 53 | "gigaam-v2-ctc", 54 | "gigaam-v2-rnnt", 55 | "gigaam-v3-e2e-ctc", 56 | "gigaam-v3-e2e-rnnt", 57 | "kaldi-rnnt", 58 | "nemo-conformer-ctc", 59 | "nemo-conformer-rnnt", 60 | "nemo-conformer-tdt", 61 | "nemo-conformer-aed", 62 | "vosk", 63 | "whisper-ort", 64 | "whisper", 65 | ] 66 | VadNames = Literal["silero"] 67 | 68 | 69 | class ModelNotSupportedError(ValueError): 70 | """Model not supported error.""" 71 | 72 | def __init__(self, model: str): 73 | """Create error.""" 74 | super().__init__(f"Model '{model}' not supported!") 75 | 76 | 77 | class ModelPathNotDirectoryError(NotADirectoryError): 78 | """Model path not a directory error.""" 79 | 80 | def __init__(self, path: str | Path): 81 | """Create error.""" 82 | super().__init__(f"The path '{path}' is not a directory.") 83 | 84 | 85 | class ModelFileNotFoundError(FileNotFoundError): 86 | """Model file not found error.""" 87 | 88 | def __init__(self, filename: str | Path, path: str | Path): 89 | """Create error.""" 90 | super().__init__(f"File '{filename}' not found in path '{path}'.") 91 | 92 | 93 | class MoreThanOneModelFileFoundError(Exception): 94 | """More than one model file found error.""" 95 | 96 | def __init__(self, filename: str | Path, path: str | Path): 97 | """Create error.""" 98 | super().__init__(f"Found more than 1 file '{filename}' found in path '{path}'.") 99 | 100 | 101 | class NoModelNameOrPathSpecifiedError(Exception): 102 | """No model name or path specified error.""" 103 | 104 | def __init__(self) -> None: 105 | """Create error.""" 106 | super().__init__("If the path is not specified, you must specify a specific model name.") 107 | 108 | 109 | class InvalidModelTypeInConfigError(Exception): 110 | """Invalid model type in config error.""" 111 | 112 | def __init__(self, model_type: str) -> None: 113 | """Create error.""" 114 | super().__init__(f"Invalid model type '{model_type}' in config.json.") 115 | 116 | 
117 | def _download_config(repo_id: str) -> str: 118 | from huggingface_hub import hf_hub_download # noqa: PLC0415 119 | 120 | return hf_hub_download(repo_id, "config.json") 121 | 122 | 123 | def _download_model(repo_id: str, files: list[str], *, local_dir: str | Path | None, local_files_only: bool) -> str: 124 | from huggingface_hub import snapshot_download # noqa: PLC0415 125 | 126 | files = [ 127 | "config.json", 128 | *files, 129 | *(str(path.with_suffix(".onnx?data")) for file in files if (path := Path(file)).suffix == ".onnx"), 130 | ] 131 | return snapshot_download(repo_id, local_dir=local_dir, local_files_only=local_files_only, allow_patterns=files) 132 | 133 | 134 | def _find_files(path: str | Path, files: dict[str, str]) -> dict[str, Path]: 135 | if not Path(path).is_dir(): 136 | raise ModelPathNotDirectoryError(path) 137 | 138 | if Path(path, "config.json").exists(): 139 | files |= {"config": "config.json"} 140 | 141 | def find(filename: str) -> Path: 142 | files = list(Path(path).glob(filename)) 143 | if len(files) == 0: 144 | raise ModelFileNotFoundError(filename, path) 145 | if len(files) > 1: 146 | raise MoreThanOneModelFileFoundError(filename, path) 147 | return files[0] 148 | 149 | return {key: find(filename) for key, filename in files.items()} 150 | 151 | 152 | def _download_files(repo_id: str | None, path: str | Path | None, files: dict[str, str]) -> dict[str, Path]: 153 | if path is not None and Path(path).exists(): 154 | return _find_files(path, files) 155 | 156 | if repo_id is None: 157 | raise NoModelNameOrPathSpecifiedError 158 | 159 | try: 160 | return _find_files(_download_model(repo_id, list(files.values()), local_dir=path, local_files_only=True), files) 161 | except (FileNotFoundError, ModelFileNotFoundError): 162 | return _find_files(_download_model(repo_id, list(files.values()), local_dir=path, local_files_only=False), files) 163 | 164 | 165 | def _find_model_type(repo_id: str, path: str | Path | None) -> str: 166 | if path is not None and Path(path).exists(): 167 | config_path = Path(path, "config.json") 168 | if not config_path.is_file(): 169 | raise ModelFileNotFoundError(config_path.name, path) 170 | else: 171 | config_path = Path(_download_config(repo_id)) 172 | 173 | with config_path.open("rt", encoding="utf-8") as f: 174 | config = json.load(f) 175 | config_model_type: str = config.get("model_type") 176 | if config_model_type not in get_args(ModelTypes): 177 | raise InvalidModelTypeInConfigError(config_model_type) 178 | 179 | return config_model_type 180 | 181 | 182 | def load_model( # noqa: C901 183 | model: str | ModelNames | ModelTypes, 184 | path: str | Path | None = None, 185 | *, 186 | quantization: str | None = None, 187 | sess_options: rt.SessionOptions | None = None, 188 | providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None, 189 | provider_options: Sequence[dict[Any, Any]] | None = None, 190 | cpu_preprocessing: bool = True, 191 | preprocessor_config: PreprocessorRuntimeConfig | None = None, 192 | resampler_config: OnnxSessionOptions | None = None, 193 | ) -> TextResultsAsrAdapter: 194 | """Load ASR model. 
195 | 196 | Args: 197 | model: Model name or type (download from Hugging Face supported if full model name is provided): 198 | GigaAM v2 (`gigaam-v2-ctc` | `gigaam-v2-rnnt`) 199 | GigaAM v3 (`gigaam-v3-ctc` | `gigaam-v3-rnnt` | `gigaam-v3-e2e-ctc` | `gigaam-v3-e2e-rnnt`) 200 | Kaldi Transducer (`kaldi-rnnt`) 201 | NeMo Conformer (`nemo-conformer-ctc` | `nemo-conformer-rnnt` | `nemo-conformer-tdt` | `nemo-conformer-aed`) 202 | NeMo FastConformer Hybrid Large Ru P&C (`nemo-fastconformer-ru-ctc` | `nemo-fastconformer-ru-rnnt`) 203 | NeMo Parakeet 0.6B En (`nemo-parakeet-ctc-0.6b` | `nemo-parakeet-rnnt-0.6b` | `nemo-parakeet-tdt-0.6b-v2`) 204 | NeMo Parakeet 0.6B Multilingual (`nemo-parakeet-tdt-0.6b-v3`) 205 | NeMo Canary (`nemo-canary-1b-v2`) 206 | T-One (`t-tech/t-one`) 207 | Vosk (`vosk` | `alphacep/vosk-model-ru` | `alphacep/vosk-model-small-ru`) 208 | Whisper Base exported with onnxruntime (`whisper-ort` | `whisper-base-ort`) 209 | Whisper from onnx-community (`whisper` | `onnx-community/whisper-large-v3-turbo` | `onnx-community/*whisper*`) 210 | path: Path to directory with model files. 211 | quantization: Model quantization (`None` | `int8` | ... ). 212 | sess_options: Optional SessionOptions for onnxruntime. 213 | providers: Optional providers for onnxruntime. 214 | provider_options: Optional provider_options for onnxruntime. 215 | cpu_preprocessing: Run preprocessors on CPU. 216 | preprocessor_config: Preprocessor ONNX and concurrency config. 217 | resampler_config: Resampler ONNX config. 218 | 219 | Returns: 220 | ASR model class. 221 | 222 | """ 223 | if "/" in model and not model.startswith(("alphacep/", "t-tech/")): 224 | repo_id = model 225 | model = _find_model_type(repo_id, path) 226 | else: 227 | repo_id = None 228 | 229 | model_type: type[ 230 | GigaamV2Ctc 231 | | GigaamV2Rnnt 232 | | KaldiTransducer 233 | | NemoConformerCtc 234 | | NemoConformerRnnt 235 | | NemoConformerAED 236 | | TOneCtc 237 | | WhisperOrt 238 | | WhisperHf 239 | ] 240 | default_repo_id = None 241 | match model: 242 | case "gigaam-v2-ctc": 243 | model_type = GigaamV2Ctc 244 | default_repo_id = "istupakov/gigaam-v2-onnx" 245 | case "gigaam-v2-rnnt": 246 | model_type = GigaamV2Rnnt 247 | default_repo_id = "istupakov/gigaam-v2-onnx" 248 | case "gigaam-v3-ctc": 249 | model_type = GigaamV2Ctc 250 | default_repo_id = "istupakov/gigaam-v3-onnx" 251 | case "gigaam-v3-rnnt": 252 | model_type = GigaamV2Rnnt 253 | default_repo_id = "istupakov/gigaam-v3-onnx" 254 | case "gigaam-v3-e2e-ctc": 255 | model_type = GigaamV3E2eCtc 256 | default_repo_id = "istupakov/gigaam-v3-onnx" 257 | case "gigaam-v3-e2e-rnnt": 258 | model_type = GigaamV3E2eRnnt 259 | default_repo_id = "istupakov/gigaam-v3-onnx" 260 | case "kaldi-rnnt" | "vosk": 261 | model_type = KaldiTransducer 262 | case "alphacep/vosk-model-ru" | "alphacep/vosk-model-small-ru": 263 | model_type = KaldiTransducer 264 | default_repo_id = model 265 | case "nemo-conformer-ctc": 266 | model_type = NemoConformerCtc 267 | case "nemo-fastconformer-ru-ctc": 268 | model_type = NemoConformerCtc 269 | default_repo_id = "istupakov/stt_ru_fastconformer_hybrid_large_pc_onnx" 270 | case "nemo-parakeet-ctc-0.6b": 271 | model_type = NemoConformerCtc 272 | default_repo_id = "istupakov/parakeet-ctc-0.6b-onnx" 273 | case "nemo-conformer-rnnt": 274 | model_type = NemoConformerRnnt 275 | case "nemo-fastconformer-ru-rnnt": 276 | model_type = NemoConformerRnnt 277 | default_repo_id = "istupakov/stt_ru_fastconformer_hybrid_large_pc_onnx" 278 | case "nemo-parakeet-rnnt-0.6b": 279 | model_type 
= NemoConformerRnnt 280 | default_repo_id = "istupakov/parakeet-rnnt-0.6b-onnx" 281 | case "nemo-conformer-tdt": 282 | model_type = NemoConformerTdt 283 | case "nemo-parakeet-tdt-0.6b-v2": 284 | model_type = NemoConformerTdt 285 | default_repo_id = "istupakov/parakeet-tdt-0.6b-v2-onnx" 286 | case "nemo-parakeet-tdt-0.6b-v3": 287 | model_type = NemoConformerTdt 288 | default_repo_id = "istupakov/parakeet-tdt-0.6b-v3-onnx" 289 | case "nemo-conformer-aed": 290 | model_type = NemoConformerAED 291 | case "nemo-canary-1b-v2": 292 | model_type = NemoConformerAED 293 | default_repo_id = "istupakov/canary-1b-v2-onnx" 294 | case "t-tech/t-one": 295 | model_type = TOneCtc 296 | default_repo_id = model 297 | case "whisper-ort": 298 | model_type = WhisperOrt 299 | case "whisper-base": 300 | model_type = WhisperOrt 301 | default_repo_id = "istupakov/whisper-base-onnx" 302 | case "whisper": 303 | model_type = WhisperHf 304 | case _: 305 | raise ModelNotSupportedError(model) 306 | 307 | onnx_options: PreprocessorRuntimeConfig = { 308 | "sess_options": sess_options, 309 | "providers": providers or rt.get_available_providers(), 310 | "provider_options": provider_options, 311 | } 312 | 313 | if resampler_config is None: 314 | resampler_config = {"sess_options": sess_options} if cpu_preprocessing else onnx_options 315 | 316 | if preprocessor_config is None: 317 | preprocessor_config = {"sess_options": sess_options} if cpu_preprocessing else onnx_options 318 | preprocessor_config |= {"max_concurrent_workers": 1} 319 | 320 | return TextResultsAsrAdapter( 321 | model_type( 322 | _download_files(repo_id or default_repo_id, path, model_type._get_model_files(quantization)), 323 | AsrRuntimeConfig(onnx_options, preprocessor_config), 324 | ), 325 | Resampler(model_type._get_sample_rate(), resampler_config), 326 | ) 327 | 328 | 329 | def load_vad( 330 | model: VadNames = "silero", 331 | path: str | Path | None = None, 332 | *, 333 | quantization: str | None = None, 334 | sess_options: rt.SessionOptions | None = None, 335 | providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None, 336 | provider_options: Sequence[dict[Any, Any]] | None = None, 337 | ) -> Vad: 338 | """Load VAD model. 339 | 340 | Args: 341 | model: VAD model name (supports download from Hugging Face). 342 | path: Path to directory with model files. 343 | quantization: Model quantization (`None` | `int8` | ... ). 344 | sess_options: Optional SessionOptions for onnxruntime. 345 | providers: Optional providers for onnxruntime. 346 | provider_options: Optional provider_options for onnxruntime. 347 | 348 | Returns: 349 | VAD model class. 
350 | 351 | """ 352 | model_type: type[SileroVad | PyAnnoteVad] 353 | match model: 354 | case "silero": 355 | model_type = SileroVad 356 | repo_id = "onnx-community/silero-vad" 357 | case "pyannote": 358 | model_type = PyAnnoteVad 359 | repo_id = "onnx-community/pyannote-segmentation-3.0" 360 | case _: 361 | raise ModelNotSupportedError(model) 362 | 363 | onnx_options: OnnxSessionOptions = { 364 | "sess_options": sess_options, 365 | "providers": providers or rt.get_available_providers(), 366 | "provider_options": provider_options, 367 | } 368 | 369 | return model_type(_download_files(repo_id, path, model_type._get_model_files(quantization)), onnx_options) 370 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ONNX ASR 2 | 3 | [![PyPI - Version](https://img.shields.io/pypi/v/onnx-asr)](https://pypi.org/project/onnx-asr) 4 | [![PyPI - Downloads](https://img.shields.io/pypi/dm/onnx-asr)](https://pypi.org/project/onnx-asr) 5 | [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/onnx-asr)](https://pypi.org/project/onnx-asr) 6 | [![PyPI - Types](https://img.shields.io/pypi/types/onnx-asr)](https://pypi.org/project/onnx-asr) 7 | [![GitHub - License](https://img.shields.io/github/license/istupakov/onnx-asr)](https://github.com/istupakov/onnx-asr/blob/main/LICENSE) 8 | [![GitHub - CI](https://github.com/istupakov/onnx-asr/actions/workflows/python-package.yml/badge.svg)](https://github.com/istupakov/onnx-asr/actions/workflows/python-package.yml) 9 | [![GitHub - Release Date](https://img.shields.io/github/release-date/istupakov/onnx-asr)](https://github.com/istupakov/onnx-asr/releases/latest) 10 | 11 | [![Open in Spaces](https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-xl-dark.svg)](https://istupakov-onnx-asr.hf.space/) 12 | 13 | **onnx-asr** is a Python package for Automatic Speech Recognition using ONNX models. It's written in pure Python with minimal dependencies (no PyTorch, Transformers, or FFmpeg required): 14 | 15 | [![numpy](https://img.shields.io/badge/numpy-required-blue?logo=numpy)](https://pypi.org/project/numpy/) 16 | [![onnxruntime](https://img.shields.io/badge/onnxruntime-required-blue?logo=onnx)](https://pypi.org/project/onnxruntime/) 17 | [![huggingface-hub](https://img.shields.io/badge/huggingface--hub-optional-blue?logo=huggingface)](https://pypi.org/project/huggingface-hub/) 18 | 19 | > [!TIP] 20 | > Supports **Parakeet v2 (En) / v3 (Multilingual)**, **Canary v2 (Multilingual)** and **GigaAM v2/v3 (Ru)** models! 
21 | 22 | The **onnx-asr** package supports many modern ASR [models](#supported-model-architectures) and the following features: 23 | * Runs on Windows, Linux, and macOS on a variety of devices, from IoT devices with Arm CPUs to servers with Nvidia GPUs ([benchmarks](#benchmarks)) 24 | * Loads models from Hugging Face or local folders (including quantized versions) 25 | * Accepts wav files or NumPy arrays (built-in support for file reading and resampling) 26 | * Batch processing 27 | * (experimental) Long-form recognition with VAD (Voice Activity Detection) 28 | * (experimental) Returns token timestamps 29 | * Simple CLI 30 | * Online demo in [HF Spaces](https://istupakov-onnx-asr.hf.space/) 31 | 32 | ## Supported model architectures 33 | 34 | The package supports the following modern ASR model architectures ([comparison](#comparison-with-original-implementations) with original implementations): 35 | * Nvidia NeMo Conformer/FastConformer/Parakeet/Canary (with CTC, RNN-T, TDT and Transformer decoders) 36 | * Kaldi Icefall Zipformer (with stateless RNN-T decoder), including Alpha Cephei Vosk 0.52+ 37 | * Sber GigaAM v2/v3 (with CTC and RNN-T decoders, including E2E versions) 38 | * T-Tech T-one (with CTC decoder, no streaming support yet) 39 | * OpenAI Whisper 40 | 41 | When these models are saved in ONNX format, usually only the encoder and decoder are exported. To run them, the corresponding preprocessing and decoding must also be implemented. Therefore, the package contains these implementations for all supported models: 42 | * Log-mel spectrogram preprocessors 43 | * Greedy search decoding 44 | 45 | ## Installation 46 | 47 | The package can be installed from [PyPI](https://pypi.org/project/onnx-asr/): 48 | 49 | 1. With CPU `onnxruntime` and `huggingface-hub` 50 | ```shell 51 | pip install onnx-asr[cpu,hub] 52 | ``` 53 | 2. With GPU `onnxruntime` and `huggingface-hub` 54 | 55 | > [!IMPORTANT] 56 | > First, you need to install the [required](https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#requirements) version of CUDA. 57 | 58 | ```shell 59 | pip install onnx-asr[gpu,hub] 60 | ``` 61 | 62 | 3. Without `onnxruntime` and `huggingface-hub` (if you already have some version of `onnxruntime` installed and prefer to download the models yourself) 63 | ```shell 64 | pip install onnx-asr 65 | ``` 66 | 4. To build onnx-asr from source, you need to install [pdm](https://pdm-project.org/en/latest/#installation). Then you can build onnx-asr with the command: 67 | ```shell 68 | pdm build 69 | ``` 70 | 71 | ## Usage examples 72 | 73 | ### Load ONNX model from Hugging Face 74 | 75 | Load an ONNX model from Hugging Face and recognize a wav file: 76 | ```py 77 | import onnx_asr 78 | model = onnx_asr.load_model("nemo-parakeet-tdt-0.6b-v3") 79 | print(model.recognize("test.wav")) 80 | ``` 81 | 82 | > [!IMPORTANT] 83 | > Supported wav file formats: PCM_U8, PCM_16, PCM_24 and PCM_32. For other formats, you either need to convert them first or use a library that can read them into a NumPy array.
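For example, a FLAC file can be read with `soundfile` (not an onnx-asr dependency; the file name below is only a placeholder) and passed in as a NumPy array — a minimal sketch:
```py
import onnx_asr
import soundfile as sf

model = onnx_asr.load_model("nemo-parakeet-tdt-0.6b-v3")

# Read a non-wav file (e.g. FLAC) into a float32 NumPy array and pass the sample rate explicitly
waveform, sample_rate = sf.read("test.flac", dtype="float32")
print(model.recognize(waveform, sample_rate=sample_rate))
```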
84 | 85 | #### Supported model names: 86 | * `gigaam-v2-ctc` for Sber GigaAM v2 CTC ([origin](https://github.com/salute-developers/GigaAM), [onnx](https://huggingface.co/istupakov/gigaam-v2-onnx)) 87 | * `gigaam-v2-rnnt` for Sber GigaAM v2 RNN-T ([origin](https://github.com/salute-developers/GigaAM), [onnx](https://huggingface.co/istupakov/gigaam-v2-onnx)) 88 | * `gigaam-v3-ctc` for Sber GigaAM v3 CTC ([origin](https://github.com/salute-developers/GigaAM), [onnx](https://huggingface.co/istupakov/gigaam-v3-onnx)) 89 | * `gigaam-v3-rnnt` for Sber GigaAM v3 RNN-T ([origin](https://github.com/salute-developers/GigaAM), [onnx](https://huggingface.co/istupakov/gigaam-v3-onnx)) 90 | * `gigaam-v3-e2e-ctc` for Sber GigaAM v3 E2E CTC ([origin](https://github.com/salute-developers/GigaAM), [onnx](https://huggingface.co/istupakov/gigaam-v3-onnx)) 91 | * `gigaam-v3-e2e-rnnt` for Sber GigaAM v3 E2E RNN-T ([origin](https://github.com/salute-developers/GigaAM), [onnx](https://huggingface.co/istupakov/gigaam-v3-onnx)) 92 | * `nemo-fastconformer-ru-ctc` for Nvidia FastConformer-Hybrid Large (ru) with CTC decoder ([origin](https://huggingface.co/nvidia/stt_ru_fastconformer_hybrid_large_pc), [onnx](https://huggingface.co/istupakov/stt_ru_fastconformer_hybrid_large_pc_onnx)) 93 | * `nemo-fastconformer-ru-rnnt` for Nvidia FastConformer-Hybrid Large (ru) with RNN-T decoder ([origin](https://huggingface.co/nvidia/stt_ru_fastconformer_hybrid_large_pc), [onnx](https://huggingface.co/istupakov/stt_ru_fastconformer_hybrid_large_pc_onnx)) 94 | * `nemo-parakeet-ctc-0.6b` for Nvidia Parakeet CTC 0.6B (en) ([origin](https://huggingface.co/nvidia/parakeet-ctc-0.6b), [onnx](https://huggingface.co/istupakov/parakeet-ctc-0.6b-onnx)) 95 | * `nemo-parakeet-rnnt-0.6b` for Nvidia Parakeet RNNT 0.6B (en) ([origin](https://huggingface.co/nvidia/parakeet-rnnt-0.6b), [onnx](https://huggingface.co/istupakov/parakeet-rnnt-0.6b-onnx)) 96 | * `nemo-parakeet-tdt-0.6b-v2` for Nvidia Parakeet TDT 0.6B V2 (en) ([origin](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v2), [onnx](https://huggingface.co/istupakov/parakeet-tdt-0.6b-v2-onnx)) 97 | * `nemo-parakeet-tdt-0.6b-v3` for Nvidia Parakeet TDT 0.6B V3 (multilingual) ([origin](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3), [onnx](https://huggingface.co/istupakov/parakeet-tdt-0.6b-v3-onnx)) 98 | * `nemo-canary-1b-v2` for Nvidia Canary 1B V2 (multilingual) ([origin](https://huggingface.co/nvidia/canary-1b-v2), [onnx](https://huggingface.co/istupakov/canary-1b-v2-onnx)) 99 | * `whisper-base` for OpenAI Whisper Base exported with onnxruntime ([origin](https://huggingface.co/openai/whisper-base), [onnx](https://huggingface.co/istupakov/whisper-base-onnx)) 100 | * `alphacep/vosk-model-ru` for Alpha Cephei Vosk 0.54-ru ([origin](https://huggingface.co/alphacep/vosk-model-ru)) 101 | * `alphacep/vosk-model-small-ru` for Alpha Cephei Vosk 0.52-small-ru ([origin](https://huggingface.co/alphacep/vosk-model-small-ru)) 102 | * `t-tech/t-one` for T-Tech T-one ([origin](https://huggingface.co/t-tech/T-one)) 103 | * `onnx-community/whisper-tiny`, `onnx-community/whisper-base`, `onnx-community/whisper-small`, `onnx-community/whisper-large-v3-turbo`, etc. for OpenAI Whisper exported with Hugging Face optimum ([onnx-community](https://huggingface.co/onnx-community?search_models=whisper)) 104 | 105 | > [!IMPORTANT] 106 | > Some long-ago converted `onnx-community` models have a broken `fp16` precision version. 107 | 108 | > [!IMPORTANT] 109 | > Canary models do not work with the CoreML provider. 
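If a provider causes problems (for example, CoreML with Canary models), you can pass an explicit provider list to `load_model` — a minimal sketch using onnxruntime's standard CPU provider:
```py
import onnx_asr

# Force the CPU execution provider, e.g. on macOS where Canary models do not work with CoreML
model = onnx_asr.load_model("nemo-canary-1b-v2", providers=["CPUExecutionProvider"])
print(model.recognize("test.wav"))
```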
110 | 111 | Example with `soundfile`: 112 | ```py 113 | import onnx_asr 114 | import soundfile as sf 115 | 116 | model = onnx_asr.load_model("nemo-parakeet-tdt-0.6b-v3") 117 | 118 | waveform, sample_rate = sf.read("test.wav", dtype="float32") 119 | model.recognize(waveform, sample_rate=sample_rate) 120 | ``` 121 | 122 | Batch processing is also supported: 123 | ```py 124 | import onnx_asr 125 | model = onnx_asr.load_model("nemo-parakeet-tdt-0.6b-v3") 126 | print(model.recognize(["test1.wav", "test2.wav", "test3.wav", "test4.wav"])) 127 | ``` 128 | 129 | Most models have quantized versions: 130 | ```py 131 | import onnx_asr 132 | model = onnx_asr.load_model("nemo-parakeet-tdt-0.6b-v3", quantization="int8") 133 | print(model.recognize("test.wav")) 134 | ``` 135 | 136 | Return tokens and timestamps: 137 | ```py 138 | import onnx_asr 139 | model = onnx_asr.load_model("nemo-parakeet-tdt-0.6b-v3").with_timestamps() 140 | print(model.recognize("test1.wav")) 141 | ``` 142 | 143 | ### VAD 144 | 145 | Load a VAD ONNX model from Hugging Face and recognize a wav file: 146 | ```py 147 | import onnx_asr 148 | vad = onnx_asr.load_vad("silero") 149 | model = onnx_asr.load_model("nemo-parakeet-tdt-0.6b-v3").with_vad(vad) 150 | for res in model.recognize("test.wav"): 151 | print(res) 152 | ``` 153 | 154 | > [!NOTE] 155 | > You will most likely need to adjust VAD parameters to get the correct results. 156 | 157 | #### Supported VAD names: 158 | * `silero` for Silero VAD ([origin](https://github.com/snakers4/silero-vad), [onnx](https://huggingface.co/onnx-community/silero-vad)) 159 | 160 | ### CLI 161 | 162 | The package has a simple CLI interface: 163 | ```shell 164 | onnx-asr nemo-parakeet-tdt-0.6b-v3 test.wav 165 | ``` 166 | 167 | For full usage parameters, see help: 168 | ```shell 169 | onnx-asr -h 170 | ``` 171 | 172 | ### Gradio 173 | 174 | Create a simple web interface with Gradio: 175 | ```py 176 | import onnx_asr 177 | import gradio as gr 178 | 179 | model = onnx_asr.load_model("nemo-parakeet-tdt-0.6b-v3") 180 | 181 | def recognize(audio): 182 | if not audio: 183 | return None 184 | 185 | sample_rate, waveform = audio 186 | waveform = waveform / 2**15 187 | if waveform.ndim == 2: 188 | waveform = waveform.mean(axis=1) 189 | return model.recognize(waveform, sample_rate=sample_rate) 190 | 191 | demo = gr.Interface(fn=recognize, inputs="audio", outputs="text") 192 | demo.launch() 193 | ``` 194 | 195 | ### Load ONNX model from local directory 196 | 197 | Load an ONNX model from a local directory and recognize a wav file: 198 | ```py 199 | import onnx_asr 200 | model = onnx_asr.load_model("nemo-parakeet-tdt-0.6b-v3", "models/parakeet-v3") 201 | print(model.recognize("test.wav")) 202 | ``` 203 | 204 | > [!NOTE] 205 | > If the directory does not exist, it will be created and the model will be loaded into it.
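A model that you converted yourself (see [Convert model to ONNX](#convert-model-to-onnx)) can also be loaded from a local folder by one of the model types listed below — a minimal sketch, where the folder name is a placeholder:
```py
import onnx_asr

# "models/my-fastconformer" is a hypothetical folder with a self-exported RNN-T model
# (encoder-model.onnx, decoder_joint-model.onnx, vocab.txt, ...)
model = onnx_asr.load_model("nemo-conformer-rnnt", "models/my-fastconformer")
print(model.recognize("test.wav"))
```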
206 | 207 | #### Supported model types: 208 | * All models from [supported model names](#supported-model-names) 209 | * `nemo-conformer-ctc` for NeMo Conformer/FastConformer/Parakeet with CTC decoder 210 | * `nemo-conformer-rnnt` for NeMo Conformer/FastConformer/Parakeet with RNN-T decoder 211 | * `nemo-conformer-tdt` for NeMo Conformer/FastConformer/Parakeet with TDT decoder 212 | * `nemo-conformer-aed` for NeMo Canary with Transformer decoder 213 | * `kaldi-rnnt` or `vosk` for Kaldi Icefall Zipformer with stateless RNN-T decoder 214 | * `whisper-ort` for Whisper (exported with [onnxruntime](#openai-whisper-with-onnxruntime-export)) 215 | * `whisper` for Whisper (exported with [optimum](#openai-whisper-with-optimum-export)) 216 | 217 | ## Comparison with original implementations 218 | 219 | Packages with original implementations: 220 | * `gigaam` for GigaAM models ([github](https://github.com/salute-developers/GigaAM)) 221 | * `nemo-toolkit` for NeMo models ([github](https://github.com/nvidia/nemo)) 222 | * `openai-whisper` for Whisper models ([github](https://github.com/openai/whisper)) 223 | * `sherpa-onnx` for Vosk models ([github](https://github.com/k2-fsa/sherpa-onnx), [docs](https://k2-fsa.github.io/sherpa/onnx/index.html)) 224 | * `T-one` for T-Tech T-one model ([github](https://github.com/voicekit-team/T-one)) 225 | 226 | Hardware: 227 | 1. CPU tests were run on a laptop with an Intel i7-7700HQ processor. 228 | 2. GPU tests were run in Google Colab on Nvidia T4 229 | 230 | Tests of Russian ASR models were performed on a *test* subset of the [Russian LibriSpeech](https://huggingface.co/datasets/istupakov/russian_librispeech) dataset. 231 | 232 | | Model | Package / decoding | CER | WER | RTFx (CPU) | RTFx (GPU) | 233 | |---------------------------|----------------------|--------|--------|------------|--------------| 234 | | GigaAM v2 CTC | default | 1.06% | 5.23% | 7.2 | 44.2 | 235 | | GigaAM v2 CTC | onnx-asr | 1.06% | 5.23% | 11.6 | 64.3 | 236 | | GigaAM v2 RNN-T | default | 1.10% | 5.22% | 5.5 | 23.3 | 237 | | GigaAM v2 RNN-T | onnx-asr | 1.10% | 5.22% | 10.7 | 38.7 | 238 | | GigaAM v3 CTC | default | 0.98% | 4.72% | 12.2 | 73.3 | 239 | | GigaAM v3 CTC | onnx-asr | 0.98% | 4.72% | 14.5 | 68.3 | 240 | | GigaAM v3 RNN-T | default | 0.93% | 4.39% | 8.2 | 41.6 | 241 | | GigaAM v3 RNN-T | onnx-asr | 0.93% | 4.39% | 13.3 | 39.9 | 242 | | GigaAM v3 E2E CTC | default | 1.50% | 7.10% | N/A | 178.0 | 243 | | GigaAM v3 E2E CTC | onnx-asr | 1.56% | 7.80% | N/A | 65.6 | 244 | | GigaAM v3 E2E RNN-T | default | 1.61% | 6.94% | N/A | 47.6 | 245 | | GigaAM v3 E2E RNN-T | onnx-asr | 1.67% | 7.60% | N/A | 42.8 | 246 | | Nemo FastConformer CTC | default | 3.11% | 13.12% | 29.1 | 143.0 | 247 | | Nemo FastConformer CTC | onnx-asr | 3.11% | 13.12% | 45.8 | 103.3 | 248 | | Nemo FastConformer RNN-T | default | 2.63% | 11.62% | 17.4 | 111.6 | 249 | | Nemo FastConformer RNN-T | onnx-asr | 2.63% | 11.62% | 27.2 | 53.4 | 250 | | Nemo Parakeet TDT 0.6B V3 | default | 2.34% | 10.95% | 5.6 | 75.4 | 251 | | Nemo Parakeet TDT 0.6B V3 | onnx-asr | 2.38% | 10.95% | 9.7 | 59.7 | 252 | | Nemo Canary 1B V2 | default | 4.89% | 20.00% | N/A | 14.0 | 253 | | Nemo Canary 1B V2 | onnx-asr | 5.00% | 20.03% | N/A | 17.4 | 254 | | T-Tech T-one | default | 1.28% | 6.56% | 11.9 | N/A | 255 | | T-Tech T-one | onnx-asr | 1.28% | 6.57% | 11.7 | 16.5 | 256 | | Vosk 0.52 small | greedy_search | 3.64% | 14.53% | 48.2 | 71.4 | 257 | | Vosk 0.52 small | modified_beam_search | 3.50% | 14.25% | 29.0 | 24.7 | 258 | | Vosk 0.52 small | 
onnx-asr | 3.64% | 14.53% | 45.5 | 75.2 | 259 | | Vosk 0.54 | greedy_search | 2.21% | 9.89% | 34.8 | 64.2 | 260 | | Vosk 0.54 | modified_beam_search | 2.21% | 9.85% | 23.9 | 24 | 261 | | Vosk 0.54 | onnx-asr | 2.21% | 9.89% | 33.6 | 69.6 | 262 | | Whisper base | default | 10.61% | 38.89% | 5.4 | 17.3 | 263 | | Whisper base | onnx-asr* | 10.64% | 38.33% | 6.6 | 20.1 | 264 | | Whisper large-v3-turbo | default | 2.96% | 10.27% | N/A | 13.6 | 265 | | Whisper large-v3-turbo | onnx-asr** | 2.63% | 10.13% | N/A | 12.4 | 266 | 267 | Tests of English ASR models were performed on a *test* subset of the [Voxpopuli](https://huggingface.co/datasets/facebook/voxpopuli) dataset. 268 | 269 | | Model | Package / decoding | CER | WER | RTFx (CPU) | RTFx (GPU) | 270 | |---------------------------|----------------------|--------|--------|------------|--------------| 271 | | Nemo Parakeet CTC 0.6B | default | 4.09% | 7.20% | 8.3 | 107.7 | 272 | | Nemo Parakeet CTC 0.6B | onnx-asr | 4.09% | 7.20% | 11.5 | 89.0 | 273 | | Nemo Parakeet RNN-T 0.6B | default | 3.64% | 6.32% | 6.7 | 85.0 | 274 | | Nemo Parakeet RNN-T 0.6B | onnx-asr | 3.64% | 6.32% | 8.7 | 48.0 | 275 | | Nemo Parakeet TDT 0.6B V2 | default | 3.88% | 6.52% | 6.5 | 87.6 | 276 | | Nemo Parakeet TDT 0.6B V2 | onnx-asr | 3.88% | 6.52% | 10.5 | 70.1 | 277 | | Nemo Parakeet TDT 0.6B V3 | default | 3.97% | 6.76% | 6.1 | 90.0 | 278 | | Nemo Parakeet TDT 0.6B V3 | onnx-asr | 3.97% | 6.75% | 9.5 | 68.2 | 279 | | Nemo Canary 1B V2 | default | 4.62% | 7.42% | N/A | 17.5 | 280 | | Nemo Canary 1B V2 | onnx-asr | 4.67% | 7.47% | N/A | 20.8 | 281 | | Whisper base | default | 7.81% | 13.24% | 8.4 | 27.7 | 282 | | Whisper base | onnx-asr* | 7.52% | 12.76% | 9.2 | 28.9 | 283 | | Whisper large-v3-turbo | default | 6.85% | 11.16% | N/A | 20.4 | 284 | | Whisper large-v3-turbo | onnx-asr** | 10.31% | 14.65% | N/A | 17.9 | 285 | 286 | > [!NOTE] 287 | > 1. \* `whisper-ort` model ([model types](#supported-model-types)). 288 | > 2. ** `whisper` model ([model types](#supported-model-types)) with `fp16` precision. 289 | > 3. All other models were run with the default precision - `fp32` on CPU and `fp32` or `fp16` (some of the original models) on GPU. 290 | 291 | ## Benchmarks 292 | 293 | Hardware: 294 | 1. Arm tests were run on an Orange Pi Zero 3 with a Cortex-A53 processor. 295 | 2. x64 tests were run on a laptop with an Intel i7-7700HQ processor. 296 | 3. 
T4 tests were run in Google Colab on Nvidia T4 297 | 298 | ### Russian ASR models 299 | Notebook with benchmark code - [benchmark-ru](examples/benchmark-ru.ipynb) 300 | 301 | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/istupakov/onnx-asr/blob/main/examples/benchmark-ru.ipynb) 302 | 303 | | Model | RTFx (Arm) | RTFx (x64) | RTFx (T4) | 304 | |---------------------------|------------|------------|-----------| 305 | | GigaAM v2 CTC | 0.8 | 11.6 | 64.3 | 306 | | GigaAM v2 RNN-T | 0.8 | 10.7 | 38.7 | 307 | | GigaAM v3 CTC | N/A | 14.5 | 68.3 | 308 | | GigaAM v3 RNN-T | N/A | 13.3 | 39.9 | 309 | | Nemo FastConformer CTC | 4.0 | 45.8 | 103.3 | 310 | | Nemo FastConformer RNN-T | 3.2 | 27.2 | 53.4 | 311 | | Nemo Parakeet TDT 0.6B V3 | N/A | 9.7 | 59.7 | 312 | | Nemo Canary 1B V2 | N/A | N/A | 17.4 | 313 | | T-Tech T-one | N/A | 11.7 | 16.5 | 314 | | Vosk 0.52 small | 5.1 | 45.5 | 75.2 | 315 | | Vosk 0.54 | 3.8 | 33.6 | 69.6 | 316 | | Whisper base | 0.8 | 6.6 | 20.1 | 317 | | Whisper large-v3-turbo | N/A | N/A | 12.4 | 318 | 319 | ### English ASR models 320 | 321 | Notebook with benchmark code - [benchmark-en](examples/benchmark-en.ipynb) 322 | 323 | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/istupakov/onnx-asr/blob/main/examples/benchmark-en.ipynb) 324 | 325 | | Model | RTFx (Arm) | RTFx (x64) | RTFx (T4) | 326 | |---------------------------|------------|------------|-----------| 327 | | Nemo Parakeet CTC 0.6B | 1.1 | 11.5 | 89.0 | 328 | | Nemo Parakeet RNN-T 0.6B | 1.0 | 8.7 | 48.0 | 329 | | Nemo Parakeet TDT 0.6B V2 | 1.1 | 10.5 | 70.1 | 330 | | Nemo Parakeet TDT 0.6B V3 | N/A | 9.5 | 68.2 | 331 | | Nemo Canary 1B V2 | N/A | N/A | 20.8 | 332 | | Whisper base | 1.2 | 9.2 | 28.9 | 333 | | Whisper large-v3-turbo | N/A | N/A | 17.9 | 334 | 335 | ## Convert model to ONNX 336 | 337 | Save the model according to the instructions below and add config.json: 338 | 339 | ```json 340 | { 341 | "model_type": "nemo-conformer-rnnt", // See "Supported model types" 342 | "features_size": 80, // Size of preprocessor features for Whisper or Nemo models, supported 80 and 128 343 | "subsampling_factor": 8, // Subsampling factor - 4 for conformer models and 8 for fastconformer and parakeet models 344 | "max_tokens_per_step": 10 // Max tokens per step for RNN-T decoder 345 | } 346 | ``` 347 | Then you can upload the model into Hugging Face and use `load_model` to download it. 
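For example (a minimal sketch — the repository name below is a placeholder for your own upload):
```py
import onnx_asr

# The model type is read from config.json in the repository
model = onnx_asr.load_model("your-username/your-model-onnx")
print(model.recognize("test.wav"))
```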
348 | 349 | ### Nvidia NeMo Conformer/FastConformer/Parakeet 350 | Install the **NeMo Toolkit**: 351 | ```shell 352 | pip install nemo_toolkit['asr'] 353 | ``` 354 | 355 | Download the model and export it to ONNX format: 356 | ```py 357 | import nemo.collections.asr as nemo_asr 358 | from pathlib import Path 359 | 360 | model = nemo_asr.models.ASRModel.from_pretrained("nvidia/stt_ru_fastconformer_hybrid_large_pc") 361 | 362 | # To export Hybrid models with the CTC decoder 363 | # model.set_export_config({"decoder_type": "ctc"}) 364 | 365 | onnx_dir = Path("nemo-onnx") 366 | onnx_dir.mkdir(exist_ok=True) 367 | model.export(str(Path(onnx_dir, "model.onnx"))) 368 | 369 | with Path(onnx_dir, "vocab.txt").open("wt") as f: 370 | for i, token in enumerate([*model.tokenizer.vocab, ""]): 371 | f.write(f"{token} {i}\n") 372 | ``` 373 | 374 | ### Sber GigaAM v2/v3 375 | Install **GigaAM**: 376 | ```shell 377 | git clone https://github.com/salute-developers/GigaAM.git 378 | pip install ./GigaAM --extra-index-url https://download.pytorch.org/whl/cpu 379 | ``` 380 | 381 | Download the model and export it to ONNX format: 382 | ```py 383 | import gigaam 384 | from pathlib import Path 385 | 386 | onnx_dir = "gigaam-onnx" 387 | model_type = "rnnt" # or "ctc" 388 | 389 | model = gigaam.load_model( 390 | model_type, 391 | fp16_encoder=False, # only fp32 tensors 392 | use_flash=False, # disable flash attention 393 | ) 394 | model.to_onnx(dir_path=onnx_dir) 395 | 396 | with Path(onnx_dir, "v2_vocab.txt").open("wt") as f: 397 | for i, token in enumerate(["\u2581", *(chr(ord("а") + i) for i in range(32)), ""]): 398 | f.write(f"{token} {i}\n") 399 | ``` 400 | 401 | ### OpenAI Whisper (with `onnxruntime` export) 402 | 403 | Read the onnxruntime [instructions](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/models/whisper/README.md) for converting Whisper to ONNX. 404 | 405 | Download the model and export it with *Beam Search* and *Forced Decoder Input Ids*: 406 | ```shell 407 | python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-base --output ./whisper-onnx --use_forced_decoder_ids --optimize_onnx --precision fp32 408 | ``` 409 | 410 | Save the tokenizer config: 411 | ```py 412 | from transformers import WhisperTokenizer 413 | 414 | processor = WhisperTokenizer.from_pretrained("openai/whisper-base") 415 | processor.save_pretrained("whisper-onnx") 416 | ``` 417 | 418 | ### OpenAI Whisper (with `optimum` export) 419 | 420 | Export the model to ONNX with Hugging Face `optimum-cli`: 421 | ```shell 422 | optimum-cli export onnx --model openai/whisper-base ./whisper-onnx 423 | ``` 424 | --------------------------------------------------------------------------------