├── tests
│   ├── __init__.py
│   └── test_wrapper.py
├── MANIFEST.in
├── simpler_whisper
│   ├── __init__.py
│   └── whisper.py
├── LICENSE
├── CMakeLists.txt
├── .github
│   └── workflows
│       ├── release.yaml
│       └── build.yaml
├── .gitignore
├── pyproject.toml
├── README.md
├── setup.py
├── test_simpler_whisper.py
├── cmake
│   └── BuildWhispercpp.cmake
└── src
    └── whisper_wrapper.cpp
/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include simpler_whisper/*.so 2 | include simpler_whisper/*.pyd 3 | include simpler_whisper/*.dll 4 | include simpler_whisper/*.py 5 | include simpler_whisper/*.metal 6 | -------------------------------------------------------------------------------- /simpler_whisper/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | from ._whisper_cpp import * 3 | except ImportError as e: 4 | import sys 5 | print(f"Error importing modules: {e}", file=sys.stderr) 6 | try: 7 | from .whisper import ( 8 | WhisperModel, 9 | AsyncWhisperModel, 10 | ThreadedWhisperModel, 11 | WhisperSegment, 12 | WhisperToken, 13 | set_log_callback, 14 | LogLevel, 15 | ) 16 | 17 | __all__ = [ 18 | "WhisperModel", 19 | "AsyncWhisperModel", 20 | "ThreadedWhisperModel", 21 | "WhisperSegment", 22 | "WhisperToken", 23 | "set_log_callback", 24 | "LogLevel", 25 | ] 26 | except ImportError as e: 27 | import sys 28 | 29 | print(f"Error importing modules: {e}", file=sys.stderr) 30 | __all__ = [] 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Locaal, Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.15) 2 | project(whisper_cpp_wrapper) 3 | 4 | set(CMAKE_CXX_STANDARD 11) 5 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 6 | 7 | if(APPLE) 8 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -arch x86_64 -arch arm64") 9 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -arch x86_64 -arch arm64") 10 | endif() 11 | 12 | # Fetch pybind11 13 | include(FetchContent) 14 | FetchContent_Declare( 15 | pybind11 16 | DOWNLOAD_EXTRACT_TIMESTAMP TRUE 17 | GIT_REPOSITORY https://github.com/pybind/pybind11.git 18 | GIT_TAG v2.13.6 # Specify a version/tag here 19 | ) 20 | FetchContent_MakeAvailable(pybind11) 21 | 22 | include(cmake/BuildWhispercpp.cmake) 23 | 24 | # Create the extension module 25 | pybind11_add_module(_whisper_cpp src/whisper_wrapper.cpp) 26 | target_link_libraries(_whisper_cpp PRIVATE Whispercpp) 27 | 28 | # Set the output directory for the built module 29 | set_target_properties( 30 | _whisper_cpp PROPERTIES LIBRARY_OUTPUT_DIRECTORY 31 | ${CMAKE_CURRENT_SOURCE_DIR}/simpler_whisper) 32 | 33 | # Copy the additional runtime files (DLLs, metal shaders) next to the module on Windows and macOS 34 | if(WIN32 OR APPLE) 35 | foreach(WHISPER_ADDITIONAL_FILE ${WHISPER_ADDITIONAL_FILES}) 36 | add_custom_command( 37 | TARGET _whisper_cpp 38 | POST_BUILD 39 | COMMAND ${CMAKE_COMMAND} -E copy_if_different 40 | "${WHISPER_ADDITIONAL_FILE}" $<TARGET_FILE_DIR:_whisper_cpp>) 41 | endforeach() 42 | endif() 43 | 44 | if(APPLE) 45 | # Additional macOS-specific settings for the module 46 | set_target_properties(_whisper_cpp PROPERTIES 47 | INSTALL_RPATH "@loader_path" 48 | BUILD_WITH_INSTALL_RPATH TRUE 49 | ) 50 | endif() 51 | -------------------------------------------------------------------------------- /.github/workflows/release.yaml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | tags: 6 | - '[0-9]+.[0-9]+.[0-9]+' 7 | - '[0-9]+.[0-9]+.[0-9]+a[0-9]*' 8 | - '[0-9]+.[0-9]+.[0-9]+b[0-9]*' 9 | - '[0-9]+.[0-9]+.[0-9]+rc[0-9]*' 10 | - '[0-9]+.[0-9]+.[0-9]+.post[0-9]*' 11 | - '[0-9]+.[0-9]+.[0-9]+.dev[0-9]*' 12 | 13 | jobs: 14 | build-project: 15 | name: Build Project 🧱 16 | uses: ./.github/workflows/build.yaml 17 | secrets: inherit 18 | permissions: 19 | contents: read 20 | 21 | create_release: 22 | needs: build-project 23 | runs-on: ubuntu-latest 24 | steps: 25 | - uses: actions/checkout@v4 26 | 27 | - name: Set up Python 28 | uses: actions/setup-python@v5 29 | with: 30 | python-version: '3.12' 31 | 32 | - name: Install dependencies 33 | run: | 34 | python -m pip install --upgrade pip build twine 35 | 36 | - name: Build source distribution 37 | run: python -m build --sdist 38 | 39 | - name: Download all workflow run artifacts 40 | uses: actions/download-artifact@v4 41 | 42 | - name: Create Release 43 | id: create_release 44 | uses: softprops/action-gh-release@v1 45 | env: 46 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 47 | with: 48 | tag_name: ${{ github.ref }} 49 | name: Release ${{ github.ref_name }} 50 | draft: false 51 | prerelease: false 52 | files: | 53 | dist/*.tar.gz 54 | */*.whl 55 | 56 | - name: Publish to PyPI 57 | env: 58 | TWINE_USERNAME: __token__ 59 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 60 | run: | 61 | python -m twine upload */*.whl 62 | -------------------------------------------------------------------------------- /.github/workflows/build.yaml:
-------------------------------------------------------------------------------- 1 | name: Build and Test 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | workflow_call: 9 | 10 | jobs: 11 | build: 12 | strategy: 13 | fail-fast: false 14 | matrix: 15 | os: ['windows-latest', 'macos-latest', 'ubuntu-latest'] 16 | acceleration: ['cpu'] # , 'cuda' 17 | exclude: 18 | - os: macos-latest 19 | acceleration: cuda 20 | - os: ubuntu-latest 21 | acceleration: cuda 22 | 23 | runs-on: ${{ matrix.os }} 24 | 25 | steps: 26 | - uses: actions/checkout@v4 27 | with: 28 | fetch-tags: true 29 | fetch-depth: 0 30 | 31 | - name: Set up Python 32 | uses: actions/setup-python@v5 33 | with: 34 | python-version: '3.12' 35 | 36 | - name: Install dependencies 37 | run: | 38 | python -m pip install --upgrade pip 39 | pip install numpy cmake wheel setuptools build cibuildwheel 40 | 41 | - name: Get latest tag 42 | id: get_tag 43 | run: | 44 | echo "tag=$(git describe --tags --abbrev=1)" >> $GITHUB_OUTPUT 45 | shell: bash 46 | 47 | - name: Build wheel 48 | env: 49 | CIBW_ENVIRONMENT: "SIMPLER_WHISPER_ACCELERATION='${{ matrix.acceleration }}' MACOSX_DEPLOYMENT_TARGET=10.13" 50 | CIBW_BUILD: "cp310-* cp311-* cp312-*" 51 | CIBW_ARCHS_MACOS: "universal2" 52 | CIBW_ARCHS_WINDOWS: "AMD64" 53 | CIBW_ARCHS_LINUX: "x86_64" 54 | CIBW_SKIP: "*-musllinux_*" 55 | CIBW_BUILD_VERBOSITY: 1 56 | SIMPLER_WHISPER_VERSION: ${{ steps.get_tag.outputs.tag }} 57 | run: | 58 | python -m cibuildwheel --output-dir wheelhouse 59 | 60 | - name: Test import 61 | if: false 62 | run: | 63 | python -c "import sys; sys.path.pop(0); import simpler_whisper; print(simpler_whisper.__file__)" 64 | 65 | - name: Upload wheel 66 | uses: actions/upload-artifact@v4 67 | with: 68 | name: wheels-${{ matrix.os }}-${{ matrix.acceleration }} 69 | path: | 70 | wheelhouse/*.whl 71 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | db.sqlite3 60 | db.sqlite3-journal 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # IPython 79 | profile_default/ 80 | ipython_config.py 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # celery beat schedule file 86 | celerybeat-schedule 87 | 88 | # SageMath parsed files 89 | *.sage.py 90 | 91 | # Environments 92 | .env 93 | .venv 94 | env/ 95 | venv/ 96 | ENV/ 97 | env.bak/ 98 | venv.bak/ 99 | 100 | # Spyder project settings 101 | .spyderproject 102 | .spyderworkspace 103 | 104 | # Rope project settings 105 | .ropeproject 106 | 107 | # mkdocs documentation 108 | /site 109 | 110 | # mypy 111 | .mypy_cache/ 112 | .dmypy.json 113 | dmypy.json 114 | 115 | # Pyre type checker 116 | .pyre/ 117 | 118 | # pytype static type analyzer 119 | .pytype/ 120 | 121 | # Cython debug symbols 122 | cython_debug/ 123 | 124 | # VS Code 125 | .vscode/ 126 | *.code-workspace 127 | 128 | *.dll -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=45", "wheel", "cmake>=3.12", "numpy<=1.26.4"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "simpler-whisper" 7 | dynamic = ["version"] 8 | authors = [ 9 | {name = "Roy Shilkrot", email = "roy.shil@gmail.com"}, 10 | ] 11 | description = "A simple Python wrapper for whisper.cpp" 12 | readme = "README.md" 13 | requires-python = ">=3.10" 14 | classifiers = [ 15 | "Development Status :: 3 - Alpha", 16 | "Intended Audience :: Developers", 17 | "License :: OSI Approved :: MIT License", 18 | "Operating System :: OS Independent", 19 | "Programming Language :: Python :: 3", 20 | "Programming Language :: Python :: 3.10", 21 | "Programming Language :: Python :: 3.11", 22 | "Programming Language :: Python :: 3.12", 23 | ] 24 | dependencies = [ 25 | "numpy", 26 | ] 27 | 28 | [project.optional-dependencies] 29 | test = [ 30 | "pytest>=7.0", 31 | "requests>=2.0.0", 32 | ] 33 | 34 | [project.urls] 35 | "Homepage" = "https://github.com/locaal-ai/simpler-whisper" 36 | "Bug Tracker" = "https://github.com/locaal-ai/simpler-whisper/issues" 37 | 38 | [tool.setuptools] 39 | packages = ["simpler_whisper"] 40 | 41 | [tool.pytest.ini_options] 42 | testpaths = ["tests"] 43 | python_files = "test_*.py" 44 | python_classes = "Test*" 45 | python_functions = "test_*" 46 | 47 | [tool.cibuildwheel] 48 | # Environment variables 49 | environment = { PIP_PREFER_BINARY="1" } 50 | 51 | # Build configuration 52 | build-verbosity = 1 53 | 54 | # Test configuration 55 | test-command = """ 56 | python -c " 57 | import os 58 | import sys 59 | import site 60 | site.addsitedir(os.path.abspath('..')) 61 | print('Python path:', sys.path) 62 | print('Working directory:', os.getcwd()) 63 | print('Directory listing:', os.listdir()) 64 | print('Parent directory:', os.listdir('..')) 65 | try: 66 | import simpler_whisper
67 | print('simpler_whisper found at:', simpler_whisper.__file__) 68 | except ImportError as e: 69 | print('Failed to import simpler_whisper:', e) 70 | " 71 | pytest {project}/tests 72 | """ 73 | 74 | test-extras = ["test"] 75 | 76 | # Before test configuration 77 | before-test = """ 78 | pip install pytest requests numpy 79 | python -c " 80 | import os 81 | import requests 82 | model_url = 'https://ggml.ggerganov.com/ggml-model-whisper-tiny.en-q5_1.bin' 83 | model_path = 'ggml-tiny.en-q5_1.bin' 84 | if not os.path.exists(model_path): 85 | print('Downloading whisper model...') 86 | response = requests.get(model_url) 87 | with open(model_path, 'wb') as f: 88 | f.write(response.content) 89 | print('Model downloaded successfully') 90 | " 91 | """ 92 | 93 | [tool.cibuildwheel.macos] 94 | environment = { MACOSX_DEPLOYMENT_TARGET="10.13" } 95 | repair-wheel-command = """ 96 | MACOSX_DEPLOYMENT_TARGET=10.13 delocate-wheel --require-archs {delocate_archs} -w {dest_dir} {wheel} 97 | """ 98 | test-skip = "*_universal2:*" 99 | 100 | [tool.cibuildwheel.linux] 101 | test-skip = "*" 102 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Simpler Whisper 2 | 3 | ![Build and Test](https://img.shields.io/github/actions/workflow/status/locaal-ai/simpler-whisper/build.yaml) 4 | 5 | A zero-dependency simple Python wrapper for [whisper.cpp](https://github.com/ggerganov/whisper.cpp), providing an easy-to-use interface for speech recognition using the Whisper model. 6 | 7 | Why it's better than [faster-whisper](https://github.com/SYSTRAN/faster-whisper) and [pywhispercpp](https://github.com/abdeladim-s/pywhispercpp): 8 | - Zero-dependency: everything ships inside the built wheel; no Python dependencies (`av`, `ctranslate2`, etc.) except for `numpy`. 9 | - Dead-simple API: call `.transcribe()` and get a result 10 | - Acceleration enabled: supports whatever whisper.cpp supports 11 | - Updated: uses precompiled whisper.cpp from https://github.com/locaal-ai/occ-ai-dep-whispercpp 12 | - Build time: builds in 2 minutes because it uses a precompiled binary 13 | 14 | ## Installation 15 | 16 | To install simpler-whisper, you need: 17 | - A C++ compiler (e.g., GCC, Clang, or MSVC) 18 | - CMake (version 3.12 or higher) 19 | - NumPy 20 | 21 | Then you can install using pip: 22 | 23 | ```bash 24 | pip install simpler-whisper 25 | ``` 26 | 27 | ## Usage 28 | 29 | There are three ways to use simpler-whisper: 30 | 31 | ### 1. Basic Usage 32 | ```python 33 | import numpy as np 34 | from simpler_whisper.whisper import WhisperModel 35 | # Load the model (models can be downloaded from https://huggingface.co/ggerganov/whisper.cpp) 36 | model = WhisperModel("path/to/model.bin", use_gpu=True) 37 | 38 | # Load and prepare your audio 39 | # You can use av, librosa, or any method that gives you 16kHz mono float32 samples 40 | import av 41 | container = av.open("audio.mp3") 42 | audio_stream = container.streams.audio[0] 43 | samples = np.concatenate([ 44 | frame.to_ndarray().mean(axis=0) if frame.format.channels == 2 else frame.to_ndarray() 45 | for frame in container.decode(audio_stream) 46 | ]) 47 | 48 | # Transcribe 49 | transcription = model.transcribe(samples) 50 | for segment in transcription: 51 | print(f"{segment.text} ({segment.start} - {segment.end})") 52 | ``` 53 |
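If you'd rather not use `av`, any loader that yields 16kHz mono float32 samples works. Below is a minimal sketch using [librosa](https://github.com/librosa/librosa), which this repo's own test script uses; librosa is not a dependency of the wheel, so install it separately:

```python
# Sketch: librosa decodes, downmixes to mono, and resamples in one call.
import librosa

samples, _ = librosa.load("audio.mp3", sr=16000, mono=True)  # float32 in [-1, 1]
transcription = model.transcribe(samples)
```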
54 | ### 2. Async Processing 55 | 56 | This creates a background thread in the native backend (not locked by the GIL) to allow for asynchronous transcription. 57 | 58 | ```python 59 | from typing import List 60 | from simpler_whisper.whisper import AsyncWhisperModel, WhisperSegment 61 | def handle_result(chunk_id: int, segments: List[WhisperSegment], is_partial: bool): 62 | text = " ".join([seg.text for seg in segments]) 63 | print(f"Chunk {chunk_id}: {text}") 64 | 65 | # Create and start async model 66 | model = AsyncWhisperModel("path/to/model.bin", callback=handle_result, use_gpu=True) 67 | model.start() 68 | 69 | # Queue audio chunks for processing 70 | chunk_id = model.transcribe(audio_samples) 71 | 72 | # When done 73 | model.stop() 74 | ``` 75 | 76 | ### 3. Real-time Threaded Processing 77 | 78 | This method creates a background thread for real-time transcription that continuously 79 | processes the input in fixed-size chunks (e.g., 10 seconds) and reports both partial and final results. 80 | 81 | ```python 82 | from typing import List 83 | from simpler_whisper.whisper import ThreadedWhisperModel, WhisperSegment 84 | def handle_result(chunk_id: int, segments: List[WhisperSegment], is_partial: bool): 85 | text = " ".join([seg.text for seg in segments]) 86 | print(f"Chunk {chunk_id}: {text}") 87 | 88 | # Create and start threaded model with 10-second chunks 89 | model = ThreadedWhisperModel( 90 | "path/to/model.bin", 91 | callback=handle_result, 92 | use_gpu=True, 93 | max_duration_sec=10.0 94 | ) 95 | model.start() 96 | 97 | # Queue audio chunks as they arrive 98 | chunk_id = model.queue_audio(audio_samples) 99 | 100 | # When done 101 | model.stop() 102 | ``` 103 | 104 | ## Platform-specific notes 105 | 106 | - On Windows, the package uses a DLL (whisper.dll), which is included in the package. 107 | - On Mac and Linux, the package uses static libraries that are linked into the extension. 108 | 109 | ## Building from source 110 | 111 | If you're building from source: 112 | 1. Clone the repository: 113 | ``` 114 | git clone https://github.com/locaal-ai/simpler-whisper.git 115 | cd simpler-whisper 116 | ``` 117 | 2. Install the package in editable mode: 118 | ``` 119 | pip install -e . 120 | ``` 121 | 122 | This will run the CMake build process and compile the extension. 123 | 124 | ## Build Configuration 125 | 126 | Simpler Whisper supports various build configurations to optimize for different hardware and acceleration methods. You can specify the build configuration using environment variables: 127 | 128 | - `SIMPLER_WHISPER_ACCELERATION`: Specifies the acceleration method. Options are: 129 | - `cpu` (default) 130 | - `cuda` (for NVIDIA GPUs) 131 | - `hipblas` (for AMD GPUs) 132 | - `vulkan` (for cross-platform GPU acceleration) 133 | 134 | ### Example: Building for Windows with CUDA acceleration 135 | 136 | ```powershell 137 | $env:SIMPLER_WHISPER_ACCELERATION="cuda" 138 | pip install . 139 | ``` 140 | 141 | ### Example: Building for macOS ARM64 142 | 143 | ```bash 144 | pip install . 145 | ``` 146 | 147 | ## License 148 | 149 | This project is licensed under the MIT License - see the LICENSE file for details.
150 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | from setuptools import setup, Extension 3 | from setuptools.command.build_ext import build_ext 4 | from setuptools.command.build_py import build_py 5 | import sys 6 | import os 7 | import subprocess 8 | import platform 9 | import sysconfig 10 | 11 | 12 | class CMakeExtension(Extension): 13 | def __init__(self, name, sourcedir=""): 14 | Extension.__init__(self, name, sources=[]) 15 | self.sourcedir = os.path.abspath(sourcedir) 16 | 17 | 18 | class BuildPyCommand(build_py): 19 | def run(self): 20 | self.run_command("build_ext") 21 | return super().run() 22 | 23 | 24 | class CMakeBuild(build_ext): 25 | def run(self): 26 | try: 27 | out = subprocess.check_output(["cmake", "--version"]) 28 | except OSError: 29 | raise RuntimeError( 30 | "CMake must be installed to build the following extensions: " 31 | + ", ".join(e.name for e in self.extensions) 32 | ) 33 | 34 | for ext in self.extensions: 35 | self.build_extension(ext) 36 | 37 | def build_extension(self, ext): 38 | ext_suffix = sysconfig.get_config_var("EXT_SUFFIX") 39 | extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name))) 40 | os.makedirs(extdir, exist_ok=True) 41 | 42 | acceleration = os.environ.get("SIMPLER_WHISPER_ACCELERATION", "cpu") 43 | 44 | cmake_args = [ 45 | f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}", 46 | f"-DPYTHON_EXTENSION_SUFFIX={ext_suffix}", 47 | f"-DACCELERATION={acceleration}", 48 | f"-DPYTHON_EXECUTABLE={sys.executable}", 49 | ] 50 | 51 | env = os.environ.copy() 52 | 53 | if platform.system() == "Darwin": 54 | cmake_args += [ 55 | "-DCMAKE_OSX_ARCHITECTURES=arm64;x86_64", 56 | "-DCMAKE_OSX_DEPLOYMENT_TARGET=10.13", 57 | "-DCMAKE_BUILD_WITH_INSTALL_RPATH=ON", 58 | "-DCMAKE_INSTALL_RPATH_USE_LINK_PATH=ON", 59 | f"-DCMAKE_INSTALL_NAME_DIR=@rpath", 60 | ] 61 | # Set environment variables for universal build 62 | env["MACOSX_DEPLOYMENT_TARGET"] = "10.13" 63 | env["_PYTHON_HOST_PLATFORM"] = "macosx-10.13-universal2" 64 | 65 | # Remove any existing arch flags that might interfere 66 | if "ARCHFLAGS" in env: 67 | del env["ARCHFLAGS"] 68 | if "MACOS_ARCH" in env: 69 | del env["MACOS_ARCH"] 70 | 71 | cfg = "Debug" if self.debug else "Release" 72 | build_args = ["--config", cfg] 73 | 74 | if platform.system() == "Windows": 75 | cmake_args += [f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}"] 76 | if sys.maxsize > 2**32: 77 | cmake_args += ["-A", "x64"] 78 | build_args += ["--", "/m"] 79 | else: 80 | cmake_args += [f"-DCMAKE_BUILD_TYPE={cfg}"] 81 | build_args += ["--", "-j2"] 82 | 83 | env["CXXFLAGS"] = ( 84 | f'{env.get("CXXFLAGS", "")} -DVERSION_INFO=\\"{self.distribution.get_version()}\\"' 85 | ) 86 | 87 | if not os.path.exists(self.build_temp): 88 | os.makedirs(self.build_temp) 89 | else: 90 | # Remove the existing CMakeCache.txt to ensure a clean build 91 | cache_file = os.path.join(self.build_temp, "CMakeCache.txt") 92 | if os.path.exists(cache_file): 93 | os.remove(cache_file) 94 | 95 | # Configure and build the extension 96 | subprocess.check_call( 97 | ["cmake", ext.sourcedir] + cmake_args, cwd=self.build_temp, env=env 98 | ) 99 | subprocess.check_call( 100 | ["cmake", "--build", "."] + build_args, cwd=self.build_temp 101 | ) 102 | 103 | 104 | def get_latest_git_tag(): 105 | tag = os.environ.get("SIMPLER_WHISPER_VERSION") 106 | if not tag: 107 | try: 108 | tag = subprocess.check_output( 109 | 
["git", "describe", "--tags"], encoding="utf-8" 110 | ).strip() 111 | except subprocess.CalledProcessError: 112 | return "0.0.0-dev" 113 | parts = tag.split("-") 114 | if len(parts) == 3: 115 | return f"{parts[0]}-dev{parts[1]}" 116 | return tag 117 | 118 | 119 | setup( 120 | name="simpler-whisper", 121 | version=get_latest_git_tag(), 122 | author="Roy Shilkrot", 123 | author_email="roy.shil@gmail.com", 124 | description="A simple Python wrapper for whisper.cpp", 125 | long_description=open("README.md").read(), 126 | long_description_content_type="text/markdown", 127 | ext_modules=[CMakeExtension("simpler_whisper._whisper_cpp")], 128 | cmdclass={ 129 | "build_ext": CMakeBuild, 130 | "build_py": BuildPyCommand, 131 | }, 132 | zip_safe=False, 133 | packages=["simpler_whisper"], 134 | python_requires=">=3.10", 135 | install_requires=[ 136 | "numpy", 137 | ], 138 | package_data={ 139 | "simpler_whisper": [ 140 | "./*.dll", 141 | "./*.pyd", 142 | "./*.so", 143 | "./*.metal", 144 | "./*.bin", 145 | "./*.dylib", 146 | ], 147 | }, 148 | include_package_data=True, 149 | ) 150 | -------------------------------------------------------------------------------- /tests/test_wrapper.py: -------------------------------------------------------------------------------- 1 | import platform 2 | import unittest 3 | import numpy as np 4 | import threading 5 | import time 6 | import queue 7 | import os 8 | from concurrent.futures import ThreadPoolExecutor 9 | from simpler_whisper import ( 10 | WhisperModel, 11 | AsyncWhisperModel, 12 | ThreadedWhisperModel, 13 | set_log_callback, 14 | LogLevel, 15 | ) 16 | 17 | 18 | class TestWhisperWrapper(unittest.TestCase): 19 | @classmethod 20 | def setUpClass(cls): 21 | # download the model from https://ggml.ggerganov.com/ggml-model-whisper-tiny.en-q5_1.bin 22 | # and place it in the project root 23 | url = "https://ggml.ggerganov.com/ggml-model-whisper-tiny.en-q5_1.bin" 24 | model_path = os.path.join( 25 | os.path.dirname(os.path.dirname(__file__)), "ggml-tiny.en-q5_1.bin" 26 | ) 27 | if not os.path.exists(model_path): 28 | import requests 29 | 30 | print(f"Downloading model from {url}...") 31 | response = requests.get(url) 32 | with open(model_path, "wb") as f: 33 | f.write(response.content) 34 | print(f"Model downloaded to {model_path}") 35 | 36 | # Get the model path relative to the project root 37 | cls.model_path = model_path 38 | 39 | # Verify model exists 40 | if not os.path.exists(cls.model_path): 41 | raise FileNotFoundError(f"Model file not found at {cls.model_path}") 42 | 43 | # Create sample audio data (silence) 44 | cls.sample_rate = 16000 45 | duration_sec = 3 46 | cls.test_audio = np.zeros(cls.sample_rate * duration_sec, dtype=np.float32) 47 | 48 | # Create some mock audio with varying amplitudes for better testing 49 | cls.mock_speech = np.sin( 50 | 2 * np.pi * 440 * np.linspace(0, 1, cls.sample_rate) 51 | ).astype(np.float32) 52 | 53 | def test_sync_model_basic(self): 54 | """Test basic synchronous model initialization and transcription""" 55 | try: 56 | model = WhisperModel(self.model_path, False) 57 | result = model.transcribe(self.test_audio) 58 | self.assertIsInstance(result, list) 59 | except Exception as e: 60 | self.fail(f"Basic synchronous model test failed: {str(e)}") 61 | 62 | def test_sync_model_empty_audio(self): 63 | """Test synchronous model with empty audio""" 64 | model = WhisperModel(self.model_path, False) 65 | empty_audio = np.array([], dtype=np.float32) 66 | response = model.transcribe(empty_audio) 67 | self.assertEqual(response, []) 68 
| 69 | # def test_sync_model_invalid_audio(self): 70 | # """Test synchronous model with invalid audio data""" 71 | # model = WhisperModel(self.model_path, False) 72 | # invalid_audio = np.array([1.5, -1.5], dtype=np.float64) # Wrong dtype 73 | # with self.assertRaises(Exception): 74 | # model.transcribe(invalid_audio) 75 | 76 | # def test_async_model_basic(self): 77 | # """Test basic async model functionality""" 78 | # results = queue.Queue() 79 | 80 | # def callback(chunk_id, segments, is_partial): 81 | # results.put((chunk_id, segments, is_partial)) 82 | 83 | # model = whisper.AsyncWhisperModel(self.model_path, False) 84 | # try: 85 | # model.start(callback) 86 | # chunk_id = model.transcribe(self.test_audio) 87 | 88 | # # Wait for result with timeout 89 | # try: 90 | # result = results.get(timeout=10) 91 | # self.assertEqual(result[0], chunk_id) # Check if chunk_id matches 92 | # except queue.Empty: 93 | # self.fail("Async transcription timeout") 94 | 95 | # finally: 96 | # model.stop() 97 | 98 | # def test_threaded_model_basic(self): 99 | # """Test basic threaded model functionality""" 100 | # results = queue.Queue() 101 | 102 | # def callback(chunk_id, segments, is_partial): 103 | # results.put((chunk_id, segments, is_partial)) 104 | 105 | # model = whisper.ThreadedWhisperModel( 106 | # self.model_path, 107 | # False, 108 | # max_duration_sec=5.0, 109 | # sample_rate=self.sample_rate, 110 | # ) 111 | 112 | # try: 113 | # model.start(callback) 114 | # chunk_id = model.queue_audio(self.mock_speech) 115 | 116 | # # Wait for result with timeout 117 | # try: 118 | # result = results.get(timeout=10) 119 | # self.assertEqual(result[0], chunk_id) 120 | # except queue.Empty: 121 | # self.fail("Threaded transcription timeout") 122 | # finally: 123 | # model.stop() 124 | 125 | # def test_threaded_model_continuous(self): 126 | # """Test threaded model with continuous audio chunks""" 127 | # results = [] 128 | # result_lock = threading.Lock() 129 | 130 | # def callback(chunk_id, segments, is_partial): 131 | # with result_lock: 132 | # results.append((chunk_id, segments, is_partial)) 133 | 134 | # model = whisper.ThreadedWhisperModel( 135 | # self.model_path, 136 | # False, 137 | # max_duration_sec=1.0, 138 | # sample_rate=self.sample_rate, 139 | # ) 140 | 141 | # try: 142 | # model.start(callback) 143 | 144 | # # Queue multiple chunks of audio 145 | # chunk_size = self.sample_rate # 1 second chunks 146 | # num_chunks = 3 147 | # chunk_ids = [] 148 | 149 | # for i in range(num_chunks): 150 | # chunk = self.mock_speech[i * chunk_size : (i + 1) * chunk_size] 151 | # chunk_id = model.queue_audio(chunk) 152 | # chunk_ids.append(chunk_id) 153 | # time.sleep(0.1) # Small delay between chunks 154 | 155 | # # Wait for all results 156 | # max_wait = 15 # seconds 157 | # start_time = time.time() 158 | # while len(results) < num_chunks and (time.time() - start_time) < max_wait: 159 | # time.sleep(0.1) 160 | 161 | # self.assertGreaterEqual(len(results), num_chunks) 162 | 163 | # finally: 164 | # model.stop() 165 | 166 | def test_log_callback(self): 167 | """Test log callback functionality""" 168 | log_messages = queue.Queue() 169 | 170 | def log_callback(level, message): 171 | log_messages.put((level, message)) 172 | 173 | try: 174 | # Set the log callback 175 | set_log_callback(log_callback) 176 | 177 | # Create a minimal audio sample 178 | sample_rate = 16000 179 | duration_sec = 1 180 | test_audio = np.zeros(sample_rate * duration_sec, dtype=np.float32) 181 | 182 | # Create a model and do minimal 
transcription 183 | model = WhisperModel(self.model_path, use_gpu=False) 184 | try: 185 | model.transcribe(test_audio) 186 | except Exception as e: 187 | print(f"Transcription failed but continuing: {e}") 188 | 189 | # Check if we received any log messages 190 | try: 191 | log_message = log_messages.get_nowait() 192 | self.assertIsInstance(log_message, tuple) 193 | self.assertIsInstance(log_message[0], int) # level 194 | self.assertIsInstance(log_message[1], str) # message 195 | except queue.Empty: 196 | self.skipTest("No log messages received") 197 | 198 | except Exception as e: 199 | self.skipTest(f"Log callback test failed: {e}") 200 | finally: 201 | # Reset log callback 202 | try: 203 | set_log_callback(None) 204 | except: 205 | pass 206 | 207 | 208 | if __name__ == "__main__": 209 | unittest.main() 210 | -------------------------------------------------------------------------------- /test_simpler_whisper.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import av 3 | import argparse 4 | import sys 5 | import numpy as np 6 | import time 7 | import resampy 8 | import librosa 9 | 10 | # Remove the current directory from sys.path to avoid conflicts with the installed package 11 | sys.path.pop(0) 12 | 13 | from simpler_whisper.whisper import ( 14 | WhisperSegment, 15 | set_log_callback, 16 | LogLevel, 17 | WhisperModel, 18 | ThreadedWhisperModel, 19 | AsyncWhisperModel, 20 | ) 21 | 22 | 23 | log_levels = {LogLevel.ERROR: "ERROR", LogLevel.WARN: "WARN", LogLevel.INFO: "INFO"} 24 | 25 | 26 | def my_log_callback(level, message): 27 | if message is not None and len(message.strip()) > 0: 28 | print(f"whisper.cpp [{log_levels.get(level, 'UNKNOWN')}] {message.strip()}") 29 | 30 | 31 | # Path to your Whisper model file 32 | # Parse command-line arguments 33 | parser = argparse.ArgumentParser(description="Test simpler-whisper model.") 34 | parser.add_argument("model_path", type=str, help="Path to the Whisper model file") 35 | parser.add_argument("audio_file", type=str, help="Path to the audio file") 36 | # positional required arg for the method to use (regular, threaded, or async) 37 | parser.add_argument( 38 | "method", 39 | type=str, 40 | choices=["regular", "threaded", "async"], 41 | help="The method to use for testing the model", 42 | ) 43 | args = parser.parse_args() 44 | 45 | model_path = args.model_path 46 | audio_file = args.audio_file 47 | 48 | 49 | def get_samples_from_frame(frame: av.AudioFrame) -> np.ndarray: 50 | """ 51 | Extracts and processes audio samples from an audio frame. 52 | This function reads an audio chunk from the provided audio frame, converts it to mono if it is stereo, 53 | normalizes the audio if it is in integer (int16/int32) format, and resamples it to 16kHz if necessary. 54 | Parameters: 55 | frame (av.AudioFrame): The input audio frame containing the audio data. 56 | Returns: 57 | numpy.ndarray: The processed audio samples, normalized and resampled to 16kHz if needed.
58 | """ 59 | # Read audio chunk 60 | incoming_audio = frame.to_ndarray() 61 | # check if stereo 62 | if incoming_audio.shape[0] == 2: 63 | incoming_audio = incoming_audio.mean(axis=0) 64 | # check if the type is int16 or float32 65 | if incoming_audio.dtype == np.int16: 66 | incoming_audio = incoming_audio / 32768.0 # normalize to [-1, 1] 67 | if incoming_audio.dtype == np.int32: 68 | incoming_audio = incoming_audio / 2147483648.0 # normalize to [-1, 1] 69 | # resample to 16kHz if needed 70 | if frame.rate != 16000: 71 | samples = resampy.resample(incoming_audio, frame.rate, 16000) 72 | else: 73 | samples = incoming_audio 74 | 75 | return samples 76 | 77 | 78 | def test_simpler_whisper(): 79 | set_log_callback(my_log_callback) 80 | 81 | # Load the model 82 | print("Loading the Whisper model...") 83 | model = WhisperModel(model_path, use_gpu=True) 84 | print("Model loaded successfully!") 85 | 86 | # Load audio from file with av 87 | container = av.open(audio_file) 88 | audio = container.streams.audio[0] 89 | print(audio) 90 | 91 | frame_generator = container.decode(audio) 92 | 93 | # Run transcription 94 | print("Running transcription...") 95 | run_times = [] 96 | samples_for_transcription = np.array([]) 97 | for i, frame in enumerate(frame_generator): 98 | samples = get_samples_from_frame(frame) 99 | # append the samples to the samples_for_transcription 100 | samples_for_transcription = np.append(samples_for_transcription, samples) 101 | 102 | # if there are less than 30 seconds of audio, append the samples and continue to the next frame 103 | if len(samples_for_transcription) < 16000 * 30: 104 | continue 105 | 106 | start_time = time.time() 107 | transcription = model.transcribe(samples_for_transcription) 108 | end_time = time.time() 109 | elapsed_time = end_time - start_time 110 | run_times.append(elapsed_time) 111 | print(f"Run {i + 1}: Transcription took {elapsed_time:.3f} seconds.") 112 | for segment in transcription: 113 | for j, tok in enumerate(segment.tokens): 114 | print(f"Token {j}: {tok.text} ({tok.t0:.3f} - {tok.t1:.3f})") 115 | # reset the samples_for_transcription 116 | samples_for_transcription = np.array([]) 117 | 118 | avg_time = np.mean(run_times) 119 | min_time = np.min(run_times) 120 | max_time = np.max(run_times) 121 | 122 | print(f"\nStatistics over runs:") 123 | print(f"Average time: {avg_time:.3f} seconds") 124 | print(f"Minimum time: {min_time:.3f} seconds") 125 | print(f"Maximum time: {max_time:.3f} seconds") 126 | 127 | print("Transcription completed.") 128 | 129 | 130 | def test_async_whisper(): 131 | set_log_callback(my_log_callback) 132 | chunk_ids = [] 133 | 134 | def handle_result(chunk_id: int, segments: List[WhisperSegment], is_partial: bool): 135 | text = " ".join([seg.text for seg in segments]) 136 | print( 137 | f"Chunk {chunk_id} results ({'partial' if is_partial else 'final'}): {text}" 138 | ) 139 | # remove the chunk_id from the list of chunk_ids 140 | chunk_ids.remove(chunk_id) 141 | 142 | # Create model 143 | model = AsyncWhisperModel( 144 | model_path=model_path, callback=handle_result, use_gpu=True 145 | ) 146 | 147 | print("Loading audio from file...") 148 | # Load audio from file with librosa 149 | audio_data, sample_rate = librosa.load(audio_file, sr=16000) 150 | 151 | # Start processing with callback 152 | print("Starting Whisper model") 153 | model.start() 154 | 155 | # create 30-seconds chunks of audio_data 156 | for i in range(0, len(audio_data), 16000 * 30): 157 | try: 158 | samples_for_transcription = audio_data[i : i + 16000 * 30] 159 
| 160 | # Queue the chunk for processing 161 | chunk_id = model.transcribe(samples_for_transcription) 162 | chunk_ids.append(chunk_id) 163 | print(f"Queued chunk {chunk_id}") 164 | 165 | # reset 166 | samples_for_transcription = np.array([]) 167 | except: 168 | break 169 | 170 | # wait for all chunks to finish processing 171 | while len(chunk_ids) > 0: 172 | try: 173 | time.sleep(0.1) 174 | except: 175 | break 176 | 177 | # When done 178 | print("Stopping Whisper model") 179 | model.stop() 180 | 181 | 182 | def test_threaded_whisper(): 183 | set_log_callback(my_log_callback) 184 | 185 | def handle_result(chunk_id: int, segments: List[WhisperSegment], is_partial: bool): 186 | text = " ".join([seg.text for seg in segments]) 187 | print( 188 | f"Chunk {chunk_id} results ({'partial' if is_partial else 'final'}): {text}" 189 | ) 190 | 191 | # Create model with 10-second max duration 192 | model = ThreadedWhisperModel( 193 | model_path=model_path, 194 | callback=handle_result, 195 | use_gpu=True, 196 | max_duration_sec=10.0, 197 | ) 198 | 199 | # load audio from file with av 200 | container = av.open(audio_file) 201 | audio = container.streams.audio[0] 202 | print(audio) 203 | frame_generator = container.decode(audio) 204 | 205 | # Start processing with callback 206 | print("Starting threaded Whisper model...") 207 | model.start() 208 | 209 | for i, frame in enumerate(frame_generator): 210 | try: 211 | samples = get_samples_from_frame(frame) 212 | 213 | # Queue some audio (will get partial results until 10 seconds accumulate) 214 | chunk_id = model.queue_audio(samples) 215 | # sleep for the size of the audio chunk 216 | time.sleep(float(len(samples)) / float(16000)) 217 | except: 218 | break 219 | 220 | # close the container 221 | container.close() 222 | 223 | # When done 224 | print("Stopping threaded Whisper model...") 225 | model.stop() # Will process any remaining audio as final 226 | 227 | 228 | if __name__ == "__main__": 229 | if args.method == "regular": 230 | test_simpler_whisper() 231 | elif args.method == "async": 232 | test_async_whisper() 233 | else: 234 | test_threaded_whisper() 235 | -------------------------------------------------------------------------------- /simpler_whisper/whisper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Callable, List, Union 3 | from . 
import _whisper_cpp 4 | from dataclasses import dataclass 5 | 6 | 7 | @dataclass 8 | class WhisperToken: 9 | """A token from the Whisper model with timing and probability information.""" 10 | 11 | id: int 12 | p: float 13 | t0: int # Start time in milliseconds 14 | t1: int # End time in milliseconds 15 | text: str 16 | 17 | 18 | @dataclass 19 | class WhisperSegment: 20 | """A segment of transcribed text with timing information and token details.""" 21 | 22 | text: str 23 | start: int # Start time in milliseconds 24 | end: int # End time in milliseconds 25 | tokens: List[WhisperToken] 26 | 27 | 28 | class WhisperModel: 29 | def __init__(self, model_path: str, use_gpu=False): 30 | self.model = _whisper_cpp.WhisperModel(model_path, use_gpu) 31 | 32 | def transcribe(self, audio: Union[np.ndarray, List[float]]) -> List[WhisperSegment]: 33 | # Ensure audio is a numpy array of float32 34 | audio = np.array(audio, dtype=np.float32) 35 | 36 | # Run inference 37 | transcription = self.model.transcribe(audio) 38 | 39 | return transcription 40 | 41 | def __del__(self): 42 | # Explicitly delete the C++ object 43 | if hasattr(self, "model"): 44 | del self.model 45 | 46 | 47 | class AsyncWhisperModel: 48 | """ 49 | AsyncWhisperModel is a class that provides asynchronous transcription of audio data using a Whisper model. 50 | """ 51 | 52 | def __init__( 53 | self, 54 | model_path: str, 55 | callback: Callable[[int, List[WhisperSegment], bool], None], 56 | use_gpu=False, 57 | ): 58 | self.model = _whisper_cpp.AsyncWhisperModel(model_path, use_gpu) 59 | self._is_running = False 60 | self.callback = callback 61 | 62 | def transcribe(self, audio: Union[np.ndarray, List[float]]) -> int: 63 | """ 64 | Transcribes the given audio input using the model. 65 | Args: 66 | audio (Union[np.ndarray, List[float]]): The audio data to be transcribed. 67 | It can be either a numpy array or a list of floats. 68 | Returns: 69 | int: The queued chunk ID. 70 | """ 71 | # Ensure audio is a numpy array of float32 72 | audio = np.array(audio, dtype=np.float32) 73 | 74 | # Run async inference (no return value) 75 | return self.model.transcribe(audio) 76 | 77 | def handle_result( 78 | self, chunk_id: int, segments: List[WhisperSegment], is_partial: bool 79 | ): 80 | if self.callback is not None: 81 | self.callback(chunk_id, segments, is_partial) 82 | 83 | def start(self, result_check_interval_ms=100): 84 | """ 85 | Start the processing threads with a callback for results. 86 | 87 | Args: 88 | callback: Function that takes three arguments: 89 | - chunk_id (int): Unique identifier for the audio chunk 90 | - segments (WhisperSegment): Transcribed text for the audio chunk 91 | - is_partial (bool): Whether this is a partial result 92 | result_check_interval_ms (int): How often to check for results 93 | """ 94 | if self._is_running: 95 | return 96 | 97 | self.model.start(self.handle_result, result_check_interval_ms) 98 | self._is_running = True 99 | 100 | def stop(self): 101 | """ 102 | Stop processing and clean up resources. 103 | Any remaining audio will be processed as a final segment. 
104 | """ 105 | if not self._is_running: 106 | return 107 | 108 | self.model.stop() 109 | self._is_running = False 110 | 111 | def __del__(self): 112 | # Explicitly delete the C++ object 113 | if hasattr(self, "model"): 114 | if self._is_running: 115 | self.stop() 116 | self._is_running = False 117 | del self.model 118 | 119 | 120 | class ThreadedWhisperModel: 121 | def __init__( 122 | self, 123 | model_path: str, 124 | callback: Callable[[int, List[WhisperSegment], bool], None], 125 | use_gpu=False, 126 | max_duration_sec=10.0, 127 | sample_rate=16000, 128 | ): 129 | """ 130 | Initialize a threaded Whisper model for continuous audio processing. 131 | 132 | Args: 133 | model_path (str): Path to the Whisper model file 134 | use_gpu (bool): Whether to use GPU acceleration 135 | max_duration_sec (float): Maximum duration in seconds before finalizing a segment 136 | sample_rate (int): Audio sample rate (default: 16000) 137 | callback: Function that takes three arguments: 138 | - chunk_id (int): Unique identifier for the audio chunk 139 | - segments (List[WhisperSegment]): Transcribed text for the audio chunk 140 | - is_partial (bool): Whether this is a partial result 141 | """ 142 | self.model = _whisper_cpp.ThreadedWhisperModel( 143 | model_path, use_gpu, max_duration_sec, sample_rate 144 | ) 145 | self._is_running = False 146 | self.callback = callback 147 | 148 | def handle_result( 149 | self, chunk_id: int, segments: List[WhisperSegment], is_partial: bool 150 | ): 151 | if self.callback is not None: 152 | self.callback(chunk_id, segments, is_partial) 153 | 154 | def start(self, result_check_interval_ms=100): 155 | """ 156 | Start the processing threads with a callback for results. 157 | 158 | Args: 159 | callback: Function that takes three arguments: 160 | - chunk_id (int): Unique identifier for the audio chunk 161 | - segments (WhisperSegment): Transcribed text for the audio chunk 162 | - is_partial (bool): Whether this is a partial result 163 | result_check_interval_ms (int): How often to check for results 164 | """ 165 | if self._is_running: 166 | return 167 | 168 | self.model.start(self.handle_result, result_check_interval_ms) 169 | self._is_running = True 170 | 171 | def stop(self): 172 | """ 173 | Stop processing and clean up resources. 174 | Any remaining audio will be processed as a final segment. 175 | """ 176 | if not self._is_running: 177 | return 178 | 179 | self.model.stop() 180 | self._is_running = False 181 | 182 | def queue_audio(self, audio): 183 | """ 184 | Queue audio for processing. 185 | 186 | Args: 187 | audio: Audio samples as numpy array or array-like object. 188 | Will be converted to float32. 189 | 190 | Returns: 191 | chunk_id (int): Unique identifier for this audio chunk 192 | """ 193 | # Ensure audio is a numpy array of float32 194 | audio = np.array(audio, dtype=np.float32) 195 | return self.model.queue_audio(audio) 196 | 197 | def set_max_duration(self, max_duration_sec, sample_rate=16000): 198 | """ 199 | Change the maximum duration for partial segments. 200 | 201 | Args: 202 | max_duration_sec (float): New maximum duration in seconds 203 | sample_rate (int): Audio sample rate (default: 16000) 204 | """ 205 | self.model.set_max_duration(max_duration_sec, sample_rate) 206 | 207 | def __del__(self): 208 | # Ensure threads are stopped and resources cleaned up 209 | if hasattr(self, "model"): 210 | if self._is_running: 211 | self.stop() 212 | del self.model 213 | 214 | 215 | def set_log_callback(callback): 216 | """ 217 | Set a custom logging callback function. 
218 | 219 | The callback function should accept two arguments: 220 | - level: An integer representing the log level (use LogLevel enum for interpretation) 221 | - message: A string containing the log message 222 | 223 | Example: 224 | def my_log_callback(level, message): 225 | print(f"[{LogLevel(level).name}] {message}") 226 | 227 | set_log_callback(my_log_callback) 228 | """ 229 | _whisper_cpp.set_log_callback(callback) 230 | 231 | 232 | # Expose LogLevel enum from C++ module 233 | LogLevel = _whisper_cpp.LogLevel 234 | -------------------------------------------------------------------------------- /cmake/BuildWhispercpp.cmake: -------------------------------------------------------------------------------- 1 | include(ExternalProject) 2 | include(FetchContent) 3 | 4 | if(UNIX AND NOT APPLE) 5 | find_package(OpenMP REQUIRED) 6 | # Set compiler flags for OpenMP 7 | set(WHISPER_EXTRA_CXX_FLAGS "${WHISPER_EXTRA_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") 8 | set(WHISPER_EXTRA_C_FLAGS "${WHISPER_EXTRA_CXX_FLAGS} ${OpenMP_C_FLAGS}") 9 | endif() 10 | 11 | set(PREBUILT_WHISPERCPP_VERSION "0.0.7") 12 | set(PREBUILT_WHISPERCPP_URL_BASE 13 | "https://github.com/locaal-ai/occ-ai-dep-whispercpp/releases/download/${PREBUILT_WHISPERCPP_VERSION}" 14 | ) 15 | 16 | if(APPLE) 17 | # Store source directories for each architecture 18 | foreach(MACOS_ARCH IN ITEMS "x86_64" "arm64") 19 | if(${MACOS_ARCH} STREQUAL "x86_64") 20 | set(WHISPER_CPP_HASH 21 | "dc7fd5ff9c7fbb8623f8e14d9ff2872186cab4cd7a52066fcb2fab790d6092fc") 22 | elseif(${MACOS_ARCH} STREQUAL "arm64") 23 | set(WHISPER_CPP_HASH 24 | "ebed595ee431b182261bce41583993b149eed539e15ebf770d98a6bc85d53a92") 25 | endif() 26 | 27 | set(WHISPER_CPP_URL 28 | "${PREBUILT_WHISPERCPP_URL_BASE}/whispercpp-macos-${MACOS_ARCH}-${PREBUILT_WHISPERCPP_VERSION}.tar.gz" 29 | ) 30 | 31 | # Use unique names for each architecture's fetch 32 | FetchContent_Declare( 33 | whispercpp_fetch_${MACOS_ARCH} 34 | DOWNLOAD_EXTRACT_TIMESTAMP TRUE 35 | URL ${WHISPER_CPP_URL} 36 | URL_HASH SHA256=${WHISPER_CPP_HASH}) 37 | FetchContent_MakeAvailable(whispercpp_fetch_${MACOS_ARCH}) 38 | 39 | # Store the source dir for each arch 40 | if(${MACOS_ARCH} STREQUAL "x86_64") 41 | set(WHISPER_X86_64_DIR ${whispercpp_fetch_x86_64_SOURCE_DIR}) 42 | else() 43 | set(WHISPER_ARM64_DIR ${whispercpp_fetch_arm64_SOURCE_DIR}) 44 | endif() 45 | endforeach() 46 | 47 | # Create a directory for the universal binaries 48 | set(UNIVERSAL_LIB_DIR ${CMAKE_BINARY_DIR}/universal/lib) 49 | file(MAKE_DIRECTORY ${UNIVERSAL_LIB_DIR}) 50 | 51 | # Create universal binaries using lipo 52 | execute_process( 53 | COMMAND 54 | lipo -create "${WHISPER_X86_64_DIR}/lib/libwhisper.a" 55 | "${WHISPER_ARM64_DIR}/lib/libwhisper.a" -output 56 | "${UNIVERSAL_LIB_DIR}/libwhisper.a") 57 | 58 | execute_process( 59 | COMMAND 60 | lipo -create "${WHISPER_X86_64_DIR}/lib/libggml.a" 61 | "${WHISPER_ARM64_DIR}/lib/libggml.a" -output 62 | "${UNIVERSAL_LIB_DIR}/libggml.a") 63 | 64 | execute_process( 65 | COMMAND 66 | lipo -create "${WHISPER_X86_64_DIR}/lib/libwhisper.coreml.a" 67 | "${WHISPER_ARM64_DIR}/lib/libwhisper.coreml.a" -output 68 | "${UNIVERSAL_LIB_DIR}/libwhisper.coreml.a") 69 | 70 | # Set up the imported libraries to use the universal binaries 71 | add_library(Whispercpp::Whisper STATIC IMPORTED) 72 | set_target_properties( 73 | Whispercpp::Whisper PROPERTIES IMPORTED_LOCATION 74 | "${UNIVERSAL_LIB_DIR}/libwhisper.a") 75 | set_target_properties( 76 | Whispercpp::Whisper PROPERTIES INTERFACE_INCLUDE_DIRECTORIES 77 | 
${WHISPER_ARM64_DIR}/include) # Either arch's 78 | # include dir 79 | # is fine 80 | 81 | add_library(Whispercpp::GGML STATIC IMPORTED) 82 | set_target_properties( 83 | Whispercpp::GGML PROPERTIES IMPORTED_LOCATION 84 | "${UNIVERSAL_LIB_DIR}/libggml.a") 85 | 86 | add_library(Whispercpp::CoreML STATIC IMPORTED) 87 | set_target_properties( 88 | Whispercpp::CoreML PROPERTIES IMPORTED_LOCATION 89 | "${UNIVERSAL_LIB_DIR}/libwhisper.coreml.a") 90 | 91 | # Copy the metal file from either architecture (they should be identical) 92 | set(WHISPER_ADDITIONAL_FILES ${WHISPER_ARM64_DIR}/bin/ggml-metal.metal) 93 | elseif(WIN32) 94 | if(NOT DEFINED ACCELERATION) 95 | message( 96 | FATAL_ERROR 97 | "ACCELERATION is not set. Please set it to either `cpu`, `cuda`, `vulkan` or `hipblas`" 98 | ) 99 | endif() 100 | 101 | set(ARCH_PREFIX ${ACCELERATION}) 102 | set(WHISPER_CPP_URL 103 | "${PREBUILT_WHISPERCPP_URL_BASE}/whispercpp-windows-${ARCH_PREFIX}-${PREBUILT_WHISPERCPP_VERSION}.zip" 104 | ) 105 | if(${ACCELERATION} STREQUAL "cpu") 106 | set(WHISPER_CPP_HASH 107 | "c23862b4aac7d8448cf7de4d339a86498f88ecba6fa7d243bbd7fabdb13d4dd4") 108 | add_compile_definitions("LOCALVOCAL_WITH_CPU") 109 | elseif(${ACCELERATION} STREQUAL "cuda") 110 | set(WHISPER_CPP_HASH 111 | "a0adeaccae76fab0678d016a62b79a19661ed34eb810d8bae3b610345ee9a405") 112 | add_compile_definitions("LOCALVOCAL_WITH_CUDA") 113 | elseif(${ACCELERATION} STREQUAL "hipblas") 114 | set(WHISPER_CPP_HASH 115 | "bbad0b4eec01c5a801d384c03745ef5e97061958f8cf8f7724281d433d7d92a1") 116 | add_compile_definitions("LOCALVOCAL_WITH_HIPBLAS") 117 | elseif(${ACCELERATION} STREQUAL "vulkan") 118 | set(WHISPER_CPP_HASH 119 | "12bb34821f9efcd31f04a487569abff2b669221f2706fe0d09c17883635ef58a") 120 | add_compile_definitions("LOCALVOCAL_WITH_VULKAN") 121 | else() 122 | message( 123 | FATAL_ERROR 124 | "The ACCELERATION environment variable is not set to a valid value. 
Please set it to either `cpu` or `cuda` or `vulkan` or `hipblas`" 125 | ) 126 | endif() 127 | 128 | FetchContent_Declare( 129 | whispercpp_fetch 130 | URL ${WHISPER_CPP_URL} 131 | URL_HASH SHA256=${WHISPER_CPP_HASH} 132 | DOWNLOAD_EXTRACT_TIMESTAMP TRUE) 133 | FetchContent_MakeAvailable(whispercpp_fetch) 134 | 135 | add_library(Whispercpp::Whisper SHARED IMPORTED) 136 | set_target_properties( 137 | Whispercpp::Whisper 138 | PROPERTIES 139 | IMPORTED_LOCATION 140 | ${whispercpp_fetch_SOURCE_DIR}/bin/${CMAKE_SHARED_LIBRARY_PREFIX}whisper${CMAKE_SHARED_LIBRARY_SUFFIX} 141 | ) 142 | set_target_properties( 143 | Whispercpp::Whisper 144 | PROPERTIES 145 | IMPORTED_IMPLIB 146 | ${whispercpp_fetch_SOURCE_DIR}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}whisper${CMAKE_STATIC_LIBRARY_SUFFIX} 147 | ) 148 | set_target_properties( 149 | Whispercpp::Whisper PROPERTIES INTERFACE_INCLUDE_DIRECTORIES 150 | ${whispercpp_fetch_SOURCE_DIR}/include) 151 | add_library(Whispercpp::GGML SHARED IMPORTED) 152 | set_target_properties( 153 | Whispercpp::GGML 154 | PROPERTIES 155 | IMPORTED_LOCATION 156 | ${whispercpp_fetch_SOURCE_DIR}/bin/${CMAKE_SHARED_LIBRARY_PREFIX}ggml${CMAKE_SHARED_LIBRARY_SUFFIX} 157 | ) 158 | set_target_properties( 159 | Whispercpp::GGML 160 | PROPERTIES 161 | IMPORTED_IMPLIB 162 | ${whispercpp_fetch_SOURCE_DIR}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}ggml${CMAKE_STATIC_LIBRARY_SUFFIX} 163 | ) 164 | 165 | if(${ACCELERATION} STREQUAL "cpu") 166 | # add openblas to the link line 167 | add_library(Whispercpp::OpenBLAS STATIC IMPORTED) 168 | set_target_properties( 169 | Whispercpp::OpenBLAS 170 | PROPERTIES IMPORTED_LOCATION 171 | ${whispercpp_fetch_SOURCE_DIR}/lib/libopenblas.dll.a) 172 | endif() 173 | 174 | # glob all dlls in the bin directory and install them 175 | file(GLOB WHISPER_ADDITIONAL_FILES ${whispercpp_fetch_SOURCE_DIR}/bin/*.dll) 176 | list(FILTER WHISPER_ADDITIONAL_FILES EXCLUDE REGEX "^.*/cu.*\\.dll$") 177 | else() 178 | if(${CMAKE_BUILD_TYPE} STREQUAL Release OR ${CMAKE_BUILD_TYPE} STREQUAL 179 | RelWithDebInfo) 180 | set(Whispercpp_BUILD_TYPE Release) 181 | else() 182 | set(Whispercpp_BUILD_TYPE Debug) 183 | endif() 184 | set(Whispercpp_Build_GIT_TAG "v1.7.1") 185 | set(WHISPER_EXTRA_CXX_FLAGS "-fPIC") 186 | 187 | find_package(OpenMP REQUIRED) 188 | # Set compiler flags for OpenMP 189 | set(WHISPER_EXTRA_CXX_FLAGS "${WHISPER_EXTRA_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") 190 | set(WHISPER_EXTRA_C_FLAGS "${WHISPER_EXTRA_CXX_FLAGS} ${OpenMP_C_FLAGS}") 191 | 192 | # On Linux build a static Whisper library 193 | ExternalProject_Add( 194 | Whispercpp_Build 195 | DOWNLOAD_EXTRACT_TIMESTAMP true 196 | GIT_REPOSITORY https://github.com/ggerganov/whisper.cpp.git 197 | GIT_TAG ${Whispercpp_Build_GIT_TAG} 198 | BUILD_COMMAND ${CMAKE_COMMAND} --build <BINARY_DIR> --config 199 | ${Whispercpp_BUILD_TYPE} 200 | BUILD_BYPRODUCTS 201 | <INSTALL_DIR>/lib64/${CMAKE_STATIC_LIBRARY_PREFIX}whisper${CMAKE_STATIC_LIBRARY_SUFFIX} 202 | <INSTALL_DIR>/lib64/${CMAKE_STATIC_LIBRARY_PREFIX}ggml${CMAKE_STATIC_LIBRARY_SUFFIX} 203 | CMAKE_GENERATOR ${CMAKE_GENERATOR} 204 | INSTALL_COMMAND 205 | ${CMAKE_COMMAND} --install <BINARY_DIR> --config ${Whispercpp_BUILD_TYPE} 206 | && ${CMAKE_COMMAND} -E copy <SOURCE_DIR>/ggml/include/ggml.h 207 | <INSTALL_DIR>/include 208 | CONFIGURE_COMMAND 209 | ${CMAKE_COMMAND} -E env ${WHISPER_ADDITIONAL_ENV} ${CMAKE_COMMAND} 210 | <SOURCE_DIR> -B <BINARY_DIR> -G ${CMAKE_GENERATOR} 211 | -DCMAKE_INSTALL_PREFIX=<INSTALL_DIR> 212 | -DCMAKE_BUILD_TYPE=${Whispercpp_BUILD_TYPE} 213 | -DCMAKE_GENERATOR_PLATFORM=${CMAKE_GENERATOR_PLATFORM} 214 | -DCMAKE_OSX_DEPLOYMENT_TARGET=10.13 215 |
-DCMAKE_OSX_ARCHITECTURES=${CMAKE_OSX_ARCHITECTURES_} 216 | -DCMAKE_CXX_FLAGS=${WHISPER_EXTRA_CXX_FLAGS} 217 | -DCMAKE_C_FLAGS=${WHISPER_EXTRA_C_FLAGS} -DBUILD_SHARED_LIBS=OFF 218 | -DWHISPER_BUILD_TESTS=OFF -DWHISPER_BUILD_EXAMPLES=OFF 219 | -DGGML_OPENMP=ON -DWHISPER_BUILD_SERVER=OFF 220 | -DGGML_BLAS=OFF -DGGML_CUDA=OFF -DGGML_VULKAN=OFF -DGGML_HIPBLAS=OFF) 221 | 222 | ExternalProject_Get_Property(Whispercpp_Build INSTALL_DIR) 223 | 224 | # add the static Whisper library to the link line 225 | add_library(Whispercpp::Whisper STATIC IMPORTED) 226 | set_target_properties( 227 | Whispercpp::Whisper 228 | PROPERTIES 229 | IMPORTED_LOCATION 230 | ${INSTALL_DIR}/lib64/${CMAKE_STATIC_LIBRARY_PREFIX}whisper${CMAKE_STATIC_LIBRARY_SUFFIX} 231 | ) 232 | add_library(Whispercpp::GGML STATIC IMPORTED) 233 | set_target_properties( 234 | Whispercpp::GGML 235 | PROPERTIES 236 | IMPORTED_LOCATION 237 | ${INSTALL_DIR}/lib64/${CMAKE_STATIC_LIBRARY_PREFIX}ggml${CMAKE_STATIC_LIBRARY_SUFFIX} 238 | ) 239 | set_target_properties( 240 | Whispercpp::Whisper PROPERTIES INTERFACE_INCLUDE_DIRECTORIES 241 | ${INSTALL_DIR}/include) 242 | set_property( 243 | TARGET Whispercpp::Whisper 244 | APPEND 245 | PROPERTY INTERFACE_LINK_LIBRARIES OpenMP::OpenMP_CXX) 246 | 247 | endif() 248 | 249 | add_library(Whispercpp INTERFACE) 250 | add_dependencies(Whispercpp Whispercpp_Build) 251 | target_link_libraries(Whispercpp INTERFACE Whispercpp::Whisper Whispercpp::GGML) 252 | if(WIN32 AND "${ACCELERATION}" STREQUAL "cpu") 253 | target_link_libraries(Whispercpp INTERFACE Whispercpp::OpenBLAS) 254 | endif() 255 | if(APPLE) 256 | target_link_libraries( 257 | Whispercpp 258 | INTERFACE "-framework Accelerate -framework CoreML -framework Metal") 259 | target_link_libraries(Whispercpp INTERFACE Whispercpp::CoreML) 260 | endif(APPLE) 261 | if(UNIX AND NOT APPLE) 262 | target_link_libraries(Whispercpp INTERFACE OpenMP::OpenMP_CXX) 263 | endif() 264 | -------------------------------------------------------------------------------- /src/whisper_wrapper.cpp: -------------------------------------------------------------------------------- 1 | #include <pybind11/pybind11.h> 2 | #include <pybind11/numpy.h> 3 | #include <pybind11/stl.h> 4 | #include <pybind11/functional.h> 5 | 6 | #include <whisper.h> 7 | #include <thread> 8 | #include <queue> 9 | #include <mutex> 10 | #include <condition_variable> 11 | #include <atomic> 12 | #include <iostream> 13 | #include <cstring> 14 | 15 | namespace py = pybind11; 16 | 17 | std::string trim(const std::string &str) 18 | { 19 | size_t start = str.find_first_not_of(" \t\n\r"); 20 | size_t end = str.find_last_not_of(" \t\n\r"); 21 | 22 | if (start == std::string::npos) // handles empty string "" and all-whitespace strings like " " 23 | return ""; 24 | 25 | return str.substr(start, end - start + 1); 26 | } 27 | 28 | // Global variable to store the Python callback function 29 | py::function g_py_log_callback; 30 | 31 | // C++ callback function that will be passed to whisper_log_set 32 | void cpp_log_callback(ggml_log_level level, const char *text, void *) 33 | { 34 | if (!g_py_log_callback.is_none() && text != nullptr && strlen(text) > 0) 35 | { 36 | py::gil_scoped_acquire gil; 37 | g_py_log_callback(level, std::string(text)); 38 | } 39 | } 40 | 41 | // Function to set the log callback 42 | void set_log_callback(py::function callback) 43 | { 44 | g_py_log_callback = callback; 45 | whisper_log_set(cpp_log_callback, nullptr); 46 | ggml_log_set(cpp_log_callback, nullptr); 47 | } 48 | 49 | struct WhisperToken 50 | { 51 | int id; 52 | float p; 53 | int64_t t0; 54 | int64_t t1; 55 | std::string text; 56 | }; 57 |
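// WhisperToken (above) and WhisperSegment (below) mirror whisper.cpp's
// token-level and segment-level results; t0/t1 and start/end carry the raw
// timestamp ticks returned by whisper_full_get_token_data and
// whisper_full_get_segment_t0/t1 (see transcribe_raw_audio below).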
49 | struct WhisperToken
50 | {
51 |     int id;
52 |     float p;
53 |     int64_t t0;
54 |     int64_t t1;
55 |     std::string text;
56 | };
57 | 
58 | struct WhisperSegment
59 | {
60 |     std::string text;
61 |     int64_t start;
62 |     int64_t end;
63 |     std::vector<WhisperToken> tokens;
64 | };
65 | 
66 | // Original synchronous implementation
67 | class WhisperModel
68 | {
69 | public:
70 |     WhisperModel(const std::string &model_path, bool use_gpu = false)
71 |     {
72 |         whisper_context_params ctx_params = whisper_context_default_params();
73 |         ctx_params.use_gpu = use_gpu;
74 |         ctx = whisper_init_from_file_with_params(model_path.c_str(), ctx_params);
75 |         if (!ctx)
76 |         {
77 |             throw std::runtime_error("Failed to initialize whisper context");
78 |         }
79 |         params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
80 |         params.no_timestamps = false;
81 |         params.token_timestamps = true;
82 |     }
83 | 
84 |     ~WhisperModel()
85 |     {
86 |         if (ctx)
87 |         {
88 |             whisper_free(ctx);
89 |         }
90 |     }
91 | 
92 |     py::list transcribe(py::array_t<float> audio)
93 |     {
94 |         py::list result;
95 |         // Check if input is empty
96 |         if (audio.is_none() || audio.size() == 0)
97 |         {
98 |             return result;
99 |         }
100 | 
101 |         auto audio_buffer = audio.request();
102 |         float *audio_data = static_cast<float *>(audio_buffer.ptr);
103 |         int n_samples = static_cast<int>(audio_buffer.size);
104 | 
105 |         std::vector<WhisperSegment> segments = transcribe_raw_audio(audio_data, n_samples);
106 | 
107 |         for (const auto &segment : segments)
108 |         {
109 |             result.append(py::cast(segment));
110 |         }
111 | 
112 |         return result;
113 |     }
114 | 
115 |     std::vector<WhisperSegment> transcribe_raw_audio(const float *audio_data, int n_samples)
116 |     {
117 |         if (whisper_full(ctx, params, audio_data, n_samples) != 0)
118 |         {
119 |             throw std::runtime_error("Whisper inference failed");
120 |         }
121 | 
122 |         int n_segments = whisper_full_n_segments(ctx);
123 |         std::vector<WhisperSegment> transcription;
124 |         for (int i = 0; i < n_segments; i++)
125 |         {
126 |             const char *text = whisper_full_get_segment_text(ctx, i);
127 |             WhisperSegment segment;
128 |             segment.start = whisper_full_get_segment_t0(ctx, i);
129 |             segment.end = whisper_full_get_segment_t1(ctx, i);
130 |             segment.text = std::string(text);
131 |             const int n_tokens = whisper_full_n_tokens(ctx, i);
132 |             for (int j = 0; j < n_tokens; ++j)
133 |             {
134 |                 // get token data (id, probability, timestamps)
135 |                 whisper_token_data token =
136 |                     whisper_full_get_token_data(ctx, i, j);
137 |                 WhisperToken wt;
138 |                 wt.id = token.id;
139 |                 wt.p = token.p;
140 |                 wt.t0 = token.t0;
141 |                 wt.t1 = token.t1;
142 |                 wt.text = std::string(whisper_token_to_str(ctx, token.id));
143 |                 segment.tokens.push_back(wt);
144 |             }
145 | 
146 |             transcription.push_back(segment);
147 |         }
148 | 
149 |         return transcription;
150 |     }
151 | 
152 | private:
153 |     whisper_context *ctx;
154 |     whisper_full_params params;
155 | };
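A usage sketch for the synchronous class above, as seen from Python. The model path is a placeholder (GGML model files are downloaded separately), and since the `WhisperModel` binding at the bottom of this file takes positional arguments, no keywords are used:

    import numpy as np
    from simpler_whisper import WhisperModel

    model = WhisperModel("ggml-base.en.bin", False)  # placeholder path, use_gpu=False

    # whisper.cpp expects mono float32 PCM at 16 kHz; one second of silence here.
    audio = np.zeros(16000, dtype=np.float32)

    for segment in model.transcribe(audio):
        print(segment.start, segment.end, segment.text)
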
156 | 
157 | struct AudioChunk
158 | {
159 |     std::vector<float> data;
160 |     size_t id;
161 | };
162 | 
163 | struct TranscriptionResult
164 | {
165 |     size_t chunk_id;
166 |     bool is_partial;
167 |     std::vector<WhisperSegment> segments;
168 | };
169 | 
170 | class AsyncWhisperModel
171 | {
172 | public:
173 |     AsyncWhisperModel(const std::string &model_path, bool use_gpu = false)
174 |         : model_path(model_path), use_gpu(use_gpu), running(false), next_chunk_id(0), current_chunk_id(0)
175 |     {
176 |     }
177 | 
178 |     virtual ~AsyncWhisperModel()
179 |     {
180 |         stop(); // make sure the worker threads are joined before destruction
181 |     }
182 |     void start(py::function callback, int result_check_interval_ms = 100)
183 |     {
184 |         if (running)
185 |             return;
186 | 
187 |         running = true;
188 |         result_callback = callback;
189 | 
190 |         process_thread = std::thread(&AsyncWhisperModel::processThread, this);
191 |         result_thread = std::thread(&AsyncWhisperModel::resultThread, this,
192 |                                     result_check_interval_ms);
193 |     }
194 | 
195 |     /**
196 |      * @brief Transcribes the given audio data asynchronously.
197 |      *
198 |      * This function takes an audio input in the form of a py::array_t<float> and
199 |      * processes it by queuing the audio for transcription.
200 |      *
201 |      * @param audio A py::array_t<float> containing the audio data to be transcribed.
202 |      * @return size_t The queued chunk ID (0 when the input is empty and nothing was queued).
203 |      */
204 |     size_t transcribe(py::array_t<float> audio)
205 |     {
206 |         // Check if input is empty
207 |         if (audio.is_none() || audio.size() == 0)
208 |         {
209 |             return 0;
210 |         }
211 | 
212 |         return this->queueAudio(audio);
213 |     }
214 | 
215 |     virtual void stop()
216 |     {
217 |         if (!running)
218 |             return;
219 |         running = false;
220 | 
221 |         {
222 |             std::lock_guard<std::mutex> lock(input_mutex);
223 |             input_cv.notify_one();
224 |         }
225 | 
226 |         {
227 |             std::lock_guard<std::mutex> lock(result_mutex);
228 |             result_cv.notify_one();
229 |         }
230 | 
231 |         if (process_thread.joinable())
232 |             process_thread.join();
233 |         if (result_thread.joinable())
234 |             result_thread.join();
235 |     }
236 | 
237 |     size_t queueAudio(py::array_t<float> audio)
238 |     {
239 |         auto buffer = audio.request();
240 |         float *data = static_cast<float *>(buffer.ptr);
241 |         size_t n_samples = buffer.size;
242 | 
243 |         AudioChunk chunk;
244 |         chunk.data.assign(data, data + n_samples);
245 |         const size_t id = chunk.id = next_chunk_id++;
246 | 
247 |         {
248 |             std::lock_guard<std::mutex> lock(input_mutex);
249 |             input_queue.push(std::move(chunk));
250 |             input_cv.notify_one();
251 |         }
252 | 
253 |         return id; // the chunk itself was moved into the queue, so return the saved id
254 |     }
255 | 
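A rough sketch of the intended call pattern for this public API. The model path is again a placeholder; note from `processThread` below that chunks still queued when `stop()` is called may go unprocessed, and that the result callback only fires for chunks whose transcription produced non-empty text:

    import time
    import numpy as np
    from simpler_whisper import AsyncWhisperModel

    def on_result(chunk_id, segments, is_partial):
        print(chunk_id, "".join(s.text for s in segments))

    model = AsyncWhisperModel("ggml-base.en.bin", False)  # placeholder path
    model.start(on_result)
    chunk_id = model.queue_audio(np.zeros(32000, dtype=np.float32))  # 2 s of audio
    time.sleep(5.0)  # crude: give the worker thread time to finish
    model.stop()     # joins the process and result threads
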
256 | protected:
257 |     virtual void processThread()
258 |     {
259 |         WhisperModel model(model_path, use_gpu);
260 | 
261 |         while (running)
262 |         {
263 |             AudioChunk chunk;
264 |             // Get next chunk from input queue
265 |             {
266 |                 std::unique_lock<std::mutex> lock(input_mutex);
267 |                 input_cv.wait_for(lock,
268 |                                   std::chrono::milliseconds(100),
269 |                                   [this]
270 |                                   { return !input_queue.empty() || !running; });
271 | 
272 |                 if (!running)
273 |                     break;
274 | 
275 |                 if (input_queue.empty())
276 |                     continue;
277 | 
278 |                 chunk = std::move(input_queue.front());
279 |                 input_queue.pop();
280 |             }
281 | 
282 |             // Process audio
283 |             TranscriptionResult result;
284 |             result.chunk_id = chunk.id;
285 |             result.is_partial = false;
286 |             try
287 |             {
288 |                 result.segments = model.transcribe_raw_audio(chunk.data.data(), chunk.data.size());
289 |             }
290 |             catch (const std::exception &e)
291 |             {
292 |                 std::cerr << "Exception during transcription: " << e.what() << std::endl;
293 |             }
294 |             catch (...)
295 |             {
296 |                 std::cerr << "Unknown exception during transcription" << std::endl;
297 |             }
298 | 
299 |             // Add result to output queue
300 |             {
301 |                 std::lock_guard<std::mutex> lock(result_mutex);
302 |                 result_queue.push(result);
303 |                 result_cv.notify_one();
304 |             }
305 |         }
306 |     }
307 | 
308 |     void resultThread(int check_interval_ms)
309 |     {
310 |         while (running)
311 |         {
312 |             std::vector<TranscriptionResult> results;
313 | 
314 |             {
315 |                 std::unique_lock<std::mutex> lock(result_mutex);
316 |                 result_cv.wait_for(lock,
317 |                                    std::chrono::milliseconds(check_interval_ms),
318 |                                    [this]
319 |                                    { return !result_queue.empty() || !running; });
320 | 
321 |                 if (!running && result_queue.empty())
322 |                     break;
323 | 
324 |                 while (!result_queue.empty())
325 |                 {
326 |                     results.push_back(std::move(result_queue.front()));
327 |                     result_queue.pop();
328 |                 }
329 |             }
330 | 
331 |             if (!results.empty())
332 |             {
333 |                 py::gil_scoped_acquire gil;
334 |                 for (const auto &result : results)
335 |                 {
336 |                     if (result.segments.empty())
337 |                         continue;
338 | 
339 |                     // concatenate segments into a single string
340 |                     std::string full_text;
341 |                     for (const auto &segment : result.segments)
342 |                     {
343 |                         full_text += segment.text;
344 |                     }
345 |                     full_text = trim(full_text);
346 |                     if (full_text.empty())
347 |                         continue;
348 | 
349 |                     if (result_callback)
350 |                     {
351 |                         try
352 |                         {
353 |                             result_callback((int)result.chunk_id, result.segments, result.is_partial);
354 |                         }
355 |                         catch (const std::exception &e)
356 |                         {
357 |                             std::cerr << "Exception in result callback: " << e.what() << std::endl;
358 |                         }
359 |                         catch (...)
360 |                         {
361 |                             std::cerr << "Unknown exception in result callback" << std::endl;
362 |                         }
363 |                     }
364 |                 }
365 |             }
366 |         }
367 |     }
368 | 
369 |     std::string model_path;
370 |     bool use_gpu;
371 | 
372 |     std::atomic<bool> running;
373 |     std::atomic<size_t> next_chunk_id;
374 |     size_t current_chunk_id;
375 | 
376 |     std::thread process_thread;
377 |     std::thread result_thread;
378 | 
379 |     std::queue<AudioChunk> input_queue;
380 |     std::mutex input_mutex;
381 |     std::condition_variable input_cv;
382 | 
383 |     std::queue<TranscriptionResult> result_queue;
384 |     std::mutex result_mutex;
385 |     std::condition_variable result_cv;
386 | 
387 |     py::function result_callback;
388 | };
389 | 
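The callback invoked above receives `(chunk_id, segments, is_partial)`. In this base class `is_partial` is always False (see `processThread`), while the `ThreadedWhisperModel` subclass below emits partial results until its accumulation window fills. A Python callback sketch handling both cases:

    def on_result(chunk_id, segments, is_partial):
        # segments arrives as a list of WhisperSegment objects
        text = "".join(s.text for s in segments).strip()
        if is_partial:
            print(f"(partial, chunk {chunk_id}) {text}")
        else:
            print(f"(final, chunk {chunk_id}) {text}")
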
390 | class ThreadedWhisperModel : public AsyncWhisperModel
391 | {
392 | public:
393 |     ThreadedWhisperModel(const std::string &model_path, bool use_gpu = false,
394 |                          float max_duration_sec = 10.0f, int sample_rate = 16000)
395 |         : AsyncWhisperModel(model_path, use_gpu),
396 |           max_samples(static_cast<size_t>(max_duration_sec * sample_rate))
397 |     {
398 |     }
399 | 
400 |     ~ThreadedWhisperModel()
401 |     {
402 |         stop();
403 |     }
404 | 
405 |     void setMaxDuration(float max_duration_sec, int sample_rate = 16000)
406 |     {
407 |         max_samples = static_cast<size_t>(max_duration_sec * sample_rate);
408 |     }
409 | 
410 |     void start(py::function callback, int result_check_interval_ms = 100)
411 |     {
412 |         AsyncWhisperModel::start(callback, result_check_interval_ms);
413 |     }
414 | 
415 |     void stop() override
416 |     {
417 |         AsyncWhisperModel::stop();
418 | 
419 |         // Clear accumulated buffer
420 |         {
421 |             std::lock_guard<std::mutex> lock(buffer_mutex);
422 |             accumulated_buffer.clear();
423 |         }
424 |     }
425 | 
426 | private:
427 |     void processAccumulatedAudio(WhisperModel &model, bool force_final = false)
428 |     {
429 |         std::vector<float> process_buffer;
430 |         size_t current_id;
431 | 
432 |         {
433 |             std::lock_guard<std::mutex> lock(buffer_mutex);
434 |             if (accumulated_buffer.empty() || accumulated_buffer.size() < 16000) // require at least 1 s of audio (assumes 16 kHz)
435 |                 return;
436 | 
437 |             process_buffer = accumulated_buffer;
438 |             current_id = current_chunk_id;
439 | 
440 |             // Only clear the buffer if we're processing a final result
441 |             if (force_final || accumulated_buffer.size() >= max_samples)
442 |             {
443 |                 accumulated_buffer.clear();
444 |             }
445 |         }
446 | 
447 |         // Process audio
448 |         std::vector<WhisperSegment> segments;
449 |         try
450 |         {
451 |             segments = model.transcribe_raw_audio(process_buffer.data(), process_buffer.size());
452 |         }
453 |         catch (const std::exception &e)
454 |         {
455 |             std::cerr << "Exception during transcription: " << e.what() << std::endl;
456 |         }
457 |         catch (...)
458 |         {
459 |             std::cerr << "Unknown exception during transcription" << std::endl;
460 |         }
461 | 
462 |         if (segments.empty())
463 |         {
464 |             return;
465 |         }
466 | 
467 |         TranscriptionResult result;
468 |         result.chunk_id = current_id;
469 |         for (const auto &segment : segments)
470 |         {
471 |             result.segments.push_back(segment);
472 |         }
473 |         // Set partial flag based on whether this is a final result
474 |         result.is_partial = !(force_final || process_buffer.size() >= max_samples);
475 | 
476 |         // Add result to output queue
477 |         {
478 |             std::lock_guard<std::mutex> lock(result_mutex);
479 |             result_queue.push(result);
480 |             result_cv.notify_one();
481 |         }
482 |     }
483 | 
484 |     void processThread() override
485 |     {
486 |         WhisperModel model(model_path, use_gpu);
487 | 
488 |         while (running)
489 |         {
490 |             AudioChunk all_chunks;
491 |             bool has_chunk = false;
492 | 
493 |             // Get next chunk from input queue
494 |             {
495 |                 std::unique_lock<std::mutex> lock(input_mutex);
496 |                 input_cv.wait(lock, [this]
497 |                               { return !input_queue.empty() || !running; });
498 | 
499 |                 if (!running)
500 |                 {
501 |                     // Process any remaining audio as final before shutting down
502 |                     processAccumulatedAudio(model, true);
503 |                     break;
504 |                 }
505 | 
506 |                 // take all chunks from the queue and merge them into a single chunk
507 |                 while (!input_queue.empty())
508 |                 {
509 |                     AudioChunk chunk = std::move(input_queue.front());
510 |                     input_queue.pop();
511 |                     all_chunks.data.insert(all_chunks.data.end(), chunk.data.begin(), chunk.data.end());
512 |                     all_chunks.id = chunk.id;
513 |                     has_chunk = true;
514 |                 }
515 |             }
516 | 
517 |             if (has_chunk)
518 |             {
519 |                 // Add to accumulated buffer
520 |                 {
521 |                     std::lock_guard<std::mutex> lock(buffer_mutex);
522 |                     size_t old_size = accumulated_buffer.size();
523 |                     accumulated_buffer.resize(old_size + all_chunks.data.size());
524 |                     std::copy(all_chunks.data.begin(), all_chunks.data.end(),
525 |                               accumulated_buffer.begin() + old_size);
526 | 
527 |                     current_chunk_id = all_chunks.id;
528 |                 }
529 | 
530 |                 // Process the accumulated audio
531 |                 processAccumulatedAudio(model, false);
532 |             }
533 |         }
534 |     }
535 | 
536 |     // Audio accumulation
537 |     std::vector<float> accumulated_buffer;
538 |     size_t max_samples;
539 |     std::mutex buffer_mutex;
540 | };
541 | 
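A streaming sketch for the class above: small chunks are queued as they arrive, the model re-transcribes the growing window and reports partials, and the buffer is flushed as a final result once `max_duration_sec` of audio accumulates or `stop()` is called. The model path is a placeholder, and `on_result` is the callback sketched earlier:

    import numpy as np
    from simpler_whisper import ThreadedWhisperModel

    model = ThreadedWhisperModel("ggml-base.en.bin", use_gpu=False,
                                 max_duration_sec=10.0, sample_rate=16000)
    model.start(on_result, result_check_interval_ms=100)

    for _ in range(20):
        model.queue_audio(np.zeros(1600, dtype=np.float32))  # 100 ms per chunk

    model.stop()  # flushes any accumulated audio as a final result

Note the 16000-sample minimum in `processAccumulatedAudio`: nothing is transcribed until at least one second of audio has accumulated.
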
542 | PYBIND11_MODULE(_whisper_cpp, m)
543 | {
544 |     // Bind WhisperToken
545 |     py::class_<WhisperToken>(m, "WhisperToken")
546 |         .def(py::init<>())
547 |         .def_readwrite("id", &WhisperToken::id)
548 |         .def_readwrite("p", &WhisperToken::p)
549 |         .def_readwrite("t0", &WhisperToken::t0)
550 |         .def_readwrite("t1", &WhisperToken::t1)
551 |         .def_readwrite("text", &WhisperToken::text)
552 |         .def("__str__", [](const WhisperToken &t)
553 |              {
554 |                 std::stringstream ss;
555 |                 ss << t.text << " (id: " << t.id << ", p: " << t.p << ")";
556 |                 return ss.str(); });
557 | 
558 |     // Bind WhisperSegment
559 |     py::class_<WhisperSegment>(m, "WhisperSegment")
560 |         .def(py::init<>())
561 |         .def_readwrite("text", &WhisperSegment::text)
562 |         .def_readwrite("start", &WhisperSegment::start)
563 |         .def_readwrite("end", &WhisperSegment::end)
564 |         .def_readwrite("tokens", &WhisperSegment::tokens)
565 |         .def("__str__", [](const WhisperSegment &s)
566 |              { return s.text; })
567 |         .def("__repr__", [](const WhisperSegment &s)
568 |              {
569 |                 std::stringstream ss;
570 |                 ss << "WhisperSegment(text=\"" << s.text << "\", start=" << s.start << ", end=" << s.end << ")";
571 |                 return ss.str(); });
572 | 
573 |     // Expose synchronous model
574 |     py::class_<WhisperModel>(m, "WhisperModel")
575 |         .def(py::init<const std::string &, bool>())
576 |         .def("transcribe", &WhisperModel::transcribe);
577 | 
578 |     // Expose asynchronous model
579 |     py::class_<AsyncWhisperModel>(m, "AsyncWhisperModel")
580 |         .def(py::init<const std::string &, bool>())
581 |         .def("start", &AsyncWhisperModel::start,
582 |              py::arg("callback"),
583 |              py::arg("result_check_interval_ms") = 100)
584 |         .def("stop", &AsyncWhisperModel::stop)
585 |         .def("transcribe", &AsyncWhisperModel::transcribe)
586 |         .def("queue_audio", &AsyncWhisperModel::queueAudio);
587 | 
588 |     py::class_<ThreadedWhisperModel>(m, "ThreadedWhisperModel")
589 |         .def(py::init<const std::string &, bool, float, int>(),
590 |              py::arg("model_path"),
591 |              py::arg("use_gpu") = false,
592 |              py::arg("max_duration_sec") = 10.0f,
593 |              py::arg("sample_rate") = 16000)
594 |         .def("start", &ThreadedWhisperModel::start,
595 |              py::arg("callback"),
596 |              py::arg("result_check_interval_ms") = 100)
597 |         .def("stop", &ThreadedWhisperModel::stop)
598 |         .def("queue_audio", &ThreadedWhisperModel::queueAudio)
599 |         .def("set_max_duration", &ThreadedWhisperModel::setMaxDuration,
600 |              py::arg("max_duration_sec"),
601 |              py::arg("sample_rate") = 16000);
602 | 
603 |     // Expose logging functionality
604 |     m.def("set_log_callback", &set_log_callback, "Set the log callback function");
605 | 
606 |     py::enum_<ggml_log_level>(m, "LogLevel")
607 |         .value("NONE", GGML_LOG_LEVEL_NONE)
608 |         .value("INFO", GGML_LOG_LEVEL_INFO)
609 |         .value("WARN", GGML_LOG_LEVEL_WARN)
610 |         .value("ERROR", GGML_LOG_LEVEL_ERROR)
611 |         .value("DEBUG", GGML_LOG_LEVEL_DEBUG)
612 |         .value("CONT", GGML_LOG_LEVEL_CONT)
613 |         .export_values();
614 | }
615 | 
--------------------------------------------------------------------------------
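Finally, a small end-to-end sketch tying the bindings together from Python, including per-token inspection. The model and audio file paths are hypothetical, and the timestamp conversion follows whisper.cpp's convention of reporting segment times in 10 ms ticks:

    import numpy as np
    from simpler_whisper import WhisperModel

    model = WhisperModel("ggml-base.en.bin", False)  # placeholder model path
    audio = np.load("speech_16khz_mono.npy").astype(np.float32)  # hypothetical file

    for segment in model.transcribe(audio):
        start_s, end_s = segment.start * 0.01, segment.end * 0.01
        print(f"[{start_s:.2f}s -> {end_s:.2f}s] {segment.text}")
        for token in segment.tokens:
            print(f"  {token.text!r} id={token.id} p={token.p:.3f}")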