├── pywhispercpp
│   ├── __init__.py
│   ├── examples
│   │   ├── __init__.py
│   │   ├── recording.py
│   │   ├── main.py
│   │   ├── assistant.py
│   │   ├── livestream.py
│   │   └── gui.py
│   ├── utils.py
│   ├── constants.py
│   └── model.py
├── requirements.txt
├── .directory
├── .github
│   ├── dependabot.yml
│   └── workflows
│       ├── docs.yml
│       ├── pip.yml
│       └── wheels.yml
├── .gitmodules
├── CMakeLists.txt
├── docs
│   └── index.md
├── MANIFEST.in
├── .appveyor.yml
├── LICENSE
├── tests
│   ├── test_segfault.py
│   ├── test_model.py
│   └── test_c_api.py
├── pyproject.toml
├── .pre-commit-config.yaml
├── mkdocs.yml
├── .gitignore
├── setup.py
├── README.md
└── src
    └── main.cpp

/pywhispercpp/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/pywhispercpp/examples/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | sounddevice~=0.4.6
3 | webrtcvad~=2.0.10
4 | requests
5 | tqdm
6 | platformdirs
--------------------------------------------------------------------------------
/.directory:
--------------------------------------------------------------------------------
1 | [Dolphin]
2 | HeaderColumnWidths=499,145,72
3 | Timestamp=2023,3,9,19,49,1.313
4 | Version=4
5 | ViewMode=1
6 | 
7 | [Settings]
8 | HiddenFilesShown=true
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | updates:
3 |   # Maintain dependencies for GitHub Actions
4 |   - package-ecosystem: "github-actions"
5 |     directory: "/"
6 |     schedule:
7 |       interval: "daily"
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "pybind11"]
2 |     path = pybind11
3 |     url = https://github.com/pybind/pybind11.git
4 |     branch = master
5 | 
6 | [submodule "whisper.cpp"]
7 |     path = whisper.cpp
8 |     url = https://github.com/ggml-org/whisper.cpp.git
9 |     branch = master
--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.4...3.18)
2 | project(pywhispercpp)
3 | 
4 | add_subdirectory(pybind11)
5 | add_subdirectory(whisper.cpp)
6 | 
7 | pybind11_add_module(_pywhispercpp
8 |     src/main.cpp
9 | )
10 | 
11 | target_link_libraries(_pywhispercpp PRIVATE whisper)
12 | 
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | # PyWhisperCpp API Reference
2 | 
3 | 
4 | ::: pywhispercpp.model
5 | 
6 | ::: pywhispercpp.constants
7 |     options:
8 |       show_if_no_docstring: true
9 | 
10 | ::: pywhispercpp.utils
11 | 
12 | ::: pywhispercpp.examples
13 |     options:
14 |       show_if_no_docstring: false
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include README.md LICENSE pybind11/LICENSE version.txt
2 | graft pybind11/include
3 | graft pybind11/tools
4 | graft src
5 | global-include CMakeLists.txt *.cmake
6 | 
7 | graft whisper.cpp/cmake
8 | graft whisper.cpp/ggml
9 | graft whisper.cpp/grammars
10 | graft whisper.cpp/include
11 | graft whisper.cpp/spm-headers
12 | graft whisper.cpp/src
13 | exclude whisper.cpp/**/*.o
14 | exclude whisper.cpp/**/*.so
15 | exclude whisper.cpp/**/*.a
16 | exclude whisper.cpp/**/*.dylib
17 | exclude whisper.cpp/**/*.dll
18 | exclude whisper.cpp/**/*.lib
--------------------------------------------------------------------------------
/.github/workflows/docs.yml:
--------------------------------------------------------------------------------
1 | name: docs
2 | on:
3 |   push:
4 |     branches:
5 |       - main
6 | permissions:
7 |   contents: write
8 | jobs:
9 |   deploy:
10 |     runs-on: ubuntu-latest
11 |     steps:
12 |       - uses: actions/checkout@v4
13 |       - uses: actions/setup-python@v5
14 |         with:
15 |           python-version: '3.x'
16 |       - uses: actions/cache@v4
17 |         with:
18 |           key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
19 |           path: ~/.cache/pip
20 |       - run: |
21 |           pip install mkdocs-macros-plugin mkdocs-material mkdocstrings[python] black pywhispercpp
22 |       - run: mkdocs gh-deploy --force
--------------------------------------------------------------------------------
/.appveyor.yml:
--------------------------------------------------------------------------------
1 | version: '{build}'
2 | image: Visual Studio 2019
3 | platform:
4 | - x86
5 | - x64
6 | environment:
7 |   global:
8 |     DISTUTILS_USE_SDK: 1
9 |     PYTHONWARNINGS: ignore:DEPRECATION
10 |     MSSdk: 1
11 |   matrix:
12 |   - PYTHON: 37
13 | install:
14 | - cmd: '"%VS140COMNTOOLS%\..\..\VC\vcvarsall.bat" %PLATFORM%'
15 | - ps: |
16 |     git submodule update -q --init --recursive
17 |     if ($env:PLATFORM -eq "x64") { $env:PYTHON = "$env:PYTHON-x64" }
18 |     $env:PATH = "C:\Python$env:PYTHON\;C:\Python$env:PYTHON\Scripts\;$env:PATH"
19 |     python -m pip install --disable-pip-version-check --upgrade --no-warn-script-location pip build pytest
20 | build_script:
21 | - ps: |
22 |     python -m build -s
23 |     cd dist
24 |     python -m pip install --verbose (Get-Item pywhispercpp-*.tar.gz).FullName
25 |     cd ..
26 | test_script: 27 | - ps: python -m pytest 28 | -------------------------------------------------------------------------------- /.github/workflows/pip.yml: -------------------------------------------------------------------------------- 1 | name: Pip 2 | 3 | on: 4 | workflow_dispatch: 5 | pull_request: 6 | push: 7 | branches: 8 | - main 9 | 10 | jobs: 11 | build: 12 | strategy: 13 | fail-fast: false 14 | matrix: 15 | platform: [windows-latest, macos-14, ubuntu-latest] 16 | python-version: ["3.8", "3.11"] 17 | 18 | runs-on: ${{ matrix.platform }} 19 | 20 | steps: 21 | - uses: actions/checkout@v4 22 | with: 23 | submodules: true 24 | 25 | - uses: actions/setup-python@v5 26 | with: 27 | python-version: ${{ matrix.python-version }} 28 | 29 | - name: Add requirements 30 | run: python -m pip install --upgrade wheel setuptools 31 | 32 | - name: Install requirements 33 | run: python -m pip install -r requirements.txt 34 | 35 | - name: Build and install 36 | run: pip install --verbose .[test] 37 | 38 | # - name: Test C-API 39 | # run: python -m unittest ./tests/test_c_api.py 40 | 41 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Abdeladim Sadiki 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /tests/test_segfault.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | import unittest 6 | from pathlib import Path 7 | from unittest import TestCase 8 | 9 | from pywhispercpp.model import Model, Segment 10 | 11 | if __name__ == '__main__': 12 | pass 13 | 14 | WHISPER_CPP_DIR = Path(__file__).parent.parent / 'whisper.cpp' 15 | 16 | class TestSegfault(TestCase): 17 | audio_file = WHISPER_CPP_DIR/ 'samples/jfk.wav' 18 | 19 | def voice_to_text(self, tmp_path): 20 | # n_threads=1 is 3x faster than n_threads=6 when running in Docker. 
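# (A plausible explanation, not verified here: whisper.cpp worker threads spin
# while waiting for work, so when a container CPU quota allows fewer cores than
# n_threads, the extra threads burn quota and slow the whole run down.)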
21 |         whisper_model = Model('tiny.en-q5_1', n_threads=1)
22 |         segments: list[Segment] = whisper_model.transcribe(tmp_path)
23 |         text = next((segment.text for segment in segments if segment.text and '(' not in segment.text and '[' not in segment.text))
24 |         return text
25 | 
26 |     def test_sample_file(self):
27 |         expected_text = "ask not what your country can do for you"
28 |         text = self.voice_to_text(str(self.audio_file))
29 |         self.assertIn(expected_text.lower(), text.lower(),
30 |                       f"Expected text '{expected_text}' not found in transcription: '{text}'")
31 | 
32 | 
33 | if __name__ == '__main__':
34 |     unittest.main()
35 | 
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = [
3 |     "setuptools>=42",
4 |     "wheel",
5 |     "ninja",
6 |     "cmake>=3.12",
7 |     "repairwheel",
8 |     "setuptools-scm>=8"
9 | ]
10 | build-backend = "setuptools.build_meta"
11 | 
12 | [tool.mypy]
13 | files = "setup.py"
14 | python_version = "3.8"
15 | strict = true
16 | show_error_codes = true
17 | enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
18 | warn_unreachable = true
19 | 
20 | [[tool.mypy.overrides]]
21 | module = ["ninja"]
22 | ignore_missing_imports = true
23 | 
24 | 
25 | [tool.pytest.ini_options]
26 | minversion = "6.0"
27 | addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config"]
28 | xfail_strict = true
29 | filterwarnings = ["error"]
30 | testpaths = ["tests"]
31 | 
32 | [tool.cibuildwheel]
33 | #test-command = "pytest {project}/tests"
34 | #test-extras = ["test"]
35 | test-skip = ["*universal2:arm64"]
36 | # Setuptools bug causes collision between pypy and cpython artifacts
37 | before-build = "rm -rf {project}/build"
38 | 
39 | [tool.ruff]
40 | extend-select = [
41 |     "B",    # flake8-bugbear
42 |     "B904",
43 |     "I",    # isort
44 |     "PGH",  # pygrep-hooks
45 |     "RUF",  # Ruff-specific
46 |     "UP",   # pyupgrade
47 | ]
48 | extend-ignore = [
49 |     "E501",  # Line too long
50 | ]
51 | target-version = "py39"
52 | 
53 | [tool.setuptools_scm]
54 | version_file = "_version.py"
--------------------------------------------------------------------------------
/tests/test_model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | 
4 | """
5 | Test model.py
6 | """
7 | import unittest
8 | from pathlib import Path
9 | from unittest import TestCase
10 | 
11 | from pywhispercpp.model import Model, Segment
12 | 
13 | if __name__ == '__main__':
14 |     pass
15 | 
16 | WHISPER_CPP_DIR = Path(__file__).parent.parent / 'whisper.cpp'
17 | 
18 | class TestModel(TestCase):
19 |     audio_file = WHISPER_CPP_DIR / 'samples/jfk.wav'
20 |     model = Model("tiny", models_dir=str(WHISPER_CPP_DIR / 'models'))
21 | 
22 |     def test_transcribe(self):
23 |         # note: the assert* helpers return None, so chaining them with `and`
24 |         # silently skipped the later checks; run each assertion on its own
25 |         segments = self.model.transcribe(str(self.audio_file))
26 |         self.assertIsInstance(segments, list)
27 |         if len(segments) > 0:
28 |             self.assertIsInstance(segments[0], Segment)
29 | 
30 |     def test_get_params(self):
31 |         params = self.model.get_params()
32 |         self.assertIsInstance(params, dict)
33 | 
34 |     def test_lang_max_id(self):
35 |         n = self.model.lang_max_id()
36 |         self.assertGreater(n, 0)
37 | 
38 |     def test_available_languages(self):
39 |         av_langs = self.model.available_languages()
40 |         self.assertIsInstance(av_langs, list)
41 |         self.assertGreater(len(av_langs), 1)
42 | 
43 |     def test__load_audio(self):
44 |         audio_arr = self.model._load_audio(str(self.audio_file))
45 |         self.assertIsNotNone(audio_arr)
46 | 
47 |     def test_auto_detect_language(self):
48 |         detected_language, probs = self.model.auto_detect_language(str(self.audio_file))
49 |         self.assertIsInstance(detected_language, tuple)
50 |         self.assertEqual(detected_language[0], 'en')
51 | 
52 | 
53 | if __name__ == '__main__':
54 |     unittest.main()
55 | 
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # To use:
2 | #
3 | #     pre-commit run -a
4 | #
5 | # Or:
6 | #
7 | #     pre-commit install  # (runs every time you commit in git)
8 | #
9 | # To update this file:
10 | #
11 | #     pre-commit autoupdate
12 | #
13 | # See https://github.com/pre-commit/pre-commit
14 | 
15 | ci:
16 |   autoupdate_commit_msg: "chore: update pre-commit hooks"
17 |   autofix_commit_msg: "style: pre-commit fixes"
18 | 
19 | repos:
20 | # Standard hooks
21 | - repo: https://github.com/pre-commit/pre-commit-hooks
22 |   rev: v4.4.0
23 |   hooks:
24 |   - id: check-added-large-files
25 |   - id: check-case-conflict
26 |   - id: check-merge-conflict
27 |   - id: check-symlinks
28 |   - id: check-yaml
29 |     exclude: ^conda\.recipe/meta\.yaml$
30 |   - id: debug-statements
31 |   - id: end-of-file-fixer
32 |   - id: mixed-line-ending
33 |   - id: requirements-txt-fixer
34 |   - id: trailing-whitespace
35 | 
36 | # Black, the code formatter, natively supports pre-commit
37 | - repo: https://github.com/psf/black
38 |   rev: 23.1.0
39 |   hooks:
40 |   - id: black
41 |     exclude: ^(docs)
42 | 
43 | - repo: https://github.com/charliermarsh/ruff-pre-commit
44 |   rev: "v0.0.253"
45 |   hooks:
46 |   - id: ruff
47 |     args: ["--fix", "--show-fixes"]
48 | 
49 | # Checking static types
50 | - repo: https://github.com/pre-commit/mirrors-mypy
51 |   rev: "v1.0.1"
52 |   hooks:
53 |   - id: mypy
54 |     files: "setup.py"
55 |     args: []
56 |     additional_dependencies: [types-setuptools]
57 | 
58 | # Changes tabs to spaces
59 | - repo: https://github.com/Lucas-C/pre-commit-hooks
60 |   rev: v1.4.2
61 |   hooks:
62 |   - id: remove-tabs
63 |     exclude: ^(docs)
64 | 
65 | # CMake formatting
66 | - repo: https://github.com/cheshirekow/cmake-format-precommit
67 |   rev: v0.6.13
68 |   hooks:
69 |   - id: cmake-format
70 |     additional_dependencies: [pyyaml]
71 |     types: [file]
72 |     files: (\.cmake|CMakeLists.txt)(.in)?$
73 | 
74 | # Suggested hook if you add a .clang-format file
75 | # - repo: https://github.com/pre-commit/mirrors-clang-format
76 | #   rev: v13.0.0
77 | #   hooks:
78 | #   - id: clang-format
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: PyWhisperCpp
2 | repo_url: https://github.com/absadiki/pywhispercpp
3 | repo_name: absadiki/pywhispercpp
4 | theme:
5 |   name: material
6 |   language: en
7 |   features:
8 |     - navigation.tabs
9 |     - navigation.sections
10 |     - toc.integrate
11 |     - navigation.top
12 |     - search.suggest
13 |     - search.highlight
14 |     - content.tabs.link
15 |     - content.code.annotation
16 |     - content.code.copy
17 |   palette:
18 |     - scheme: default
19 |       toggle:
20 |         icon: material/toggle-switch-off-outline
21 |         name: Switch to dark mode
22 |     - scheme: slate
23 |       toggle:
24 |         icon: material/toggle-switch
25 |         name: Switch to light mode
26 | 
27 | 
28 | extra:
29 |   social:
30 |     - icon: fontawesome/brands/github-alt
31 |       link: https://github.com/absadiki/pywhispercpp
32 | 
33 | 
markdown_extensions: 34 | - pymdownx.inlinehilite 35 | - pymdownx.snippets 36 | - admonition 37 | - pymdownx.arithmatex: 38 | generic: true 39 | - footnotes 40 | - pymdownx.details 41 | - pymdownx.superfences 42 | - pymdownx.mark 43 | - attr_list 44 | - pymdownx.emoji: 45 | emoji_index: !!python/name:materialx.emoji.twemoji 46 | emoji_generator: !!python/name:materialx.emoji.to_svg 47 | - pymdownx.highlight: 48 | use_pygments: true 49 | pygments_lang_class: true 50 | 51 | copyright: | 52 | © 2023 absadiki 53 | 54 | plugins: 55 | - search 56 | - macros: 57 | include_dir: . 58 | - mkdocstrings: 59 | handlers: 60 | python: 61 | paths: [src] 62 | options: 63 | separate_signature: true 64 | docstring_style: sphinx 65 | docstring_section_style: list 66 | members_order: source 67 | merge_init_into_class: true 68 | show_bases: true 69 | show_if_no_docstring: false 70 | show_root_full_path: true 71 | show_root_heading: true 72 | show_submodules: true 73 | filters: 74 | - "!^_" 75 | watch: 76 | - pywhispercpp/ 77 | 78 | -------------------------------------------------------------------------------- /pywhispercpp/examples/recording.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | A simple example showcasing how to use pywhispercpp to transcribe a recording. 6 | """ 7 | import argparse 8 | import logging 9 | import sounddevice as sd 10 | import pywhispercpp.constants 11 | from pywhispercpp.model import Model 12 | import importlib.metadata 13 | 14 | 15 | __version__ = importlib.metadata.version('pywhispercpp') 16 | 17 | __header__ = f""" 18 | =================================================================== 19 | PyWhisperCpp 20 | A simple example of transcribing a recording, based on whisper.cpp 21 | Version: {__version__} 22 | =================================================================== 23 | """ 24 | 25 | 26 | class Recording: 27 | """ 28 | Recording class 29 | 30 | Example usage 31 | ```python 32 | from pywhispercpp.examples.recording import Recording 33 | 34 | myrec = Recording(5) 35 | myrec.start() 36 | ``` 37 | """ 38 | def __init__(self, 39 | duration: int, 40 | model: str = 'tiny.en', 41 | **model_params): 42 | self.duration = duration 43 | self.sample_rate = pywhispercpp.constants.WHISPER_SAMPLE_RATE 44 | self.channels = 1 45 | self.pwcpp_model = Model(model, print_realtime=True, **model_params) 46 | 47 | def start(self): 48 | logging.info(f"Start recording for {self.duration}s ...") 49 | recording = sd.rec(int(self.duration * self.sample_rate), samplerate=self.sample_rate, channels=self.channels) 50 | sd.wait() 51 | logging.info('Duration finished') 52 | res = self.pwcpp_model.transcribe(recording) 53 | self.pwcpp_model.print_timings() 54 | 55 | 56 | def _main(): 57 | print(__header__) 58 | parser = argparse.ArgumentParser(description="", allow_abbrev=True) 59 | # Positional args 60 | parser.add_argument('duration', type=int, help=f"duration in seconds") 61 | parser.add_argument('-m', '--model', default='tiny.en', type=str, help="Whisper.cpp model, default to %(default)s") 62 | 63 | args = parser.parse_args() 64 | 65 | myrec = Recording(duration=args.duration, model=args.model) 66 | myrec.start() 67 | 68 | 69 | if __name__ == '__main__': 70 | _main() 71 | -------------------------------------------------------------------------------- /tests/test_c_api.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import _pywhispercpp as pw 4 | 5 
| import unittest 6 | from unittest import TestCase 7 | 8 | 9 | class TestCAPI(TestCase): 10 | 11 | model_file = './whisper.cpp/models/for-tests-ggml-tiny.en.bin' 12 | 13 | def test_whisper_init_from_file(self): 14 | ctx = pw.whisper_init_from_file(self.model_file) 15 | self.assertIsInstance(ctx, pw.whisper_context) 16 | 17 | def test_whisper_lang_str(self): 18 | return self.assertEqual(pw.whisper_lang_str(0), 'en') 19 | 20 | def test_whisper_lang_id(self): 21 | return self.assertEqual(pw.whisper_lang_id('en'), 0) 22 | 23 | def test_whisper_full_params_language_set_to_de(self): 24 | params = pw.whisper_full_params() 25 | params.language = 'de' 26 | return self.assertEqual(params.language, 'de') 27 | 28 | def test_whisper_full_params_language_set_to_german(self): 29 | params = pw.whisper_full_params() 30 | params.language = 'german' 31 | return self.assertEqual(params.language, 'de') 32 | 33 | def test_whisper_full_params_context(self): 34 | 35 | params = pw.whisper_full_params() 36 | # to ensure that the string is not cached 37 | prompt = str(10120923) + "A" + " test" 38 | params.initial_prompt = prompt 39 | print("Params Prompt: ", params.initial_prompt) 40 | del prompt 41 | import gc 42 | gc.collect() 43 | return self.assertEqual(params.initial_prompt, str(10120923) + "A test") 44 | 45 | def test_whisper_full_params_regex(self): 46 | params = pw.whisper_full_params() 47 | val = str(10120923) + "A" + " test" 48 | params.suppress_regex = val 49 | print("Params Prompt: ", params.suppress_regex) 50 | del val 51 | import gc 52 | gc.collect() 53 | return self.assertEqual(params.suppress_regex, str(10120923) + "A" + " test") 54 | 55 | def test_whisper_full_params_default(self): 56 | params = pw.whisper_full_default_params(pw.whisper_sampling_strategy.WHISPER_SAMPLING_GREEDY) 57 | self.assertIsInstance(params, pw.whisper_full_params) 58 | self.assertEqual(params.suppress_regex, "") 59 | 60 | def test_whisper_lang_id(self): 61 | return self.assertEqual(pw.whisper_lang_id('en'), 0) 62 | 63 | def test_whisper_full_params(self): 64 | params = pw.whisper_full_params() 65 | return self.assertIsInstance(params.n_threads, int) 66 | 67 | 68 | if __name__ == '__main__': 69 | 70 | unittest.main() 71 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | dist/ 3 | _build/ 4 | _generate/ 5 | *.so 6 | *.dylib 7 | *.py[cod] 8 | *.egg-info 9 | *env* 10 | 11 | 12 | # custom 13 | .idea 14 | _docs 15 | _examples 16 | src/.idea 17 | 18 | # Byte-compiled / optimized / DLL files 19 | __pycache__/ 20 | *.py[cod] 21 | *$py.class 22 | 23 | # C extensions 24 | *.so 25 | 26 | # Distribution / packaging 27 | .Python 28 | build/ 29 | develop-eggs/ 30 | dist/ 31 | downloads/ 32 | eggs/ 33 | .eggs/ 34 | lib/ 35 | lib64/ 36 | parts/ 37 | sdist/ 38 | var/ 39 | wheels/ 40 | pip-wheel-metadata/ 41 | share/python-wheels/ 42 | *.egg-info/ 43 | .installed.cfg 44 | *.egg 45 | MANIFEST 46 | 47 | # PyInstaller 48 | # Usually these files are written by a python script from a template 49 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
50 | *.manifest 51 | *.spec 52 | 53 | # Installer logs 54 | pip-log.txt 55 | pip-delete-this-directory.txt 56 | 57 | # Unit test / coverage reports 58 | htmlcov/ 59 | .tox/ 60 | .nox/ 61 | .coverage 62 | .coverage.* 63 | .cache 64 | nosetests.xml 65 | coverage.xml 66 | *.cover 67 | *.py,cover 68 | .hypothesis/ 69 | .pytest_cache/ 70 | 71 | # Translations 72 | *.mo 73 | *.pot 74 | 75 | # Django stuff: 76 | *.log 77 | local_settings.py 78 | db.sqlite3 79 | db.sqlite3-journal 80 | 81 | # Flask stuff: 82 | instance/ 83 | .webassets-cache 84 | 85 | # Scrapy stuff: 86 | .scrapy 87 | 88 | # Sphinx documentation 89 | docs/_build/ 90 | 91 | # PyBuilder 92 | target/ 93 | 94 | # Jupyter Notebook 95 | .ipynb_checkpoints 96 | 97 | # IPython 98 | profile_default/ 99 | ipython_config.py 100 | 101 | # pyenv 102 | .python-version 103 | 104 | # pipenv 105 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 106 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 107 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 108 | # install all needed dependencies. 109 | #Pipfile.lock 110 | 111 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 112 | __pypackages__/ 113 | 114 | # Celery stuff 115 | celerybeat-schedule 116 | celerybeat.pid 117 | 118 | # SageMath parsed files 119 | *.sage.py 120 | 121 | # Environments 122 | .env 123 | .venv 124 | env/ 125 | venv/ 126 | ENV/ 127 | env.bak/ 128 | venv.bak/ 129 | 130 | # Spyder project settings 131 | .spyderproject 132 | .spyproject 133 | 134 | # Rope project settings 135 | .ropeproject 136 | 137 | # mkdocs documentation 138 | /site 139 | 140 | # mypy 141 | .mypy_cache/ 142 | .dmypy.json 143 | dmypy.json 144 | 145 | # Pyre type checker 146 | .pyre/ 147 | 148 | 149 | -------------------------------------------------------------------------------- /.github/workflows/wheels.yml: -------------------------------------------------------------------------------- 1 | name: Wheels 2 | 3 | on: 4 | workflow_dispatch: 5 | pull_request: 6 | push: 7 | branches: 8 | - release 9 | release: 10 | types: 11 | - published 12 | 13 | jobs: 14 | build_sdist: 15 | name: Build SDist 16 | runs-on: ubuntu-latest 17 | steps: 18 | - uses: actions/checkout@v4 19 | with: 20 | submodules: true 21 | 22 | - name: Build SDist 23 | run: pipx run build --sdist 24 | 25 | - name: Check metadata 26 | run: pipx run twine check dist/* 27 | 28 | - uses: actions/upload-artifact@v4 29 | with: 30 | name: artifact-sdist 31 | path: dist/*.tar.gz 32 | 33 | 34 | build_wheels: 35 | name: Wheels on ${{ matrix.os }} 36 | runs-on: ${{ matrix.os }} 37 | strategy: 38 | fail-fast: false 39 | matrix: 40 | os: [ubuntu-24.04-arm, ubuntu-latest, windows-2022, macos-14] 41 | 42 | steps: 43 | - uses: actions/checkout@v4 44 | with: 45 | submodules: true 46 | 47 | # Used to host cibuildwheel 48 | - uses: actions/setup-python@v5 49 | 50 | - name: Install cibuildwheel 51 | run: python -m pip install cibuildwheel 52 | 53 | - name: Build wheels 54 | run: python -m cibuildwheel --output-dir wheelhouse 55 | env: 56 | CIBW_ARCHS: auto 57 | # for windows setup.py repairwheel step should solve it 58 | CIBW_SKIP: pp* cp38-* 59 | # Whisper.cpp tries to use BMI2 on 32 bit Windows, so disable BMI2 when building on Windows to avoid that bug. 
See https://github.com/ggml-org/whisper.cpp/pull/3543 60 | CIBW_ENVIRONMENT: CMAKE_ARGS="${{ contains(matrix.os, 'arm') && '-DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8-a' || ''}} ${{ contains(matrix.os, 'windows') && '-DGGML_BMI2=OFF' || '' }}" 61 | 62 | - name: Verify clean directory 63 | run: git diff --exit-code 64 | shell: bash 65 | 66 | - name: List files 67 | run: ls wheelhouse 68 | 69 | - name: Upload wheels 70 | uses: actions/upload-artifact@v4 71 | with: 72 | name: artifact-${{ matrix.os }} 73 | path: wheelhouse/*.whl 74 | 75 | test_wheels: 76 | name: Test wheels on ${{ matrix.os }} (Python ${{ matrix.python-version }}) 77 | runs-on: ${{ matrix.os }} 78 | needs: build_wheels 79 | strategy: 80 | fail-fast: false 81 | matrix: 82 | os: [ubuntu-latest, windows-latest, macos-latest, ubuntu-24.04-arm] 83 | python-version: [3.11, 3.12, 3.13] 84 | 85 | steps: 86 | - uses: actions/checkout@v4 87 | with: 88 | submodules: true 89 | 90 | - uses: actions/download-artifact@v4 91 | with: 92 | pattern: artifact-* 93 | merge-multiple: true 94 | path: wheelhouse 95 | 96 | - name: Verify artifact download 97 | run: | 98 | ls -l wheelhouse 99 | 100 | - name: Set up Python 101 | uses: actions/setup-python@v5 102 | with: 103 | python-version: ${{ matrix.python-version }} 104 | 105 | - name: Install dependencies 106 | run: | 107 | pip install -r requirements.txt 108 | pip install pytest 109 | 110 | - name: Install Wheel 111 | run: | 112 | pip install --no-index --find-links=./wheelhouse pywhispercpp 113 | 114 | - name: Run tests 115 | run: | 116 | pytest tests/ 117 | 118 | 119 | upload_all: 120 | name: Upload if release 121 | needs: [build_wheels, build_sdist] 122 | runs-on: ubuntu-latest 123 | if: github.event_name == 'release' && github.event.action == 'published' 124 | 125 | steps: 126 | - uses: actions/setup-python@v5 127 | with: 128 | python-version: "3.x" 129 | 130 | - uses: actions/download-artifact@v4 131 | with: 132 | pattern: artifact-* 133 | merge-multiple: true 134 | path: dist 135 | 136 | - uses: pypa/gh-action-pypi-publish@release/v1 137 | with: 138 | verbose: true 139 | password: ${{ secrets.PYPI_API_TOKEN }} 140 | -------------------------------------------------------------------------------- /pywhispercpp/examples/main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | A simple Command Line Interface to test the package 6 | """ 7 | import argparse 8 | import importlib.metadata 9 | import logging 10 | 11 | import pywhispercpp.constants as constants 12 | 13 | __version__ = importlib.metadata.version('pywhispercpp') 14 | 15 | __header__ = f""" 16 | PyWhisperCpp 17 | A simple Command Line Interface to test the package 18 | Version: {__version__} 19 | ==================================================== 20 | """ 21 | 22 | from pywhispercpp.model import Model 23 | import pywhispercpp.utils as utils 24 | 25 | 26 | def _get_params(args) -> dict: 27 | """ 28 | Helper function to get params from argparse as a `dict` 29 | """ 30 | params = {} 31 | for arg in args.__dict__: 32 | if arg in constants.PARAMS_SCHEMA.keys() and getattr(args, arg) is not None: 33 | if constants.PARAMS_SCHEMA[arg]['type'] is bool: 34 | if getattr(args, arg).lower() == 'false': 35 | params[arg] = False 36 | else: 37 | params[arg] = True 38 | else: 39 | params[arg] = constants.PARAMS_SCHEMA[arg]['type'](getattr(args, arg)) 40 | return params 41 | 42 | 43 | def run(args): 44 | logging.info(f"Running with model 
`{args.model}`") 45 | params = _get_params(args) 46 | logging.info(f"Running with params {params}") 47 | m = Model(model=args.model, **params) 48 | logging.info(f"System info: n_threads = {m.get_params()['n_threads']} | Processors = {args.processors} " 49 | f"| {m.system_info()}") 50 | for file in args.media_file: 51 | segs = [] 52 | try: 53 | logging.info(f"Processing file {file} ...") 54 | m.transcribe(file, 55 | n_processors=int(args.processors) if args.processors else None, 56 | new_segment_callback=lambda seg: segs.append(seg) 57 | ) 58 | m.print_timings() 59 | except KeyboardInterrupt: 60 | logging.info("Transcription manually stopped") 61 | break 62 | except Exception as e: 63 | logging.error(f"Error while processing file {file}: {e}") 64 | finally: 65 | if segs: 66 | # output stuff 67 | if args.output_txt: 68 | logging.info(f"Saving result as a txt file ...") 69 | txt_file = utils.output_txt(segs, file) 70 | logging.info(f"txt file saved to {txt_file}") 71 | if args.output_vtt: 72 | logging.info(f"Saving results as a vtt file ...") 73 | vtt_file = utils.output_vtt(segs, file) 74 | logging.info(f"vtt file saved to {vtt_file}") 75 | if args.output_srt: 76 | logging.info(f"Saving results as a srt file ...") 77 | srt_file = utils.output_srt(segs, file) 78 | logging.info(f"srt file saved to {srt_file}") 79 | if args.output_csv: 80 | logging.info(f"Saving results as a csv file ...") 81 | csv_file = utils.output_csv(segs, file) 82 | logging.info(f"csv file saved to {csv_file}") 83 | 84 | 85 | def main(): 86 | print(__header__) 87 | parser = argparse.ArgumentParser(description="", allow_abbrev=True) 88 | # Positional args 89 | parser.add_argument('media_file', type=str, nargs='+', help="The path of the media file or a list of files" 90 | "separated by space") 91 | 92 | parser.add_argument('-m', '--model', default='tiny', help="Path to the `ggml` model, or just the model name") 93 | 94 | parser.add_argument('--version', action='version', version=f'%(prog)s {__version__}') 95 | parser.add_argument('--processors', help="number of processors to use during computation") 96 | parser.add_argument('-otxt', '--output-txt', action='store_true', help="output result in a text file") 97 | parser.add_argument('-ovtt', '--output-vtt', action='store_true', help="output result in a vtt file") 98 | parser.add_argument('-osrt', '--output-srt', action='store_true', help="output result in a srt file") 99 | parser.add_argument('-ocsv', '--output-csv', action='store_true', help="output result in a CSV file") 100 | 101 | # add params from PARAMS_SCHEMA 102 | for param in constants.PARAMS_SCHEMA: 103 | param_fields = constants.PARAMS_SCHEMA[param] 104 | parser.add_argument(f'--{param}', 105 | help=f'{param_fields["description"]}') 106 | 107 | args = parser.parse_args() 108 | run(args) 109 | 110 | 111 | if __name__ == '__main__': 112 | main() 113 | -------------------------------------------------------------------------------- /pywhispercpp/examples/assistant.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | A simple example showcasing the use of `pywhispercpp` as an assistant. 6 | The idea is to use a `VAD` to detect speech (in this example we used webrtcvad), and when speech is detected 7 | we run the inference. 
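The rough flow, as implemented below: sounddevice delivers fixed-size audio
blocks, webrtcvad classifies each block as speech or silence, speech blocks
are queued, and after enough consecutive silent blocks the queued audio is
transcribed and each resulting segment is passed to `commands_callback`.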
8 | """ 9 | import argparse 10 | import importlib.metadata 11 | import queue 12 | import time 13 | from typing import Callable 14 | import numpy as np 15 | import sounddevice as sd 16 | import pywhispercpp.constants as constants 17 | import webrtcvad 18 | import logging 19 | from pywhispercpp.model import Model 20 | 21 | __version__ = importlib.metadata.version('pywhispercpp') 22 | 23 | __header__ = f""" 24 | ===================================== 25 | PyWhisperCpp 26 | A simple assistant using Whisper.cpp 27 | Version: {__version__} 28 | ===================================== 29 | """ 30 | 31 | 32 | class Assistant: 33 | """ 34 | Assistant class 35 | 36 | Example usage 37 | ```python 38 | from pywhispercpp.examples.assistant import Assistant 39 | 40 | my_assistant = Assistant(commands_callback=print, n_threads=8) 41 | my_assistant.start() 42 | ``` 43 | """ 44 | 45 | def __init__(self, 46 | model='tiny', 47 | input_device: int = None, 48 | silence_threshold: int = 8, 49 | q_threshold: int = 16, 50 | block_duration: int = 30, 51 | commands_callback: Callable[[str], None] = None, 52 | **model_params): 53 | 54 | """ 55 | :param model: whisper.cpp model name or a direct path to a`ggml` model 56 | :param input_device: The input device (aka microphone), keep it None to take the default 57 | :param silence_threshold: The duration of silence after which the inference will be running 58 | :param q_threshold: The inference won't be running until the data queue is having at least `q_threshold` elements 59 | :param block_duration: minimum time audio updates in ms 60 | :param commands_callback: The callback to run when a command is received 61 | :param model_log_level: Logging level 62 | :param model_params: any other parameter to pass to the whsiper.cpp model see ::: pywhispercpp.constants.PARAMS_SCHEMA 63 | """ 64 | 65 | self.input_device = input_device 66 | self.sample_rate = constants.WHISPER_SAMPLE_RATE # same as whisper.cpp 67 | self.channels = 1 # same as whisper.cpp 68 | self.block_duration = block_duration 69 | self.block_size = int(self.sample_rate * self.block_duration / 1000) 70 | self.q = queue.Queue() 71 | 72 | self.vad = webrtcvad.Vad() 73 | self.silence_threshold = silence_threshold 74 | self.q_threshold = q_threshold 75 | self._silence_counter = 0 76 | 77 | self.pwccp_model = Model(model, 78 | print_realtime=False, 79 | print_progress=False, 80 | print_timestamps=False, 81 | single_segment=True, 82 | no_context=True, 83 | **model_params) 84 | self.commands_callback = commands_callback 85 | 86 | def _audio_callback(self, indata, frames, time, status): 87 | """ 88 | This is called (from a separate thread) for each audio block. 
89 | """ 90 | if status: 91 | logging.warning(F"underlying audio stack warning:{status}") 92 | 93 | assert frames == self.block_size 94 | audio_data = map(lambda x: (x + 1) / 2, indata) # normalize from [-1,+1] to [0,1] 95 | audio_data = np.fromiter(audio_data, np.float16) 96 | audio_data = audio_data.tobytes() 97 | detection = self.vad.is_speech(audio_data, self.sample_rate) 98 | if detection: 99 | self.q.put(indata.copy()) 100 | self._silence_counter = 0 101 | else: 102 | if self._silence_counter >= self.silence_threshold: 103 | if self.q.qsize() > self.q_threshold: 104 | self._transcribe_speech() 105 | self._silence_counter = 0 106 | else: 107 | self._silence_counter += 1 108 | 109 | def _transcribe_speech(self): 110 | logging.info(f"Speech detected ...") 111 | audio_data = np.array([]) 112 | while self.q.qsize() > 0: 113 | # get all the data from the q 114 | audio_data = np.append(audio_data, self.q.get()) 115 | # Appending zeros to the audio data as a workaround for small audio packets (small commands) 116 | audio_data = np.concatenate([audio_data, np.zeros((int(self.sample_rate) + 10))]) 117 | # running the inference 118 | self.pwccp_model.transcribe(audio_data, 119 | new_segment_callback=self._new_segment_callback) 120 | 121 | def _new_segment_callback(self, seg): 122 | if self.commands_callback: 123 | self.commands_callback(seg.text) 124 | 125 | def start(self) -> None: 126 | """ 127 | Use this function to start the assistant 128 | :return: None 129 | """ 130 | logging.info(f"Starting Assistant ...") 131 | with sd.InputStream( 132 | device=self.input_device, # the default input device 133 | channels=self.channels, 134 | samplerate=constants.WHISPER_SAMPLE_RATE, 135 | blocksize=self.block_size, 136 | callback=self._audio_callback): 137 | 138 | try: 139 | logging.info(f"Assistant is listening ... (CTRL+C to stop)") 140 | while True: 141 | time.sleep(0.1) 142 | except KeyboardInterrupt: 143 | logging.info("Assistant stopped") 144 | 145 | @staticmethod 146 | def available_devices(): 147 | return sd.query_devices() 148 | 149 | 150 | def _main(): 151 | parser = argparse.ArgumentParser(description="", allow_abbrev=True) 152 | # Positional args 153 | parser.add_argument('-m', '--model', default='tiny.en', type=str, help="Whisper.cpp model, default to %(default)s") 154 | parser.add_argument('-ind', '--input_device', type=int, default=None, 155 | help=f'Id of The input device (aka microphone)\n' 156 | f'available devices {Assistant.available_devices()}') 157 | parser.add_argument('-st', '--silence_threshold', default=16, type=int, 158 | help=f"he duration of silence after which the inference will be running, default to %(default)s") 159 | parser.add_argument('-bd', '--block_duration', default=30, 160 | help=f"minimum time audio updates in ms, default to %(default)s") 161 | 162 | args = parser.parse_args() 163 | 164 | my_assistant = Assistant(model=args.model, 165 | input_device=args.input_device, 166 | silence_threshold=args.silence_threshold, 167 | block_duration=args.block_duration, 168 | commands_callback=print) 169 | my_assistant.start() 170 | 171 | 172 | if __name__ == '__main__': 173 | _main() 174 | -------------------------------------------------------------------------------- /pywhispercpp/examples/livestream.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | Quick and dirty realtime livestream transcription. 
6 | 7 | Not fully satisfying though :) 8 | You are welcome to make it better. 9 | """ 10 | import argparse 11 | import logging 12 | import queue 13 | from multiprocessing import Process 14 | import ffmpeg 15 | import numpy as np 16 | import pywhispercpp.constants as constants 17 | import sounddevice as sd 18 | from pywhispercpp.model import Model 19 | import importlib.metadata 20 | 21 | 22 | __version__ = importlib.metadata.version('pywhispercpp') 23 | 24 | __header__ = f""" 25 | ======================================================== 26 | PyWhisperCpp 27 | A simple Livestream transcription, based on whisper.cpp 28 | Version: {__version__} 29 | ======================================================== 30 | """ 31 | 32 | class LiveStream: 33 | """ 34 | LiveStream class 35 | 36 | ???+ note 37 | 38 | It heavily depends on the machine power, the processor will jump quickly to 100% with the wrong parameters. 39 | 40 | Example usage 41 | ```python 42 | from pywhispercpp.examples.livestream import LiveStream 43 | 44 | url = "" # Make sure it is a direct stream URL 45 | ls = LiveStream(url=url, n_threads=4) 46 | ls.start() 47 | ``` 48 | """ 49 | 50 | def __init__(self, 51 | url, 52 | model='tiny.en', 53 | block_size: int = 1024, 54 | buffer_size: int = 20, 55 | sample_size: int = 4, 56 | output_device: int = None, 57 | model_log_level=logging.CRITICAL, 58 | **model_params): 59 | 60 | """ 61 | :param url: Live stream url 62 | :param model: whisper.cpp model 63 | :param block_size: block size, default to 1024 64 | :param buffer_size: number of blocks used for buffering, default to 20 65 | :param sample_size: sample size 66 | :param output_device: the output device, aka the speaker, leave it None to take the default 67 | :param model_log_level: logging level 68 | :param model_params: any other whisper.cpp params 69 | """ 70 | self.url = url 71 | self.block_size = block_size 72 | self.buffer_size = buffer_size 73 | self.sample_size = sample_size 74 | self.output_device = output_device 75 | 76 | self.channels = 1 77 | self.samplerate = constants.WHISPER_SAMPLE_RATE 78 | 79 | self.q = queue.Queue(maxsize=buffer_size) 80 | self.audio_data = np.array([]) 81 | 82 | self.pwccp_model = Model(model, 83 | log_level=model_log_level, 84 | print_realtime=True, 85 | print_progress=False, 86 | print_timestamps=False, 87 | single_segment=True, 88 | **model_params) 89 | 90 | def _transcribe_process(self): 91 | self.pwccp_model.transcribe(self.audio_data, n_processors=None) 92 | 93 | def _audio_callback(self, outdata, frames, time, status): 94 | assert frames == self.block_size 95 | if status.output_underflow: 96 | logging.error('Output underflow: increase blocksize?') 97 | raise sd.CallbackAbort 98 | assert not status 99 | try: 100 | data = self.q.get_nowait() 101 | except queue.Empty as e: 102 | logging.error('Buffer is empty: increase buffer_size?') 103 | raise sd.CallbackAbort from e 104 | assert len(data) == len(outdata) 105 | outdata[:] = data 106 | audio = np.frombuffer(data[:], np.float32) 107 | audio = audio.reshape((audio.size, 1)) / 2 ** 5 108 | self.audio_data = np.append(self.audio_data, audio) 109 | if self.audio_data.size > self.samplerate: 110 | # Create a separate process for transcription 111 | p1 = Process(target=self._transcribe_process,) 112 | p1.start() 113 | self.audio_data = np.array([]) 114 | 115 | def start(self): 116 | process = ffmpeg.input(self.url).output( 117 | 'pipe:', 118 | format='f32le', 119 | acodec='pcm_f32le', 120 | ac=self.channels, 121 | ar=self.samplerate, 122 | 
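            # f32le / pcm_f32le above request raw 32-bit float little-endian PCM
            # from ffmpeg, matching the float32 RawOutputStream used for playback
            # and whisper.cpp's expected 16 kHz mono input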
loglevel='quiet', 123 | ).run_async(pipe_stdout=True) 124 | 125 | out_stream = sd.RawOutputStream( 126 | device=self.output_device, 127 | samplerate=self.samplerate, 128 | blocksize=self.block_size, 129 | channels=self.channels, 130 | dtype='float32', 131 | callback=self._audio_callback) 132 | 133 | read_size = self.block_size * self.channels * self.sample_size 134 | 135 | logging.info('Buffering ...') 136 | for _ in range(self.buffer_size): 137 | self.q.put_nowait(process.stdout.read(read_size)) 138 | 139 | with out_stream: 140 | logging.info('Starting Playback ... (CTRL+C) to stop') 141 | try: 142 | timeout = self.block_size * self.buffer_size / self.samplerate 143 | while True: 144 | buffer_data = process.stdout.read(read_size) 145 | self.q.put(buffer_data, timeout=timeout) 146 | except KeyboardInterrupt: 147 | logging.info("Interrupted!") 148 | 149 | @staticmethod 150 | def available_devices(): 151 | return sd.query_devices() 152 | 153 | 154 | def _main(): 155 | print(__header__) 156 | parser = argparse.ArgumentParser(description="", allow_abbrev=True) 157 | # Positional args 158 | parser.add_argument('url', type=str, help=f"Stream URL") 159 | 160 | parser.add_argument('-nt', '--n_threads', type=int, default=3, 161 | help="number of threads, default to %(default)s") 162 | parser.add_argument('-m', '--model', default='tiny.en', type=str, help="Whisper.cpp model, default to %(default)s") 163 | parser.add_argument('-od', '--output_device', type=int, default=None, 164 | help=f'the output device, aka the speaker, leave it None to take the default\n' 165 | f'available devices {LiveStream.available_devices()}') 166 | parser.add_argument('-bls', '--block_size', type=int, default=1024, 167 | help=f"block size, default to %(default)s") 168 | parser.add_argument('-bus', '--buffer_size', type=int, default=20, 169 | help=f"number of blocks used for buffering, default to %(default)s") 170 | parser.add_argument('-ss', '--sample_size', type=int, default=4, 171 | help=f"Sample size, default to %(default)s") 172 | args = parser.parse_args() 173 | 174 | 175 | # url = "http://n03.radiojar.com/t2n88q0st5quv?rj-ttl=5&rj-tok=AAABhsR2u6MAYFxz69dJ6eQnww" # VOA english 176 | ls = LiveStream(url=args.url, 177 | model=args.model, 178 | block_size=args.block_size, 179 | buffer_size=args.buffer_size, 180 | sample_size=args.sample_size, 181 | output_device=args.output_device, 182 | n_threads=args.n_threads) 183 | ls.start() 184 | 185 | 186 | if __name__ == '__main__': 187 | _main() 188 | -------------------------------------------------------------------------------- /pywhispercpp/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | Helper functions 5 | """ 6 | 7 | import contextlib 8 | import logging 9 | import os 10 | import sys 11 | from pathlib import Path 12 | from typing import TextIO 13 | 14 | import requests 15 | from tqdm import tqdm 16 | 17 | from pywhispercpp.constants import ( 18 | AVAILABLE_MODELS, 19 | MODELS_BASE_URL, 20 | MODELS_DIR, 21 | MODELS_PREFIX_URL, 22 | ) 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | def _get_model_url(model_name: str) -> str: 28 | """ 29 | Returns the url of the `ggml` model 30 | :param model_name: name of the model 31 | :return: URL of the model 32 | """ 33 | return f"{MODELS_BASE_URL}/{MODELS_PREFIX_URL}-{model_name}.bin" 34 | 35 | 36 | def download_model(model_name: str, download_dir=None, chunk_size=1024) -> str: 37 | """ 38 | Helper function to download the `ggml` models 39 
| :param model_name: name of the model, one of ::: constants.AVAILABLE_MODELS 40 | :param download_dir: Where to store the models 41 | :param chunk_size: size of the download chunk 42 | 43 | :return: Absolute path of the downloaded model 44 | """ 45 | if model_name not in AVAILABLE_MODELS: 46 | logger.error(f"Invalid model name `{model_name}`, available models are: {AVAILABLE_MODELS}") 47 | return 48 | if download_dir is None: 49 | download_dir = MODELS_DIR 50 | logger.info(f"No download directory was provided, models will be downloaded to {download_dir}") 51 | 52 | os.makedirs(download_dir, exist_ok=True) 53 | 54 | url = _get_model_url(model_name=model_name) 55 | file_path = Path(download_dir) / os.path.basename(url) 56 | # check if the file is already there 57 | if file_path.exists(): 58 | logger.info(f"Model {model_name} already exists in {download_dir}") 59 | else: 60 | # download it from huggingface 61 | resp = requests.get(url, stream=True) 62 | total = int(resp.headers.get('content-length', 0)) 63 | 64 | progress_bar = tqdm(desc=f"Downloading Model {model_name} ...", 65 | total=total, 66 | unit='iB', 67 | unit_scale=True, 68 | unit_divisor=1024) 69 | 70 | try: 71 | with open(file_path, 'wb') as file, progress_bar: 72 | for data in resp.iter_content(chunk_size=chunk_size): 73 | size = file.write(data) 74 | progress_bar.update(size) 75 | logger.info(f"Model downloaded to {file_path.absolute()}") 76 | except Exception as e: 77 | # error download, just remove the file 78 | os.remove(file_path) 79 | raise e 80 | return str(file_path.absolute()) 81 | 82 | 83 | def to_timestamp(t: int, separator=',') -> str: 84 | """ 85 | 376 -> 00:00:03,760 86 | 1344 -> 00:00:13,440 87 | 88 | Implementation from `whisper.cpp/examples/main` 89 | 90 | :param t: input time from whisper timestamps 91 | :param separator: seprator between seconds and milliseconds 92 | :return: time representation in hh: mm: ss[separator]ms 93 | """ 94 | # logic exactly from whisper.cpp 95 | 96 | msec = t * 10 97 | hr = msec // (1000 * 60 * 60) 98 | msec = msec - hr * (1000 * 60 * 60) 99 | min = msec // (1000 * 60) 100 | msec = msec - min * (1000 * 60) 101 | sec = msec // 1000 102 | msec = msec - sec * 1000 103 | return f"{int(hr):02,.0f}:{int(min):02,.0f}:{int(sec):02,.0f}{separator}{int(msec):03,.0f}" 104 | 105 | 106 | def output_txt(segments: list, output_file_path: str) -> str: 107 | """ 108 | Creates a raw text from a list of segments 109 | 110 | Implementation from `whisper.cpp/examples/main` 111 | 112 | :param segments: list of segments 113 | :return: path of the file 114 | """ 115 | if not output_file_path.endswith('.txt'): 116 | output_file_path = output_file_path + '.txt' 117 | 118 | absolute_path = Path(output_file_path).absolute() 119 | 120 | with open(str(absolute_path), 'w') as file: 121 | for seg in segments: 122 | file.write(seg.text) 123 | file.write('\n') 124 | return absolute_path 125 | 126 | 127 | def output_vtt(segments: list, output_file_path: str) -> str: 128 | """ 129 | Creates a vtt file from a list of segments 130 | 131 | Implementation from `whisper.cpp/examples/main` 132 | 133 | :param segments: list of segments 134 | :return: path of the file 135 | 136 | :return: Absolute path of the file 137 | """ 138 | if not output_file_path.endswith('.vtt'): 139 | output_file_path = output_file_path + '.vtt' 140 | 141 | absolute_path = Path(output_file_path).absolute() 142 | 143 | with open(absolute_path, 'w') as file: 144 | file.write("WEBVTT\n\n") 145 | for seg in segments: 146 | 
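            # seg.t0 / seg.t1 come from whisper.cpp in 10 ms units; to_timestamp
            # renders them as hh:mm:ss.mmm, with '.' as the separator WebVTT expects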
file.write(f"{to_timestamp(seg.t0, separator='.')} --> {to_timestamp(seg.t1, separator='.')}\n") 147 | file.write(f"{seg.text}\n\n") 148 | return absolute_path 149 | 150 | 151 | def output_srt(segments: list, output_file_path: str) -> str: 152 | """ 153 | Creates a srt file from a list of segments 154 | 155 | :param segments: list of segments 156 | :return: path of the file 157 | 158 | :return: Absolute path of the file 159 | """ 160 | if not output_file_path.endswith('.srt'): 161 | output_file_path = output_file_path + '.srt' 162 | 163 | absolute_path = Path(output_file_path).absolute() 164 | 165 | with open(absolute_path, 'w') as file: 166 | for i in range(len(segments)): 167 | seg = segments[i] 168 | file.write(f"{i+1}\n") 169 | file.write(f"{to_timestamp(seg.t0, separator=',')} --> {to_timestamp(seg.t1, separator=',')}\n") 170 | file.write(f"{seg.text}\n\n") 171 | return absolute_path 172 | 173 | 174 | def output_csv(segments: list, output_file_path: str) -> str: 175 | """ 176 | Creates a srt file from a list of segments 177 | 178 | :param segments: list of segments 179 | :return: path of the file 180 | 181 | :return: Absolute path of the file 182 | """ 183 | if not output_file_path.endswith('.csv'): 184 | output_file_path = output_file_path + '.csv' 185 | 186 | absolute_path = Path(output_file_path).absolute() 187 | 188 | with open(absolute_path, 'w') as file: 189 | for seg in segments: 190 | file.write(f"{10 * seg.t0}, {10 * seg.t1}, \"{seg.text}\"\n") 191 | return absolute_path 192 | 193 | 194 | @contextlib.contextmanager 195 | def redirect_stderr(to: bool | TextIO | str | None = False) -> None: 196 | """ 197 | Redirect stderr to the specified target. 198 | 199 | :param to: 200 | - None to suppress output (redirect to devnull), 201 | - sys.stdout to redirect to stdout, 202 | - A file path (str) to redirect to a file, 203 | - False to do nothing (no redirection). 204 | """ 205 | 206 | if to is False: 207 | # do nothing 208 | yield 209 | return 210 | 211 | def _resolve_target(target): 212 | opened_stream = None 213 | if target is None: 214 | opened_stream = open(os.devnull, "w") 215 | return opened_stream, True 216 | if isinstance(target, str): 217 | opened_stream = open(target, "w") 218 | return opened_stream, True 219 | if hasattr(target, "write"): 220 | return target, False 221 | raise ValueError( 222 | "Invalid `to` parameter; expected None, a filepath string, or a file-like object." 
223 | ) 224 | 225 | sys.stderr.flush() 226 | try: 227 | original_fd = sys.stderr.fileno() 228 | except (AttributeError, OSError): 229 | # Jupyter or non-standard stderr implementations 230 | original_fd = None 231 | 232 | stream, should_close = _resolve_target(to) 233 | 234 | if original_fd is not None and hasattr(stream, "fileno"): 235 | saved_fd = os.dup(original_fd) 236 | try: 237 | os.dup2(stream.fileno(), original_fd) 238 | yield 239 | finally: 240 | os.dup2(saved_fd, original_fd) 241 | os.close(saved_fd) 242 | if should_close: 243 | stream.close() 244 | return 245 | 246 | # Fallback: Python-level redirect 247 | try: 248 | with contextlib.redirect_stderr(stream): 249 | yield 250 | finally: 251 | if should_close: 252 | stream.close() 253 | -------------------------------------------------------------------------------- /pywhispercpp/constants.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | Constants 6 | """ 7 | from pathlib import Path 8 | from typing import Tuple 9 | 10 | import _pywhispercpp as _pwcpp 11 | from platformdirs import user_data_dir 12 | 13 | 14 | WHISPER_SAMPLE_RATE = _pwcpp.WHISPER_SAMPLE_RATE 15 | # MODELS URL MODELS_BASE_URL+ '/' + MODELS_PREFIX_URL+'-'+MODEL_NAME+'.bin' 16 | # example = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.bin" 17 | MODELS_BASE_URL = "https://huggingface.co/ggerganov/whisper.cpp" 18 | MODELS_PREFIX_URL = "resolve/main/ggml" 19 | 20 | 21 | PACKAGE_NAME = 'pywhispercpp' 22 | 23 | 24 | MODELS_DIR = Path(user_data_dir(PACKAGE_NAME)) / 'models' 25 | 26 | 27 | AVAILABLE_MODELS = [ 28 | "base", 29 | "base-q5_1", 30 | "base-q8_0", 31 | "base.en", 32 | "base.en-q5_1", 33 | "base.en-q8_0", 34 | "large-v1", 35 | "large-v2", 36 | "large-v2-q5_0", 37 | "large-v2-q8_0", 38 | "large-v3", 39 | "large-v3-q5_0", 40 | "large-v3-turbo", 41 | "large-v3-turbo-q5_0", 42 | "large-v3-turbo-q8_0", 43 | "medium", 44 | "medium-q5_0", 45 | "medium-q8_0", 46 | "medium.en", 47 | "medium.en-q5_0", 48 | "medium.en-q8_0", 49 | "small", 50 | "small-q5_1", 51 | "small-q8_0", 52 | "small.en", 53 | "small.en-q5_1", 54 | "small.en-q8_0", 55 | "tiny", 56 | "tiny-q5_1", 57 | "tiny-q8_0", 58 | "tiny.en", 59 | "tiny.en-q5_1", 60 | "tiny.en-q8_0", 61 | ] 62 | PARAMS_SCHEMA = { # as exactly presented in whisper.cpp 63 | 'n_threads': { 64 | 'type': int, 65 | 'description': "Number of threads to allocate for the inference" 66 | "default to min(4, available hardware_concurrency)", 67 | 'options': None, 68 | 'default': None 69 | }, 70 | 'n_max_text_ctx': { 71 | 'type': int, 72 | 'description': "max tokens to use from past text as prompt for the decoder", 73 | 'options': None, 74 | 'default': 16384 75 | }, 76 | 'offset_ms': { 77 | 'type': int, 78 | 'description': "start offset in ms", 79 | 'options': None, 80 | 'default': 0 81 | }, 82 | 'duration_ms': { 83 | 'type': int, 84 | 'description': "audio duration to process in ms", 85 | 'options': None, 86 | 'default': 0 87 | }, 88 | 'translate': { 89 | 'type': bool, 90 | 'description': "whether to translate the audio to English", 91 | 'options': None, 92 | 'default': False 93 | }, 94 | 'no_context': { 95 | 'type': bool, 96 | 'description': "do not use past transcription (if any) as initial prompt for the decoder", 97 | 'options': None, 98 | 'default': False 99 | }, 100 | 'single_segment': { 101 | 'type': bool, 102 | 'description': "force single segment output (useful for streaming)", 103 | 'options': None, 104 | 
'default': False
105 |     },
106 |     'print_special': {
107 |         'type': bool,
108 |         'description': "print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)",
109 |         'options': None,
110 |         'default': False
111 |     },
112 |     'print_progress': {
113 |         'type': bool,
114 |         'description': "print progress information",
115 |         'options': None,
116 |         'default': True
117 |     },
118 |     'print_realtime': {
119 |         'type': bool,
120 |         'description': "print results from within whisper.cpp (avoid it, use callback instead)",
121 |         'options': None,
122 |         'default': False
123 |     },
124 |     'print_timestamps': {
125 |         'type': bool,
126 |         'description': "print timestamps for each text segment when printing realtime",
127 |         'options': None,
128 |         'default': True
129 |     },
130 |     # [EXPERIMENTAL] token-level timestamps
131 |     'token_timestamps': {
132 |         'type': bool,
133 |         'description': "enable token-level timestamps",
134 |         'options': None,
135 |         'default': False
136 |     },
137 |     'thold_pt': {
138 |         'type': float,
139 |         'description': "timestamp token probability threshold (~0.01)",
140 |         'options': None,
141 |         'default': 0.01
142 |     },
143 |     'thold_ptsum': {
144 |         'type': float,
145 |         'description': "timestamp token sum probability threshold (~0.01)",
146 |         'options': None,
147 |         'default': 0.01
148 |     },
149 |     'max_len': {
150 |         'type': int,
151 |         'description': "max segment length in characters, note: token_timestamps needs to be set to True for this to work",
152 |         'options': None,
153 |         'default': 0
154 |     },
155 |     'split_on_word': {
156 |         'type': bool,
157 |         'description': "split on word rather than on token (when used with max_len)",
158 |         'options': None,
159 |         'default': False
160 |     },
161 |     'max_tokens': {
162 |         'type': int,
163 |         'description': "max tokens per segment (0 = no limit)",
164 |         'options': None,
165 |         'default': 0
166 |     },
167 |     'audio_ctx': {
168 |         'type': int,
169 |         'description': "overwrite the audio context size (0 = use default)",
170 |         'options': None,
171 |         'default': 0
172 |     },
173 |     'initial_prompt': {
174 |         'type': str,
175 |         'description': "Initial prompt, these are prepended to any existing text context from a previous call",
176 |         'options': None,
177 |         'default': None
178 |     },
179 |     'prompt_tokens': {
180 |         'type': Tuple,
181 |         'description': "tokens to provide to the whisper decoder as initial prompt",
182 |         'options': None,
183 |         'default': None
184 |     },
185 |     'prompt_n_tokens': {
186 |         'type': int,
187 |         'description': "number of tokens to provide to the whisper decoder as initial prompt",
188 |         'options': None,
189 |         'default': 0
190 |     },
191 |     'language': {
192 |         'type': str,
193 |         'description': 'for auto-detection, set to None, "" or "auto"',
194 |         'options': None,
195 |         'default': ""
196 |     },
197 |     'suppress_blank': {
198 |         'type': bool,
199 |         'description': 'common decoding parameters',
200 |         'options': None,
201 |         'default': True
202 |     },
203 |     'suppress_non_speech_tokens': {
204 |         'type': bool,
205 |         'description': 'common decoding parameters',
206 |         'options': None,
207 |         'default': False
208 |     },
209 |     'temperature': {
210 |         'type': float,
211 |         'description': 'initial decoding temperature',
212 |         'options': None,
213 |         'default': 0.0
214 |     },
215 |     'max_initial_ts': {
216 |         'type': float,
217 |         'description': 'max_initial_ts',
218 |         'options': None,
219 |         'default': 1.0
220 |     },
221 |     'length_penalty': {
222 |         'type': float,
223 |         'description': 'length_penalty',
224 |         'options': None,
225 |         'default': -1.0
226 |     },
227 |     'temperature_inc': {
228 |         'type': float,
229 |         'description': 'temperature_inc',
230 
'options': None, 231 | 'default': 0.2 232 | }, 233 | 'entropy_thold': { 234 | 'type': float, 235 | 'description': 'similar to OpenAI\'s "compression_ratio_threshold"', 236 | 'options': None, 237 | 'default': 2.4 238 | }, 239 | 'logprob_thold': { 240 | 'type': float, 241 | 'description': 'logprob_thold', 242 | 'options': None, 243 | 'default': -1.0 244 | }, 245 | 'no_speech_thold': { # not implemented 246 | 'type': float, 247 | 'description': 'no_speech_thold', 248 | 'options': None, 249 | 'default': 0.6 250 | }, 251 | 'greedy': { 252 | 'type': dict, 253 | 'description': 'greedy', 254 | 'options': None, 255 | 'default': {"best_of": -1} 256 | }, 257 | 'beam_search': { 258 | 'type': dict, 259 | 'description': 'beam_search', 260 | 'options': None, 261 | 'default': {"beam_size": -1, "patience": -1.0} 262 | }, 263 | 'extract_probability': { 264 | 'type': bool, 265 | 'description': 'calculate the geometric mean of token probabilities for each segment.', 266 | 'options': None, 267 | 'default': True 268 | } 269 | } 270 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # This setup.py is used to build the pywhispercpp package. 2 | # The environment variables you may find interesting are: 3 | # 4 | # PYWHISPERCPP_VERSION 5 | # if set, it will be used as the version number. 6 | # 7 | # GGML_VULKAN=1 8 | # if set, whisper.cpp will be built with Vulkan support. 9 | # 10 | # WHISPER_COREML=1 11 | # WHISPER_COREML_ALLOW_FALLBACK=1 12 | # if set, whisper.cpp will be built with CoreML support, which requires special models. 13 | # It is best used with WHISPER_COREML_ALLOW_FALLBACK=1 14 | 15 | 16 | import os 17 | import re 18 | import subprocess 19 | import sys 20 | from pathlib import Path 21 | 22 | 23 | from setuptools import Extension, setup, find_packages 24 | from setuptools.command.build_ext import build_ext 25 | from setuptools.command.bdist_wheel import bdist_wheel 26 | 27 | # Convert distutils Windows platform specifiers to CMake -A arguments 28 | PLAT_TO_CMAKE = { 29 | "win32": "Win32", 30 | "win-amd64": "x64", 31 | "win-arm32": "ARM", 32 | "win-arm64": "ARM64", 33 | } 34 | 35 | 36 | # A CMakeExtension needs a sourcedir instead of a file list. 37 | # The name must be the _single_ output extension from the CMake build. 38 | # If you need multiple extensions, see scikit-build. 39 | class CMakeExtension(Extension): 40 | def __init__(self, name: str, sourcedir: str = "") -> None: 41 | super().__init__(name, sources=[]) 42 | self.sourcedir = os.fspath(Path(sourcedir).resolve()) 43 | 44 | 45 | dll_folder = 'unset' 46 | 47 | 48 | class CMakeBuild(build_ext): 49 | def build_extension(self, ext: CMakeExtension) -> None: 50 | # Must be in this form due to bug in .resolve() only fixed in Python 3.10+ 51 | ext_fullpath = Path.cwd() / self.get_ext_fullpath(ext.name) # type: ignore[no-untyped-call] 52 | extdir = ext_fullpath.parent.resolve() 53 | 54 | # Using this requires trailing slash for auto-detection & inclusion of 55 | # auxiliary "native" libs 56 | 57 | debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug 58 | cfg = "Debug" if debug else "Release" 59 | 60 | # CMake lets you override the generator - we need to check this. 61 | # Can be set with Conda-Build, for example.
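        # e.g. (illustrative invocation): CMAKE_GENERATOR=Ninja pip install . keeps Ninja as the generator in the non-MSVC branch below.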
62 | cmake_generator = os.environ.get("CMAKE_GENERATOR", "") 63 | 64 | # Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON 65 | # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code 66 | # from Python. 67 | cmake_args = [ 68 | f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}{os.sep}", 69 | f"-DPYTHON_EXECUTABLE={sys.executable}", 70 | f"-DCMAKE_BUILD_TYPE={cfg}", # not used on MSVC, but no harm 71 | ] 72 | if self.editable_mode: 73 | # Platform-specific rpath settings 74 | if sys.platform.startswith('darwin'): 75 | # macOS-specific settings 76 | cmake_args += [ 77 | "-DCMAKE_INSTALL_RPATH=@loader_path", 78 | "-DCMAKE_BUILD_WITH_INSTALL_RPATH=ON" 79 | ] 80 | elif sys.platform.startswith('linux'): 81 | # Linux-specific settings 82 | cmake_args += [ 83 | "-DCMAKE_INSTALL_RPATH=$ORIGIN", 84 | "-DCMAKE_BUILD_WITH_INSTALL_RPATH=ON" 85 | ] 86 | 87 | build_args = [] 88 | # Adding CMake arguments set as environment variable 89 | # (needed e.g. to build for ARM OSx on conda-forge) 90 | if "CMAKE_ARGS" in os.environ: 91 | cmake_args += [item for item in os.environ["CMAKE_ARGS"].split(" ") if item] 92 | 93 | # In this example, we pass in the version to C++. You might not need to. 94 | cmake_args += [f"-DEXAMPLE_VERSION_INFO={self.distribution.get_version()}"] # type: ignore[attr-defined] 95 | 96 | if self.compiler.compiler_type != "msvc": 97 | # Using Ninja-build since it a) is available as a wheel and b) 98 | # multithreads automatically. MSVC would require all variables be 99 | # exported for Ninja to pick it up, which is a little tricky to do. 100 | # Users can override the generator with CMAKE_GENERATOR in CMake 101 | # 3.15+. 102 | if not cmake_generator or cmake_generator == "Ninja": 103 | try: 104 | import ninja 105 | 106 | ninja_executable_path = Path(ninja.BIN_DIR) / "ninja" 107 | cmake_args += [ 108 | "-GNinja", 109 | f"-DCMAKE_MAKE_PROGRAM:FILEPATH={ninja_executable_path}", 110 | ] 111 | except ImportError: 112 | pass 113 | 114 | else: 115 | # Single config generators are handled "normally" 116 | single_config = any(x in cmake_generator for x in {"NMake", "Ninja"}) 117 | 118 | # CMake allows an arch-in-generator style for backward compatibility 119 | contains_arch = any(x in cmake_generator for x in {"ARM", "Win64"}) 120 | 121 | # Specify the arch if using MSVC generator, but only if it doesn't 122 | # contain a backward-compatibility arch spec already in the 123 | # generator name. 124 | if not single_config and not contains_arch: 125 | cmake_args += ["-A", PLAT_TO_CMAKE[self.plat_name]] 126 | 127 | # Multi-config generators have a different way to specify configs 128 | if not single_config: 129 | cmake_args += [ 130 | f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}" 131 | ] 132 | build_args += ["--config", cfg] 133 | 134 | if sys.platform.startswith("darwin"): 135 | # Cross-compile support for macOS - respect ARCHFLAGS if set 136 | archs = re.findall(r"-arch (\S+)", os.environ.get("ARCHFLAGS", "")) 137 | if archs: 138 | cmake_args += ["-DCMAKE_OSX_ARCHITECTURES={}".format(";".join(archs))] 139 | 140 | # Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level 141 | # across all generators. 142 | if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ: 143 | # self.parallel is a Python 3 only way to set parallel jobs by hand 144 | # using -j in the build_ext call, not supported by pip or PyPA-build. 145 | if hasattr(self, "parallel") and self.parallel: 146 | # CMake 3.12+ only. 
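                # e.g. (illustrative): python setup.py build_ext -j8 sets self.parallel to 8, which becomes "cmake --build . -j8" below; pip does not pass -j.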
147 | build_args += [f"-j{self.parallel}"] 148 | 149 | build_temp = Path(self.build_temp) / ext.name 150 | if not build_temp.exists(): 151 | build_temp.mkdir(parents=True) 152 | 153 | for key, value in os.environ.items(): 154 | cmake_args.append(f'-D{key}={value}') 155 | 156 | subprocess.run( 157 | ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True 158 | ) 159 | subprocess.run( 160 | ["cmake", "--build", ".", *build_args], cwd=build_temp, check=True 161 | ) 162 | 163 | # store the dll folder in a global variable to use in repairwheel 164 | global dll_folder 165 | cfg = "Debug" if self.debug else "Release" 166 | dll_folder = os.path.join(self.build_temp, '_pywhispercpp', 'bin', cfg) 167 | print("dll_folder in build_extension", dll_folder) 168 | # self.copy_extensions_to_source() 169 | 170 | def copy_extensions_to_source(self): 171 | super().copy_extensions_to_source() 172 | 173 | if self.editable_mode: 174 | build_lib = Path(self.build_lib) 175 | for ext in self.extensions: 176 | extdir = Path(self.get_ext_fullpath(ext.name)).parent.resolve() 177 | # Assuming all shared libraries are in the same directory 178 | shared_lib_files = [*build_lib.glob('**/*.dylib'), *build_lib.glob('**/*.so*')] 179 | for shared_lib in shared_lib_files: 180 | self.copy_file(shared_lib, extdir) 181 | 182 | 183 | # read the contents of your README file 184 | this_directory = Path(__file__).parent 185 | long_description = (this_directory / "README.md").read_text() 186 | 187 | 188 | class RepairWheel(bdist_wheel): 189 | def run(self): 190 | super().run() 191 | if os.environ.get('NO_REPAIR', '0') == '1': 192 | print("Skipping wheel repair") 193 | return 194 | if os.environ.get('CIBUILDWHEEL', '0') == '0' or sys.platform.startswith('win'): 195 | # for linux and macos we use the default wheel repair command from cibuildwheel, for windows we need to do it manually as there is no repair command 196 | self.repair_wheel() 197 | 198 | def repair_wheel(self): 199 | # on windows the dlls are in D:\a\pywhispercpp\pywhispercpp\build\temp.win-amd64-cpython-311\Release\_pywhispercpp\bin\Release\whisper.dll 200 | global dll_folder 201 | print("dll_folder in repairwheel", dll_folder) 202 | print("Files in dll_folder:", *Path(dll_folder).glob('*')) 203 | # build\temp.win-amd64-cpython-311\Release\_pywhispercpp\bin\Release\whisper.dll 204 | 205 | wheel_path = next(Path(self.dist_dir).glob(f"{self.distribution.get_name()}*.whl")) 206 | # Create a temporary directory for the repaired wheel 207 | import tempfile 208 | with tempfile.TemporaryDirectory(prefix='repaired_wheel_') as tmp_dir: 209 | tmp_dir = Path(tmp_dir) 210 | subprocess.call(['repairwheel', wheel_path, '-o', tmp_dir, '-l', dll_folder]) 211 | print("Repaired wheel: ", *tmp_dir.glob('*.whl')) 212 | # We need to glob as repairwheel may change the name of the wheel 213 | # on linux from pywhispercpp-1.2.0-cp312-cp312-linux_aarch64.whl 214 | # to pywhispercpp-1.2.0-cp312-cp312-manylinux_2_34_aarch64.whl 215 | repaired_wheel = next(tmp_dir.glob("*.whl")) 216 | self.copy_file(repaired_wheel, wheel_path) 217 | print(f"Copied repaired wheel to: {wheel_path}") 218 | 219 | 220 | def get_local_version() -> str: 221 | try: 222 | git_sha = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode("utf-8") 223 | return f"+git{git_sha[:7]}" 224 | except (FileNotFoundError, subprocess.CalledProcessError): 225 | return "" 226 | 227 | 228 | def get_version() -> str: 229 | try: 230 | return os.environ['PYWHISPERCPP_VERSION'] 231 | except KeyError: 232 | pass 233 | 
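    # Illustrative precedence: PYWHISPERCPP_VERSION=1.2.3 pip install . (hypothetical value) wins; otherwise version.txt plus the "+git<sha>" local suffix is used.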
with open("version.txt") as f: 234 | version = f.read().strip() 235 | return f"{version}{get_local_version()}" 236 | 237 | 238 | # The information here can also be placed in setup.cfg - better separation of 239 | # logic and declaration, and simpler if you include description/version in a file. 240 | setup( 241 | name="pywhispercpp", 242 | author="absadiki", 243 | description="Python bindings for whisper.cpp", 244 | long_description=long_description, 245 | ext_modules=[CMakeExtension("_pywhispercpp")], 246 | cmdclass={"build_ext": CMakeBuild, 247 | 'bdist_wheel': RepairWheel, }, 248 | zip_safe=False, 249 | # extras_require={"test": ["pytest>=6.0"]}, 250 | python_requires=">=3.8", 251 | packages=find_packages('.'), 252 | package_dir={'': '.'}, 253 | include_package_data=True, 254 | package_data={'pywhispercpp': []}, 255 | long_description_content_type="text/markdown", 256 | license='MIT', 257 | entry_points={ 258 | 'console_scripts': ['pwcpp=pywhispercpp.examples.main:main', 259 | 'pwcpp-assistant=pywhispercpp.examples.assistant:_main', 260 | 'pwcpp-livestream=pywhispercpp.examples.livestream:_main', 261 | 'pwcpp-recording=pywhispercpp.examples.recording:_main', 262 | 'pwcpp-gui=pywhispercpp.examples.gui:_main', ] 263 | }, 264 | project_urls={ 265 | 'Documentation': 'https://absadiki.github.io/pywhispercpp/', 266 | 'Source': 'https://github.com/absadiki/pywhispercpp', 267 | 'Tracker': 'https://github.com/absadiki/pywhispercpp/issues', 268 | }, 269 | install_requires=['numpy', "requests", "tqdm", "platformdirs"], 270 | extras_require={"examples": ["sounddevice", "webrtcvad"], 271 | "gui": ["pyqt5"]}, 272 | ) 273 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pywhispercpp 2 | Python bindings for [whisper.cpp](https://github.com/ggerganov/whisper.cpp) with a simple Pythonic API on top of it. 
3 | 4 | [![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT) 5 | [![Wheels](https://github.com/absadiki/pywhispercpp/actions/workflows/wheels.yml/badge.svg?branch=main&event=push)](https://github.com/absadiki/pywhispercpp/actions/workflows/wheels.yml) 6 | [![PyPi version](https://badgen.net/pypi/v/pywhispercpp)](https://pypi.org/project/pywhispercpp/) 7 | [![Downloads](https://static.pepy.tech/badge/pywhispercpp)](https://pepy.tech/project/pywhispercpp) 8 | 9 | # Table of contents 10 | 11 | * [Installation](#installation) 12 | * [From source](#from-source) 13 | * [Pre-built wheels](#pre-built-wheels) 14 | * [NVIDIA GPU support](#nvidia-gpu-support) 15 | * [CoreML support](#coreml-support) 16 | * [Vulkan support](#vulkan-support) 17 | * [Quick start](#quick-start) 18 | * [Examples](#examples) 19 | * [CLI](#cli) 20 | * [GUI](#gui) 21 | * [Assistant](#assistant) 22 | * [Advanced usage](#advanced-usage) 23 | * [Discussions and contributions](#discussions-and-contributions) 24 | * [License](#license) 25 | 26 | 27 | # Installation 28 | 29 | ### From source 30 | * For the best performance, you need to install the package from source: 31 | ```shell 32 | pip install git+https://github.com/absadiki/pywhispercpp 33 | ``` 34 | ### Pre-built wheels 35 | * Otherwise, basic pre-built CPU wheels are available on PyPI: 36 | 37 | ```shell 38 | pip install pywhispercpp # or pywhispercpp[examples] to install the extra dependencies needed for the examples 39 | ``` 40 | 41 | [Optional] To transcribe files other than WAV, you need to install ffmpeg: 42 | ```shell 43 | # on Ubuntu or Debian 44 | sudo apt update && sudo apt install ffmpeg 45 | 46 | # on Arch Linux 47 | sudo pacman -S ffmpeg 48 | 49 | # on MacOS using Homebrew (https://brew.sh/) 50 | brew install ffmpeg 51 | 52 | # on Windows using Chocolatey (https://chocolatey.org/) 53 | choco install ffmpeg 54 | 55 | # on Windows using Scoop (https://scoop.sh/) 56 | scoop install ffmpeg 57 | ``` 58 | 59 | ### NVIDIA GPU support 60 | To install the package with CUDA support, make sure you have [cuda](https://developer.nvidia.com/cuda-downloads) installed and use `GGML_CUDA=1`: 61 | 62 | ```shell 63 | GGML_CUDA=1 pip install git+https://github.com/absadiki/pywhispercpp 64 | ``` 65 | ### CoreML support 66 | 67 | Install the package with `WHISPER_COREML=1`: 68 | 69 | ```shell 70 | WHISPER_COREML=1 pip install git+https://github.com/absadiki/pywhispercpp 71 | ``` 72 | 73 | ### Vulkan support 74 | 75 | Install the package with `GGML_VULKAN=1`: 76 | 77 | ```shell 78 | GGML_VULKAN=1 pip install git+https://github.com/absadiki/pywhispercpp 79 | ``` 80 | 81 | ### OpenBLAS support 82 | 83 | If OpenBLAS is installed, you can use `GGML_BLAS=1`. The extra flags below force a fresh build with the correct flags and print verbose output for sanity checking. 84 | ```shell 85 | GGML_BLAS=1 pip install git+https://github.com/absadiki/pywhispercpp --no-cache --force-reinstall -v 86 | ``` 87 | 88 | ### OpenVINO support 89 | 90 | Follow the steps to download the correct OpenVINO package (https://github.com/ggerganov/whisper.cpp?tab=readme-ov-file#openvino-support). 91 | 92 | Then initialize the OpenVINO environment and build:
93 | ```shell 94 | source ~/l_openvino_toolkit_ubuntu22_2023.0.0.10926.b4452d56304_x86_64/setupvars.sh 95 | WHISPER_OPENVINO=1 pip install git+https://github.com/absadiki/pywhispercpp --no-cache --force-reinstall 96 | ``` 97 | 98 | Note that the toolkit for Ubuntu 22 also works on Ubuntu 24. 99 | 100 | 101 | __Feel free to update this list and submit a PR if you have tested the package on other backends.__ 102 | 103 | 104 | # Quick start 105 | 106 | ```python 107 | from pywhispercpp.model import Model 108 | 109 | model = Model('base.en') 110 | segments = model.transcribe('file.wav') 111 | for segment in segments: 112 | print(segment.text) 113 | ``` 114 | 115 | You can also assign a custom `new_segment_callback`: 116 | 117 | ```python 118 | from pywhispercpp.model import Model 119 | 120 | model = Model('base.en', print_realtime=False, print_progress=False) 121 | segments = model.transcribe('file.mp3', new_segment_callback=print) 122 | ``` 123 | 124 | 125 | * The model will be downloaded automatically, or you can use the path to a local model. 126 | * You can pass any `whisper.cpp` [parameter](https://absadiki.github.io/pywhispercpp/#pywhispercpp.constants.PARAMS_SCHEMA) as a keyword argument to the `Model` class or to the `transcribe` function. 127 | * Check the [Model](https://absadiki.github.io/pywhispercpp/#pywhispercpp.model.Model) class documentation for more details. 128 | 129 | # Examples 130 | 131 | ## CLI 132 | A straightforward example Command Line Interface. 133 | You can use it as follows: 134 | 135 | ```shell 136 | pwcpp file.wav -m base --output-srt --print_realtime true 137 | ``` 138 | Run ```pwcpp --help``` to get the help message: 139 | 140 | ```shell 141 | usage: pwcpp [-h] [-m MODEL] [--version] [--processors PROCESSORS] [-otxt] [-ovtt] [-osrt] [-ocsv] [--strategy STRATEGY] 142 | [--n_threads N_THREADS] [--n_max_text_ctx N_MAX_TEXT_CTX] [--offset_ms OFFSET_MS] [--duration_ms DURATION_MS] 143 | [--translate TRANSLATE] [--no_context NO_CONTEXT] [--single_segment SINGLE_SEGMENT] [--print_special PRINT_SPECIAL] 144 | [--print_progress PRINT_PROGRESS] [--print_realtime PRINT_REALTIME] [--print_timestamps PRINT_TIMESTAMPS] 145 | [--token_timestamps TOKEN_TIMESTAMPS] [--thold_pt THOLD_PT] [--thold_ptsum THOLD_PTSUM] [--max_len MAX_LEN] 146 | [--split_on_word SPLIT_ON_WORD] [--max_tokens MAX_TOKENS] [--audio_ctx AUDIO_CTX] 147 | [--prompt_tokens PROMPT_TOKENS] [--prompt_n_tokens PROMPT_N_TOKENS] [--language LANGUAGE] [--suppress_blank SUPPRESS_BLANK] 148 | [--suppress_non_speech_tokens SUPPRESS_NON_SPEECH_TOKENS] [--temperature TEMPERATURE] [--max_initial_ts MAX_INITIAL_TS] 149 | [--length_penalty LENGTH_PENALTY] [--temperature_inc TEMPERATURE_INC] [--entropy_thold ENTROPY_THOLD] 150 | [--logprob_thold LOGPROB_THOLD] [--no_speech_thold NO_SPEECH_THOLD] [--greedy GREEDY] [--beam_search BEAM_SEARCH] 151 | media_file [media_file ...]
152 | 153 | positional arguments: 154 | media_file The path of the media file or a list of files separated by space 155 | 156 | options: 157 | -h, --help show this help message and exit 158 | -m MODEL, --model MODEL 159 | Path to the `ggml` model, or just the model name 160 | --version show program's version number and exit 161 | --processors PROCESSORS 162 | number of processors to use during computation 163 | -otxt, --output-txt output result in a text file 164 | -ovtt, --output-vtt output result in a vtt file 165 | -osrt, --output-srt output result in an srt file 166 | -ocsv, --output-csv output result in a CSV file 167 | --strategy STRATEGY Available sampling strategies: GreedyDecoder -> 0, BeamSearchDecoder -> 1 168 | --n_threads N_THREADS 169 | Number of threads to allocate for the inference, default to min(4, available hardware_concurrency) 170 | --n_max_text_ctx N_MAX_TEXT_CTX 171 | max tokens to use from past text as prompt for the decoder 172 | --offset_ms OFFSET_MS 173 | start offset in ms 174 | --duration_ms DURATION_MS 175 | audio duration to process in ms 176 | --translate TRANSLATE 177 | whether to translate the audio to English 178 | --no_context NO_CONTEXT 179 | do not use past transcription (if any) as initial prompt for the decoder 180 | --single_segment SINGLE_SEGMENT 181 | force single segment output (useful for streaming) 182 | --print_special PRINT_SPECIAL 183 | print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.) 184 | --print_progress PRINT_PROGRESS 185 | print progress information 186 | --print_realtime PRINT_REALTIME 187 | print results from within whisper.cpp (avoid it, use callback instead) 188 | --print_timestamps PRINT_TIMESTAMPS 189 | print timestamps for each text segment when printing realtime 190 | --token_timestamps TOKEN_TIMESTAMPS 191 | enable token-level timestamps 192 | --thold_pt THOLD_PT timestamp token probability threshold (~0.01) 193 | --thold_ptsum THOLD_PTSUM 194 | timestamp token sum probability threshold (~0.01) 195 | --max_len MAX_LEN max segment length in characters 196 | --split_on_word SPLIT_ON_WORD 197 | split on word rather than on token (when used with max_len) 198 | --max_tokens MAX_TOKENS 199 | max tokens per segment (0 = no limit) 200 | --audio_ctx AUDIO_CTX 201 | overwrite the audio context size (0 = use default) 202 | --prompt_tokens PROMPT_TOKENS 203 | tokens to provide to the whisper decoder as initial prompt 204 | --prompt_n_tokens PROMPT_N_TOKENS 205 | number of tokens in prompt_tokens 206 | --language LANGUAGE for auto-detection, set to None, "" or "auto" 207 | --suppress_blank SUPPRESS_BLANK 208 | common decoding parameters 209 | --suppress_non_speech_tokens SUPPRESS_NON_SPEECH_TOKENS 210 | common decoding parameters 211 | --temperature TEMPERATURE 212 | initial decoding temperature 213 | --max_initial_ts MAX_INITIAL_TS 214 | max_initial_ts 215 | --length_penalty LENGTH_PENALTY 216 | length_penalty 217 | --temperature_inc TEMPERATURE_INC 218 | temperature_inc 219 | --entropy_thold ENTROPY_THOLD 220 | similar to OpenAI's "compression_ratio_threshold" 221 | --logprob_thold LOGPROB_THOLD 222 | logprob_thold 223 | --no_speech_thold NO_SPEECH_THOLD 224 | no_speech_thold 225 | --greedy GREEDY greedy 226 | --beam_search BEAM_SEARCH 227 | beam_search 228 | 229 | ``` 230 | 231 | ## GUI 232 | If you prefer a Graphical User Interface, you can use the `pwcpp-gui` command, which will launch a simple graphical interface built with PyQt5.
233 | * First you need to install the GUI dependencies: 234 | ```bash 235 | pip install pywhispercpp[gui] 236 | ``` 237 | 238 | * Then you can run the GUI with: 239 | ```bash 240 | pwcpp-gui 241 | ``` 242 | 243 | The GUI provides a user-friendly way to: 244 | 245 | - Select audio files 246 | - Choose models 247 | - Adjust basic transcription settings 248 | - View and export transcription results 249 | 250 | ## Assistant 251 | This is a simple example showcasing the use of `pywhispercpp` to build an assistant-like application. 252 | The idea is to use a Voice Activity Detector (VAD) to detect speech (in this example, we used webrtcvad), and when speech is detected, we run the transcription. 253 | It is inspired by the [whisper.cpp/examples/command](https://github.com/ggerganov/whisper.cpp/tree/master/examples/command) example. 254 | 255 | You can check the source code [here](https://github.com/absadiki/pywhispercpp/blob/main/pywhispercpp/examples/assistant.py) 256 | or you can use the class directly to create your own assistant: 257 | 258 | 259 | ```python 260 | from pywhispercpp.examples.assistant import Assistant 261 | 262 | my_assistant = Assistant(commands_callback=print, n_threads=8) 263 | my_assistant.start() 264 | ``` 265 | Here, we set the `commands_callback` to a simple print function, so the commands will just get printed on the screen. 266 | 267 | You can also run this example from the command line: 268 | ```shell 269 | $ pwcpp-assistant --help 270 | 271 | usage: pwcpp-assistant [-h] [-m MODEL] [-ind INPUT_DEVICE] [-st SILENCE_THRESHOLD] [-bd BLOCK_DURATION] 272 | 273 | options: 274 | -h, --help show this help message and exit 275 | -m MODEL, --model MODEL 276 | Whisper.cpp model, defaults to tiny.en 277 | -ind INPUT_DEVICE, --input_device INPUT_DEVICE 278 | Id of the input device (aka microphone) 279 | -st SILENCE_THRESHOLD, --silence_threshold SILENCE_THRESHOLD 280 | The duration of silence after which the inference will be run, defaults to 16 281 | -bd BLOCK_DURATION, --block_duration BLOCK_DURATION 282 | minimum duration of audio updates in ms, defaults to 30 283 | ``` 284 | ------------- 285 | 286 | * Check the [examples folder](https://github.com/absadiki/pywhispercpp/tree/main/pywhispercpp/examples) for more examples. 287 | 288 | # Advanced usage 289 | * First check the [API documentation](https://absadiki.github.io/pywhispercpp/) for more advanced usage. 290 | * If you are a more experienced user, you can access the exposed C-APIs directly from the binding module `_pywhispercpp`. 291 | 292 | ```python 293 | import _pywhispercpp as pwcpp 294 | 295 | ctx = pwcpp.whisper_init_from_file('path/to/ggml/model') 296 | ``` 297 | 298 | # Discussions and contributions 299 | If you find a bug, please open an [issue](https://github.com/absadiki/pywhispercpp/issues). 300 | 301 | If you have any feedback, or you want to share how you are using this project, feel free to use the [Discussions](https://github.com/absadiki/pywhispercpp/discussions) and open a new topic. 302 | 303 | # License 304 | 305 | This project is licensed under the same license as [whisper.cpp](https://github.com/ggerganov/whisper.cpp/blob/master/LICENSE) (MIT [License](./LICENSE)).
306 | -------------------------------------------------------------------------------- /pywhispercpp/model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | This module contains a simple Python API on-top of the C-style 6 | [whisper.cpp](https://github.com/ggerganov/whisper.cpp) API. 7 | """ 8 | import importlib.metadata 9 | import logging 10 | import shutil 11 | import sys 12 | from pathlib import Path 13 | from time import time 14 | from typing import Union, Callable, List, TextIO, Tuple, Optional 15 | import _pywhispercpp as pw 16 | import numpy as np 17 | import pywhispercpp.utils as utils 18 | import pywhispercpp.constants as constants 19 | import subprocess 20 | import os 21 | import tempfile 22 | import wave 23 | 24 | __author__ = "absadiki" 25 | __copyright__ = "Copyright 2023, " 26 | __license__ = "MIT" 27 | __version__ = importlib.metadata.version('pywhispercpp') 28 | 29 | logger = logging.getLogger(__name__) 30 | 31 | 32 | class Segment: 33 | """ 34 | A small class representing a transcription segment 35 | """ 36 | 37 | def __init__(self, t0: int, t1: int, text: str, probability: float = np.nan): 38 | """ 39 | :param t0: start time 40 | :param t1: end time 41 | :param text: text 42 | :param probability: Confidence score for the segment, computed as the geometric mean of 43 | the token probabilities for the segment (NaN if not calculated). 44 | This makes it interpretable as a probability in [0, 1]. 45 | """ 46 | self.t0 = t0 47 | self.t1 = t1 48 | self.text = text 49 | self.probability = probability 50 | 51 | def __str__(self): 52 | return f"t0={self.t0}, t1={self.t1}, text={self.text}, probability={self.probability}" 53 | 54 | def __repr__(self): 55 | return str(self) 56 | 57 | 58 | class Model: 59 | """ 60 | This class defines a Whisper.cpp model. 61 | 62 | Example usage: 63 | ```python 64 | model = Model('base.en', n_threads=6) 65 | segments = model.transcribe('file.mp3') 66 | for segment in segments: 67 | print(segment.text) 68 | ``` 69 | """ 70 | 71 | _new_segment_callback = None 72 | 73 | def __init__(self, 74 | model: str = 'tiny', 75 | models_dir: str = None, 76 | params_sampling_strategy: int = 0, 77 | redirect_whispercpp_logs_to: Union[bool, TextIO, str, None] = False, 78 | use_openvino: bool = False, 79 | openvino_model_path: str = None, 80 | openvino_device: str = 'CPU', 81 | openvino_cache_dir: str = None, 82 | **params): 83 | """ 84 | :param model: The name of the model, one of the [AVAILABLE_MODELS](/pywhispercpp/#pywhispercpp.constants.AVAILABLE_MODELS), 85 | (defaults to `tiny`), or a direct path to a `ggml` model.
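            e.g. `Model('base.en')` or `Model('/path/to/ggml-base.en.bin')` (illustrative local path).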
86 | :param models_dir: The directory where the models are stored, or where they will be downloaded if they don't 87 | exist, defaults to [MODELS_DIR](/pywhispercpp/#pywhispercpp.constants.MODELS_DIR) 88 | :param params_sampling_strategy: 0 -> GREEDY, else BEAM_SEARCH 89 | :param redirect_whispercpp_logs_to: where to redirect the whisper.cpp logs, defaults to False (no redirection); accepts a str file path, sys.stdout, sys.stderr, or None to redirect to devnull 90 | :param use_openvino: whether to use OpenVINO or not 91 | :param openvino_model_path: path to the OpenVINO model 92 | :param openvino_device: OpenVINO device, defaults to CPU 93 | :param openvino_cache_dir: OpenVINO cache directory 94 | :param params: keyword arguments for different whisper.cpp parameters, 95 | see [PARAMS_SCHEMA](/pywhispercpp/#pywhispercpp.constants.PARAMS_SCHEMA) 96 | """ 97 | if Path(model).is_file(): 98 | self.model_path = model 99 | else: 100 | self.model_path = utils.download_model(model, models_dir) 101 | self._ctx = None 102 | self._sampling_strategy = pw.whisper_sampling_strategy.WHISPER_SAMPLING_GREEDY if params_sampling_strategy == 0 else \ 103 | pw.whisper_sampling_strategy.WHISPER_SAMPLING_BEAM_SEARCH 104 | self._params = pw.whisper_full_default_params(self._sampling_strategy) 105 | # assign params 106 | self.params = params 107 | self._set_params(params) 108 | self.redirect_whispercpp_logs_to = redirect_whispercpp_logs_to 109 | self.use_openvino = use_openvino 110 | self.openvino_model_path = openvino_model_path 111 | self.openvino_device = openvino_device 112 | self.openvino_cache_dir = openvino_cache_dir 113 | # init the model 114 | self._init_model() 115 | 116 | def transcribe(self, 117 | media: Union[str, np.ndarray], 118 | n_processors: int = None, 119 | new_segment_callback: Callable[[Segment], None] = None, 120 | **params) -> List[Segment]: 121 | """ 122 | Transcribes the media provided as input and returns a list of `Segment` objects. 123 | Accepts a media_file path (audio/video) or a raw numpy array. 124 | 125 | :param media: Media file path or a numpy array 126 | :param n_processors: if not None, it will run the transcription on multiple processes 127 | binding to whisper.cpp/whisper_full_parallel 128 | > Split the input audio in chunks and process each chunk separately using whisper_full() 129 | :param new_segment_callback: callback function that will be called when a new segment is generated 130 | :param params: keyword arguments for different whisper.cpp parameters, see ::: constants.PARAMS_SCHEMA 131 | :param extract_probability: If True, calculates the geometric mean of token probabilities for each segment, 132 | providing a confidence score interpretable as a probability in [0, 1].
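            Example (illustrative file name): `segments = model.transcribe('file.wav', translate=True, extract_probability=True)`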
133 | :return: List of transcription segments 134 | """ 135 | if type(media) is np.ndarray: 136 | audio = media 137 | else: 138 | if not Path(media).exists(): 139 | raise FileNotFoundError(media) 140 | audio = self._load_audio(media) 141 | 142 | # Handle extract_probability parameter 143 | self.extract_probability = params.pop('extract_probability', False) 144 | 145 | # update params if any 146 | self._set_params(params) 147 | 148 | # setting up callback 149 | if new_segment_callback: 150 | Model._new_segment_callback = new_segment_callback 151 | pw.assign_new_segment_callback(self._params, Model.__call_new_segment_callback) 152 | 153 | # run inference 154 | start_time = time() 155 | logger.info("Transcribing ...") 156 | res = self._transcribe(audio, n_processors=n_processors) 157 | end_time = time() 158 | logger.info(f"Inference time: {end_time - start_time:.3f} s") 159 | return res 160 | 161 | @staticmethod 162 | def _get_segments(ctx, start: int, end: int, extract_probability: bool = False) -> List[Segment]: 163 | """ 164 | Helper function to get generated segments between `start` and `end` 165 | 166 | :param ctx: whisper context 167 | :param start: start index 168 | :param end: end index 169 | :param extract_probability: whether to calculate token probabilities 170 | 171 | :return: list of segments 172 | """ 173 | n = pw.whisper_full_n_segments(ctx) 174 | assert end <= n, f"{end} > {n}: `End` index must be less or equal than the total number of segments" 175 | res = [] 176 | for i in range(start, end): 177 | t0 = pw.whisper_full_get_segment_t0(ctx, i) 178 | t1 = pw.whisper_full_get_segment_t1(ctx, i) 179 | bytes = pw.whisper_full_get_segment_text(ctx, i) 180 | text = bytes.decode('utf-8', errors='replace') 181 | 182 | avg_prob = np.nan 183 | 184 | # Only calculate probabilities if requested 185 | if extract_probability: 186 | n_tokens = pw.whisper_full_n_tokens(ctx, i) 187 | if n_tokens == 1: 188 | avg_prob = pw.whisper_full_get_token_p(ctx, i, 0) 189 | elif n_tokens > 1: 190 | total_logprob = 0.0 191 | for j in range(n_tokens): 192 | total_logprob += np.log(pw.whisper_full_get_token_p(ctx, i, j)) 193 | avg_prob = np.exp(total_logprob / n_tokens) 194 | else: 195 | avg_prob = np.nan 196 | 197 | res.append(Segment(t0, t1, text.strip(), probability=np.float32(avg_prob))) 198 | return res 199 | 200 | def get_params(self) -> dict: 201 | """ 202 | Returns a `dict` representation of the actual params 203 | 204 | :return: params dict 205 | """ 206 | res = {} 207 | for param in dir(self._params): 208 | if param.startswith('__'): 209 | continue 210 | try: 211 | res[param] = getattr(self._params, param) 212 | except Exception: 213 | # ignore callback functions 214 | continue 215 | return res 216 | 217 | @staticmethod 218 | def get_params_schema() -> dict: 219 | """ 220 | A simple link to ::: constants.PARAMS_SCHEMA 221 | :return: dict of params schema 222 | """ 223 | return constants.PARAMS_SCHEMA 224 | 225 | @staticmethod 226 | def lang_max_id() -> int: 227 | """ 228 | Returns number of supported languages. 
229 | Direct binding to whisper.cpp/lang_max_id 230 | :return: 231 | """ 232 | return pw.whisper_lang_max_id() 233 | 234 | def print_timings(self) -> None: 235 | """ 236 | Direct binding to whisper.cpp/whisper_print_timings 237 | 238 | :return: None 239 | """ 240 | pw.whisper_print_timings(self._ctx) 241 | 242 | @staticmethod 243 | def system_info() -> None: 244 | """ 245 | Direct binding to whisper.cpp/whisper_print_system_info 246 | 247 | :return: None 248 | """ 249 | return pw.whisper_print_system_info() 250 | 251 | @staticmethod 252 | def available_languages() -> list[str]: 253 | """ 254 | Returns a list of supported language codes 255 | 256 | :return: list of supported language codes 257 | """ 258 | n = pw.whisper_lang_max_id() 259 | res = [] 260 | for i in range(n+1): 261 | res.append(pw.whisper_lang_str(i)) 262 | return res 263 | 264 | def _init_model(self) -> None: 265 | """ 266 | Private method to initialize the method from the bindings, it will be called automatically from the __init__ 267 | :return: 268 | """ 269 | logger.info("Initializing the model ...") 270 | with utils.redirect_stderr(to=self.redirect_whispercpp_logs_to): 271 | self._ctx = pw.whisper_init_from_file(self.model_path) 272 | if self.use_openvino: 273 | pw.whisper_ctx_init_openvino_encoder(self._ctx, self.openvino_model_path, self.openvino_device, self.openvino_cache_dir) 274 | 275 | 276 | 277 | def _set_params(self, kwargs: dict) -> None: 278 | """ 279 | Private method to set the kwargs params to the `Params` class 280 | :param kwargs: dict like object for the different params 281 | :return: None 282 | """ 283 | for param in kwargs: 284 | setattr(self._params, param, kwargs[param]) 285 | 286 | def _transcribe(self, audio: np.ndarray, n_processors: int = None): 287 | """ 288 | Private method to call the whisper.cpp/whisper_full function 289 | 290 | :param audio: numpy array of audio data 291 | :param n_processors: if not None, it will run whisper.cpp/whisper_full_parallel with n_processors 292 | :return: 293 | """ 294 | 295 | if n_processors: 296 | pw.whisper_full_parallel(self._ctx, self._params, audio, audio.size, n_processors) 297 | else: 298 | pw.whisper_full(self._ctx, self._params, audio, audio.size) 299 | n = pw.whisper_full_n_segments(self._ctx) 300 | res = Model._get_segments(self._ctx, 0, n, self.extract_probability) 301 | return res 302 | 303 | @staticmethod 304 | def __call_new_segment_callback(ctx, n_new, user_data) -> None: 305 | """ 306 | Internal new_segment_callback, it just calls the user's callback with the `Segment` object 307 | :param ctx: whisper.cpp ctx param 308 | :param n_new: whisper.cpp n_new param 309 | :param user_data: whisper.cpp user_data param 310 | :return: None 311 | """ 312 | n = pw.whisper_full_n_segments(ctx) 313 | start = n - n_new 314 | res = Model._get_segments(ctx, start, n, False) 315 | for segment in res: 316 | Model._new_segment_callback(segment) 317 | 318 | @staticmethod 319 | def _load_audio(media_file_path: str) -> np.array: 320 | """ 321 | Helper method to return a `np.array` object from a media file 322 | If the media file is not a WAV file, it will try to convert it using ffmpeg 323 | 324 | :param media_file_path: Path of the media file 325 | :return: Numpy array 326 | """ 327 | 328 | def wav_to_np(file_path): 329 | with wave.open(file_path, 'rb') as wf: 330 | num_channels = wf.getnchannels() 331 | sample_width = wf.getsampwidth() 332 | sample_rate = wf.getframerate() 333 | num_frames = wf.getnframes() 334 | 335 | if num_channels not in (1, 2): 336 | raise 
Exception(f"WAV file must be mono or stereo") 337 | 338 | if sample_rate != pw.WHISPER_SAMPLE_RATE: 339 | raise Exception(f"WAV file must be {pw.WHISPER_SAMPLE_RATE} Hz") 340 | 341 | if sample_width != 2: 342 | raise Exception(f"WAV file must be 16-bit") 343 | 344 | raw = wf.readframes(num_frames) 345 | wf.close() 346 | audio = np.frombuffer(raw, dtype=np.int16).astype(np.float32) 347 | n = num_frames 348 | if num_channels == 1: 349 | pcmf32 = audio / 32768.0 350 | else: 351 | audio = audio.reshape(-1, 2) 352 | # Averaging the two channels 353 | pcmf32 = (audio[:, 0] + audio[:, 1]) / 65536.0 354 | return pcmf32 355 | 356 | if media_file_path.endswith('.wav'): 357 | return wav_to_np(media_file_path) 358 | else: 359 | if shutil.which('ffmpeg') is None: 360 | raise Exception( 361 | "FFMPEG is not installed or not in PATH. Please install it, or provide a WAV file or a NumPy array instead!") 362 | 363 | temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False) 364 | temp_file_path = temp_file.name 365 | temp_file.close() 366 | try: 367 | subprocess.run([ 368 | 'ffmpeg', '-i', media_file_path, '-ac', '1', '-ar', '16000', 369 | temp_file_path, '-y' 370 | ], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) 371 | return wav_to_np(temp_file_path) 372 | finally: 373 | os.remove(temp_file_path) 374 | 375 | def auto_detect_language(self, media: Union[str, np.ndarray], offset_ms: int = 0, n_threads: int = 4) -> Tuple[Tuple[str, np.float32], dict[str, np.float32]]: 376 | """ 377 | Automatic language detection using whisper.cpp/whisper_pcm_to_mel and whisper.cpp/whisper_lang_auto_detect 378 | 379 | :param media: Media file path or a numpy array 380 | :param offset_ms: offset in milliseconds 381 | :param n_threads: number of threads to use 382 | :return: ((detected_language, probability), probabilities for all languages) 383 | """ 384 | if type(media) is np.ndarray: 385 | audio = media 386 | else: 387 | if not Path(media).exists(): 388 | raise FileNotFoundError(media) 389 | audio = self._load_audio(media) 390 | 391 | pw.whisper_pcm_to_mel(self._ctx, audio, len(audio), n_threads) 392 | lang_max_id = self.lang_max_id() 393 | probs = np.zeros(lang_max_id, dtype=np.float32) 394 | auto_detect = pw.whisper_lang_auto_detect(self._ctx, offset_ms, n_threads, probs) 395 | langs = self.available_languages() 396 | lang_probs = {langs[i]: probs[i] for i in range(lang_max_id)} 397 | return (langs[auto_detect], probs[auto_detect]), lang_probs 398 | 399 | def __del__(self): 400 | """ 401 | Free up resources 402 | :return: None 403 | """ 404 | pw.whisper_free(self._ctx) -------------------------------------------------------------------------------- /pywhispercpp/examples/gui.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import threading 3 | from datetime import datetime 4 | 5 | from PyQt5.QtWidgets import ( 6 | QApplication, QWidget, QVBoxLayout, QHBoxLayout, QPushButton, 7 | QFileDialog, QProgressBar, QLabel, QFrame, 8 | QSizePolicy, QTableWidget, QTableWidgetItem, QHeaderView, 9 | QGroupBox, QFormLayout, QComboBox, QLineEdit, QCheckBox, 10 | QSpinBox, QDoubleSpinBox, QToolButton, QDialog, QMenu # Import QMenu 11 | ) 12 | from PyQt5.QtCore import Qt, pyqtSignal, QObject 13 | import os 14 | import importlib.metadata 15 | 16 | __version__ = importlib.metadata.version('pywhispercpp') 17 | 18 | from pywhispercpp.model import Model, Segment 19 | from pywhispercpp.utils import output_txt, output_srt, output_vtt, output_csv # Import utilities 
20 | 21 | 22 | # --- Available Models --- 23 | # Define a custom order for model sizes 24 | MODEL_SIZE_ORDER = {"tiny": 0, "base": 1, "small": 2, "medium": 3, "large": 4} 25 | 26 | UNSORTED_MODELS = [ 27 | "base", "base-q5_1", "base-q8_0", "base.en", "base.en-q5_1", 28 | "base.en-q8_0", "large-v1", "large-v2", "large-v2-q5_0", 29 | "large-v2-q8_0", "large-v3", "large-v3-q5_0", "large-v3-turbo", 30 | "large-v3-turbo-q5_0", "large-v3-turbo-q8_0", "medium", 31 | "medium-q5_0", "medium-q8_0", "medium.en", "medium.en-q5_0", 32 | "medium.en-q8_0", "small", "small-q5_1", "small-q8_0", "small.en", 33 | "small.en-q5_1", "small.en-q8_0", "tiny", "tiny-q5_1", "tiny-q8_0", 34 | "tiny.en", "tiny.en-q5_1", "tiny.en-q8_0", 35 | ] 36 | 37 | 38 | # Custom sort key function 39 | def get_model_sort_key(model_name): 40 | # Extract the base model name (e.g., "tiny.en" -> "tiny", "large-v3-turbo" -> "large") 41 | base_name = model_name.split('.')[0].split('-')[0] 42 | # Return a tuple for multi-level sorting: (size_order, full_model_name_for_secondary_sort) 43 | return (MODEL_SIZE_ORDER.get(base_name, 99), model_name) # 99 for any unexpected names 44 | 45 | 46 | # Sort the models 47 | AVAILABLE_MODELS = sorted(UNSORTED_MODELS, key=get_model_sort_key) 48 | 49 | # --- Retouched Minimal Stylesheet --- 50 | STYLESHEET = """ 51 | /* General Application Styles */ 52 | QWidget { 53 | background-color: #f5f5f5; /* Light gray background */ 54 | color: #333333; /* Dark text */ 55 | font-family: Arial, sans-serif; 56 | font-size: 14px; 57 | } 58 | 59 | /* Buttons */ 60 | QPushButton { 61 | background-color: #e0e0e0; /* Slightly darker gray for buttons */ 62 | color: #333333; 63 | border: 1px solid #c0c0c0; /* Light gray border */ 64 | padding: 6px 12px; 65 | border-radius: 3px; /* Slightly rounded corners */ 66 | outline: none; /* Remove focus outline for a cleaner look */ 67 | } 68 | 69 | QPushButton:hover { 70 | background-color: #d0d0d0; /* Darker on hover */ 71 | } 72 | 73 | QPushButton:pressed { 74 | background-color: #c0c0c0; /* Even darker when pressed */ 75 | } 76 | 77 | QPushButton:disabled { 78 | background-color: #f0f0f0; 79 | color: #aaaaaa; 80 | border-color: #e0e0e0; 81 | } 82 | 83 | /* Specific styling for the Transcribe button (Call to Action) */ 84 | QPushButton#TranscribeButton { 85 | background-color: #007bff; /* Vibrant blue */ 86 | color: #ffffff; /* White text */ 87 | font-weight: bold; 88 | font-size: 15px; /* Slightly larger font */ 89 | padding: 8px 18px; /* More padding */ 90 | border: 1px solid #007bff; /* Matching border */ 91 | border-radius: 4px; 92 | } 93 | 94 | QPushButton#TranscribeButton:hover { 95 | background-color: #0056b3; /* Darker blue on hover */ 96 | border-color: #0056b3; 97 | } 98 | 99 | QPushButton#TranscribeButton:pressed { 100 | background-color: #004085; /* Even darker blue when pressed */ 101 | border-color: #004085; 102 | } 103 | 104 | QPushButton#TranscribeButton:disabled { 105 | background-color: #cccccc; /* Light gray for disabled state */ 106 | color: #888888; 107 | border-color: #cccccc; 108 | } 109 | 110 | /* Stop button */ 111 | QPushButton#StopButton { 112 | background-color: #ff0000; /* Red for stop */ 113 | color: #ffffff; /* White text */ 114 | font-weight: bold; 115 | font-size: 15px; /* Slightly larger font */ 116 | padding: 8px 18px; /* More padding */ 117 | border: 1px solid #ff0000; /* Matching border */ 118 | border-radius: 4px; 119 | } 120 | 121 | QPushButton#StopButton:hover { 122 | background-color: #cc0000; /* Darker red on hover */ 123 | border-color: 
#cc0000; 124 | } 125 | 126 | QPushButton#StopButton:pressed { 127 | background-color: #990000; /* Even darker red when pressed */ 128 | border-color: #990000; 129 | } 130 | 131 | QPushButton#StopButton:disabled { 132 | background-color: #cccccc; /* Light gray for disabled state */ 133 | color: #888888; 134 | border-color: #cccccc; 135 | } 136 | 137 | /* ToolButton for accordion header */ 138 | QToolButton { 139 | background-color: transparent; /* Keep it transparent */ 140 | border: none; 141 | padding: 5px 0; /* Some padding for clickable area */ 142 | font-weight: bold; 143 | color: #444444; /* Darker text for emphasis */ 144 | text-align: left; 145 | outline: none; /* Remove focus outline */ 146 | } 147 | 148 | QToolButton::menu-indicator { 149 | image: none; /* Hide default menu indicator */ 150 | } 151 | 152 | /* Table Widget */ 153 | QTableWidget { 154 | background-color: #ffffff; 155 | border: 1px solid #d0d0d0; /* Slightly softer border */ 156 | gridline-color: #e8e8e8; /* Very light grid lines */ 157 | border-radius: 3px; 158 | } 159 | 160 | QTableWidget::item { 161 | padding: 4px; 162 | border-bottom: 1px solid #f5f5f5; /* Match general background for subtle row separation */ 163 | } 164 | 165 | /* Table Header */ 166 | QHeaderView::section { 167 | background-color: #e8e8e8; /* Light header background */ 168 | color: #333333; 169 | padding: 5px; 170 | border: 1px solid #d0d0d0; 171 | font-weight: bold; 172 | } 173 | 174 | /* Labels */ 175 | QLabel { 176 | color: #333333; 177 | } 178 | 179 | /* Specific styling for the main title label */ 180 | QLabel#TitleLabel { 181 | font-size: 18px; /* Slightly smaller than previous '20px' to integrate better */ 182 | font-weight: bold; 183 | color: #0056b3; /* A slightly darker blue for main title */ 184 | padding-bottom: 3px; 185 | margin-bottom: 5px; 186 | border-bottom: 1px solid #e0e0e0; 187 | } 188 | 189 | /* Progress Bar */ 190 | QProgressBar { 191 | border: 1px solid #bbbbbb; 192 | border-radius: 3px; 193 | text-align: center; 194 | background-color: #e6e6e6; 195 | color: #333333; 196 | } 197 | 198 | QProgressBar::chunk { 199 | background-color: #4CAF50; /* A pleasant green */ 200 | border-radius: 2px; 201 | } 202 | 203 | /* Input Widgets (ComboBox, LineEdit, SpinBox, DoubleSpinBox) */ 204 | QComboBox, QLineEdit, QSpinBox, QDoubleSpinBox { 205 | border: 1px solid #cccccc; 206 | padding: 3px; 207 | background-color: #ffffff; 208 | color: #333333; 209 | border-radius: 3px; /* Apply rounded corners */ 210 | } 211 | 212 | /* File label should visually match text inputs */ 213 | QLabel#file_label { 214 | border: 1px solid #cccccc; 215 | padding: 3px; 216 | background-color: #ffffff; 217 | border-radius: 3px; 218 | } 219 | 220 | QFrame { 221 | border: none; /* Keep frames invisible unless needed for structure */ 222 | } 223 | 224 | /* Status Bar Label */ 225 | QLabel#status_bar_label { 226 | background-color: #e0e0e0; /* Light grey background */ 227 | border-top: 1px solid #cccccc; /* Separator */ 228 | color: #444444; /* Darker text */ 229 | padding: 3px 5px; /* Add internal padding */ 230 | font-size: 13px; /* Slightly smaller font */ 231 | } 232 | """ 233 | 234 | 235 | # --- Communication Object for Threading --- 236 | class WorkerSignals(QObject): 237 | """ 238 | Defines signals available from a running worker thread. 
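    Because the worker runs in a plain Python thread, Qt delivers these signals to the GUI thread through queued connections, so the connected slots execute safely on the main thread.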
239 | Supported signals are: 240 | - finished: No data 241 | - error: tuple (exctype, value, traceback.format_exc()) 242 | - result: list (the transcribed segments) 243 | - progress: int (0-100) 244 | - status_update: str 245 | """ 246 | finished = pyqtSignal() 247 | error = pyqtSignal(tuple) 248 | segment = pyqtSignal(Segment) 249 | result = pyqtSignal(list) 250 | progress = pyqtSignal(int) 251 | status_update = pyqtSignal(str) 252 | 253 | 254 | # --- Worker Thread for Transcription --- 255 | class PyWhisperCppWorker(threading.Thread): 256 | 257 | def __init__(self, audio_file_path, model_name, **transcribe_params): 258 | super().__init__() 259 | self.audio_file_path = audio_file_path 260 | self.model_name = model_name 261 | self.transcribe_params = transcribe_params 262 | self.signals = WorkerSignals() 263 | self._is_running = False 264 | 265 | def run(self): 266 | """ 267 | Executes the transcription process. 268 | """ 269 | try: 270 | self._is_running = True 271 | self.signals.status_update.emit(f"Loading model: {self.model_name}...") 272 | 273 | # pywhispercpp will download the specified model if not found 274 | model_init_params = {} 275 | if 'n_threads' in self.transcribe_params and self.transcribe_params['n_threads'] is not None: 276 | model_init_params['n_threads'] = self.transcribe_params['n_threads'] 277 | # Remove from transcribe_params as it's a model init param 278 | del self.transcribe_params['n_threads'] 279 | 280 | model = Model(self.model_name, **model_init_params) 281 | 282 | self.signals.status_update.emit("Model loaded. Starting transcription...") 283 | 284 | def new_segment_callback(segment): 285 | if not self._is_running: 286 | raise RuntimeError("Transcription manually stopped") 287 | self.signals.segment.emit(segment) 288 | 289 | segments = model.transcribe(self.audio_file_path, 290 | new_segment_callback=new_segment_callback, 291 | progress_callback=lambda progress: self.signals.progress.emit(progress), 292 | **self.transcribe_params) 293 | 294 | self.signals.status_update.emit("Transcription complete!") 295 | self.signals.result.emit(segments) 296 | 297 | except Exception as e: 298 | print(e) 299 | self.signals.status_update.emit(f"Error: {str(e)}") 300 | self.signals.error.emit((type(e), e, str(e))) 301 | finally: 302 | self._is_running = False 303 | self.signals.finished.emit() 304 | 305 | def stop(self): 306 | self._is_running = False 307 | 308 | 309 | # --- Main Application Window --- 310 | class TranscriptionApp(QWidget): 311 | def __init__(self): 312 | super().__init__() 313 | self.selected_file_path = None 314 | self.whisper_thread = None 315 | # Settings widgets 316 | self.model_combo = None 317 | self.language_input = None 318 | self.translate_checkbox = None 319 | self.n_threads_spinbox = None 320 | self.no_context_checkbox = None 321 | self.temperature_spinbox = None 322 | self.settings_content_frame = None # Frame to hold collapsible settings 323 | self.toggle_settings_button = None # Button to toggle settings 324 | self.status_bar_label = None # New label for the status bar 325 | self.about_button = None # About button 326 | self.segments = [] # Store segments for export 327 | self.copy_text_button = None # New button for copy text 328 | 329 | self.initUI() 330 | 331 | def initUI(self): 332 | """ 333 | Initializes the user interface of the application. 
334 | """ 335 | self.setWindowTitle('PyWhisperCpp Simple GUI') 336 | self.setGeometry(100, 100, 450, 500) 337 | # Apply the updated stylesheet 338 | self.setStyleSheet(STYLESHEET) 339 | 340 | # Main vertical layout 341 | main_layout = QVBoxLayout() 342 | # Set bottom margin to 0 for the main layout to ensure status bar is flush 343 | main_layout.setContentsMargins(4, 4, 4, 0) 344 | main_layout.setSpacing(10) 345 | 346 | # --- Header (Title + About Button) --- 347 | header_layout = QHBoxLayout() 348 | title_label = QLabel("PyWhisperCpp Simple GUI") # Updated main title label 349 | title_label.setObjectName("TitleLabel") # Add objectName for styling 350 | title_label.setAlignment(Qt.AlignLeft) # Keep title centered within its allocated space 351 | 352 | # Adding stretch before and after title to center it 353 | # header_layout.addStretch() 354 | header_layout.addWidget(title_label) 355 | header_layout.addStretch() 356 | 357 | # About button 358 | self.about_button = QPushButton("About") 359 | self.about_button.clicked.connect(self.show_about_dialog) 360 | # Removed setFixedSize to allow text to fit, or adjust as needed 361 | # self.about_button.setFixedSize(50, 25) 362 | header_layout.addWidget(self.about_button) # Add it to the header layout 363 | 364 | main_layout.addLayout(header_layout) # Add the combined header to main layout 365 | 366 | # --- File Selection Area --- 367 | file_frame = QFrame() 368 | file_layout = QHBoxLayout(file_frame) 369 | file_layout.setContentsMargins(0, 0, 0, 0) 370 | file_layout.setSpacing(10) 371 | 372 | self.select_button = QPushButton("Select Audio File") 373 | self.select_button.clicked.connect(self.select_file) 374 | 375 | self.file_label = QLabel("No file selected.") 376 | self.file_label.setObjectName("file_label") # Added objectName for styling 377 | self.file_label.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Preferred) 378 | 379 | file_layout.addWidget(self.select_button) 380 | file_layout.addWidget(self.file_label) 381 | main_layout.addWidget(file_frame) 382 | 383 | # --- Collapsible Settings Section --- 384 | settings_group = QGroupBox() # No title here, using QToolButton for title 385 | settings_group_layout = QVBoxLayout(settings_group) 386 | settings_group_layout.setContentsMargins(5, 5, 5, 5) 387 | 388 | # Custom title bar for the collapsible group box 389 | header_layout_settings = QHBoxLayout() # Renamed to avoid clash 390 | self.toggle_settings_button = QToolButton(settings_group) 391 | self.toggle_settings_button.setText("Transcription Settings") 392 | self.toggle_settings_button.setToolButtonStyle(Qt.ToolButtonTextBesideIcon) 393 | self.toggle_settings_button.setArrowType(Qt.RightArrow) 394 | self.toggle_settings_button.setCheckable(True) 395 | self.toggle_settings_button.setChecked(False) # Start collapsed 396 | self.toggle_settings_button.clicked.connect(self.toggle_settings_visibility) 397 | 398 | header_layout_settings.addWidget(self.toggle_settings_button) 399 | header_layout_settings.addStretch() # Push button to left 400 | 401 | settings_group_layout.addLayout(header_layout_settings) 402 | 403 | # Frame to hold the actual settings form (this will be hidden/shown) 404 | self.settings_content_frame = QFrame() 405 | settings_form_layout = QFormLayout(self.settings_content_frame) 406 | settings_form_layout.setContentsMargins(15, 5, 10, 10) 407 | settings_form_layout.setSpacing(8) 408 | 409 | # Model Selection 410 | self.model_combo = QComboBox() 411 | self.model_combo.addItems(AVAILABLE_MODELS) 412 | 
self.model_combo.setCurrentText("tiny") # Default to 'tiny' as requested 413 | settings_form_layout.addRow("Model:", self.model_combo) 414 | 415 | # Language Input 416 | self.language_input = QLineEdit() 417 | self.language_input.setPlaceholderText('e.g., "en", "es", or leave empty for auto-detect') 418 | self.language_input.setText("") # Default to auto-detect 419 | settings_form_layout.addRow("Language:", self.language_input) 420 | 421 | # Translate Checkbox 422 | self.translate_checkbox = QCheckBox("Translate to English") 423 | self.translate_checkbox.setChecked(False) # Default 424 | settings_form_layout.addRow("Translate:", self.translate_checkbox) 425 | 426 | # N Threads Spinbox 427 | self.n_threads_spinbox = QSpinBox() 428 | self.n_threads_spinbox.setRange(1, os.cpu_count() if os.cpu_count() else 8) # Max threads based on CPU cores 429 | self.n_threads_spinbox.setValue(4) # Sensible default 430 | settings_form_layout.addRow("Number of Threads:", self.n_threads_spinbox) 431 | 432 | # No Context Checkbox 433 | self.no_context_checkbox = QCheckBox("No Context (do not use past transcription)") 434 | self.no_context_checkbox.setChecked(False) # Default 435 | settings_form_layout.addRow("No Context:", self.no_context_checkbox) 436 | 437 | # Temperature Spinbox 438 | self.temperature_spinbox = QDoubleSpinBox() 439 | self.temperature_spinbox.setRange(0.0, 1.0) 440 | self.temperature_spinbox.setSingleStep(0.1) 441 | self.temperature_spinbox.setValue(0.0) # Default 442 | settings_form_layout.addRow("Temperature:", self.temperature_spinbox) 443 | 444 | settings_group_layout.addWidget(self.settings_content_frame) 445 | self.settings_content_frame.setVisible(False) # Initially hidden 446 | 447 | main_layout.addWidget(settings_group) 448 | 449 | # --- Transcription Button --- 450 | self.transcribe_button = QPushButton("Transcribe") 451 | self.transcribe_button.setObjectName("TranscribeButton") # Add objectName for styling 452 | self.transcribe_button.setEnabled(False) 453 | self.transcribe_button.clicked.connect(self.start_transcription) 454 | main_layout.addWidget(self.transcribe_button) 455 | 456 | # --- Stop Button --- 457 | self.stop_button = QPushButton("Stop") 458 | self.stop_button.setObjectName("StopButton") # Add objectName for styling 459 | self.stop_button.setEnabled(True) 460 | self.stop_button.setVisible(False) 461 | self.stop_button.clicked.connect(self.stop_transcription) 462 | main_layout.addWidget(self.stop_button) 463 | 464 | # --- Progress Bar --- 465 | progress_frame = QFrame() 466 | progress_layout = QVBoxLayout(progress_frame) 467 | progress_layout.setContentsMargins(0, 5, 0, 5) 468 | progress_layout.setSpacing(5) 469 | 470 | self.progress_bar = QProgressBar() 471 | self.progress_bar.setVisible(False) 472 | 473 | progress_layout.addWidget(self.progress_bar) 474 | main_layout.addWidget(progress_frame) 475 | 476 | # --- Transcription Output Table --- 477 | output_label = QLabel("Transcription Output:") 478 | main_layout.addWidget(output_label) 479 | 480 | self.results_table = QTableWidget() 481 | self.results_table.setColumnCount(3) 482 | self.results_table.setHorizontalHeaderLabels(["Start Time", "End Time", "Text"]) 483 | header = self.results_table.horizontalHeader() 484 | header.setSectionResizeMode(0, QHeaderView.ResizeToContents) 485 | header.setSectionResizeMode(1, QHeaderView.ResizeToContents) 486 | header.setSectionResizeMode(2, QHeaderView.Stretch) 487 | self.results_table.verticalHeader().setVisible(False) 488 | main_layout.addWidget(self.results_table) 489 | 490 | 
# --- Output Buttons (Export and Copy) --- 491 | output_buttons_layout = QHBoxLayout() 492 | output_buttons_layout.addStretch() # Pushes buttons to the right 493 | 494 | # Export Button with Menu 495 | self.export_button = QPushButton("Export as...") 496 | self.export_button.setEnabled(False) 497 | self.export_menu = QMenu(self) 498 | 499 | self.export_action_txt = self.export_menu.addAction("Plain Text (.txt)") 500 | self.export_action_srt = self.export_menu.addAction("SRT Subtitle (.srt)") 501 | self.export_action_vtt = self.export_menu.addAction("VTT Subtitle (.vtt)") 502 | self.export_action_csv = self.export_menu.addAction("CSV (.csv)") 503 | 504 | self.export_action_txt.triggered.connect(lambda: self.export_transcription("txt")) 505 | self.export_action_srt.triggered.connect(lambda: self.export_transcription("srt")) 506 | self.export_action_vtt.triggered.connect(lambda: self.export_transcription("vtt")) 507 | self.export_action_csv.triggered.connect(lambda: self.export_transcription("csv")) 508 | 509 | self.export_button.setMenu(self.export_menu) 510 | output_buttons_layout.addWidget(self.export_button) 511 | 512 | # Copy Text Button 513 | self.copy_text_button = QPushButton("Copy Text") 514 | self.copy_text_button.setEnabled(False) # Initially disabled 515 | self.copy_text_button.clicked.connect(self.copy_all_text_to_clipboard) # Connect to new method 516 | output_buttons_layout.addWidget(self.copy_text_button) 517 | 518 | main_layout.addLayout(output_buttons_layout) 519 | 520 | # --- Status Bar at the very bottom --- 521 | self.status_bar_label = QLabel("Ready.") 522 | self.status_bar_label.setObjectName("status_bar_label") # Add objectName for styling 523 | self.status_bar_label.setAlignment(Qt.AlignLeft | Qt.AlignVCenter) 524 | self.status_bar_label.setContentsMargins(5, 2, 5, 2) 525 | main_layout.addWidget(self.status_bar_label) 526 | 527 | self.setLayout(main_layout) 528 | 529 | def toggle_settings_visibility(self): 530 | """Toggles the visibility of the settings content frame and updates the arrow.""" 531 | is_visible = self.settings_content_frame.isVisible() 532 | self.settings_content_frame.setVisible(not is_visible) 533 | if not is_visible: 534 | self.toggle_settings_button.setArrowType(Qt.DownArrow) 535 | else: 536 | self.toggle_settings_button.setArrowType(Qt.RightArrow) 537 | 538 | def select_file(self): 539 | """ 540 | Opens a file dialog to select an audio file. 541 | """ 542 | options = QFileDialog.Options() 543 | file_path, _ = QFileDialog.getOpenFileName( 544 | self, "Select a Media File", "", 545 | "All Files (*)", 546 | options=options 547 | ) 548 | if file_path: 549 | self.selected_file_path = file_path 550 | self.file_label.setText(f"Selected: {os.path.basename(file_path)}") 551 | self.transcribe_button.setEnabled(True) 552 | self.results_table.setRowCount(0) 553 | self.export_button.setEnabled(False) # Disable export until transcription 554 | self.copy_text_button.setEnabled(False) # Disable copy until transcription 555 | self.update_status("File selected: " + os.path.basename(file_path)) # Update new status bar 556 | 557 | def start_transcription(self): 558 | """ 559 | Starts the transcription process in a separate thread, passing selected settings. 
560 | """ 561 | if self.selected_file_path: 562 | self.transcribe_button.setVisible(False) 563 | self.stop_button.setVisible(True) 564 | self.select_button.setEnabled(False) 565 | self.progress_bar.setVisible(True) 566 | self.progress_bar.setValue(0) 567 | self.results_table.setRowCount(0) 568 | self.export_button.setEnabled(False) # Disable export during transcription 569 | self.copy_text_button.setEnabled(False) # Disable copy during transcription 570 | self.update_status("Starting transcription...") 571 | self.segments = [] # Clear segments for new transcription 572 | 573 | # Gather settings from GUI widgets 574 | selected_model = self.model_combo.currentText() 575 | transcribe_params = { 576 | "language": self.language_input.text() if self.language_input.text() else None, 577 | "translate": self.translate_checkbox.isChecked(), 578 | "n_threads": self.n_threads_spinbox.value(), 579 | "no_context": self.no_context_checkbox.isChecked(), 580 | "temperature": self.temperature_spinbox.value(), 581 | } 582 | # Remove None values to use pywhispercpp defaults where applicable 583 | transcribe_params = {k: v for k, v in transcribe_params.items() if v is not None} 584 | 585 | # Create and start the worker thread 586 | self.whisper_thread = PyWhisperCppWorker( 587 | self.selected_file_path, 588 | selected_model, 589 | **transcribe_params 590 | ) 591 | self.whisper_thread.signals.result.connect(self.on_transcription_result) 592 | self.whisper_thread.signals.segment.connect(self.on_new_segment) 593 | self.whisper_thread.signals.finished.connect(self.on_transcription_finished) 594 | self.whisper_thread.signals.error.connect(self.on_transcription_error) 595 | self.whisper_thread.signals.progress.connect(self.update_progress) 596 | self.whisper_thread.signals.status_update.connect(self.update_status) 597 | self.whisper_thread.start() 598 | 599 | def stop_transcription(self): 600 | if self.whisper_thread: 601 | self.whisper_thread.stop() 602 | # self.transcribe_button.setVisible(True) 603 | # self.stop_button.setVisible(False) 604 | # self.select_button.setEnabled(True) 605 | # self.progress_bar.setVisible(False) 606 | # self.on_transcription_finished() 607 | 608 | def update_progress(self, value): 609 | self.progress_bar.setValue(value) 610 | # Update status bar with progress if not already showing a specific message 611 | if not self.status_bar_label.text().startswith("Error:") and \ 612 | not self.status_bar_label.text().startswith("Finished.") and \ 613 | not self.status_bar_label.text().startswith("Text exported") and \ 614 | not self.status_bar_label.text().startswith("Text copied"): # Updated check 615 | self.update_status(f"Progress: {value}%") 616 | 617 | def update_status(self, status_text): 618 | # Update the new status bar label directly 619 | if self.status_bar_label: 620 | self.status_bar_label.setText(status_text) 621 | # Polish stylesheet for status_bar_label to ensure updates are reflected 622 | self.status_bar_label.style().unpolish(self.status_bar_label) 623 | self.status_bar_label.style().polish(self.status_bar_label) 624 | 625 | def format_time(self, milliseconds): 626 | """Converts milliseconds to HH:MM:SS.ms format.""" 627 | seconds_total = milliseconds / 1000 628 | minutes, seconds = divmod(seconds_total, 60) 629 | hours, minutes = divmod(minutes, 60) 630 | return f"{int(hours):02d}:{int(minutes):02d}:{seconds:06.3f}" 631 | 632 | def on_new_segment(self, segment): 633 | row_position = self.results_table.rowCount() 634 | self.results_table.insertRow(row_position) 635 | 
        start_time_str = self.format_time(segment.t0)
636 |         end_time_str = self.format_time(segment.t1)
637 | 
638 |         start_item = QTableWidgetItem(start_time_str)
639 |         end_item = QTableWidgetItem(end_time_str)
640 |         text_item = QTableWidgetItem(segment.text.strip())
641 | 
642 |         self.results_table.setItem(row_position, 0, start_item)
643 |         self.results_table.setItem(row_position, 1, end_item)
644 |         self.results_table.setItem(row_position, 2, text_item)
645 | 
646 |     def on_transcription_result(self, segments):
647 |         """
648 |         Stores the transcription segments for export and enables the
649 |         export and copy buttons; the table itself is filled by on_new_segment.
650 |         """
651 |         self.segments = segments  # Store segments
652 |         self.export_button.setEnabled(bool(segments))  # Enable export if segments exist
653 |         self.copy_text_button.setEnabled(bool(segments))  # Enable copy if segments exist
654 | 
655 |     def on_transcription_finished(self):
656 |         """
657 |         Cleans up after the transcription thread is finished.
658 |         """
659 |         self.transcribe_button.setVisible(True)
660 |         self.transcribe_button.setEnabled(True)
661 |         self.stop_button.setVisible(False)
662 |         self.select_button.setEnabled(True)
663 |         self.progress_bar.setVisible(False)
664 |         if self.results_table.rowCount() == 0:
665 |             self.update_status("Finished. No transcription data.")
666 |         else:
667 |             self.update_status("Transcription finished successfully!")
668 |         self.whisper_thread = None
669 | 
670 |     def on_transcription_error(self, err):
671 |         """
672 |         Displays an error message if transcription fails.
673 |         """
674 |         exctype, value, tb = err
675 |         error_message = f"Error: {value}"
676 |         self.update_status(error_message)  # Show the error in the status bar
677 |         self.on_transcription_finished()
678 | 
679 |     def export_transcription(self, format_type):
680 |         """
681 |         Handles exporting the transcription to a chosen file format.
682 | """ 683 | if not self.segments: 684 | self.update_status("No transcription data to export.") 685 | return 686 | 687 | file_dialog_filter = { 688 | "txt": "Plain Text Files (*.txt)", 689 | "srt": "SRT Subtitle Files (*.srt)", 690 | "vtt": "VTT Subtitle Files (*.vtt)", 691 | "csv": "CSV (Comma Separated Values) Files (*.csv)", 692 | } 693 | 694 | default_file_name = os.path.basename(self.selected_file_path).rsplit('.', 1)[ 695 | 0] + f".{format_type}" if self.selected_file_path else f"transcription.{format_type}" 696 | 697 | options = QFileDialog.Options() 698 | file_path, _ = QFileDialog.getSaveFileName( 699 | self, f"Save Transcription as {format_type.upper()}", 700 | default_file_name, 701 | file_dialog_filter.get(format_type, "All Files (*)"), 702 | options=options 703 | ) 704 | 705 | if file_path: 706 | try: 707 | # Use pywhispercpp.utils functions based on format_type 708 | if format_type == "txt": 709 | # For TXT, we'll re-use the text from the table or segments 710 | all_text = [] 711 | for segment in self.segments: 712 | all_text.append(segment.text.strip()) 713 | output_txt_content = "\n".join(all_text) 714 | with open(file_path, 'w', encoding='utf-8') as f: 715 | f.write(output_txt_content) 716 | 717 | elif format_type == "srt": 718 | if output_srt: 719 | output_srt(self.segments, file_path) 720 | else: 721 | raise ImportError("pywhispercpp.utils.output_srt not available.") 722 | elif format_type == "vtt": 723 | if output_vtt: 724 | output_vtt(self.segments, file_path) 725 | else: 726 | raise ImportError("pywhispercpp.utils.output_vtt not available.") 727 | elif format_type == "csv": 728 | if output_csv: 729 | # For CSV, we need to pass a list of lists/tuples representing rows 730 | # pywhispercpp.utils.output_csv expects a list of segments and a file path 731 | output_csv(self.segments, file_path) 732 | else: 733 | raise ImportError("pywhispercpp.utils.output_csv not available.") 734 | 735 | self.update_status(f"Transcription successfully exported to {os.path.basename(file_path)}") 736 | except Exception as e: 737 | self.update_status(f"Error exporting to {format_type.upper()}: {e}") 738 | else: 739 | self.update_status("Export cancelled.") 740 | 741 | def copy_all_text_to_clipboard(self): 742 | """ 743 | Concatenates all text from segments and copies it to the clipboard. 744 | """ 745 | if not self.segments: 746 | self.update_status("No transcription data to copy.") 747 | return 748 | 749 | all_text = [] 750 | for segment in self.segments: 751 | all_text.append(segment.text.strip()) 752 | 753 | QApplication.clipboard().setText("\n".join(all_text)) 754 | self.update_status("Text copied to clipboard!") 755 | 756 | def show_about_dialog(self): 757 | """Opens a small dialog with About information.""" 758 | about_dialog = QDialog(self) 759 | about_dialog.setWindowTitle("About PyWhisperCPP Simple GUI") 760 | about_dialog.setFixedSize(400, 220) 761 | 762 | dialog_layout = QVBoxLayout(about_dialog) 763 | dialog_layout.setContentsMargins(20, 20, 20, 20) 764 | 765 | info_text = QLabel() 766 | info_text.setTextFormat(Qt.RichText) 767 | info_text.setText( 768 | "PyWhisperCPP Simple GUI
" 769 | f"Version {__version__}
" 770 | "
" 771 | "A simple graphical user interface for PyWhisperCpp Using PyQt.

" 772 | "
PyWhisperCpp GitHub repository
" 773 | "
" 774 | f"Copyright © {datetime.now().year}" 775 | ) 776 | info_text.setOpenExternalLinks(True) 777 | 778 | dialog_layout.addWidget(info_text) 779 | 780 | close_button = QPushButton("Close") 781 | close_button.clicked.connect(about_dialog.accept) 782 | dialog_layout.addWidget(close_button, alignment=Qt.AlignCenter) 783 | 784 | about_dialog.exec_() 785 | 786 | 787 | def _main(): 788 | """Main function to run the application.""" 789 | if Model is None: 790 | print("pywhispercpp is not installed.") 791 | print("Please install it by running: pip install pywhispercpp") 792 | print("You also need ffmpeg installed on your system.") 793 | return 794 | 795 | app = QApplication(sys.argv) 796 | ex = TranscriptionApp() 797 | ex.show() 798 | sys.exit(app.exec_()) 799 | 800 | 801 | if __name__ == '__main__': 802 | _main() 803 | -------------------------------------------------------------------------------- /src/main.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | ******************************************************************************** 3 | * @file main.cpp 4 | * @author [absadiki](https://github.com/absadiki) 5 | * @date 2023 6 | * @brief Python bindings for [whisper.cpp](https://github.com/ggerganov/whisper.cpp) using Pybind11 7 | * 8 | * @par 9 | * COPYRIGHT NOTICE: (c) 2023. All rights reserved. 10 | ******************************************************************************** 11 | */ 12 | 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | #include "whisper.h" 19 | 20 | 21 | #define STRINGIFY(x) #x 22 | #define MACRO_STRINGIFY(x) STRINGIFY(x) 23 | 24 | #define DEF_RELEASE_GIL(name, fn, doc) \ 25 | m.def(name, fn, doc, py::call_guard()) 26 | 27 | 28 | namespace py = pybind11; 29 | using namespace pybind11::literals; // to bring in the `_a` literal 30 | 31 | 32 | py::function py_new_segment_callback; 33 | py::function py_encoder_begin_callback; 34 | py::function py_logits_filter_callback; 35 | 36 | 37 | // whisper context wrapper, to solve the incomplete type issue 38 | // Thanks to https://github.com/pybind/pybind11/issues/2770 39 | struct whisper_context_wrapper { 40 | whisper_context* ptr; 41 | 42 | }; 43 | 44 | 45 | // struct inside params 46 | struct greedy{ 47 | int best_of; 48 | }; 49 | 50 | struct beam_search{ 51 | int beam_size; 52 | float patience; 53 | }; 54 | 55 | 56 | struct whisper_model_loader_wrapper { 57 | whisper_model_loader* ptr; 58 | 59 | }; 60 | 61 | struct whisper_context_wrapper whisper_init_from_file_wrapper(const char * path_model){ 62 | struct whisper_context * ctx = whisper_init_from_file(path_model); 63 | struct whisper_context_wrapper ctw_w; 64 | ctw_w.ptr = ctx; 65 | return ctw_w; 66 | } 67 | 68 | struct whisper_context_wrapper whisper_init_from_buffer_wrapper(void * buffer, size_t buffer_size){ 69 | struct whisper_context * ctx = whisper_init_from_buffer(buffer, buffer_size); 70 | struct whisper_context_wrapper ctw_w; 71 | ctw_w.ptr = ctx; 72 | return ctw_w; 73 | } 74 | 75 | struct whisper_context_wrapper whisper_init_wrapper(struct whisper_model_loader_wrapper * loader){ 76 | struct whisper_context * ctx = whisper_init(loader->ptr); 77 | struct whisper_context_wrapper ctw_w; 78 | ctw_w.ptr = ctx; 79 | return ctw_w; 80 | }; 81 | 82 | void whisper_free_wrapper(struct whisper_context_wrapper * ctx_w){ 83 | whisper_free(ctx_w->ptr); 84 | }; 85 | 86 | int whisper_pcm_to_mel_wrapper( 87 | struct whisper_context_wrapper * ctx, 88 | py::array_t samples, 89 | int n_samples, 90 | int n_threads){ 91 | 
91 |     py::buffer_info buf = samples.request();
92 |     float *samples_ptr = static_cast<float *>(buf.ptr);
93 |     return whisper_pcm_to_mel(ctx->ptr, samples_ptr, n_samples, n_threads);
94 | };
95 | 
96 | int whisper_set_mel_wrapper(
97 |     struct whisper_context_wrapper * ctx,
98 |     py::array_t<float> data,
99 |     int n_len,
100 |     int n_mel){
101 |     py::buffer_info buf = data.request();
102 |     float *data_ptr = static_cast<float *>(buf.ptr);
103 |     return whisper_set_mel(ctx->ptr, data_ptr, n_len, n_mel);
104 | 
105 | };
106 | 
107 | int whisper_n_len_wrapper(struct whisper_context_wrapper * ctx_w){
108 |     return whisper_n_len(ctx_w->ptr);
109 | };
110 | 
111 | int whisper_n_vocab_wrapper(struct whisper_context_wrapper * ctx_w){
112 |     return whisper_n_vocab(ctx_w->ptr);
113 | };
114 | 
115 | int whisper_n_text_ctx_wrapper(struct whisper_context_wrapper * ctx_w){
116 |     return whisper_n_text_ctx(ctx_w->ptr);
117 | };
118 | 
119 | int whisper_n_audio_ctx_wrapper(struct whisper_context_wrapper * ctx_w){
120 |     return whisper_n_audio_ctx(ctx_w->ptr);
121 | }
122 | 
123 | int whisper_is_multilingual_wrapper(struct whisper_context_wrapper * ctx_w){
124 |     return whisper_is_multilingual(ctx_w->ptr);
125 | }
126 | 
127 | 
128 | float * whisper_get_logits_wrapper(struct whisper_context_wrapper * ctx_w){
129 |     return whisper_get_logits(ctx_w->ptr);
130 | };
131 | 
132 | const char * whisper_token_to_str_wrapper(struct whisper_context_wrapper * ctx_w, whisper_token token){
133 |     return whisper_token_to_str(ctx_w->ptr, token);
134 | };
135 | 
136 | py::bytes whisper_token_to_bytes_wrapper(struct whisper_context_wrapper * ctx_w, whisper_token token){
137 |     const char* str = whisper_token_to_str(ctx_w->ptr, token);
138 |     size_t l = strlen(str);
139 |     return py::bytes(str, l);
140 | }
141 | 
142 | whisper_token whisper_token_eot_wrapper(struct whisper_context_wrapper * ctx_w){
143 |     return whisper_token_eot(ctx_w->ptr);
144 | }
145 | 
146 | whisper_token whisper_token_sot_wrapper(struct whisper_context_wrapper * ctx_w){
147 |     return whisper_token_sot(ctx_w->ptr);
148 | }
149 | 
150 | whisper_token whisper_token_prev_wrapper(struct whisper_context_wrapper * ctx_w){
151 |     return whisper_token_prev(ctx_w->ptr);
152 | }
153 | 
154 | whisper_token whisper_token_solm_wrapper(struct whisper_context_wrapper * ctx_w){
155 |     return whisper_token_solm(ctx_w->ptr);
156 | }
157 | 
158 | whisper_token whisper_token_not_wrapper(struct whisper_context_wrapper * ctx_w){
159 |     return whisper_token_not(ctx_w->ptr);
160 | }
161 | 
162 | whisper_token whisper_token_beg_wrapper(struct whisper_context_wrapper * ctx_w){
163 |     return whisper_token_beg(ctx_w->ptr);
164 | }
165 | 
166 | whisper_token whisper_token_lang_wrapper(struct whisper_context_wrapper * ctx_w, int lang_id){
167 |     return whisper_token_lang(ctx_w->ptr, lang_id);
168 | }
169 | 
170 | whisper_token whisper_token_translate_wrapper(struct whisper_context_wrapper * ctx_w){
171 |     return whisper_token_translate(ctx_w->ptr);
172 | }
173 | 
174 | whisper_token whisper_token_transcribe_wrapper(struct whisper_context_wrapper * ctx_w){
175 |     return whisper_token_transcribe(ctx_w->ptr);
176 | }
177 | 
178 | void whisper_print_timings_wrapper(struct whisper_context_wrapper * ctx_w){
179 |     return whisper_print_timings(ctx_w->ptr);
180 | }
181 | 
182 | void whisper_reset_timings_wrapper(struct whisper_context_wrapper * ctx_w){
183 |     return whisper_reset_timings(ctx_w->ptr);
184 | }
185 | 
186 | int whisper_encode_wrapper(
187 |     struct whisper_context_wrapper * ctx,
188 |     int offset,
189 |     int n_threads){
190 |     return whisper_encode(ctx->ptr, offset, n_threads);
191 | }
192 | 
193 | 
194 | int whisper_decode_wrapper(
195 |     struct whisper_context_wrapper * ctx,
196 |     const whisper_token * tokens,
197 |     int n_tokens,
198 |     int n_past,
199 |     int n_threads){
200 |     return whisper_decode(ctx->ptr, tokens, n_tokens, n_past, n_threads);
201 | };
202 | 
203 | int whisper_tokenize_wrapper(
204 |     struct whisper_context_wrapper * ctx,
205 |     const char * text,
206 |     whisper_token * tokens,
207 |     int n_max_tokens){
208 |     return whisper_tokenize(ctx->ptr, text, tokens, n_max_tokens);
209 | };
210 | 
211 | int whisper_lang_auto_detect_wrapper(
212 |     struct whisper_context_wrapper * ctx,
213 |     int offset_ms,
214 |     int n_threads,
215 |     py::array_t<float> lang_probs){
216 | 
217 |     py::buffer_info buf = lang_probs.request();
218 |     float *lang_probs_ptr = static_cast<float *>(buf.ptr);
219 |     return whisper_lang_auto_detect(ctx->ptr, offset_ms, n_threads, lang_probs_ptr);
220 | 
221 | }
222 | 
223 | int whisper_full_wrapper(
224 |     struct whisper_context_wrapper * ctx_w,
225 |     struct whisper_full_params params,
226 |     py::array_t<float> samples,
227 |     int n_samples){
228 |     py::buffer_info buf = samples.request();
229 |     float *samples_ptr = static_cast<float *>(buf.ptr);
230 | 
231 |     py::gil_scoped_release release;
232 |     return whisper_full(ctx_w->ptr, params, samples_ptr, n_samples);
233 | }
234 | 
235 | int whisper_full_parallel_wrapper(
236 |     struct whisper_context_wrapper * ctx_w,
237 |     struct whisper_full_params params,
238 |     py::array_t<float> samples,
239 |     int n_samples,
240 |     int n_processors){
241 |     py::buffer_info buf = samples.request();
242 |     float *samples_ptr = static_cast<float *>(buf.ptr);
243 | 
244 |     py::gil_scoped_release release;
245 |     return whisper_full_parallel(ctx_w->ptr, params, samples_ptr, n_samples, n_processors);
246 | }
247 | 
248 | 
249 | int whisper_full_n_segments_wrapper(struct whisper_context_wrapper * ctx){
250 |     py::gil_scoped_release release;
251 |     return whisper_full_n_segments(ctx->ptr);
252 | }
253 | 
254 | int whisper_full_lang_id_wrapper(struct whisper_context_wrapper * ctx){
255 |     return whisper_full_lang_id(ctx->ptr);
256 | }
257 | 
258 | int64_t whisper_full_get_segment_t0_wrapper(struct whisper_context_wrapper * ctx, int i_segment){
259 |     return whisper_full_get_segment_t0(ctx->ptr, i_segment);
260 | }
261 | 
262 | int64_t whisper_full_get_segment_t1_wrapper(struct whisper_context_wrapper * ctx, int i_segment){
263 |     return whisper_full_get_segment_t1(ctx->ptr, i_segment);
264 | }
265 | 
266 | // https://pybind11.readthedocs.io/en/stable/advanced/cast/strings.html
267 | const py::bytes whisper_full_get_segment_text_wrapper(struct whisper_context_wrapper * ctx, int i_segment){
268 |     const char * c_array = whisper_full_get_segment_text(ctx->ptr, i_segment);
269 |     size_t length = strlen(c_array); // Determine the length of the array
270 |     return py::bytes(c_array, length); // Return the data without transcoding
271 | };
272 | 
273 | int whisper_full_n_tokens_wrapper(struct whisper_context_wrapper * ctx, int i_segment){
274 |     return whisper_full_n_tokens(ctx->ptr, i_segment);
275 | }
276 | 
277 | const char * whisper_full_get_token_text_wrapper(struct whisper_context_wrapper * ctx, int i_segment, int i_token){
278 |     return whisper_full_get_token_text(ctx->ptr, i_segment, i_token);
279 | }
280 | 
281 | whisper_token whisper_full_get_token_id_wrapper(struct whisper_context_wrapper * ctx, int i_segment, int i_token){
282 |     return whisper_full_get_token_id(ctx->ptr, i_segment, i_token);
283 | }
284 | 
285 | whisper_token_data whisper_full_get_token_data_wrapper(struct whisper_context_wrapper * ctx, int i_segment, int i_token){
286 |     return whisper_full_get_token_data(ctx->ptr, i_segment, i_token);
287 | }
288 | 
289 | float whisper_full_get_token_p_wrapper(struct whisper_context_wrapper * ctx, int i_segment, int i_token){
290 |     return whisper_full_get_token_p(ctx->ptr, i_segment, i_token);
291 | }
292 | 
293 | int whisper_ctx_init_openvino_encoder_wrapper(struct whisper_context_wrapper * ctx, const char * model_path,
294 |                                               const char * device,
295 |                                               const char * cache_dir){
296 |     return whisper_ctx_init_openvino_encoder(ctx->ptr, model_path, device, cache_dir);
297 | }
298 | 
299 | class WhisperFullParamsWrapper : public whisper_full_params {
300 |     std::string initial_prompt_str;
301 |     std::string suppress_regex_str;
302 | public:
303 |     py::function py_progress_callback;
304 |     WhisperFullParamsWrapper(const whisper_full_params& params = whisper_full_params())
305 |         : whisper_full_params(params),
306 |           initial_prompt_str(params.initial_prompt ? params.initial_prompt : ""),
307 |           suppress_regex_str(params.suppress_regex ? params.suppress_regex : "") {
308 |         initial_prompt = initial_prompt_str.empty() ? nullptr : initial_prompt_str.c_str();
309 |         suppress_regex = suppress_regex_str.empty() ? nullptr : suppress_regex_str.c_str();
310 |         // progress callback
311 |         progress_callback_user_data = this;
312 |         progress_callback = [](struct whisper_context* ctx, struct whisper_state* state, int progress, void* user_data) {
313 |             auto* self = static_cast<WhisperFullParamsWrapper*>(user_data);
314 |             if(self && self->print_progress){
315 |                 if (self->py_progress_callback) {
316 |                     // call the python callback
317 |                     py::gil_scoped_acquire gil;
318 |                     self->py_progress_callback(progress); // Call Python callback
319 |                 }
320 |                 else {
321 |                     fprintf(stderr, "Progress: %3d%%\n", progress);
322 |                 } // Default message
323 |             }
324 |         };
325 |     }
326 | 
327 |     WhisperFullParamsWrapper(const WhisperFullParamsWrapper& other)
328 |         : WhisperFullParamsWrapper(static_cast<const whisper_full_params&>(other)) {}
329 | 
330 |     void set_initial_prompt(const std::string& prompt) {
331 |         initial_prompt_str = prompt;
332 |         initial_prompt = initial_prompt_str.c_str();
333 |     }
334 | 
335 |     void set_suppress_regex(const std::string& regex) {
336 |         suppress_regex_str = regex;
337 |         suppress_regex = suppress_regex_str.c_str();
338 |     }
339 | };
340 | 
341 | WhisperFullParamsWrapper whisper_full_default_params_wrapper(enum whisper_sampling_strategy strategy) {
342 |     return WhisperFullParamsWrapper(whisper_full_default_params(strategy));
343 | }
344 | 
345 | // callbacks mechanism
346 | 
347 | void _new_segment_callback(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data){
348 |     struct whisper_context_wrapper ctx_w;
349 |     ctx_w.ptr = ctx;
350 |     // call the python callback
351 |     py::gil_scoped_acquire gil; // Acquire the GIL while in this scope.
352 |     py_new_segment_callback(ctx_w, n_new, user_data);
353 | };
354 | 
355 | void assign_new_segment_callback(struct whisper_full_params *params, py::function f){
356 |     params->new_segment_callback = _new_segment_callback;
357 |     py_new_segment_callback = f;
358 | };
359 | 
360 | bool _encoder_begin_callback(struct whisper_context * ctx, struct whisper_state * state, void * user_data){
361 |     struct whisper_context_wrapper ctx_w;
362 |     ctx_w.ptr = ctx;
363 |     // call the python callback
364 |     py::object result_py = py_encoder_begin_callback(ctx_w, user_data);
365 |     bool res = result_py.cast<bool>();
366 |     return res;
367 | }
368 | 
369 | void assign_encoder_begin_callback(struct whisper_full_params *params, py::function f){
370 |     params->encoder_begin_callback = _encoder_begin_callback;
371 |     py_encoder_begin_callback = f;
372 | }
373 | 
374 | void _logits_filter_callback(
375 |     struct whisper_context * ctx,
376 |     struct whisper_state * state,
377 |     const whisper_token_data * tokens,
378 |     int n_tokens,
379 |     float * logits,
380 |     void * user_data){
381 |     struct whisper_context_wrapper ctx_w;
382 |     ctx_w.ptr = ctx;
383 |     // call the python callback
384 |     py_logits_filter_callback(ctx_w, n_tokens, logits, user_data);
385 | }
386 | 
387 | void assign_logits_filter_callback(struct whisper_full_params *params, py::function f){
388 |     params->logits_filter_callback = _logits_filter_callback;
389 |     py_logits_filter_callback = f;
390 | }
391 | 
392 | py::dict get_greedy(whisper_full_params * params){
393 |     py::dict d("best_of"_a=params->greedy.best_of);
394 |     return d;
395 | }
396 | 
397 | PYBIND11_MODULE(_pywhispercpp, m) {
398 |     m.doc() = R"pbdoc(
399 |         Pywhispercpp: Python binding to whisper.cpp
400 |         -----------------------
401 | 
402 |         .. currentmodule:: _pywhispercpp
403 | 
404 |         .. autosummary::
405 |            :toctree: _generate
406 | 
407 |     )pbdoc";
408 | 
409 |     m.attr("WHISPER_SAMPLE_RATE") = WHISPER_SAMPLE_RATE;
410 |     m.attr("WHISPER_N_FFT") = WHISPER_N_FFT;
411 |     m.attr("WHISPER_HOP_LENGTH") = WHISPER_HOP_LENGTH;
412 |     m.attr("WHISPER_CHUNK_SIZE") = WHISPER_CHUNK_SIZE;
413 | 
414 |     py::class_<whisper_context_wrapper>(m, "whisper_context");
415 |     py::class_<whisper_token>(m, "whisper_token")
416 |         .def(py::init<>());
417 |     py::class_<whisper_token_data>(m, "whisper_token_data")
418 |         .def(py::init<>())
419 |         .def_readwrite("id", &whisper_token_data::id)
420 |         .def_readwrite("tid", &whisper_token_data::tid)
421 |         .def_readwrite("p", &whisper_token_data::p)
422 |         .def_readwrite("plog", &whisper_token_data::plog)
423 |         .def_readwrite("pt", &whisper_token_data::pt)
424 |         .def_readwrite("ptsum", &whisper_token_data::ptsum)
425 |         .def_readwrite("t0", &whisper_token_data::t0)
426 |         .def_readwrite("t1", &whisper_token_data::t1)
427 |         .def_readwrite("vlen", &whisper_token_data::vlen);
428 | 
429 |     py::class_<whisper_model_loader_wrapper>(m, "whisper_model_loader")
430 |         .def(py::init<>());
431 | 
432 |     DEF_RELEASE_GIL("whisper_init_from_file", &whisper_init_from_file_wrapper, "Various functions for loading a ggml whisper model.\n"
433 |                     "Allocate (almost) all memory needed for the model.\n"
434 |                     "Return NULL on failure");
435 |     DEF_RELEASE_GIL("whisper_init_from_buffer", &whisper_init_from_buffer_wrapper, "Various functions for loading a ggml whisper model.\n"
436 |                     "Allocate (almost) all memory needed for the model.\n"
437 |                     "Return NULL on failure");
438 |     DEF_RELEASE_GIL("whisper_init", &whisper_init_wrapper, "Various functions for loading a ggml whisper model.\n"
439 |                     "Allocate (almost) all memory needed for the model.\n"
440 |                     "Return NULL on failure");
441 | 
442 | 
443 |     m.def("whisper_free", &whisper_free_wrapper, "Frees all memory allocated by the model.");
444 | 
445 |     m.def("whisper_pcm_to_mel", &whisper_pcm_to_mel_wrapper, "Convert RAW PCM audio to log mel spectrogram.\n"
446 |           "The resulting spectrogram is stored inside the provided whisper context.\n"
447 |           "Returns 0 on success");
448 | 
449 |     m.def("whisper_set_mel", &whisper_set_mel_wrapper, "This can be used to set a custom log mel spectrogram inside the provided whisper context.\n"
450 |           "Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.\n"
451 |           "n_mel must be 80\n"
452 |           "Returns 0 on success");
453 | 
454 |     m.def("whisper_encode", &whisper_encode_wrapper, "Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context.\n"
455 |           "Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.\n"
456 |           "offset can be used to specify the offset of the first frame in the spectrogram.\n"
457 |           "Returns 0 on success");
458 | 
459 |     m.def("whisper_decode", &whisper_decode_wrapper, "Run the Whisper decoder to obtain the logits and probabilities for the next token.\n"
460 |           "Make sure to call whisper_encode() first.\n"
461 |           "tokens + n_tokens is the provided context for the decoder.\n"
462 |           "n_past is the number of tokens to use from previous decoder calls.\n"
463 |           "Returns 0 on success\n"
464 |           "TODO: add support for multiple decoders");
465 | 
466 |     m.def("whisper_tokenize", &whisper_tokenize_wrapper, "Convert the provided text into tokens.\n"
467 |           "The tokens pointer must be large enough to hold the resulting tokens.\n"
468 |           "Returns the number of tokens on success, no more than n_max_tokens\n"
469 |           "Returns -1 on failure\n"
470 |           "TODO: not sure if correct");
471 | 
472 |     m.def("whisper_lang_max_id", &whisper_lang_max_id, "Largest language id (i.e. number of available languages - 1)");
473 |     m.def("whisper_lang_id", &whisper_lang_id, "Return the id of the specified language, returns -1 if not found\n"
474 |           "Examples:\n"
475 |           "\"de\" -> 2\n"
476 |           "\"german\" -> 2");
477 |     m.def("whisper_lang_str", &whisper_lang_str, "Return the short string of the specified language id (e.g. 2 -> \"de\"), returns nullptr if not found");
478 | 
479 | 
480 | 
481 | 
482 | 
483 | 
484 | 
485 |     m.def("whisper_lang_auto_detect", &whisper_lang_auto_detect_wrapper, "Use mel data at offset_ms to try and auto-detect the spoken language\n"
486 |           "Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first\n"
487 |           "Returns the top language id or negative on failure\n"
488 |           "If not null, fills the lang_probs array with the probabilities of all languages\n"
489 |           "The array must be whisper_lang_max_id() + 1 in size\n"
490 |           "ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69\n");
491 |     m.def("whisper_n_len", &whisper_n_len_wrapper, "whisper_n_len");
492 |     m.def("whisper_n_vocab", &whisper_n_vocab_wrapper, "whisper_n_vocab");
493 |     m.def("whisper_n_text_ctx", &whisper_n_text_ctx_wrapper, "whisper_n_text_ctx");
494 |     m.def("whisper_n_audio_ctx", &whisper_n_audio_ctx_wrapper, "whisper_n_audio_ctx");
495 |     m.def("whisper_is_multilingual", &whisper_is_multilingual_wrapper, "whisper_is_multilingual");
496 |     m.def("whisper_get_logits", &whisper_get_logits_wrapper, "Token logits obtained from the last call to whisper_decode()\n"
497 |           "The logits for the last token are stored in the last row\n"
498 |           "Rows: n_tokens\n"
499 |           "Cols: n_vocab");
500 | 
501 | 
502 |     m.def("whisper_token_to_str", &whisper_token_to_str_wrapper, "whisper_token_to_str");
503 |     m.def("whisper_token_to_bytes", &whisper_token_to_bytes_wrapper, "whisper_token_to_bytes");
504 |     m.def("whisper_token_eot", &whisper_token_eot_wrapper, "whisper_token_eot");
505 |     m.def("whisper_token_sot", &whisper_token_sot_wrapper, "whisper_token_sot");
506 |     m.def("whisper_token_prev", &whisper_token_prev_wrapper);
507 |     m.def("whisper_token_solm", &whisper_token_solm_wrapper);
508 |     m.def("whisper_token_not", &whisper_token_not_wrapper);
509 |     m.def("whisper_token_beg", &whisper_token_beg_wrapper);
510 |     m.def("whisper_token_lang", &whisper_token_lang_wrapper);
511 | 
512 |     m.def("whisper_token_translate", &whisper_token_translate_wrapper);
513 |     m.def("whisper_token_transcribe", &whisper_token_transcribe_wrapper);
514 | 
515 |     m.def("whisper_print_timings", &whisper_print_timings_wrapper);
516 |     m.def("whisper_reset_timings", &whisper_reset_timings_wrapper);
517 | 
518 |     m.def("whisper_print_system_info", &whisper_print_system_info);
519 | 
520 | 
521 | 
522 |     //////////////////////
523 | 
524 |     py::enum_<whisper_sampling_strategy>(m, "whisper_sampling_strategy")
525 |         .value("WHISPER_SAMPLING_GREEDY", whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY)
526 |         .value("WHISPER_SAMPLING_BEAM_SEARCH", whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH)
527 |         .export_values();
528 | 
529 |     py::class_<whisper_full_params>(m, "__whisper_full_params__internal")
530 |         .def(py::init<>())
531 |         .def("__repr__", [](const whisper_full_params& self) {
532 |             std::ostringstream oss;
533 |             oss << "whisper_full_params("
534 |                 << "strategy=" << self.strategy << ", "
535 |                 << "n_threads=" << self.n_threads << ", "
536 |                 << "n_max_text_ctx=" << self.n_max_text_ctx << ", "
537 |                 << "offset_ms=" << self.offset_ms << ", "
538 |                 << "duration_ms=" << self.duration_ms << ", "
539 |                 << "translate=" << (self.translate ?
"True" : "False") << ", " 540 | << "no_context=" << (self.no_context ? "True" : "False") << ", " 541 | << "no_timestamps=" << (self.no_timestamps ? "True" : "False") << ", " 542 | << "single_segment=" << (self.single_segment ? "True" : "False") << ", " 543 | << "print_special=" << (self.print_special ? "True" : "False") << ", " 544 | << "print_progress=" << (self.print_progress ? "True" : "False") << ", " 545 | << "print_realtime=" << (self.print_realtime ? "True" : "False") << ", " 546 | << "print_timestamps=" << (self.print_timestamps ? "True" : "False") << ", " 547 | << "token_timestamps=" << (self.token_timestamps ? "True" : "False") << ", " 548 | << "thold_pt=" << self.thold_pt << ", " 549 | << "thold_ptsum=" << self.thold_ptsum << ", " 550 | << "max_len=" << self.max_len << ", " 551 | << "split_on_word=" << (self.split_on_word ? "True" : "False") << ", " 552 | << "max_tokens=" << self.max_tokens << ", " 553 | << "debug_mode=" << (self.debug_mode ? "True" : "False") << ", " 554 | << "audio_ctx=" << self.audio_ctx << ", " 555 | << "tdrz_enable=" << (self.tdrz_enable ? "True" : "False") << ", " 556 | << "suppress_regex=" << (self.suppress_regex ? self.suppress_regex : "None") << ", " 557 | << "initial_prompt=" << (self.initial_prompt ? self.initial_prompt : "None") << ", " 558 | << "prompt_tokens=" << (self.prompt_tokens ? "(whisper_token *)" : "None") << ", " 559 | << "prompt_n_tokens=" << self.prompt_n_tokens << ", " 560 | << "language=" << (self.language ? self.language : "None") << ", " 561 | << "detect_language=" << (self.detect_language ? "True" : "False") << ", " 562 | << "suppress_blank=" << (self.suppress_blank ? "True" : "False") << ", " 563 | << "temperature=" << self.temperature << ", " 564 | << "max_initial_ts=" << self.max_initial_ts << ", " 565 | << "length_penalty=" << self.length_penalty << ", " 566 | << "temperature_inc=" << self.temperature_inc << ", " 567 | << "entropy_thold=" << self.entropy_thold << ", " 568 | << "logprob_thold=" << self.logprob_thold << ", " 569 | << "no_speech_thold=" << self.no_speech_thold << ", " 570 | << "greedy={best_of=" << self.greedy.best_of << "}, " 571 | << "beam_search={beam_size=" << self.beam_search.beam_size << ", patience=" << self.beam_search.patience << "}, " 572 | << "new_segment_callback=" << (self.new_segment_callback ? "(function pointer)" : "None") << ", " 573 | << "progress_callback=" << (self.progress_callback ? "(function pointer)" : "None") << ", " 574 | << "encoder_begin_callback=" << (self.encoder_begin_callback ? "(function pointer)" : "None") << ", " 575 | << "abort_callback=" << (self.abort_callback ? "(function pointer)" : "None") << ", " 576 | << "logits_filter_callback=" << (self.logits_filter_callback ? "(function pointer)" : "None") << ", " 577 | << "grammar_rules=" << (self.grammar_rules ? 
"(whisper_grammar_element **)" : "None") << ", " 578 | << "n_grammar_rules=" << self.n_grammar_rules << ", " 579 | << "i_start_rule=" << self.i_start_rule << ", " 580 | << "grammar_penalty=" << self.grammar_penalty 581 | << ")"; 582 | return oss.str(); 583 | }); 584 | 585 | py::class_(m, "whisper_full_params") 586 | .def(py::init<>()) 587 | .def_readwrite("strategy", &WhisperFullParamsWrapper::strategy) 588 | .def_readwrite("n_threads", &WhisperFullParamsWrapper::n_threads) 589 | .def_readwrite("n_max_text_ctx", &WhisperFullParamsWrapper::n_max_text_ctx) 590 | .def_readwrite("offset_ms", &WhisperFullParamsWrapper::offset_ms) 591 | .def_readwrite("duration_ms", &WhisperFullParamsWrapper::duration_ms) 592 | .def_readwrite("translate", &WhisperFullParamsWrapper::translate) 593 | .def_readwrite("no_context", &WhisperFullParamsWrapper::no_context) 594 | .def_readwrite("single_segment", &WhisperFullParamsWrapper::single_segment) 595 | .def_readwrite("print_special", &WhisperFullParamsWrapper::print_special) 596 | .def_readwrite("print_progress", &WhisperFullParamsWrapper::print_progress) 597 | .def_readwrite("progress_callback", &WhisperFullParamsWrapper::py_progress_callback) 598 | .def_readwrite("print_realtime", &WhisperFullParamsWrapper::print_realtime) 599 | .def_readwrite("print_timestamps", &WhisperFullParamsWrapper::print_timestamps) 600 | .def_readwrite("token_timestamps", &WhisperFullParamsWrapper::token_timestamps) 601 | .def_readwrite("thold_pt", &WhisperFullParamsWrapper::thold_pt) 602 | .def_readwrite("thold_ptsum", &WhisperFullParamsWrapper::thold_ptsum) 603 | .def_readwrite("max_len", &WhisperFullParamsWrapper::max_len) 604 | .def_readwrite("split_on_word", &WhisperFullParamsWrapper::split_on_word) 605 | .def_readwrite("max_tokens", &WhisperFullParamsWrapper::max_tokens) 606 | .def_readwrite("audio_ctx", &WhisperFullParamsWrapper::audio_ctx) 607 | .def_property("suppress_regex", 608 | [](WhisperFullParamsWrapper &self) { 609 | return py::str(self.suppress_regex ? self.suppress_regex : ""); 610 | }, 611 | [](WhisperFullParamsWrapper &self, const std::string &new_c) { 612 | self.set_suppress_regex(new_c); 613 | }) 614 | .def_property("initial_prompt", 615 | [](WhisperFullParamsWrapper &self) { 616 | return py::str(self.initial_prompt ? self.initial_prompt : ""); 617 | }, 618 | [](WhisperFullParamsWrapper &self, const std::string &initial_prompt) { 619 | self.set_initial_prompt(initial_prompt); 620 | } 621 | ) 622 | .def_readwrite("prompt_tokens", &WhisperFullParamsWrapper::prompt_tokens) 623 | .def_readwrite("prompt_n_tokens", &WhisperFullParamsWrapper::prompt_n_tokens) 624 | .def_property("language", 625 | [](WhisperFullParamsWrapper &self) { 626 | return py::str(self.language); 627 | }, 628 | [](WhisperFullParamsWrapper &self, const char *new_c) {// using lang_id let us avoid issues with memory management 629 | const int lang_id = (new_c && strlen(new_c) > 0) ? 
whisper_lang_id(new_c) : -1;
630 |                 if (lang_id != -1) {
631 |                     self.language = whisper_lang_str(lang_id);
632 |                 } else {
633 |                     self.language = ""; // defaults to auto-detect
634 |                 }
635 |             })
636 |         .def_readwrite("suppress_blank", &WhisperFullParamsWrapper::suppress_blank)
637 |         .def_readwrite("temperature", &WhisperFullParamsWrapper::temperature)
638 |         .def_readwrite("max_initial_ts", &WhisperFullParamsWrapper::max_initial_ts)
639 |         .def_readwrite("length_penalty", &WhisperFullParamsWrapper::length_penalty)
640 |         .def_readwrite("temperature_inc", &WhisperFullParamsWrapper::temperature_inc)
641 |         .def_readwrite("entropy_thold", &WhisperFullParamsWrapper::entropy_thold)
642 |         .def_readwrite("logprob_thold", &WhisperFullParamsWrapper::logprob_thold)
643 |         .def_readwrite("no_speech_thold", &WhisperFullParamsWrapper::no_speech_thold)
644 |         // little hack for the internal struct
645 |         .def_property("greedy", [](WhisperFullParamsWrapper &self) {return py::dict("best_of"_a=self.greedy.best_of);},
646 |                       [](WhisperFullParamsWrapper &self, py::dict dict) {self.greedy.best_of = dict["best_of"].cast<int>();})
647 |         .def_property("beam_search", [](WhisperFullParamsWrapper &self) {return py::dict("beam_size"_a=self.beam_search.beam_size, "patience"_a=self.beam_search.patience);},
648 |                       [](WhisperFullParamsWrapper &self, py::dict dict) {self.beam_search.beam_size = dict["beam_size"].cast<int>(); self.beam_search.patience = dict["patience"].cast<float>();})
649 |         .def_readwrite("new_segment_callback_user_data", &WhisperFullParamsWrapper::new_segment_callback_user_data)
650 |         .def_readwrite("encoder_begin_callback_user_data", &WhisperFullParamsWrapper::encoder_begin_callback_user_data)
651 |         .def_readwrite("logits_filter_callback_user_data", &WhisperFullParamsWrapper::logits_filter_callback_user_data);
652 | 
653 | 
654 |     py::implicitly_convertible<whisper_full_params, WhisperFullParamsWrapper>();
655 | 
656 |     m.def("whisper_full_default_params", &whisper_full_default_params_wrapper);
657 | 
658 |     m.def("whisper_full", &whisper_full_wrapper, "Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text\n"
659 |           "Uses the specified decoding strategy to obtain the text.\n");
660 | 
661 |     m.def("whisper_full_parallel", &whisper_full_parallel_wrapper, "Split the input audio in chunks and process each chunk separately using whisper_full()\n"
662 |           "It seems this approach can offer some speedup in some cases.\n"
663 |           "However, the transcription accuracy can be worse at the beginning and end of each chunk.");
664 | 
665 |     m.def("whisper_full_n_segments", &whisper_full_n_segments_wrapper, "Number of generated text segments.\n"
666 |           "A segment can be a few words, a sentence, or even a paragraph.\n");
667 | 
668 |     m.def("whisper_full_lang_id", &whisper_full_lang_id_wrapper, "Language id associated with the current context");
669 |     m.def("whisper_full_get_segment_t0", &whisper_full_get_segment_t0_wrapper, "Get the start time of the specified segment");
670 |     m.def("whisper_full_get_segment_t1", &whisper_full_get_segment_t1_wrapper, "Get the end time of the specified segment");
671 | 
672 |     m.def("whisper_full_get_segment_text", &whisper_full_get_segment_text_wrapper, "Get the text of the specified segment");
673 |     m.def("whisper_full_n_tokens", &whisper_full_n_tokens_wrapper, "Get number of tokens in the specified segment.");
674 | 
675 |     m.def("whisper_full_get_token_text", &whisper_full_get_token_text_wrapper, "Get the token text of the specified token in the specified segment.");
676 |     m.def("whisper_full_get_token_id", &whisper_full_get_token_id_wrapper, "Get the token id of the specified token in the specified segment.");
677 | 
678 |     m.def("whisper_full_get_token_data", &whisper_full_get_token_data_wrapper, "Get token data for the specified token in the specified segment.\n"
679 |           "This contains probabilities, timestamps, etc.");
680 | 
681 |     m.def("whisper_full_get_token_p", &whisper_full_get_token_p_wrapper, "Get the probability of the specified token in the specified segment.");
682 | 
683 |     m.def("whisper_ctx_init_openvino_encoder", &whisper_ctx_init_openvino_encoder_wrapper, "Given a context, enable use of OpenVINO for encode inference.");
684 | 
685 | 
686 |     ////////////////////////////////////////////////////////////////////////////
687 | 
688 |     m.def("whisper_bench_memcpy", &whisper_bench_memcpy, "Temporary helpers needed for exposing ggml interface");
689 |     m.def("whisper_bench_ggml_mul_mat", &whisper_bench_ggml_mul_mat, "Temporary helpers needed for exposing ggml interface");
690 | 
691 |     ////////////////////////////////////////////////////////////////////////////
692 |     // Helper mechanism to set callbacks from python
693 |     // The only difference from the C-style API
694 | 
695 |     m.def("assign_new_segment_callback", &assign_new_segment_callback, "Assigns a new_segment_callback; takes a params instance and a callable with the parameters defined in the interface",
696 |           py::arg("params"), py::arg("callback"));
697 | 
698 |     m.def("assign_encoder_begin_callback", &assign_encoder_begin_callback, "Assigns an encoder_begin_callback; takes a params instance and a callable with the parameters defined in the interface",
699 |           py::arg("params"), py::arg("callback"));
700 | 
701 |     m.def("assign_logits_filter_callback", &assign_logits_filter_callback, "Assigns a logits_filter_callback; takes a params instance and a callable with the parameters defined in the interface",
702 |           py::arg("params"), py::arg("callback"));
703 | 
704 | 
705 | #ifdef VERSION_INFO
706 |     m.attr("__version__") = MACRO_STRINGIFY(VERSION_INFO);
707 | #else
708 |     m.attr("__version__") = "dev";
709 | #endif
710 | }
711 | 
--------------------------------------------------------------------------------
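For orientation, the module defined in src/main.cpp above can be driven directly from Python along the following lines. This is a minimal sketch, not part of the repository: the model path and the silent input buffer are placeholder assumptions, and in practice the high-level pywhispercpp.model.Model class wraps these low-level calls. Only names actually exported by the module above are used.

import numpy as np
import _pywhispercpp as pw

# Load a ggml model; the path here is a placeholder assumption.
ctx = pw.whisper_init_from_file("./ggml-tiny.bin")

# Start from the default parameters with greedy sampling and adjust a field.
params = pw.whisper_full_default_params(pw.WHISPER_SAMPLING_GREEDY)
params.n_threads = 4

# whisper_full expects mono float32 samples at WHISPER_SAMPLE_RATE (16 kHz);
# five seconds of silence stand in for real audio in this sketch.
samples = np.zeros(pw.WHISPER_SAMPLE_RATE * 5, dtype=np.float32)

if pw.whisper_full(ctx, params, samples, len(samples)) == 0:
    for i in range(pw.whisper_full_n_segments(ctx)):
        t0 = pw.whisper_full_get_segment_t0(ctx, i)
        t1 = pw.whisper_full_get_segment_t1(ctx, i)
        # segment text comes back as raw bytes (see the wrapper above)
        text = pw.whisper_full_get_segment_text(ctx, i).decode("utf-8", errors="replace")
        print(f"[{t0} -> {t1}] {text}")

pw.whisper_free(ctx)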