├── pywhispercpp
│   ├── __init__.py
│   ├── examples
│   │   ├── __init__.py
│   │   ├── recording.py
│   │   ├── main.py
│   │   ├── assistant.py
│   │   ├── livestream.py
│   │   └── gui.py
│   ├── utils.py
│   ├── constants.py
│   └── model.py
├── requirements.txt
├── .directory
├── .github
│   ├── dependabot.yml
│   └── workflows
│       ├── docs.yml
│       ├── pip.yml
│       └── wheels.yml
├── .gitmodules
├── CMakeLists.txt
├── docs
│   └── index.md
├── MANIFEST.in
├── .appveyor.yml
├── LICENSE
├── tests
│   ├── test_segfault.py
│   ├── test_model.py
│   └── test_c_api.py
├── pyproject.toml
├── .pre-commit-config.yaml
├── mkdocs.yml
├── .gitignore
├── setup.py
├── README.md
└── src
    └── main.cpp
/pywhispercpp/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/pywhispercpp/examples/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | sounddevice~=0.4.6
3 | webrtcvad~=2.0.10
4 | requests
5 | tqdm
6 | platformdirs
--------------------------------------------------------------------------------
/.directory:
--------------------------------------------------------------------------------
1 | [Dolphin]
2 | HeaderColumnWidths=499,145,72
3 | Timestamp=2023,3,9,19,49,1.313
4 | Version=4
5 | ViewMode=1
6 |
7 | [Settings]
8 | HiddenFilesShown=true
9 |
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | updates:
3 | # Maintain dependencies for GitHub Actions
4 | - package-ecosystem: "github-actions"
5 | directory: "/"
6 | schedule:
7 | interval: "daily"
8 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "pybind11"]
2 | path = pybind11
3 | url = https://github.com/pybind/pybind11.git
4 | branch = master
5 |
6 | [submodule "whisper.cpp"]
7 | path = whisper.cpp
8 | url = https://github.com/ggml-org/whisper.cpp.git
9 | branch = master
10 |
--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.4...3.18)
2 | project(pywhispercpp)
3 |
4 | add_subdirectory(pybind11)
5 | add_subdirectory(whisper.cpp)
6 |
7 | pybind11_add_module(_pywhispercpp
8 | src/main.cpp
9 | )
10 |
11 | target_link_libraries (_pywhispercpp PRIVATE whisper)
12 |
13 |
--------------------------------------------------------------------------------
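
The CMake build above compiles `src/main.cpp` into the `_pywhispercpp` extension module, linked against the `whisper` target from the `whisper.cpp` submodule. A minimal smoke-test sketch of the built extension (the printed constant is re-exported by `pywhispercpp/constants.py` below):

```python
# Smoke test for the compiled _pywhispercpp extension produced by this CMake build.
import _pywhispercpp as pw

# whisper.cpp models operate on 16 kHz mono audio; the bindings expose the constant.
print(pw.WHISPER_SAMPLE_RATE)  # 16000
```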
/docs/index.md:
--------------------------------------------------------------------------------
1 | # PyWhisperCpp API Reference
2 |
3 |
4 | ::: pywhispercpp.model
5 |
6 | ::: pywhispercpp.constants
7 | options:
8 | show_if_no_docstring: true
9 |
10 | ::: pywhispercpp.utils
11 |
12 | ::: pywhispercpp.examples
13 | options:
14 | show_if_no_docstring: false
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include README.md LICENSE pybind11/LICENSE version.txt
2 | graft pybind11/include
3 | graft pybind11/tools
4 | graft src
5 | global-include CMakeLists.txt *.cmake
6 |
7 | graft whisper.cpp/cmake
8 | graft whisper.cpp/ggml
9 | graft whisper.cpp/grammars
10 | graft whisper.cpp/include
11 | graft whisper.cpp/spm-headers
12 | graft whisper.cpp/src
13 | exclude whisper.cpp/**/*.o
14 | exclude whisper.cpp/**/*.so
15 | exclude whisper.cpp/**/*.a
16 | exclude whisper.cpp/**/*.dylib
17 | exclude whisper.cpp/**/*.dll
18 | exclude whisper.cpp/**/*.lib
--------------------------------------------------------------------------------
/.github/workflows/docs.yml:
--------------------------------------------------------------------------------
1 | name: docs
2 | on:
3 | push:
4 | branches:
5 | - main
6 | permissions:
7 | contents: write
8 | jobs:
9 | deploy:
10 | runs-on: ubuntu-latest
11 | steps:
12 | - uses: actions/checkout@v4
13 | - uses: actions/setup-python@v5
14 | with:
15 | python-version: '3.x'
16 | - uses: actions/cache@v4
17 | with:
18 | key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
19 | path: ~/.cache/pip
20 | - run: |
21 | pip install mkdocs-macros-plugin mkdocs-material mkdocstrings[python] black pywhispercpp
22 | - run: mkdocs gh-deploy --force
23 |
--------------------------------------------------------------------------------
/.appveyor.yml:
--------------------------------------------------------------------------------
1 | version: '{build}'
2 | image: Visual Studio 2019
3 | platform:
4 | - x86
5 | - x64
6 | environment:
7 | global:
8 | DISTUTILS_USE_SDK: 1
9 | PYTHONWARNINGS: ignore:DEPRECATION
10 | MSSdk: 1
11 | matrix:
12 | - PYTHON: 37
13 | install:
14 | - cmd: '"%VS140COMNTOOLS%\..\..\VC\vcvarsall.bat" %PLATFORM%'
15 | - ps: |
16 | git submodule update -q --init --recursive
17 | if ($env:PLATFORM -eq "x64") { $env:PYTHON = "$env:PYTHON-x64" }
18 | $env:PATH = "C:\Python$env:PYTHON\;C:\Python$env:PYTHON\Scripts\;$env:PATH"
19 | python -m pip install --disable-pip-version-check --upgrade --no-warn-script-location pip build pytest
20 | build_script:
21 | - ps: |
22 | python -m build -s
23 | cd dist
24 |       python -m pip install --verbose (Get-Item pywhispercpp-*.tar.gz)
25 | cd ..
26 | test_script:
27 | - ps: python -m pytest
28 |
--------------------------------------------------------------------------------
/.github/workflows/pip.yml:
--------------------------------------------------------------------------------
1 | name: Pip
2 |
3 | on:
4 | workflow_dispatch:
5 | pull_request:
6 | push:
7 | branches:
8 | - main
9 |
10 | jobs:
11 | build:
12 | strategy:
13 | fail-fast: false
14 | matrix:
15 | platform: [windows-latest, macos-14, ubuntu-latest]
16 | python-version: ["3.8", "3.11"]
17 |
18 | runs-on: ${{ matrix.platform }}
19 |
20 | steps:
21 | - uses: actions/checkout@v4
22 | with:
23 | submodules: true
24 |
25 | - uses: actions/setup-python@v5
26 | with:
27 | python-version: ${{ matrix.python-version }}
28 |
29 | - name: Add requirements
30 | run: python -m pip install --upgrade wheel setuptools
31 |
32 | - name: Install requirements
33 | run: python -m pip install -r requirements.txt
34 |
35 | - name: Build and install
36 | run: pip install --verbose .[test]
37 |
38 | # - name: Test C-API
39 | # run: python -m unittest ./tests/test_c_api.py
40 |
41 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Abdeladim Sadiki
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/tests/test_segfault.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 | import unittest
6 | from pathlib import Path
7 | from unittest import TestCase
8 |
9 | from pywhispercpp.model import Model, Segment
10 |
11 | WHISPER_CPP_DIR = Path(__file__).parent.parent / 'whisper.cpp'
12 | 
13 | 
14 | class TestSegfault(TestCase):
15 |     audio_file = WHISPER_CPP_DIR / 'samples/jfk.wav'
16 | 
17 |     def voice_to_text(self, tmp_path):
18 |         # n_threads=1 is 3x faster than n_threads=6 when running in Docker.
19 |         whisper_model = Model('tiny.en-q5_1', n_threads=1)
20 |         segments: list[Segment] = whisper_model.transcribe(tmp_path)
21 |         # Return the first non-empty segment that is not a sound annotation such as "(music)".
22 |         text = next(segment.text for segment in segments
23 |                     if segment.text and '(' not in segment.text and '[' not in segment.text)
24 |         return text
25 | 
26 |     def test_sample_file(self):
27 |         expected_text = "ask not what your country can do for you"
28 |         text = self.voice_to_text(str(self.audio_file))
29 |         self.assertIn(expected_text.lower(), text.lower(),
30 |                       f"Expected text '{expected_text}' not found in transcription: '{text}'")
31 | 
32 | 
33 | if __name__ == '__main__':
34 |     unittest.main()
35 | 
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = [
3 | "setuptools>=42",
4 | "wheel",
5 | "ninja",
6 | "cmake>=3.12",
7 | "repairwheel",
8 | "setuptools-scm>=8"
9 | ]
10 | build-backend = "setuptools.build_meta"
11 |
12 | [tool.mypy]
13 | files = "setup.py"
14 | python_version = "3.8"
15 | strict = true
16 | show_error_codes = true
17 | enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
18 | warn_unreachable = true
19 |
20 | [[tool.mypy.overrides]]
21 | module = ["ninja"]
22 | ignore_missing_imports = true
23 |
24 |
25 | [tool.pytest.ini_options]
26 | minversion = "6.0"
27 | addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config"]
28 | xfail_strict = true
29 | filterwarnings = ["error"]
30 | testpaths = ["tests"]
31 |
32 | [tool.cibuildwheel]
33 | #test-command = "pytest {project}/tests"
34 | #test-extras = ["test"]
35 | test-skip = ["*universal2:arm64"]
36 | # Setuptools bug causes collision between pypy and cpython artifacts
37 | before-build = "rm -rf {project}/build"
38 |
39 | [tool.ruff]
40 | extend-select = [
41 | "B", # flake8-bugbear
42 | "B904",
43 | "I", # isort
44 | "PGH", # pygrep-hooks
45 | "RUF", # Ruff-specific
46 | "UP", # pyupgrade
47 | ]
48 | extend-ignore = [
49 | "E501", # Line too long
50 | ]
51 | target-version = "py39"
52 |
53 | [tool.setuptools_scm]
54 | version_file = "_version.py"
--------------------------------------------------------------------------------
/tests/test_model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | """
5 | Test model.py
6 | """
7 | import unittest
8 | from pathlib import Path
9 | from unittest import TestCase
10 |
11 | from pywhispercpp.model import Model, Segment
12 |
13 | WHISPER_CPP_DIR = Path(__file__).parent.parent / 'whisper.cpp'
14 | 
15 | 
16 | class TestModel(TestCase):
17 |     audio_file = WHISPER_CPP_DIR / 'samples/jfk.wav'
18 |     model = Model("tiny", models_dir=str(WHISPER_CPP_DIR / 'models'))
19 | 
20 |     def test_transcribe(self):
21 |         segments = self.model.transcribe(str(self.audio_file))
22 |         self.assertIsInstance(segments, list)
23 |         self.assertTrue(all(isinstance(seg, Segment) for seg in segments))
24 | 
25 |     def test_get_params(self):
26 |         params = self.model.get_params()
27 |         self.assertIsInstance(params, dict)
28 | 
29 |     def test_lang_max_id(self):
30 |         n = self.model.lang_max_id()
31 |         self.assertGreater(n, 0)
32 | 
33 |     def test_available_languages(self):
34 |         av_langs = self.model.available_languages()
35 |         self.assertIsInstance(av_langs, list)
36 |         self.assertGreater(len(av_langs), 1)
37 | 
38 |     def test__load_audio(self):
39 |         audio_arr = self.model._load_audio(str(self.audio_file))
40 |         self.assertIsNotNone(audio_arr)
41 | 
42 |     def test_auto_detect_language(self):
43 |         detected_language, probs = self.model.auto_detect_language(str(self.audio_file))
44 |         self.assertIsInstance(detected_language, tuple)
45 |         self.assertEqual(detected_language[0], 'en')
46 | 
47 | 
48 | if __name__ == '__main__':
49 |     unittest.main()
50 | 
--------------------------------------------------------------------------------
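
Outside the test harness, the same high-level API reduces to a few lines. A hedged sketch, assuming the named model is fetched on first use (see `utils.download_model` below) and that the whisper.cpp sample file is available:

```python
# Minimal transcription with the high-level API from pywhispercpp.model.
from pywhispercpp.model import Model

model = Model('tiny')  # a name from constants.AVAILABLE_MODELS, or a path to a ggml file
segments = model.transcribe('whisper.cpp/samples/jfk.wav')
for seg in segments:
    # seg.t0 / seg.t1 are whisper.cpp timestamps in 10 ms units (see utils.to_timestamp)
    print(seg.t0, seg.t1, seg.text)
```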
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # To use:
2 | #
3 | # pre-commit run -a
4 | #
5 | # Or:
6 | #
7 | # pre-commit install # (runs every time you commit in git)
8 | #
9 | # To update this file:
10 | #
11 | # pre-commit autoupdate
12 | #
13 | # See https://github.com/pre-commit/pre-commit
14 |
15 | ci:
16 | autoupdate_commit_msg: "chore: update pre-commit hooks"
17 | autofix_commit_msg: "style: pre-commit fixes"
18 |
19 | repos:
20 | # Standard hooks
21 | - repo: https://github.com/pre-commit/pre-commit-hooks
22 | rev: v4.4.0
23 | hooks:
24 | - id: check-added-large-files
25 | - id: check-case-conflict
26 | - id: check-merge-conflict
27 | - id: check-symlinks
28 | - id: check-yaml
29 | exclude: ^conda\.recipe/meta\.yaml$
30 | - id: debug-statements
31 | - id: end-of-file-fixer
32 | - id: mixed-line-ending
33 | - id: requirements-txt-fixer
34 | - id: trailing-whitespace
35 |
36 | # Black, the code formatter, natively supports pre-commit
37 | - repo: https://github.com/psf/black
38 | rev: 23.1.0
39 | hooks:
40 | - id: black
41 | exclude: ^(docs)
42 |
43 | - repo: https://github.com/charliermarsh/ruff-pre-commit
44 | rev: "v0.0.253"
45 | hooks:
46 | - id: ruff
47 | args: ["--fix", "--show-fixes"]
48 |
49 | # Checking static types
50 | - repo: https://github.com/pre-commit/mirrors-mypy
51 | rev: "v1.0.1"
52 | hooks:
53 | - id: mypy
54 | files: "setup.py"
55 | args: []
56 | additional_dependencies: [types-setuptools]
57 |
58 | # Changes tabs to spaces
59 | - repo: https://github.com/Lucas-C/pre-commit-hooks
60 | rev: v1.4.2
61 | hooks:
62 | - id: remove-tabs
63 | exclude: ^(docs)
64 |
65 | # CMake formatting
66 | - repo: https://github.com/cheshirekow/cmake-format-precommit
67 | rev: v0.6.13
68 | hooks:
69 | - id: cmake-format
70 | additional_dependencies: [pyyaml]
71 | types: [file]
72 | files: (\.cmake|CMakeLists.txt)(.in)?$
73 |
74 | # Suggested hook if you add a .clang-format file
75 | # - repo: https://github.com/pre-commit/mirrors-clang-format
76 | # rev: v13.0.0
77 | # hooks:
78 | # - id: clang-format
79 |
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: PyWhisperCpp
2 | repo_url: https://github.com/absadiki/pywhispercpp
3 | repo_name: absadiki/pywhispercpp
4 | theme:
5 | name: material
6 | language: en
7 | features:
8 | - navigation.tabs
9 | - navigation.sections
10 | - toc.integrate
11 | - navigation.top
12 | - search.suggest
13 | - search.highlight
14 | - content.tabs.link
15 |     - content.code.annotate
16 | - content.code.copy
17 | palette:
18 | - scheme: default
19 | toggle:
20 | icon: material/toggle-switch-off-outline
21 | name: Switch to dark mode
22 | - scheme: slate
23 | toggle:
24 | icon: material/toggle-switch
25 | name: Switch to light mode
26 |
27 |
28 | extra:
29 | social:
30 | - icon: fontawesome/brands/github-alt
31 | link: https://github.com/absadiki/pywhispercpp
32 |
33 | markdown_extensions:
34 | - pymdownx.inlinehilite
35 | - pymdownx.snippets
36 | - admonition
37 | - pymdownx.arithmatex:
38 | generic: true
39 | - footnotes
40 | - pymdownx.details
41 | - pymdownx.superfences
42 | - pymdownx.mark
43 | - attr_list
44 | - pymdownx.emoji:
45 | emoji_index: !!python/name:materialx.emoji.twemoji
46 | emoji_generator: !!python/name:materialx.emoji.to_svg
47 | - pymdownx.highlight:
48 | use_pygments: true
49 | pygments_lang_class: true
50 |
51 | copyright: |
52 | © 2023 absadiki
53 |
54 | plugins:
55 | - search
56 | - macros:
57 | include_dir: .
58 | - mkdocstrings:
59 | handlers:
60 | python:
61 | paths: [src]
62 | options:
63 | separate_signature: true
64 | docstring_style: sphinx
65 | docstring_section_style: list
66 | members_order: source
67 | merge_init_into_class: true
68 | show_bases: true
69 | show_if_no_docstring: false
70 | show_root_full_path: true
71 | show_root_heading: true
72 | show_submodules: true
73 | filters:
74 | - "!^_"
75 | watch:
76 | - pywhispercpp/
77 |
78 |
--------------------------------------------------------------------------------
/pywhispercpp/examples/recording.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | """
5 | A simple example showcasing how to use pywhispercpp to transcribe a recording.
6 | """
7 | import argparse
8 | import logging
9 | import sounddevice as sd
10 | import pywhispercpp.constants
11 | from pywhispercpp.model import Model
12 | import importlib.metadata
13 |
14 |
15 | __version__ = importlib.metadata.version('pywhispercpp')
16 |
17 | __header__ = f"""
18 | ===================================================================
19 | PyWhisperCpp
20 | A simple example of transcribing a recording, based on whisper.cpp
21 | Version: {__version__}
22 | ===================================================================
23 | """
24 |
25 |
26 | class Recording:
27 | """
28 | Recording class
29 |
30 | Example usage
31 | ```python
32 | from pywhispercpp.examples.recording import Recording
33 |
34 | myrec = Recording(5)
35 | myrec.start()
36 | ```
37 | """
38 | def __init__(self,
39 | duration: int,
40 | model: str = 'tiny.en',
41 | **model_params):
42 | self.duration = duration
43 | self.sample_rate = pywhispercpp.constants.WHISPER_SAMPLE_RATE
44 | self.channels = 1
45 | self.pwcpp_model = Model(model, print_realtime=True, **model_params)
46 |
47 | def start(self):
48 | logging.info(f"Start recording for {self.duration}s ...")
49 | recording = sd.rec(int(self.duration * self.sample_rate), samplerate=self.sample_rate, channels=self.channels)
50 | sd.wait()
51 | logging.info('Duration finished')
52 | res = self.pwcpp_model.transcribe(recording)
53 | self.pwcpp_model.print_timings()
54 |
55 |
56 | def _main():
57 | print(__header__)
58 | parser = argparse.ArgumentParser(description="", allow_abbrev=True)
59 | # Positional args
60 | parser.add_argument('duration', type=int, help=f"duration in seconds")
61 | parser.add_argument('-m', '--model', default='tiny.en', type=str, help="Whisper.cpp model, default to %(default)s")
62 |
63 | args = parser.parse_args()
64 |
65 | myrec = Recording(duration=args.duration, model=args.model)
66 | myrec.start()
67 |
68 |
69 | if __name__ == '__main__':
70 | _main()
71 |
--------------------------------------------------------------------------------
/tests/test_c_api.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import _pywhispercpp as pw
4 |
5 | import unittest
6 | from unittest import TestCase
7 |
8 |
9 | class TestCAPI(TestCase):
10 |
11 | model_file = './whisper.cpp/models/for-tests-ggml-tiny.en.bin'
12 |
13 | def test_whisper_init_from_file(self):
14 | ctx = pw.whisper_init_from_file(self.model_file)
15 | self.assertIsInstance(ctx, pw.whisper_context)
16 |
17 | def test_whisper_lang_str(self):
18 | return self.assertEqual(pw.whisper_lang_str(0), 'en')
19 |
20 | def test_whisper_lang_id(self):
21 | return self.assertEqual(pw.whisper_lang_id('en'), 0)
22 |
23 | def test_whisper_full_params_language_set_to_de(self):
24 | params = pw.whisper_full_params()
25 | params.language = 'de'
26 | return self.assertEqual(params.language, 'de')
27 |
28 | def test_whisper_full_params_language_set_to_german(self):
29 | params = pw.whisper_full_params()
30 | params.language = 'german'
31 | return self.assertEqual(params.language, 'de')
32 |
33 | def test_whisper_full_params_context(self):
34 |
35 | params = pw.whisper_full_params()
36 | # to ensure that the string is not cached
37 | prompt = str(10120923) + "A" + " test"
38 | params.initial_prompt = prompt
39 | print("Params Prompt: ", params.initial_prompt)
40 | del prompt
41 | import gc
42 | gc.collect()
43 | return self.assertEqual(params.initial_prompt, str(10120923) + "A test")
44 |
45 | def test_whisper_full_params_regex(self):
46 | params = pw.whisper_full_params()
47 | val = str(10120923) + "A" + " test"
48 | params.suppress_regex = val
49 | print("Params Prompt: ", params.suppress_regex)
50 | del val
51 | import gc
52 | gc.collect()
53 | return self.assertEqual(params.suppress_regex, str(10120923) + "A" + " test")
54 |
55 | def test_whisper_full_params_default(self):
56 | params = pw.whisper_full_default_params(pw.whisper_sampling_strategy.WHISPER_SAMPLING_GREEDY)
57 | self.assertIsInstance(params, pw.whisper_full_params)
58 | self.assertEqual(params.suppress_regex, "")
59 |
60 |     def test_whisper_full_params(self):
61 |         params = pw.whisper_full_params()
62 |         return self.assertIsInstance(params.n_threads, int)
63 | 
64 | 
65 | if __name__ == '__main__':
66 |     unittest.main()
67 | 
--------------------------------------------------------------------------------
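
For orientation, the low-level bindings exercised by these tests compose as follows; this sketch sticks to the exact calls shown above and uses the same test model path:

```python
# Combining the C-API calls covered by the tests above.
import _pywhispercpp as pw

# Default decoding parameters with the greedy sampling strategy
params = pw.whisper_full_default_params(pw.whisper_sampling_strategy.WHISPER_SAMPLING_GREEDY)
params.language = 'en'
params.initial_prompt = 'A short dictation.'

# Load the tiny English test model shipped with the whisper.cpp submodule
ctx = pw.whisper_init_from_file('./whisper.cpp/models/for-tests-ggml-tiny.en.bin')

# Language helpers round-trip between ids and codes
assert pw.whisper_lang_str(pw.whisper_lang_id('en')) == 'en'
```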
/.gitignore:
--------------------------------------------------------------------------------
1 | build/
2 | dist/
3 | _build/
4 | _generate/
5 | *.so
6 | *.dylib
7 | *.py[cod]
8 | *.egg-info
9 | *env*
10 |
11 |
12 | # custom
13 | .idea
14 | _docs
15 | _examples
16 | src/.idea
17 |
18 | # Byte-compiled / optimized / DLL files
19 | __pycache__/
20 | *.py[cod]
21 | *$py.class
22 |
23 | # C extensions
24 | *.so
25 |
26 | # Distribution / packaging
27 | .Python
28 | build/
29 | develop-eggs/
30 | dist/
31 | downloads/
32 | eggs/
33 | .eggs/
34 | lib/
35 | lib64/
36 | parts/
37 | sdist/
38 | var/
39 | wheels/
40 | pip-wheel-metadata/
41 | share/python-wheels/
42 | *.egg-info/
43 | .installed.cfg
44 | *.egg
45 | MANIFEST
46 |
47 | # PyInstaller
48 | # Usually these files are written by a python script from a template
49 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
50 | *.manifest
51 | *.spec
52 |
53 | # Installer logs
54 | pip-log.txt
55 | pip-delete-this-directory.txt
56 |
57 | # Unit test / coverage reports
58 | htmlcov/
59 | .tox/
60 | .nox/
61 | .coverage
62 | .coverage.*
63 | .cache
64 | nosetests.xml
65 | coverage.xml
66 | *.cover
67 | *.py,cover
68 | .hypothesis/
69 | .pytest_cache/
70 |
71 | # Translations
72 | *.mo
73 | *.pot
74 |
75 | # Django stuff:
76 | *.log
77 | local_settings.py
78 | db.sqlite3
79 | db.sqlite3-journal
80 |
81 | # Flask stuff:
82 | instance/
83 | .webassets-cache
84 |
85 | # Scrapy stuff:
86 | .scrapy
87 |
88 | # Sphinx documentation
89 | docs/_build/
90 |
91 | # PyBuilder
92 | target/
93 |
94 | # Jupyter Notebook
95 | .ipynb_checkpoints
96 |
97 | # IPython
98 | profile_default/
99 | ipython_config.py
100 |
101 | # pyenv
102 | .python-version
103 |
104 | # pipenv
105 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
106 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
107 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
108 | # install all needed dependencies.
109 | #Pipfile.lock
110 |
111 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
112 | __pypackages__/
113 |
114 | # Celery stuff
115 | celerybeat-schedule
116 | celerybeat.pid
117 |
118 | # SageMath parsed files
119 | *.sage.py
120 |
121 | # Environments
122 | .env
123 | .venv
124 | env/
125 | venv/
126 | ENV/
127 | env.bak/
128 | venv.bak/
129 |
130 | # Spyder project settings
131 | .spyderproject
132 | .spyproject
133 |
134 | # Rope project settings
135 | .ropeproject
136 |
137 | # mkdocs documentation
138 | /site
139 |
140 | # mypy
141 | .mypy_cache/
142 | .dmypy.json
143 | dmypy.json
144 |
145 | # Pyre type checker
146 | .pyre/
147 |
148 |
149 |
--------------------------------------------------------------------------------
/.github/workflows/wheels.yml:
--------------------------------------------------------------------------------
1 | name: Wheels
2 |
3 | on:
4 | workflow_dispatch:
5 | pull_request:
6 | push:
7 | branches:
8 | - release
9 | release:
10 | types:
11 | - published
12 |
13 | jobs:
14 | build_sdist:
15 | name: Build SDist
16 | runs-on: ubuntu-latest
17 | steps:
18 | - uses: actions/checkout@v4
19 | with:
20 | submodules: true
21 |
22 | - name: Build SDist
23 | run: pipx run build --sdist
24 |
25 | - name: Check metadata
26 | run: pipx run twine check dist/*
27 |
28 | - uses: actions/upload-artifact@v4
29 | with:
30 | name: artifact-sdist
31 | path: dist/*.tar.gz
32 |
33 |
34 | build_wheels:
35 | name: Wheels on ${{ matrix.os }}
36 | runs-on: ${{ matrix.os }}
37 | strategy:
38 | fail-fast: false
39 | matrix:
40 | os: [ubuntu-24.04-arm, ubuntu-latest, windows-2022, macos-14]
41 |
42 | steps:
43 | - uses: actions/checkout@v4
44 | with:
45 | submodules: true
46 |
47 | # Used to host cibuildwheel
48 | - uses: actions/setup-python@v5
49 |
50 | - name: Install cibuildwheel
51 | run: python -m pip install cibuildwheel
52 |
53 | - name: Build wheels
54 | run: python -m cibuildwheel --output-dir wheelhouse
55 | env:
56 | CIBW_ARCHS: auto
57 | # for windows setup.py repairwheel step should solve it
58 | CIBW_SKIP: pp* cp38-*
59 | # Whisper.cpp tries to use BMI2 on 32 bit Windows, so disable BMI2 when building on Windows to avoid that bug. See https://github.com/ggml-org/whisper.cpp/pull/3543
60 | CIBW_ENVIRONMENT: CMAKE_ARGS="${{ contains(matrix.os, 'arm') && '-DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8-a' || ''}} ${{ contains(matrix.os, 'windows') && '-DGGML_BMI2=OFF' || '' }}"
61 |
62 | - name: Verify clean directory
63 | run: git diff --exit-code
64 | shell: bash
65 |
66 | - name: List files
67 | run: ls wheelhouse
68 |
69 | - name: Upload wheels
70 | uses: actions/upload-artifact@v4
71 | with:
72 | name: artifact-${{ matrix.os }}
73 | path: wheelhouse/*.whl
74 |
75 | test_wheels:
76 | name: Test wheels on ${{ matrix.os }} (Python ${{ matrix.python-version }})
77 | runs-on: ${{ matrix.os }}
78 | needs: build_wheels
79 | strategy:
80 | fail-fast: false
81 | matrix:
82 | os: [ubuntu-latest, windows-latest, macos-latest, ubuntu-24.04-arm]
83 | python-version: [3.11, 3.12, 3.13]
84 |
85 | steps:
86 | - uses: actions/checkout@v4
87 | with:
88 | submodules: true
89 |
90 | - uses: actions/download-artifact@v4
91 | with:
92 | pattern: artifact-*
93 | merge-multiple: true
94 | path: wheelhouse
95 |
96 | - name: Verify artifact download
97 | run: |
98 | ls -l wheelhouse
99 |
100 | - name: Set up Python
101 | uses: actions/setup-python@v5
102 | with:
103 | python-version: ${{ matrix.python-version }}
104 |
105 | - name: Install dependencies
106 | run: |
107 | pip install -r requirements.txt
108 | pip install pytest
109 |
110 | - name: Install Wheel
111 | run: |
112 | pip install --no-index --find-links=./wheelhouse pywhispercpp
113 |
114 | - name: Run tests
115 | run: |
116 | pytest tests/
117 |
118 |
119 | upload_all:
120 | name: Upload if release
121 | needs: [build_wheels, build_sdist]
122 | runs-on: ubuntu-latest
123 | if: github.event_name == 'release' && github.event.action == 'published'
124 |
125 | steps:
126 | - uses: actions/setup-python@v5
127 | with:
128 | python-version: "3.x"
129 |
130 | - uses: actions/download-artifact@v4
131 | with:
132 | pattern: artifact-*
133 | merge-multiple: true
134 | path: dist
135 |
136 | - uses: pypa/gh-action-pypi-publish@release/v1
137 | with:
138 | verbose: true
139 | password: ${{ secrets.PYPI_API_TOKEN }}
140 |
--------------------------------------------------------------------------------
/pywhispercpp/examples/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | """
5 | A simple Command Line Interface to test the package
6 | """
7 | import argparse
8 | import importlib.metadata
9 | import logging
10 |
11 | import pywhispercpp.constants as constants
12 |
13 | __version__ = importlib.metadata.version('pywhispercpp')
14 |
15 | __header__ = f"""
16 | PyWhisperCpp
17 | A simple Command Line Interface to test the package
18 | Version: {__version__}
19 | ====================================================
20 | """
21 |
22 | from pywhispercpp.model import Model
23 | import pywhispercpp.utils as utils
24 |
25 |
26 | def _get_params(args) -> dict:
27 | """
28 | Helper function to get params from argparse as a `dict`
29 | """
30 | params = {}
31 | for arg in args.__dict__:
32 | if arg in constants.PARAMS_SCHEMA.keys() and getattr(args, arg) is not None:
33 | if constants.PARAMS_SCHEMA[arg]['type'] is bool:
34 | if getattr(args, arg).lower() == 'false':
35 | params[arg] = False
36 | else:
37 | params[arg] = True
38 | else:
39 | params[arg] = constants.PARAMS_SCHEMA[arg]['type'](getattr(args, arg))
40 | return params
41 |
42 |
43 | def run(args):
44 | logging.info(f"Running with model `{args.model}`")
45 | params = _get_params(args)
46 | logging.info(f"Running with params {params}")
47 | m = Model(model=args.model, **params)
48 | logging.info(f"System info: n_threads = {m.get_params()['n_threads']} | Processors = {args.processors} "
49 | f"| {m.system_info()}")
50 | for file in args.media_file:
51 | segs = []
52 | try:
53 | logging.info(f"Processing file {file} ...")
54 | m.transcribe(file,
55 | n_processors=int(args.processors) if args.processors else None,
56 | new_segment_callback=lambda seg: segs.append(seg)
57 | )
58 | m.print_timings()
59 | except KeyboardInterrupt:
60 | logging.info("Transcription manually stopped")
61 | break
62 | except Exception as e:
63 | logging.error(f"Error while processing file {file}: {e}")
64 | finally:
65 | if segs:
66 | # output stuff
67 | if args.output_txt:
68 | logging.info(f"Saving result as a txt file ...")
69 | txt_file = utils.output_txt(segs, file)
70 | logging.info(f"txt file saved to {txt_file}")
71 | if args.output_vtt:
72 | logging.info(f"Saving results as a vtt file ...")
73 | vtt_file = utils.output_vtt(segs, file)
74 | logging.info(f"vtt file saved to {vtt_file}")
75 | if args.output_srt:
76 | logging.info(f"Saving results as a srt file ...")
77 | srt_file = utils.output_srt(segs, file)
78 | logging.info(f"srt file saved to {srt_file}")
79 | if args.output_csv:
80 | logging.info(f"Saving results as a csv file ...")
81 | csv_file = utils.output_csv(segs, file)
82 | logging.info(f"csv file saved to {csv_file}")
83 |
84 |
85 | def main():
86 | print(__header__)
87 | parser = argparse.ArgumentParser(description="", allow_abbrev=True)
88 | # Positional args
89 |     parser.add_argument('media_file', type=str, nargs='+', help="The path of the media file, or a list of files "
90 |                                                                  "separated by spaces")
91 |
92 | parser.add_argument('-m', '--model', default='tiny', help="Path to the `ggml` model, or just the model name")
93 |
94 | parser.add_argument('--version', action='version', version=f'%(prog)s {__version__}')
95 | parser.add_argument('--processors', help="number of processors to use during computation")
96 | parser.add_argument('-otxt', '--output-txt', action='store_true', help="output result in a text file")
97 | parser.add_argument('-ovtt', '--output-vtt', action='store_true', help="output result in a vtt file")
98 | parser.add_argument('-osrt', '--output-srt', action='store_true', help="output result in a srt file")
99 | parser.add_argument('-ocsv', '--output-csv', action='store_true', help="output result in a CSV file")
100 |
101 | # add params from PARAMS_SCHEMA
102 | for param in constants.PARAMS_SCHEMA:
103 | param_fields = constants.PARAMS_SCHEMA[param]
104 | parser.add_argument(f'--{param}',
105 | help=f'{param_fields["description"]}')
106 |
107 | args = parser.parse_args()
108 | run(args)
109 |
110 |
111 | if __name__ == '__main__':
112 | main()
113 |
--------------------------------------------------------------------------------
/pywhispercpp/examples/assistant.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | """
5 | A simple example showcasing the use of `pywhispercpp` as an assistant.
6 | The idea is to use a `VAD` to detect speech (in this example we used webrtcvad), and when speech is detected
7 | we run the inference.
8 | """
9 | import argparse
10 | import importlib.metadata
11 | import queue
12 | import time
13 | from typing import Callable
14 | import numpy as np
15 | import sounddevice as sd
16 | import pywhispercpp.constants as constants
17 | import webrtcvad
18 | import logging
19 | from pywhispercpp.model import Model
20 |
21 | __version__ = importlib.metadata.version('pywhispercpp')
22 |
23 | __header__ = f"""
24 | =====================================
25 | PyWhisperCpp
26 | A simple assistant using Whisper.cpp
27 | Version: {__version__}
28 | =====================================
29 | """
30 |
31 |
32 | class Assistant:
33 | """
34 | Assistant class
35 |
36 | Example usage
37 | ```python
38 | from pywhispercpp.examples.assistant import Assistant
39 |
40 | my_assistant = Assistant(commands_callback=print, n_threads=8)
41 | my_assistant.start()
42 | ```
43 | """
44 |
45 | def __init__(self,
46 | model='tiny',
47 | input_device: int = None,
48 | silence_threshold: int = 8,
49 | q_threshold: int = 16,
50 | block_duration: int = 30,
51 | commands_callback: Callable[[str], None] = None,
52 | **model_params):
53 |
54 | """
55 |         :param model: whisper.cpp model name or a direct path to a `ggml` model
56 |         :param input_device: The input device (aka microphone), keep it None to take the default
57 |         :param silence_threshold: The number of consecutive silent blocks after which inference is run
58 |         :param q_threshold: Inference won't run until the data queue holds at least `q_threshold` elements
59 |         :param block_duration: minimum time between audio updates in ms
60 |         :param commands_callback: The callback to run when a command is received
61 |         :param model_params: any other parameter to pass to the whisper.cpp model,
62 |                              see ::: pywhispercpp.constants.PARAMS_SCHEMA
63 | """
64 |
65 | self.input_device = input_device
66 | self.sample_rate = constants.WHISPER_SAMPLE_RATE # same as whisper.cpp
67 | self.channels = 1 # same as whisper.cpp
68 | self.block_duration = block_duration
69 | self.block_size = int(self.sample_rate * self.block_duration / 1000)
70 | self.q = queue.Queue()
71 |
72 | self.vad = webrtcvad.Vad()
73 | self.silence_threshold = silence_threshold
74 | self.q_threshold = q_threshold
75 | self._silence_counter = 0
76 |
77 | self.pwccp_model = Model(model,
78 | print_realtime=False,
79 | print_progress=False,
80 | print_timestamps=False,
81 | single_segment=True,
82 | no_context=True,
83 | **model_params)
84 | self.commands_callback = commands_callback
85 |
86 | def _audio_callback(self, indata, frames, time, status):
87 | """
88 | This is called (from a separate thread) for each audio block.
89 | """
90 | if status:
91 |             logging.warning(f"Underlying audio stack warning: {status}")
92 |
93 | assert frames == self.block_size
94 | audio_data = map(lambda x: (x + 1) / 2, indata) # normalize from [-1,+1] to [0,1]
95 | audio_data = np.fromiter(audio_data, np.float16)
96 | audio_data = audio_data.tobytes()
97 | detection = self.vad.is_speech(audio_data, self.sample_rate)
98 | if detection:
99 | self.q.put(indata.copy())
100 | self._silence_counter = 0
101 | else:
102 | if self._silence_counter >= self.silence_threshold:
103 | if self.q.qsize() > self.q_threshold:
104 | self._transcribe_speech()
105 | self._silence_counter = 0
106 | else:
107 | self._silence_counter += 1
108 |
109 | def _transcribe_speech(self):
110 | logging.info(f"Speech detected ...")
111 | audio_data = np.array([])
112 | while self.q.qsize() > 0:
113 | # get all the data from the q
114 | audio_data = np.append(audio_data, self.q.get())
115 | # Appending zeros to the audio data as a workaround for small audio packets (small commands)
116 | audio_data = np.concatenate([audio_data, np.zeros((int(self.sample_rate) + 10))])
117 | # running the inference
118 | self.pwccp_model.transcribe(audio_data,
119 | new_segment_callback=self._new_segment_callback)
120 |
121 | def _new_segment_callback(self, seg):
122 | if self.commands_callback:
123 | self.commands_callback(seg.text)
124 |
125 | def start(self) -> None:
126 | """
127 | Use this function to start the assistant
128 | :return: None
129 | """
130 | logging.info(f"Starting Assistant ...")
131 | with sd.InputStream(
132 | device=self.input_device, # the default input device
133 | channels=self.channels,
134 | samplerate=constants.WHISPER_SAMPLE_RATE,
135 | blocksize=self.block_size,
136 | callback=self._audio_callback):
137 |
138 | try:
139 | logging.info(f"Assistant is listening ... (CTRL+C to stop)")
140 | while True:
141 | time.sleep(0.1)
142 | except KeyboardInterrupt:
143 | logging.info("Assistant stopped")
144 |
145 | @staticmethod
146 | def available_devices():
147 | return sd.query_devices()
148 |
149 |
150 | def _main():
151 | parser = argparse.ArgumentParser(description="", allow_abbrev=True)
152 | # Positional args
153 | parser.add_argument('-m', '--model', default='tiny.en', type=str, help="Whisper.cpp model, default to %(default)s")
154 | parser.add_argument('-ind', '--input_device', type=int, default=None,
155 | help=f'Id of The input device (aka microphone)\n'
156 | f'available devices {Assistant.available_devices()}')
157 |     parser.add_argument('-st', '--silence_threshold', default=16, type=int,
158 |                         help="The number of silent blocks after which inference runs, default to %(default)s")
159 | parser.add_argument('-bd', '--block_duration', default=30,
160 | help=f"minimum time audio updates in ms, default to %(default)s")
161 |
162 | args = parser.parse_args()
163 |
164 | my_assistant = Assistant(model=args.model,
165 | input_device=args.input_device,
166 | silence_threshold=args.silence_threshold,
167 | block_duration=args.block_duration,
168 | commands_callback=print)
169 | my_assistant.start()
170 |
171 |
172 | if __name__ == '__main__':
173 | _main()
174 |
--------------------------------------------------------------------------------
/pywhispercpp/examples/livestream.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | """
5 | Quick and dirty realtime livestream transcription.
6 |
7 | Not fully satisfying though :)
8 | You are welcome to make it better.
9 | """
10 | import argparse
11 | import logging
12 | import queue
13 | from multiprocessing import Process
14 | import ffmpeg
15 | import numpy as np
16 | import pywhispercpp.constants as constants
17 | import sounddevice as sd
18 | from pywhispercpp.model import Model
19 | import importlib.metadata
20 |
21 |
22 | __version__ = importlib.metadata.version('pywhispercpp')
23 |
24 | __header__ = f"""
25 | ========================================================
26 | PyWhisperCpp
27 | A simple Livestream transcription, based on whisper.cpp
28 | Version: {__version__}
29 | ========================================================
30 | """
31 |
32 | class LiveStream:
33 | """
34 | LiveStream class
35 |
36 | ???+ note
37 |
38 | It heavily depends on the machine power, the processor will jump quickly to 100% with the wrong parameters.
39 |
40 | Example usage
41 | ```python
42 | from pywhispercpp.examples.livestream import LiveStream
43 |
44 | url = "" # Make sure it is a direct stream URL
45 | ls = LiveStream(url=url, n_threads=4)
46 | ls.start()
47 | ```
48 | """
49 |
50 | def __init__(self,
51 | url,
52 | model='tiny.en',
53 | block_size: int = 1024,
54 | buffer_size: int = 20,
55 | sample_size: int = 4,
56 | output_device: int = None,
57 | model_log_level=logging.CRITICAL,
58 | **model_params):
59 |
60 | """
61 | :param url: Live stream url
62 | :param model: whisper.cpp model
63 | :param block_size: block size, default to 1024
64 | :param buffer_size: number of blocks used for buffering, default to 20
65 | :param sample_size: sample size
66 | :param output_device: the output device, aka the speaker, leave it None to take the default
67 | :param model_log_level: logging level
68 | :param model_params: any other whisper.cpp params
69 | """
70 | self.url = url
71 | self.block_size = block_size
72 | self.buffer_size = buffer_size
73 | self.sample_size = sample_size
74 | self.output_device = output_device
75 |
76 | self.channels = 1
77 | self.samplerate = constants.WHISPER_SAMPLE_RATE
78 |
79 | self.q = queue.Queue(maxsize=buffer_size)
80 | self.audio_data = np.array([])
81 |
82 | self.pwccp_model = Model(model,
83 | log_level=model_log_level,
84 | print_realtime=True,
85 | print_progress=False,
86 | print_timestamps=False,
87 | single_segment=True,
88 | **model_params)
89 |
90 | def _transcribe_process(self):
91 | self.pwccp_model.transcribe(self.audio_data, n_processors=None)
92 |
93 | def _audio_callback(self, outdata, frames, time, status):
94 | assert frames == self.block_size
95 | if status.output_underflow:
96 | logging.error('Output underflow: increase blocksize?')
97 | raise sd.CallbackAbort
98 | assert not status
99 | try:
100 | data = self.q.get_nowait()
101 | except queue.Empty as e:
102 | logging.error('Buffer is empty: increase buffer_size?')
103 | raise sd.CallbackAbort from e
104 | assert len(data) == len(outdata)
105 | outdata[:] = data
106 | audio = np.frombuffer(data[:], np.float32)
107 | audio = audio.reshape((audio.size, 1)) / 2 ** 5
108 | self.audio_data = np.append(self.audio_data, audio)
109 | if self.audio_data.size > self.samplerate:
110 | # Create a separate process for transcription
111 | p1 = Process(target=self._transcribe_process,)
112 | p1.start()
113 | self.audio_data = np.array([])
114 |
115 | def start(self):
116 | process = ffmpeg.input(self.url).output(
117 | 'pipe:',
118 | format='f32le',
119 | acodec='pcm_f32le',
120 | ac=self.channels,
121 | ar=self.samplerate,
122 | loglevel='quiet',
123 | ).run_async(pipe_stdout=True)
124 |
125 | out_stream = sd.RawOutputStream(
126 | device=self.output_device,
127 | samplerate=self.samplerate,
128 | blocksize=self.block_size,
129 | channels=self.channels,
130 | dtype='float32',
131 | callback=self._audio_callback)
132 |
133 | read_size = self.block_size * self.channels * self.sample_size
134 |
135 | logging.info('Buffering ...')
136 | for _ in range(self.buffer_size):
137 | self.q.put_nowait(process.stdout.read(read_size))
138 |
139 | with out_stream:
140 | logging.info('Starting Playback ... (CTRL+C) to stop')
141 | try:
142 | timeout = self.block_size * self.buffer_size / self.samplerate
143 | while True:
144 | buffer_data = process.stdout.read(read_size)
145 | self.q.put(buffer_data, timeout=timeout)
146 | except KeyboardInterrupt:
147 | logging.info("Interrupted!")
148 |
149 | @staticmethod
150 | def available_devices():
151 | return sd.query_devices()
152 |
153 |
154 | def _main():
155 | print(__header__)
156 | parser = argparse.ArgumentParser(description="", allow_abbrev=True)
157 | # Positional args
158 | parser.add_argument('url', type=str, help=f"Stream URL")
159 |
160 | parser.add_argument('-nt', '--n_threads', type=int, default=3,
161 | help="number of threads, default to %(default)s")
162 | parser.add_argument('-m', '--model', default='tiny.en', type=str, help="Whisper.cpp model, default to %(default)s")
163 | parser.add_argument('-od', '--output_device', type=int, default=None,
164 | help=f'the output device, aka the speaker, leave it None to take the default\n'
165 | f'available devices {LiveStream.available_devices()}')
166 | parser.add_argument('-bls', '--block_size', type=int, default=1024,
167 | help=f"block size, default to %(default)s")
168 | parser.add_argument('-bus', '--buffer_size', type=int, default=20,
169 | help=f"number of blocks used for buffering, default to %(default)s")
170 | parser.add_argument('-ss', '--sample_size', type=int, default=4,
171 | help=f"Sample size, default to %(default)s")
172 | args = parser.parse_args()
173 |
174 |
175 | # url = "http://n03.radiojar.com/t2n88q0st5quv?rj-ttl=5&rj-tok=AAABhsR2u6MAYFxz69dJ6eQnww" # VOA english
176 | ls = LiveStream(url=args.url,
177 | model=args.model,
178 | block_size=args.block_size,
179 | buffer_size=args.buffer_size,
180 | sample_size=args.sample_size,
181 | output_device=args.output_device,
182 | n_threads=args.n_threads)
183 | ls.start()
184 |
185 |
186 | if __name__ == '__main__':
187 | _main()
188 |
--------------------------------------------------------------------------------
/pywhispercpp/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | """
4 | Helper functions
5 | """
6 |
7 | import contextlib
8 | import logging
9 | import os
10 | import sys
11 | from pathlib import Path
12 | from typing import TextIO, Union
13 |
14 | import requests
15 | from tqdm import tqdm
16 |
17 | from pywhispercpp.constants import (
18 | AVAILABLE_MODELS,
19 | MODELS_BASE_URL,
20 | MODELS_DIR,
21 | MODELS_PREFIX_URL,
22 | )
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 |
27 | def _get_model_url(model_name: str) -> str:
28 | """
29 | Returns the url of the `ggml` model
30 | :param model_name: name of the model
31 | :return: URL of the model
32 | """
33 | return f"{MODELS_BASE_URL}/{MODELS_PREFIX_URL}-{model_name}.bin"
34 |
35 |
36 | def download_model(model_name: str, download_dir=None, chunk_size=1024) -> str:
37 | """
38 | Helper function to download the `ggml` models
39 | :param model_name: name of the model, one of ::: constants.AVAILABLE_MODELS
40 | :param download_dir: Where to store the models
41 | :param chunk_size: size of the download chunk
42 |
43 | :return: Absolute path of the downloaded model
44 | """
45 | if model_name not in AVAILABLE_MODELS:
46 | logger.error(f"Invalid model name `{model_name}`, available models are: {AVAILABLE_MODELS}")
47 | return
48 | if download_dir is None:
49 | download_dir = MODELS_DIR
50 | logger.info(f"No download directory was provided, models will be downloaded to {download_dir}")
51 |
52 | os.makedirs(download_dir, exist_ok=True)
53 |
54 | url = _get_model_url(model_name=model_name)
55 | file_path = Path(download_dir) / os.path.basename(url)
56 | # check if the file is already there
57 | if file_path.exists():
58 | logger.info(f"Model {model_name} already exists in {download_dir}")
59 | else:
60 | # download it from huggingface
61 | resp = requests.get(url, stream=True)
62 | total = int(resp.headers.get('content-length', 0))
63 |
64 | progress_bar = tqdm(desc=f"Downloading Model {model_name} ...",
65 | total=total,
66 | unit='iB',
67 | unit_scale=True,
68 | unit_divisor=1024)
69 |
70 | try:
71 | with open(file_path, 'wb') as file, progress_bar:
72 | for data in resp.iter_content(chunk_size=chunk_size):
73 | size = file.write(data)
74 | progress_bar.update(size)
75 | logger.info(f"Model downloaded to {file_path.absolute()}")
76 | except Exception as e:
77 | # error download, just remove the file
78 | os.remove(file_path)
79 | raise e
80 | return str(file_path.absolute())
81 |
82 |
83 | def to_timestamp(t: int, separator=',') -> str:
84 | """
85 | 376 -> 00:00:03,760
86 | 1344 -> 00:00:13,440
87 |
88 | Implementation from `whisper.cpp/examples/main`
89 |
90 |     :param t: input time from whisper timestamps (in units of 10 ms)
91 |     :param separator: separator between seconds and milliseconds
92 |     :return: time representation in hh:mm:ss[separator]ms
93 | """
94 | # logic exactly from whisper.cpp
95 |
96 | msec = t * 10
97 | hr = msec // (1000 * 60 * 60)
98 | msec = msec - hr * (1000 * 60 * 60)
99 | min = msec // (1000 * 60)
100 | msec = msec - min * (1000 * 60)
101 | sec = msec // 1000
102 | msec = msec - sec * 1000
103 | return f"{int(hr):02,.0f}:{int(min):02,.0f}:{int(sec):02,.0f}{separator}{int(msec):03,.0f}"
104 |
105 |
106 | def output_txt(segments: list, output_file_path: str) -> str:
107 | """
108 | Creates a raw text from a list of segments
109 |
110 | Implementation from `whisper.cpp/examples/main`
111 |
112 | :param segments: list of segments
113 | :return: path of the file
114 | """
115 | if not output_file_path.endswith('.txt'):
116 | output_file_path = output_file_path + '.txt'
117 |
118 | absolute_path = Path(output_file_path).absolute()
119 |
120 | with open(str(absolute_path), 'w') as file:
121 | for seg in segments:
122 | file.write(seg.text)
123 | file.write('\n')
124 |     return str(absolute_path)
125 |
126 |
127 | def output_vtt(segments: list, output_file_path: str) -> str:
128 | """
129 | Creates a vtt file from a list of segments
130 |
131 | Implementation from `whisper.cpp/examples/main`
132 |
133 | :param segments: list of segments
134 |     :param output_file_path: base path of the output file
135 |
136 | :return: Absolute path of the file
137 | """
138 | if not output_file_path.endswith('.vtt'):
139 | output_file_path = output_file_path + '.vtt'
140 |
141 | absolute_path = Path(output_file_path).absolute()
142 |
143 | with open(absolute_path, 'w') as file:
144 | file.write("WEBVTT\n\n")
145 | for seg in segments:
146 | file.write(f"{to_timestamp(seg.t0, separator='.')} --> {to_timestamp(seg.t1, separator='.')}\n")
147 | file.write(f"{seg.text}\n\n")
148 |     return str(absolute_path)
149 |
150 |
151 | def output_srt(segments: list, output_file_path: str) -> str:
152 | """
153 | Creates a srt file from a list of segments
154 |
155 | :param segments: list of segments
156 |     :param output_file_path: base path of the output file
157 |
158 | :return: Absolute path of the file
159 | """
160 | if not output_file_path.endswith('.srt'):
161 | output_file_path = output_file_path + '.srt'
162 |
163 | absolute_path = Path(output_file_path).absolute()
164 |
165 | with open(absolute_path, 'w') as file:
166 | for i in range(len(segments)):
167 | seg = segments[i]
168 | file.write(f"{i+1}\n")
169 | file.write(f"{to_timestamp(seg.t0, separator=',')} --> {to_timestamp(seg.t1, separator=',')}\n")
170 | file.write(f"{seg.text}\n\n")
171 |     return str(absolute_path)
172 |
173 |
174 | def output_csv(segments: list, output_file_path: str) -> str:
175 | """
176 |     Creates a csv file from a list of segments
177 | 
178 |     :param segments: list of segments
179 |     :param output_file_path: base path of the output file
180 | 
181 |     :return: Absolute path of the file
182 | """
183 | if not output_file_path.endswith('.csv'):
184 | output_file_path = output_file_path + '.csv'
185 |
186 | absolute_path = Path(output_file_path).absolute()
187 |
188 | with open(absolute_path, 'w') as file:
189 | for seg in segments:
190 | file.write(f"{10 * seg.t0}, {10 * seg.t1}, \"{seg.text}\"\n")
191 |     return str(absolute_path)
192 |
193 |
194 | @contextlib.contextmanager
195 | def redirect_stderr(to: Union[bool, TextIO, str, None] = False) -> None:
196 | """
197 | Redirect stderr to the specified target.
198 |
199 | :param to:
200 | - None to suppress output (redirect to devnull),
201 | - sys.stdout to redirect to stdout,
202 | - A file path (str) to redirect to a file,
203 | - False to do nothing (no redirection).
204 | """
205 |
206 | if to is False:
207 | # do nothing
208 | yield
209 | return
210 |
211 | def _resolve_target(target):
212 | opened_stream = None
213 | if target is None:
214 | opened_stream = open(os.devnull, "w")
215 | return opened_stream, True
216 | if isinstance(target, str):
217 | opened_stream = open(target, "w")
218 | return opened_stream, True
219 | if hasattr(target, "write"):
220 | return target, False
221 | raise ValueError(
222 | "Invalid `to` parameter; expected None, a filepath string, or a file-like object."
223 | )
224 |
225 | sys.stderr.flush()
226 | try:
227 | original_fd = sys.stderr.fileno()
228 | except (AttributeError, OSError):
229 | # Jupyter or non-standard stderr implementations
230 | original_fd = None
231 |
232 | stream, should_close = _resolve_target(to)
233 |
234 | if original_fd is not None and hasattr(stream, "fileno"):
235 | saved_fd = os.dup(original_fd)
236 | try:
237 | os.dup2(stream.fileno(), original_fd)
238 | yield
239 | finally:
240 | os.dup2(saved_fd, original_fd)
241 | os.close(saved_fd)
242 | if should_close:
243 | stream.close()
244 | return
245 |
246 | # Fallback: Python-level redirect
247 | try:
248 | with contextlib.redirect_stderr(stream):
249 | yield
250 | finally:
251 | if should_close:
252 | stream.close()
253 |
--------------------------------------------------------------------------------
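
A short usage sketch for the helpers above; the timestamp values come straight from the `to_timestamp` docstring, and the output helpers expect the segment list produced by `Model.transcribe`:

```python
from pywhispercpp.utils import to_timestamp

# Whisper timestamps are in 10 ms units
assert to_timestamp(376) == '00:00:03,760'
assert to_timestamp(1344, separator='.') == '00:00:13.440'

# The output helpers append the proper extension when it is missing, e.g.
# output_srt(segments, 'jfk.wav') writes jfk.wav.srt and returns its absolute path.
```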
/pywhispercpp/constants.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | """
5 | Constants
6 | """
7 | from pathlib import Path
8 | from typing import Tuple
9 |
10 | import _pywhispercpp as _pwcpp
11 | from platformdirs import user_data_dir
12 |
13 |
14 | WHISPER_SAMPLE_RATE = _pwcpp.WHISPER_SAMPLE_RATE
15 | # MODELS URL MODELS_BASE_URL+ '/' + MODELS_PREFIX_URL+'-'+MODEL_NAME+'.bin'
16 | # example = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.bin"
17 | MODELS_BASE_URL = "https://huggingface.co/ggerganov/whisper.cpp"
18 | MODELS_PREFIX_URL = "resolve/main/ggml"
19 |
20 |
21 | PACKAGE_NAME = 'pywhispercpp'
22 |
23 |
24 | MODELS_DIR = Path(user_data_dir(PACKAGE_NAME)) / 'models'
25 |
26 |
27 | AVAILABLE_MODELS = [
28 | "base",
29 | "base-q5_1",
30 | "base-q8_0",
31 | "base.en",
32 | "base.en-q5_1",
33 | "base.en-q8_0",
34 | "large-v1",
35 | "large-v2",
36 | "large-v2-q5_0",
37 | "large-v2-q8_0",
38 | "large-v3",
39 | "large-v3-q5_0",
40 | "large-v3-turbo",
41 | "large-v3-turbo-q5_0",
42 | "large-v3-turbo-q8_0",
43 | "medium",
44 | "medium-q5_0",
45 | "medium-q8_0",
46 | "medium.en",
47 | "medium.en-q5_0",
48 | "medium.en-q8_0",
49 | "small",
50 | "small-q5_1",
51 | "small-q8_0",
52 | "small.en",
53 | "small.en-q5_1",
54 | "small.en-q8_0",
55 | "tiny",
56 | "tiny-q5_1",
57 | "tiny-q8_0",
58 | "tiny.en",
59 | "tiny.en-q5_1",
60 | "tiny.en-q8_0",
61 | ]
62 | PARAMS_SCHEMA = { # as exactly presented in whisper.cpp
63 | 'n_threads': {
64 | 'type': int,
65 |         'description': "Number of threads to allocate for the inference, "
66 |                        "defaults to min(4, available hardware_concurrency)",
67 | 'options': None,
68 | 'default': None
69 | },
70 | 'n_max_text_ctx': {
71 | 'type': int,
72 | 'description': "max tokens to use from past text as prompt for the decoder",
73 | 'options': None,
74 | 'default': 16384
75 | },
76 | 'offset_ms': {
77 | 'type': int,
78 | 'description': "start offset in ms",
79 | 'options': None,
80 | 'default': 0
81 | },
82 | 'duration_ms': {
83 | 'type': int,
84 | 'description': "audio duration to process in ms",
85 | 'options': None,
86 | 'default': 0
87 | },
88 | 'translate': {
89 | 'type': bool,
90 | 'description': "whether to translate the audio to English",
91 | 'options': None,
92 | 'default': False
93 | },
94 | 'no_context': {
95 | 'type': bool,
96 | 'description': "do not use past transcription (if any) as initial prompt for the decoder",
97 | 'options': None,
98 | 'default': False
99 | },
100 | 'single_segment': {
101 | 'type': bool,
102 | 'description': "force single segment output (useful for streaming)",
103 | 'options': None,
104 | 'default': False
105 | },
106 | 'print_special': {
107 | 'type': bool,
108 |         'description': "print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)",
109 | 'options': None,
110 | 'default': False
111 | },
112 | 'print_progress': {
113 | 'type': bool,
114 | 'description': "print progress information",
115 | 'options': None,
116 | 'default': True
117 | },
118 | 'print_realtime': {
119 | 'type': bool,
120 | 'description': "print results from within whisper.cpp (avoid it, use callback instead)",
121 | 'options': None,
122 | 'default': False
123 | },
124 | 'print_timestamps': {
125 | 'type': bool,
126 | 'description': "print timestamps for each text segment when printing realtime",
127 | 'options': None,
128 | 'default': True
129 | },
130 | # [EXPERIMENTAL] token-level timestamps
131 | 'token_timestamps': {
132 | 'type': bool,
133 | 'description': "enable token-level timestamps",
134 | 'options': None,
135 | 'default': False
136 | },
137 | 'thold_pt': {
138 | 'type': float,
139 | 'description': "timestamp token probability threshold (~0.01)",
140 | 'options': None,
141 | 'default': 0.01
142 | },
143 | 'thold_ptsum': {
144 | 'type': float,
145 | 'description': "timestamp token sum probability threshold (~0.01)",
146 | 'options': None,
147 | 'default': 0.01
148 | },
149 | 'max_len': {
150 | 'type': int,
151 | 'description': "max segment length in characters, note: token_timestamps needs to be set to True for this to work",
152 | 'options': None,
153 | 'default': 0
154 | },
155 | 'split_on_word': {
156 | 'type': bool,
157 | 'description': "split on word rather than on token (when used with max_len)",
158 | 'options': None,
159 | 'default': False
160 | },
161 | 'max_tokens': {
162 | 'type': int,
163 | 'description': "max tokens per segment (0 = no limit)",
164 | 'options': None,
165 | 'default': 0
166 | },
167 | 'audio_ctx': {
168 | 'type': int,
169 | 'description': "overwrite the audio context size (0 = use default)",
170 | 'options': None,
171 | 'default': 0
172 | },
173 | 'initial_prompt': {
174 | 'type': str,
175 |         'description': "Initial prompt; prepended to any existing text context from a previous call",
176 | 'options': None,
177 | 'default': None
178 | },
179 | 'prompt_tokens': {
180 | 'type': Tuple,
181 | 'description': "tokens to provide to the whisper decoder as initial prompt",
182 | 'options': None,
183 | 'default': None
184 | },
185 | 'prompt_n_tokens': {
186 | 'type': int,
187 |         'description': "number of prompt tokens provided to the whisper decoder",
188 | 'options': None,
189 | 'default': 0
190 | },
191 | 'language': {
192 | 'type': str,
193 | 'description': 'for auto-detection, set to None, "" or "auto"',
194 | 'options': None,
195 | 'default': ""
196 | },
197 | 'suppress_blank': {
198 | 'type': bool,
199 | 'description': 'common decoding parameters',
200 | 'options': None,
201 | 'default': True
202 | },
203 | 'suppress_non_speech_tokens': {
204 | 'type': bool,
205 | 'description': 'common decoding parameters',
206 | 'options': None,
207 | 'default': False
208 | },
209 | 'temperature': {
210 | 'type': float,
211 | 'description': 'initial decoding temperature',
212 | 'options': None,
213 | 'default': 0.0
214 | },
215 | 'max_initial_ts': {
216 | 'type': float,
217 | 'description': 'max_initial_ts',
218 | 'options': None,
219 | 'default': 1.0
220 | },
221 | 'length_penalty': {
222 | 'type': float,
223 | 'description': 'length_penalty',
224 | 'options': None,
225 | 'default': -1.0
226 | },
227 | 'temperature_inc': {
228 | 'type': float,
229 | 'description': 'temperature_inc',
230 | 'options': None,
231 | 'default': 0.2
232 | },
233 | 'entropy_thold': {
234 | 'type': float,
235 | 'description': 'similar to OpenAI\'s "compression_ratio_threshold"',
236 | 'options': None,
237 | 'default': 2.4
238 | },
239 | 'logprob_thold': {
240 | 'type': float,
241 | 'description': 'logprob_thold',
242 | 'options': None,
243 | 'default': -1.0
244 | },
245 | 'no_speech_thold': { # not implemented
246 | 'type': float,
247 | 'description': 'no_speech_thold',
248 | 'options': None,
249 | 'default': 0.6
250 | },
251 | 'greedy': {
252 | 'type': dict,
253 | 'description': 'greedy',
254 | 'options': None,
255 | 'default': {"best_of": -1}
256 | },
257 | 'beam_search': {
258 | 'type': dict,
259 | 'description': 'beam_search',
260 | 'options': None,
261 | 'default': {"beam_size": -1, "patience": -1.0}
262 | },
263 | 'extract_probability': {
264 | 'type': bool,
265 | 'description': 'calculate the geometric mean of token probabilities for each segment.',
266 | 'options': None,
267 | 'default': True
268 | }
269 | }
270 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # This setup.py is used to build the pywhispercpp package.
2 | # The environment variables you may find interesting are:
3 | #
4 | # PYWHISPERCPP_VERSION
5 | # if set, it will be used as the version number.
6 | #
7 | # GGML_VULKAN=1
8 | #      if set, whisper.cpp will be built with Vulkan support.
9 | #
10 | # WHISPER_COREML=1
11 | # WHISPER_COREML_ALLOW_FALLBACK=1
12 | #      if set, whisper.cpp will be built with CoreML support, which requires special models.
13 | #      It is best used with WHISPER_COREML_ALLOW_FALLBACK=1
14 |
15 |
16 | import os
17 | import re
18 | import subprocess
19 | import sys
20 | from pathlib import Path
22 |
23 | from setuptools import Extension, setup, find_packages
24 | from setuptools.command.build_ext import build_ext
25 | from setuptools.command.bdist_wheel import bdist_wheel
26 |
27 | # Convert distutils Windows platform specifiers to CMake -A arguments
28 | PLAT_TO_CMAKE = {
29 | "win32": "Win32",
30 | "win-amd64": "x64",
31 | "win-arm32": "ARM",
32 | "win-arm64": "ARM64",
33 | }
34 |
35 |
36 | # A CMakeExtension needs a sourcedir instead of a file list.
37 | # The name must be the _single_ output extension from the CMake build.
38 | # If you need multiple extensions, see scikit-build.
39 | class CMakeExtension(Extension):
40 | def __init__(self, name: str, sourcedir: str = "") -> None:
41 | super().__init__(name, sources=[])
42 | self.sourcedir = os.fspath(Path(sourcedir).resolve())
43 |
44 |
45 | dll_folder = 'unset'
46 |
47 |
48 | class CMakeBuild(build_ext):
49 | def build_extension(self, ext: CMakeExtension) -> None:
50 | # Must be in this form due to bug in .resolve() only fixed in Python 3.10+
51 | ext_fullpath = Path.cwd() / self.get_ext_fullpath(ext.name) # type: ignore[no-untyped-call]
52 | extdir = ext_fullpath.parent.resolve()
53 |
54 | # Using this requires trailing slash for auto-detection & inclusion of
55 | # auxiliary "native" libs
56 |
57 | debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug
58 | cfg = "Debug" if debug else "Release"
59 |
60 | # CMake lets you override the generator - we need to check this.
61 | # Can be set with Conda-Build, for example.
62 | cmake_generator = os.environ.get("CMAKE_GENERATOR", "")
63 |
64 | # Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON
65 | # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code
66 | # from Python.
67 | cmake_args = [
68 | f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}{os.sep}",
69 | f"-DPYTHON_EXECUTABLE={sys.executable}",
70 | f"-DCMAKE_BUILD_TYPE={cfg}", # not used on MSVC, but no harm
71 | ]
72 | if self.editable_mode:
73 | # Platform-specific rpath settings
74 | if sys.platform.startswith('darwin'):
75 | # macOS-specific settings
76 | cmake_args += [
77 | "-DCMAKE_INSTALL_RPATH=@loader_path",
78 | "-DCMAKE_BUILD_WITH_INSTALL_RPATH=ON"
79 | ]
80 | elif sys.platform.startswith('linux'):
81 | # Linux-specific settings
82 | cmake_args += [
83 | "-DCMAKE_INSTALL_RPATH=$ORIGIN",
84 | "-DCMAKE_BUILD_WITH_INSTALL_RPATH=ON"
85 | ]
86 |
87 | build_args = []
88 | # Adding CMake arguments set as environment variable
89 | # (needed e.g. to build for ARM OSx on conda-forge)
90 | if "CMAKE_ARGS" in os.environ:
91 | cmake_args += [item for item in os.environ["CMAKE_ARGS"].split(" ") if item]
92 |
93 | # In this example, we pass in the version to C++. You might not need to.
94 | cmake_args += [f"-DEXAMPLE_VERSION_INFO={self.distribution.get_version()}"] # type: ignore[attr-defined]
95 |
96 | if self.compiler.compiler_type != "msvc":
97 | # Using Ninja-build since it a) is available as a wheel and b)
98 | # multithreads automatically. MSVC would require all variables be
99 | # exported for Ninja to pick it up, which is a little tricky to do.
100 | # Users can override the generator with CMAKE_GENERATOR in CMake
101 | # 3.15+.
102 | if not cmake_generator or cmake_generator == "Ninja":
103 | try:
104 | import ninja
105 |
106 | ninja_executable_path = Path(ninja.BIN_DIR) / "ninja"
107 | cmake_args += [
108 | "-GNinja",
109 | f"-DCMAKE_MAKE_PROGRAM:FILEPATH={ninja_executable_path}",
110 | ]
111 | except ImportError:
112 | pass
113 |
114 | else:
115 | # Single config generators are handled "normally"
116 | single_config = any(x in cmake_generator for x in {"NMake", "Ninja"})
117 |
118 | # CMake allows an arch-in-generator style for backward compatibility
119 | contains_arch = any(x in cmake_generator for x in {"ARM", "Win64"})
120 |
121 | # Specify the arch if using MSVC generator, but only if it doesn't
122 | # contain a backward-compatibility arch spec already in the
123 | # generator name.
124 | if not single_config and not contains_arch:
125 | cmake_args += ["-A", PLAT_TO_CMAKE[self.plat_name]]
126 |
127 | # Multi-config generators have a different way to specify configs
128 | if not single_config:
129 | cmake_args += [
130 | f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}"
131 | ]
132 | build_args += ["--config", cfg]
133 |
134 | if sys.platform.startswith("darwin"):
135 | # Cross-compile support for macOS - respect ARCHFLAGS if set
136 | archs = re.findall(r"-arch (\S+)", os.environ.get("ARCHFLAGS", ""))
137 | if archs:
138 | cmake_args += ["-DCMAKE_OSX_ARCHITECTURES={}".format(";".join(archs))]
139 |
140 | # Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level
141 | # across all generators.
142 | if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ:
143 | # self.parallel is a Python 3 only way to set parallel jobs by hand
144 | # using -j in the build_ext call, not supported by pip or PyPA-build.
145 | if hasattr(self, "parallel") and self.parallel:
146 | # CMake 3.12+ only.
147 | build_args += [f"-j{self.parallel}"]
148 |
149 | build_temp = Path(self.build_temp) / ext.name
150 | if not build_temp.exists():
151 | build_temp.mkdir(parents=True)
152 |
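 |         # Forward every environment variable to CMake as a -D cache entry;
 |         # this is how build flags like GGML_CUDA=1, GGML_VULKAN=1 or
 |         # WHISPER_COREML=1 (documented in the README) reach the whisper.cpp build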
153 | for key, value in os.environ.items():
154 | cmake_args.append(f'-D{key}={value}')
155 |
156 | subprocess.run(
157 | ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True
158 | )
159 | subprocess.run(
160 | ["cmake", "--build", ".", *build_args], cwd=build_temp, check=True
161 | )
162 |
163 | # store the dll folder in a global variable to use in repairwheel
164 | global dll_folder
165 | cfg = "Debug" if self.debug else "Release"
166 | dll_folder = os.path.join(self.build_temp, '_pywhispercpp', 'bin', cfg)
167 | print("dll_folder in build_extension", dll_folder)
168 | # self.copy_extensions_to_source()
169 |
170 | def copy_extensions_to_source(self):
171 | super().copy_extensions_to_source()
172 |
173 | if self.editable_mode:
174 | build_lib = Path(self.build_lib)
175 | for ext in self.extensions:
176 | extdir = Path(self.get_ext_fullpath(ext.name)).parent.resolve()
177 | # Assuming all shared libraries are in the same directory
178 | shared_lib_files = [*build_lib.glob('**/*.dylib'), *build_lib.glob('**/*.so*')]
179 | for shared_lib in shared_lib_files:
180 | self.copy_file(shared_lib, extdir)
181 |
182 |
183 | # read the contents of your README file
184 | this_directory = Path(__file__).parent
185 | long_description = (this_directory / "README.md").read_text()
186 |
187 |
188 | class RepairWheel(bdist_wheel):
189 | def run(self):
190 | super().run()
191 | if os.environ.get('NO_REPAIR', '0') == '1':
192 | print("Skipping wheel repair")
193 | return
194 | if os.environ.get('CIBUILDWHEEL', '0') == '0' or sys.platform.startswith('win'):
195 |             # For Linux and macOS, cibuildwheel provides a default wheel repair command;
196 |             # on Windows there is none, so we repair the wheel manually.
196 | self.repair_wheel()
197 |
198 | def repair_wheel(self):
199 | # on windows the dlls are in D:\a\pywhispercpp\pywhispercpp\build\temp.win-amd64-cpython-311\Release\_pywhispercpp\bin\Release\whisper.dll
200 | global dll_folder
201 | print("dll_folder in repairwheel", dll_folder)
202 | print("Files in dll_folder:", *Path(dll_folder).glob('*'))
203 | # build\temp.win-amd64-cpython-311\Release\_pywhispercpp\bin\Release\whisper.dll
204 |
205 | wheel_path = next(Path(self.dist_dir).glob(f"{self.distribution.get_name()}*.whl"))
206 | # Create a temporary directory for the repaired wheel
207 | import tempfile
208 | with tempfile.TemporaryDirectory(prefix='repaired_wheel_') as tmp_dir:
209 | tmp_dir = Path(tmp_dir)
210 | subprocess.call(['repairwheel', wheel_path, '-o', tmp_dir, '-l', dll_folder])
211 | print("Repaired wheel: ", *tmp_dir.glob('*.whl'))
212 | # We need to glob as repairwheel may change the name of the wheel
213 | # on linux from pywhispercpp-1.2.0-cp312-cp312-linux_aarch64.whl
214 | # to pywhispercpp-1.2.0-cp312-cp312-manylinux_2_34_aarch64.whl
215 | repaired_wheel = next(tmp_dir.glob("*.whl"))
216 | self.copy_file(repaired_wheel, wheel_path)
217 | print(f"Copied repaired wheel to: {wheel_path}")
218 |
219 |
220 | def get_local_version() -> str:
221 | try:
222 | git_sha = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode("utf-8")
223 | return f"+git{git_sha[:7]}"
224 | except (FileNotFoundError, subprocess.CalledProcessError):
225 | return ""
226 |
227 |
228 | def get_version() -> str:
229 | try:
230 | return os.environ['PYWHISPERCPP_VERSION']
231 | except KeyError:
232 | pass
233 | with open("version.txt") as f:
234 | version = f.read().strip()
235 | return f"{version}{get_local_version()}"
236 |
237 |
238 | # The information here can also be placed in setup.cfg - better separation of
239 | # logic and declaration, and simpler if you include description/version in a file.
240 | setup(
241 | name="pywhispercpp",
242 | author="absadiki",
243 | description="Python bindings for whisper.cpp",
244 | long_description=long_description,
245 | ext_modules=[CMakeExtension("_pywhispercpp")],
246 | cmdclass={"build_ext": CMakeBuild,
247 | 'bdist_wheel': RepairWheel, },
248 | zip_safe=False,
249 | # extras_require={"test": ["pytest>=6.0"]},
250 | python_requires=">=3.8",
251 | packages=find_packages('.'),
252 | package_dir={'': '.'},
253 | include_package_data=True,
254 | package_data={'pywhispercpp': []},
255 | long_description_content_type="text/markdown",
256 | license='MIT',
257 | entry_points={
258 | 'console_scripts': ['pwcpp=pywhispercpp.examples.main:main',
259 | 'pwcpp-assistant=pywhispercpp.examples.assistant:_main',
260 | 'pwcpp-livestream=pywhispercpp.examples.livestream:_main',
261 | 'pwcpp-recording=pywhispercpp.examples.recording:_main',
262 | 'pwcpp-gui=pywhispercpp.examples.gui:_main', ]
263 | },
264 | project_urls={
265 | 'Documentation': 'https://absadiki.github.io/pywhispercpp/',
266 | 'Source': 'https://github.com/absadiki/pywhispercpp',
267 | 'Tracker': 'https://github.com/absadiki/pywhispercpp/issues',
268 | },
269 | install_requires=['numpy', "requests", "tqdm", "platformdirs"],
270 | extras_require={"examples": ["sounddevice", "webrtcvad"],
271 | "gui": ["pyqt5"]},
272 | )
273 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # pywhispercpp
2 | Python bindings for [whisper.cpp](https://github.com/ggerganov/whisper.cpp) with a simple Pythonic API on top of it.
3 |
4 | [](https://opensource.org/licenses/MIT)
5 | [](https://github.com/absadiki/pywhispercpp/actions/workflows/wheels.yml)
6 | [](https://pypi.org/project/pywhispercpp/)
7 | [](https://pepy.tech/project/pywhispercpp)
8 |
9 | # Table of contents
10 |
11 | * [Installation](#installation)
12 | * [From source](#from-source)
13 | * [Pre-built wheels](#pre-built-wheels)
14 | * [NVIDIA GPU support](#nvidia-gpu-support)
15 | * [CoreML support](#coreml-support)
16 | * [Vulkan support](#vulkan-support)
17 | * [Quick start](#quick-start)
18 | * [Examples](#examples)
19 | * [CLI](#cli)
20 | * [GUI](#gui)
21 | * [Assistant](#assistant)
22 | * [Advanced usage](#advanced-usage)
23 | * [Discussions and contributions](#discussions-and-contributions)
24 | * [License](#license)
25 |
26 |
27 | # Installation
28 |
29 | ### From source
30 | * For the best performance, you need to install the package from source:
31 | ```shell
32 | pip install git+https://github.com/absadiki/pywhispercpp
33 | ```
34 | ### Pre-built wheels
35 | * Otherwise, basic pre-built CPU wheels are available on PyPI:
36 |
37 | ```shell
38 | pip install pywhispercpp # or pywhispercpp[examples] to install the extra dependencies needed for the examples
39 | ```
40 |
41 | [Optional] To transcribe files other than WAV, you need to install `ffmpeg`:
42 | ```shell
43 | # on Ubuntu or Debian
44 | sudo apt update && sudo apt install ffmpeg
45 |
46 | # on Arch Linux
47 | sudo pacman -S ffmpeg
48 |
49 | # on macOS using Homebrew (https://brew.sh/)
50 | brew install ffmpeg
51 |
52 | # on Windows using Chocolatey (https://chocolatey.org/)
53 | choco install ffmpeg
54 |
55 | # on Windows using Scoop (https://scoop.sh/)
56 | scoop install ffmpeg
57 | ```
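 |
 | Under the hood, non-WAV inputs are converted to 16 kHz mono WAV with a command similar to the following (a sketch of what `Model._load_audio` runs internally; `input.mp3` and `output.wav` are placeholders):
 | ```shell
 | ffmpeg -i input.mp3 -ac 1 -ar 16000 output.wav -y
 | ```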
58 |
59 | ### NVIDIA GPU support
60 | To install the package with CUDA support, make sure you have [CUDA](https://developer.nvidia.com/cuda-downloads) installed and use `GGML_CUDA=1`:
61 |
62 | ```shell
63 | GGML_CUDA=1 pip install git+https://github.com/absadiki/pywhispercpp
64 | ```
65 | ### CoreML support
66 |
67 | Install the package with `WHISPER_COREML=1`:
68 |
69 | ```shell
70 | WHISPER_COREML=1 pip install git+https://github.com/absadiki/pywhispercpp
71 | ```
72 |
73 | ### Vulkan support
74 |
75 | Install the package with `GGML_VULKAN=1`:
76 |
77 | ```shell
78 | GGML_VULKAN=1 pip install git+https://github.com/absadiki/pywhispercpp
79 | ```
80 |
81 | ### OpenBLAS support
82 |
83 | If OpenBLAS is installed, you can use `GGML_BLAS=1`. The extra pip options below force a fresh, verbose install so you can sanity-check the build flags.
84 | ```shell
85 | GGML_BLAS=1 pip install git+https://github.com/absadiki/pywhispercpp --no-cache --force-reinstall -v
86 | ```
87 |
88 | ### OpenVINO support
89 |
90 | Follow the steps to download the correct OpenVINO package (https://github.com/ggerganov/whisper.cpp?tab=readme-ov-file#openvino-support).
91 |
92 | Then initialize the OpenVINO environment and build:
93 | ```shell
94 | source ~/l_openvino_toolkit_ubuntu22_2023.0.0.10926.b4452d56304_x86_64/setupvars.sh
95 | WHISPER_OPENVINO=1 pip install git+https://github.com/absadiki/pywhispercpp --no-cache --force-reinstall
96 | ```
97 |
98 | Note that the Ubuntu 22 toolkit also works on Ubuntu 24.
99 |
100 |
101 | __Feel free to update this list and submit a PR if you have tested the package on other backends.__
102 |
103 |
104 | # Quick start
105 |
106 | ```python
107 | from pywhispercpp.model import Model
108 |
109 | model = Model('base.en')
110 | segments = model.transcribe('file.wav')
111 | for segment in segments:
112 | print(segment.text)
113 | ```
114 |
115 | You can also assign a custom `new_segment_callback`:
116 |
117 | ```python
118 | from pywhispercpp.model import Model
119 |
120 | model = Model('base.en', print_realtime=False, print_progress=False)
121 | segments = model.transcribe('file.mp3', new_segment_callback=print)
122 | ```
123 |
124 |
125 | * The model will be downloaded automatically, or you can use the path to a local model.
126 | * You can pass any `whisper.cpp` [parameter](https://absadiki.github.io/pywhispercpp/#pywhispercpp.constants.PARAMS_SCHEMA) as a keyword argument to the `Model` class or to the `transcribe` function, as shown in the sketch below.
127 | * Check the [Model](https://absadiki.github.io/pywhispercpp/#pywhispercpp.model.Model) class documentation for more details.
128 |
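 | For example, a quick sketch passing a few parameters from [PARAMS_SCHEMA](https://absadiki.github.io/pywhispercpp/#pywhispercpp.constants.PARAMS_SCHEMA) (`file.wav` is a placeholder path):
 |
 | ```python
 | from pywhispercpp.model import Model
 |
 | # `n_threads` goes to the Model constructor; `translate`, `token_timestamps`
 | # and `max_len` are forwarded to whisper.cpp at transcription time
 | model = Model('base.en', n_threads=6)
 | segments = model.transcribe('file.wav', translate=True, token_timestamps=True, max_len=64)
 | for segment in segments:
 |     print(segment.text)
 | ```
 |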
129 | # Examples
130 |
131 | ## CLI
132 | A straightforward example command-line interface.
133 | You can use it as follows:
134 |
135 | ```shell
136 | pwcpp file.wav -m base --output-srt --print_realtime true
137 | ```
138 | Run ```pwcpp --help``` to get the help message:
139 |
140 | ```shell
141 | usage: pwcpp [-h] [-m MODEL] [--version] [--processors PROCESSORS] [-otxt] [-ovtt] [-osrt] [-ocsv] [--strategy STRATEGY]
142 | [--n_threads N_THREADS] [--n_max_text_ctx N_MAX_TEXT_CTX] [--offset_ms OFFSET_MS] [--duration_ms DURATION_MS]
143 | [--translate TRANSLATE] [--no_context NO_CONTEXT] [--single_segment SINGLE_SEGMENT] [--print_special PRINT_SPECIAL]
144 | [--print_progress PRINT_PROGRESS] [--print_realtime PRINT_REALTIME] [--print_timestamps PRINT_TIMESTAMPS]
145 | [--token_timestamps TOKEN_TIMESTAMPS] [--thold_pt THOLD_PT] [--thold_ptsum THOLD_PTSUM] [--max_len MAX_LEN]
146 | [--split_on_word SPLIT_ON_WORD] [--max_tokens MAX_TOKENS] [--audio_ctx AUDIO_CTX]
147 | [--prompt_tokens PROMPT_TOKENS] [--prompt_n_tokens PROMPT_N_TOKENS] [--language LANGUAGE] [--suppress_blank SUPPRESS_BLANK]
148 | [--suppress_non_speech_tokens SUPPRESS_NON_SPEECH_TOKENS] [--temperature TEMPERATURE] [--max_initial_ts MAX_INITIAL_TS]
149 | [--length_penalty LENGTH_PENALTY] [--temperature_inc TEMPERATURE_INC] [--entropy_thold ENTROPY_THOLD]
150 | [--logprob_thold LOGPROB_THOLD] [--no_speech_thold NO_SPEECH_THOLD] [--greedy GREEDY] [--beam_search BEAM_SEARCH]
151 | media_file [media_file ...]
152 |
153 | positional arguments:
154 |   media_file            The path of the media file or a list of files separated by space
155 |
156 | options:
157 | -h, --help show this help message and exit
158 | -m MODEL, --model MODEL
159 | Path to the `ggml` model, or just the model name
160 | --version show program's version number and exit
161 | --processors PROCESSORS
162 | number of processors to use during computation
163 | -otxt, --output-txt output result in a text file
164 | -ovtt, --output-vtt output result in a vtt file
165 | -osrt, --output-srt output result in a srt file
166 | -ocsv, --output-csv output result in a CSV file
167 |   --strategy STRATEGY   Available sampling strategies: GreedyDecoder -> 0, BeamSearchDecoder -> 1
168 | --n_threads N_THREADS
169 |                         Number of threads to allocate for the inference, defaults to min(4, available hardware_concurrency)
170 | --n_max_text_ctx N_MAX_TEXT_CTX
171 | max tokens to use from past text as prompt for the decoder
172 | --offset_ms OFFSET_MS
173 | start offset in ms
174 | --duration_ms DURATION_MS
175 | audio duration to process in ms
176 | --translate TRANSLATE
177 | whether to translate the audio to English
178 | --no_context NO_CONTEXT
179 | do not use past transcription (if any) as initial prompt for the decoder
180 | --single_segment SINGLE_SEGMENT
181 | force single segment output (useful for streaming)
182 | --print_special PRINT_SPECIAL
183 |                         print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
184 | --print_progress PRINT_PROGRESS
185 | print progress information
186 | --print_realtime PRINT_REALTIME
187 | print results from within whisper.cpp (avoid it, use callback instead)
188 | --print_timestamps PRINT_TIMESTAMPS
189 | print timestamps for each text segment when printing realtime
190 | --token_timestamps TOKEN_TIMESTAMPS
191 | enable token-level timestamps
192 | --thold_pt THOLD_PT timestamp token probability threshold (~0.01)
193 | --thold_ptsum THOLD_PTSUM
194 | timestamp token sum probability threshold (~0.01)
195 | --max_len MAX_LEN max segment length in characters
196 | --split_on_word SPLIT_ON_WORD
197 | split on word rather than on token (when used with max_len)
198 | --max_tokens MAX_TOKENS
199 | max tokens per segment (0 = no limit)
200 | --audio_ctx AUDIO_CTX
201 | overwrite the audio context size (0 = use default)
202 | --prompt_tokens PROMPT_TOKENS
203 | tokens to provide to the whisper decoder as initial prompt
204 | --prompt_n_tokens PROMPT_N_TOKENS
205 | tokens to provide to the whisper decoder as initial prompt
206 | --language LANGUAGE for auto-detection, set to None, "" or "auto"
207 | --suppress_blank SUPPRESS_BLANK
208 | common decoding parameters
209 | --suppress_non_speech_tokens SUPPRESS_NON_SPEECH_TOKENS
210 | common decoding parameters
211 | --temperature TEMPERATURE
212 | initial decoding temperature
213 | --max_initial_ts MAX_INITIAL_TS
214 | max_initial_ts
215 | --length_penalty LENGTH_PENALTY
216 | length_penalty
217 | --temperature_inc TEMPERATURE_INC
218 | temperature_inc
219 | --entropy_thold ENTROPY_THOLD
220 | similar to OpenAI's "compression_ratio_threshold"
221 | --logprob_thold LOGPROB_THOLD
222 | logprob_thold
223 | --no_speech_thold NO_SPEECH_THOLD
224 | no_speech_thold
225 | --greedy GREEDY greedy
226 | --beam_search BEAM_SEARCH
227 | beam_search
228 |
229 | ```
230 |
231 | ## GUI
232 | If you prefer a Graphical User Interface, you can use the `pwcpp-gui` command, which launches a simple graphical interface built with PyQt5.
233 | * First you need to install the GUI dependencies:
234 | ```bash
235 | pip install pywhispercpp[gui]
236 | ```
237 |
238 | * Then you can run the GUI with:
239 | ```bash
240 | pwcpp-gui
241 | ```
242 |
243 | The GUI provides a user-friendly way to:
244 |
245 | - Select audio files
246 | - Choose models
247 | - Adjust basic transcription settings
248 | - View and export transcription results
249 |
250 | ## Assistant
251 | This is a simple example showcasing the use of `pywhispercpp` to build an assistant-like application.
252 | The idea is to use a Voice Activity Detector (VAD) to detect speech (in this example, we used webrtcvad), and when speech is detected, we run the transcription.
253 | It is inspired by the [whisper.cpp/examples/command](https://github.com/ggerganov/whisper.cpp/tree/master/examples/command) example.
254 |
255 | You can check the source code [here](https://github.com/absadiki/pywhispercpp/blob/main/pywhispercpp/examples/assistant.py)
256 | or you can use the class directly to create your own assistant:
257 |
258 |
259 | ```python
260 | from pywhispercpp.examples.assistant import Assistant
261 |
262 | my_assistant = Assistant(commands_callback=print, n_threads=8)
263 | my_assistant.start()
264 | ```
265 | Here, we set the `commands_callback` to a simple print function, so the commands will just get printed on the screen.
266 |
267 | You can also run this example from the command line.
268 | ```shell
269 | $ pwcpp-assistant --help
270 |
271 | usage: pwcpp-assistant [-h] [-m MODEL] [-ind INPUT_DEVICE] [-st SILENCE_THRESHOLD] [-bd BLOCK_DURATION]
272 |
273 | options:
274 | -h, --help show this help message and exit
275 | -m MODEL, --model MODEL
276 |                         Whisper.cpp model, defaults to tiny.en
277 |   -ind INPUT_DEVICE, --input_device INPUT_DEVICE
278 |                         Id of the input device (aka microphone)
279 |   -st SILENCE_THRESHOLD, --silence_threshold SILENCE_THRESHOLD
280 |                         The duration of silence after which the inference will be run, defaults to 16
281 |   -bd BLOCK_DURATION, --block_duration BLOCK_DURATION
282 |                         minimum time between audio updates in ms, defaults to 30
283 | ```
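 |
 | For example, a quick sketch combining the options above (assuming the CLI flags map to constructor arguments of the same name, as `commands_callback` and `n_threads` do):
 |
 | ```python
 | from pywhispercpp.examples.assistant import Assistant
 |
 | def handle_command(text):
 |     print(f"Heard: {text}")
 |
 | my_assistant = Assistant(model='tiny.en',
 |                          commands_callback=handle_command,
 |                          silence_threshold=16)
 | my_assistant.start()
 | ```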
284 | -------------
285 |
286 | * Check the [examples folder](https://github.com/absadiki/pywhispercpp/tree/main/pywhispercpp/examples) for more examples.
287 |
288 | # Advanced usage
289 | * First check the [API documentation](https://absadiki.github.io/pywhispercpp/) for more advanced usage.
290 | * If you are a more experienced user, you can access the exposed C-APIs directly from the binding module `_pywhispercpp`.
291 |
292 | ```python
293 | import _pywhispercpp as pwcpp
294 |
295 | ctx = pwcpp.whisper_init_from_file('path/to/ggml/model')
296 | ```
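 |
 | For example, here is a minimal sketch mirroring what the `Model` class does internally (`audio` is assumed to be a 16 kHz mono `float32` NumPy array you have already loaded):
 |
 | ```python
 | import _pywhispercpp as pwcpp
 |
 | ctx = pwcpp.whisper_init_from_file('path/to/ggml/model')
 | params = pwcpp.whisper_full_default_params(
 |     pwcpp.whisper_sampling_strategy.WHISPER_SAMPLING_GREEDY)
 | pwcpp.whisper_full(ctx, params, audio, audio.size)  # run the inference
 | for i in range(pwcpp.whisper_full_n_segments(ctx)):
 |     print(pwcpp.whisper_full_get_segment_text(ctx, i).decode('utf-8', errors='replace'))
 | pwcpp.whisper_free(ctx)  # free the context when done
 | ```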
297 |
298 | # Discussions and contributions
299 | If you find any bug, please open an [issue](https://github.com/absadiki/pywhispercpp/issues).
300 |
301 | If you have any feedback, or you want to share how you are using this project, feel free to use the [Discussions](https://github.com/absadiki/pywhispercpp/discussions) and open a new topic.
302 |
303 | # License
304 |
305 | This project is licensed under the same license as [whisper.cpp](https://github.com/ggerganov/whisper.cpp/blob/master/LICENSE) (MIT [License](./LICENSE)).
306 |
--------------------------------------------------------------------------------
/pywhispercpp/model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | """
5 | This module contains a simple Python API on-top of the C-style
6 | [whisper.cpp](https://github.com/ggerganov/whisper.cpp) API.
7 | """
8 | import importlib.metadata
9 | import logging
10 | import shutil
11 | import sys
12 | from pathlib import Path
13 | from time import time
14 | from typing import Union, Callable, List, TextIO, Tuple, Optional, Dict
15 | import _pywhispercpp as pw
16 | import numpy as np
17 | import pywhispercpp.utils as utils
18 | import pywhispercpp.constants as constants
19 | import subprocess
20 | import os
21 | import tempfile
22 | import wave
23 |
24 | __author__ = "absadiki"
25 | __copyright__ = "Copyright 2023, "
26 | __license__ = "MIT"
27 | __version__ = importlib.metadata.version('pywhispercpp')
28 |
29 | logger = logging.getLogger(__name__)
30 |
31 |
32 | class Segment:
33 | """
34 | A small class representing a transcription segment
35 | """
36 |
37 | def __init__(self, t0: int, t1: int, text: str, probability: float = np.nan):
38 | """
39 | :param t0: start time
40 | :param t1: end time
41 | :param text: text
42 | :param probability: Confidence score for the segment, computed as the geometric mean of
43 | the token probabilities for the segment (NaN if not calculated).
44 | This makes it interpretable as a probability in [0, 1].
45 | """
46 | self.t0 = t0
47 | self.t1 = t1
48 | self.text = text
49 | self.probability = probability
50 |
51 | def __str__(self):
52 | return f"t0={self.t0}, t1={self.t1}, text={self.text}, probability={self.probability}"
53 |
54 | def __repr__(self):
55 | return str(self)
56 |
57 |
58 | class Model:
59 | """
60 |     This class defines a Whisper.cpp model.
61 | 
62 |     Example usage:
63 | ```python
64 | model = Model('base.en', n_threads=6)
65 | segments = model.transcribe('file.mp3')
66 | for segment in segments:
67 | print(segment.text)
68 | ```
69 | """
70 |
71 | _new_segment_callback = None
72 |
73 | def __init__(self,
74 | model: str = 'tiny',
75 | models_dir: str = None,
76 | params_sampling_strategy: int = 0,
77 | redirect_whispercpp_logs_to: Union[bool, TextIO, str, None] = False,
78 | use_openvino: bool = False,
79 | openvino_model_path: str = None,
80 | openvino_device: str = 'CPU',
81 | openvino_cache_dir: str = None,
82 | **params):
83 | """
84 | :param model: The name of the model, one of the [AVAILABLE_MODELS](/pywhispercpp/#pywhispercpp.constants.AVAILABLE_MODELS),
85 |                       (defaults to `tiny`), or a direct path to a `ggml` model.
86 |         :param models_dir: The directory where the models are stored, or where they will be downloaded if they don't
87 |                            exist, defaults to [MODELS_DIR](/pywhispercpp/#pywhispercpp.constants.MODELS_DIR)
88 |         :param params_sampling_strategy: 0 -> GREEDY, else BEAM_SEARCH
89 |         :param redirect_whispercpp_logs_to: where to redirect the whisper.cpp logs, defaults to False (no redirection); accepts a str file path, sys.stdout, sys.stderr, or None to redirect to devnull
90 |         :param use_openvino: whether to use OpenVINO or not
91 |         :param openvino_model_path: path to the OpenVINO model
92 |         :param openvino_device: OpenVINO device, defaults to CPU
93 | :param openvino_cache_dir: OpenVINO cache directory
94 | :param params: keyword arguments for different whisper.cpp parameters,
95 | see [PARAMS_SCHEMA](/pywhispercpp/#pywhispercpp.constants.PARAMS_SCHEMA)
96 | """
97 | if Path(model).is_file():
98 | self.model_path = model
99 | else:
100 | self.model_path = utils.download_model(model, models_dir)
101 | self._ctx = None
102 | self._sampling_strategy = pw.whisper_sampling_strategy.WHISPER_SAMPLING_GREEDY if params_sampling_strategy == 0 else \
103 | pw.whisper_sampling_strategy.WHISPER_SAMPLING_BEAM_SEARCH
104 | self._params = pw.whisper_full_default_params(self._sampling_strategy)
105 | # assign params
106 | self.params = params
107 | self._set_params(params)
108 | self.redirect_whispercpp_logs_to = redirect_whispercpp_logs_to
109 | self.use_openvino = use_openvino
110 | self.openvino_model_path = openvino_model_path
111 | self.openvino_device = openvino_device
112 | self.openvino_cache_dir = openvino_cache_dir
113 | # init the model
114 | self._init_model()
115 |
116 | def transcribe(self,
117 | media: Union[str, np.ndarray],
118 | n_processors: int = None,
119 | new_segment_callback: Callable[[Segment], None] = None,
120 | **params) -> List[Segment]:
121 | """
122 |         Transcribes the media provided as input and returns a list of `Segment` objects.
123 | Accepts a media_file path (audio/video) or a raw numpy array.
124 |
125 | :param media: Media file path or a numpy array
126 | :param n_processors: if not None, it will run the transcription on multiple processes
127 | binding to whisper.cpp/whisper_full_parallel
128 | > Split the input audio in chunks and process each chunk separately using whisper_full()
129 | :param new_segment_callback: callback function that will be called when a new segment is generated
130 | :param params: keyword arguments for different whisper.cpp parameters, see ::: constants.PARAMS_SCHEMA
131 | :param extract_probability: If True, calculates the geometric mean of token probabilities for each segment,
132 | providing a confidence score interpretable as a probability in [0, 1].
133 | :return: List of transcription segments
134 | """
135 | if type(media) is np.ndarray:
136 | audio = media
137 | else:
138 | if not Path(media).exists():
139 | raise FileNotFoundError(media)
140 | audio = self._load_audio(media)
141 |
142 | # Handle extract_probability parameter
143 | self.extract_probability = params.pop('extract_probability', False)
144 |
145 | # update params if any
146 | self._set_params(params)
147 |
148 | # setting up callback
149 | if new_segment_callback:
150 | Model._new_segment_callback = new_segment_callback
151 | pw.assign_new_segment_callback(self._params, Model.__call_new_segment_callback)
152 |
153 | # run inference
154 | start_time = time()
155 | logger.info("Transcribing ...")
156 | res = self._transcribe(audio, n_processors=n_processors)
157 | end_time = time()
158 | logger.info(f"Inference time: {end_time - start_time:.3f} s")
159 | return res
160 |
161 | @staticmethod
162 | def _get_segments(ctx, start: int, end: int, extract_probability: bool = False) -> List[Segment]:
163 | """
164 | Helper function to get generated segments between `start` and `end`
165 |
166 | :param ctx: whisper context
167 | :param start: start index
168 | :param end: end index
169 | :param extract_probability: whether to calculate token probabilities
170 |
171 | :return: list of segments
172 | """
173 | n = pw.whisper_full_n_segments(ctx)
174 |         assert end <= n, f"{end} > {n}: `End` index must be less than or equal to the total number of segments"
175 | res = []
176 | for i in range(start, end):
177 | t0 = pw.whisper_full_get_segment_t0(ctx, i)
178 | t1 = pw.whisper_full_get_segment_t1(ctx, i)
179 |             segment_bytes = pw.whisper_full_get_segment_text(ctx, i)
180 |             text = segment_bytes.decode('utf-8', errors='replace')
181 |
182 | avg_prob = np.nan
183 |
184 | # Only calculate probabilities if requested
185 | if extract_probability:
186 | n_tokens = pw.whisper_full_n_tokens(ctx, i)
187 | if n_tokens == 1:
188 | avg_prob = pw.whisper_full_get_token_p(ctx, i, 0)
189 | elif n_tokens > 1:
190 | total_logprob = 0.0
191 | for j in range(n_tokens):
192 | total_logprob += np.log(pw.whisper_full_get_token_p(ctx, i, j))
193 | avg_prob = np.exp(total_logprob / n_tokens)
194 | else:
195 | avg_prob = np.nan
196 |
197 | res.append(Segment(t0, t1, text.strip(), probability=np.float32(avg_prob)))
198 | return res
199 |
200 | def get_params(self) -> dict:
201 | """
202 | Returns a `dict` representation of the actual params
203 |
204 | :return: params dict
205 | """
206 | res = {}
207 | for param in dir(self._params):
208 | if param.startswith('__'):
209 | continue
210 | try:
211 | res[param] = getattr(self._params, param)
212 | except Exception:
213 | # ignore callback functions
214 | continue
215 | return res
216 |
217 | @staticmethod
218 | def get_params_schema() -> dict:
219 | """
220 | A simple link to ::: constants.PARAMS_SCHEMA
221 | :return: dict of params schema
222 | """
223 | return constants.PARAMS_SCHEMA
224 |
225 | @staticmethod
226 | def lang_max_id() -> int:
227 | """
228 | Returns number of supported languages.
229 | Direct binding to whisper.cpp/lang_max_id
230 | :return:
231 | """
232 | return pw.whisper_lang_max_id()
233 |
234 | def print_timings(self) -> None:
235 | """
236 | Direct binding to whisper.cpp/whisper_print_timings
237 |
238 | :return: None
239 | """
240 | pw.whisper_print_timings(self._ctx)
241 |
242 | @staticmethod
243 | def system_info() -> None:
244 | """
245 | Direct binding to whisper.cpp/whisper_print_system_info
246 |
247 | :return: None
248 | """
249 | return pw.whisper_print_system_info()
250 |
251 | @staticmethod
252 |     def available_languages() -> List[str]:
253 | """
254 | Returns a list of supported language codes
255 |
256 | :return: list of supported language codes
257 | """
258 | n = pw.whisper_lang_max_id()
259 | res = []
260 | for i in range(n+1):
261 | res.append(pw.whisper_lang_str(i))
262 | return res
263 |
264 | def _init_model(self) -> None:
265 | """
266 |         Private method to initialize the model from the bindings; it will be called automatically from `__init__`
267 | :return:
268 | """
269 | logger.info("Initializing the model ...")
270 | with utils.redirect_stderr(to=self.redirect_whispercpp_logs_to):
271 | self._ctx = pw.whisper_init_from_file(self.model_path)
272 | if self.use_openvino:
273 | pw.whisper_ctx_init_openvino_encoder(self._ctx, self.openvino_model_path, self.openvino_device, self.openvino_cache_dir)
274 | 
277 | def _set_params(self, kwargs: dict) -> None:
278 | """
279 | Private method to set the kwargs params to the `Params` class
280 | :param kwargs: dict like object for the different params
281 | :return: None
282 | """
283 | for param in kwargs:
284 | setattr(self._params, param, kwargs[param])
285 |
286 | def _transcribe(self, audio: np.ndarray, n_processors: int = None):
287 | """
288 | Private method to call the whisper.cpp/whisper_full function
289 |
290 | :param audio: numpy array of audio data
291 | :param n_processors: if not None, it will run whisper.cpp/whisper_full_parallel with n_processors
292 | :return:
293 | """
294 |
295 | if n_processors:
296 | pw.whisper_full_parallel(self._ctx, self._params, audio, audio.size, n_processors)
297 | else:
298 | pw.whisper_full(self._ctx, self._params, audio, audio.size)
299 | n = pw.whisper_full_n_segments(self._ctx)
300 | res = Model._get_segments(self._ctx, 0, n, self.extract_probability)
301 | return res
302 |
303 | @staticmethod
304 | def __call_new_segment_callback(ctx, n_new, user_data) -> None:
305 | """
306 | Internal new_segment_callback, it just calls the user's callback with the `Segment` object
307 | :param ctx: whisper.cpp ctx param
308 | :param n_new: whisper.cpp n_new param
309 | :param user_data: whisper.cpp user_data param
310 | :return: None
311 | """
312 | n = pw.whisper_full_n_segments(ctx)
313 | start = n - n_new
314 | res = Model._get_segments(ctx, start, n, False)
315 | for segment in res:
316 | Model._new_segment_callback(segment)
317 |
318 | @staticmethod
319 |     def _load_audio(media_file_path: str) -> np.ndarray:
320 |         """
321 |         Helper method to return a `np.ndarray` object from a media file.
322 | If the media file is not a WAV file, it will try to convert it using ffmpeg
323 |
324 | :param media_file_path: Path of the media file
325 | :return: Numpy array
326 | """
327 |
328 | def wav_to_np(file_path):
329 | with wave.open(file_path, 'rb') as wf:
330 | num_channels = wf.getnchannels()
331 | sample_width = wf.getsampwidth()
332 | sample_rate = wf.getframerate()
333 | num_frames = wf.getnframes()
334 |
335 | if num_channels not in (1, 2):
336 |                 raise Exception("WAV file must be mono or stereo")
337 |
338 | if sample_rate != pw.WHISPER_SAMPLE_RATE:
339 | raise Exception(f"WAV file must be {pw.WHISPER_SAMPLE_RATE} Hz")
340 |
341 | if sample_width != 2:
342 |                 raise Exception("WAV file must be 16-bit")
343 |
344 | raw = wf.readframes(num_frames)
346 | audio = np.frombuffer(raw, dtype=np.int16).astype(np.float32)
347 | n = num_frames
348 | if num_channels == 1:
349 | pcmf32 = audio / 32768.0
350 | else:
351 | audio = audio.reshape(-1, 2)
352 | # Averaging the two channels
353 | pcmf32 = (audio[:, 0] + audio[:, 1]) / 65536.0
354 | return pcmf32
355 |
356 | if media_file_path.endswith('.wav'):
357 | return wav_to_np(media_file_path)
358 | else:
359 | if shutil.which('ffmpeg') is None:
360 | raise Exception(
361 | "FFMPEG is not installed or not in PATH. Please install it, or provide a WAV file or a NumPy array instead!")
362 |
363 | temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
364 | temp_file_path = temp_file.name
365 | temp_file.close()
366 | try:
367 | subprocess.run([
368 | 'ffmpeg', '-i', media_file_path, '-ac', '1', '-ar', '16000',
369 | temp_file_path, '-y'
370 | ], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
371 | return wav_to_np(temp_file_path)
372 | finally:
373 | os.remove(temp_file_path)
374 |
375 |     def auto_detect_language(self, media: Union[str, np.ndarray], offset_ms: int = 0, n_threads: int = 4) -> Tuple[Tuple[str, np.float32], Dict[str, np.float32]]:
376 | """
377 | Automatic language detection using whisper.cpp/whisper_pcm_to_mel and whisper.cpp/whisper_lang_auto_detect
378 |
379 | :param media: Media file path or a numpy array
380 | :param offset_ms: offset in milliseconds
381 | :param n_threads: number of threads to use
382 | :return: ((detected_language, probability), probabilities for all languages)
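 |
 |         Example (a quick usage sketch):
 |             (lang, prob), all_probs = model.auto_detect_language('file.wav')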
383 | """
384 | if type(media) is np.ndarray:
385 | audio = media
386 | else:
387 | if not Path(media).exists():
388 | raise FileNotFoundError(media)
389 | audio = self._load_audio(media)
390 |
391 | pw.whisper_pcm_to_mel(self._ctx, audio, len(audio), n_threads)
392 |         lang_max_id = self.lang_max_id()
393 |         # language ids run from 0 to lang_max_id inclusive
394 |         probs = np.zeros(lang_max_id + 1, dtype=np.float32)
395 |         auto_detect = pw.whisper_lang_auto_detect(self._ctx, offset_ms, n_threads, probs)
396 |         langs = self.available_languages()
397 |         lang_probs = {langs[i]: probs[i] for i in range(lang_max_id + 1)}
397 | return (langs[auto_detect], probs[auto_detect]), lang_probs
398 |
399 | def __del__(self):
400 | """
401 | Free up resources
402 | :return: None
403 | """
404 | pw.whisper_free(self._ctx)
--------------------------------------------------------------------------------
/pywhispercpp/examples/gui.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import threading
3 | from datetime import datetime
4 |
5 | from PyQt5.QtWidgets import (
6 | QApplication, QWidget, QVBoxLayout, QHBoxLayout, QPushButton,
7 | QFileDialog, QProgressBar, QLabel, QFrame,
8 | QSizePolicy, QTableWidget, QTableWidgetItem, QHeaderView,
9 | QGroupBox, QFormLayout, QComboBox, QLineEdit, QCheckBox,
10 | QSpinBox, QDoubleSpinBox, QToolButton, QDialog, QMenu # Import QMenu
11 | )
12 | from PyQt5.QtCore import Qt, pyqtSignal, QObject
13 | import os
14 | import importlib.metadata
15 |
16 | __version__ = importlib.metadata.version('pywhispercpp')
17 |
18 | from pywhispercpp.model import Model, Segment
19 | from pywhispercpp.utils import output_txt, output_srt, output_vtt, output_csv # Import utilities
20 |
21 |
22 | # --- Available Models ---
23 | # Define a custom order for model sizes
24 | MODEL_SIZE_ORDER = {"tiny": 0, "base": 1, "small": 2, "medium": 3, "large": 4}
25 |
26 | UNSORTED_MODELS = [
27 | "base", "base-q5_1", "base-q8_0", "base.en", "base.en-q5_1",
28 | "base.en-q8_0", "large-v1", "large-v2", "large-v2-q5_0",
29 | "large-v2-q8_0", "large-v3", "large-v3-q5_0", "large-v3-turbo",
30 | "large-v3-turbo-q5_0", "large-v3-turbo-q8_0", "medium",
31 | "medium-q5_0", "medium-q8_0", "medium.en", "medium.en-q5_0",
32 | "medium.en-q8_0", "small", "small-q5_1", "small-q8_0", "small.en",
33 | "small.en-q5_1", "small.en-q8_0", "tiny", "tiny-q5_1", "tiny-q8_0",
34 | "tiny.en", "tiny.en-q5_1", "tiny.en-q8_0",
35 | ]
36 |
37 |
38 | # Custom sort key function
39 | def get_model_sort_key(model_name):
40 | # Extract the base model name (e.g., "tiny.en" -> "tiny", "large-v3-turbo" -> "large")
41 | base_name = model_name.split('.')[0].split('-')[0]
42 | # Return a tuple for multi-level sorting: (size_order, full_model_name_for_secondary_sort)
43 | return (MODEL_SIZE_ORDER.get(base_name, 99), model_name) # 99 for any unexpected names
44 |
45 |
46 | # Sort the models
47 | AVAILABLE_MODELS = sorted(UNSORTED_MODELS, key=get_model_sort_key)
48 |
49 | # --- Retouched Minimal Stylesheet ---
50 | STYLESHEET = """
51 | /* General Application Styles */
52 | QWidget {
53 | background-color: #f5f5f5; /* Light gray background */
54 | color: #333333; /* Dark text */
55 | font-family: Arial, sans-serif;
56 | font-size: 14px;
57 | }
58 |
59 | /* Buttons */
60 | QPushButton {
61 | background-color: #e0e0e0; /* Slightly darker gray for buttons */
62 | color: #333333;
63 | border: 1px solid #c0c0c0; /* Light gray border */
64 | padding: 6px 12px;
65 | border-radius: 3px; /* Slightly rounded corners */
66 | outline: none; /* Remove focus outline for a cleaner look */
67 | }
68 |
69 | QPushButton:hover {
70 | background-color: #d0d0d0; /* Darker on hover */
71 | }
72 |
73 | QPushButton:pressed {
74 | background-color: #c0c0c0; /* Even darker when pressed */
75 | }
76 |
77 | QPushButton:disabled {
78 | background-color: #f0f0f0;
79 | color: #aaaaaa;
80 | border-color: #e0e0e0;
81 | }
82 |
83 | /* Specific styling for the Transcribe button (Call to Action) */
84 | QPushButton#TranscribeButton {
85 | background-color: #007bff; /* Vibrant blue */
86 | color: #ffffff; /* White text */
87 | font-weight: bold;
88 | font-size: 15px; /* Slightly larger font */
89 | padding: 8px 18px; /* More padding */
90 | border: 1px solid #007bff; /* Matching border */
91 | border-radius: 4px;
92 | }
93 |
94 | QPushButton#TranscribeButton:hover {
95 | background-color: #0056b3; /* Darker blue on hover */
96 | border-color: #0056b3;
97 | }
98 |
99 | QPushButton#TranscribeButton:pressed {
100 | background-color: #004085; /* Even darker blue when pressed */
101 | border-color: #004085;
102 | }
103 |
104 | QPushButton#TranscribeButton:disabled {
105 | background-color: #cccccc; /* Light gray for disabled state */
106 | color: #888888;
107 | border-color: #cccccc;
108 | }
109 |
110 | /* Stop button */
111 | QPushButton#StopButton {
112 | background-color: #ff0000; /* Red for stop */
113 | color: #ffffff; /* White text */
114 | font-weight: bold;
115 | font-size: 15px; /* Slightly larger font */
116 | padding: 8px 18px; /* More padding */
117 | border: 1px solid #ff0000; /* Matching border */
118 | border-radius: 4px;
119 | }
120 |
121 | QPushButton#StopButton:hover {
122 | background-color: #cc0000; /* Darker red on hover */
123 | border-color: #cc0000;
124 | }
125 |
126 | QPushButton#StopButton:pressed {
127 | background-color: #990000; /* Even darker red when pressed */
128 | border-color: #990000;
129 | }
130 |
131 | QPushButton#StopButton:disabled {
132 | background-color: #cccccc; /* Light gray for disabled state */
133 | color: #888888;
134 | border-color: #cccccc;
135 | }
136 |
137 | /* ToolButton for accordion header */
138 | QToolButton {
139 | background-color: transparent; /* Keep it transparent */
140 | border: none;
141 | padding: 5px 0; /* Some padding for clickable area */
142 | font-weight: bold;
143 | color: #444444; /* Darker text for emphasis */
144 | text-align: left;
145 | outline: none; /* Remove focus outline */
146 | }
147 |
148 | QToolButton::menu-indicator {
149 | image: none; /* Hide default menu indicator */
150 | }
151 |
152 | /* Table Widget */
153 | QTableWidget {
154 | background-color: #ffffff;
155 | border: 1px solid #d0d0d0; /* Slightly softer border */
156 | gridline-color: #e8e8e8; /* Very light grid lines */
157 | border-radius: 3px;
158 | }
159 |
160 | QTableWidget::item {
161 | padding: 4px;
162 | border-bottom: 1px solid #f5f5f5; /* Match general background for subtle row separation */
163 | }
164 |
165 | /* Table Header */
166 | QHeaderView::section {
167 | background-color: #e8e8e8; /* Light header background */
168 | color: #333333;
169 | padding: 5px;
170 | border: 1px solid #d0d0d0;
171 | font-weight: bold;
172 | }
173 |
174 | /* Labels */
175 | QLabel {
176 | color: #333333;
177 | }
178 |
179 | /* Specific styling for the main title label */
180 | QLabel#TitleLabel {
181 | font-size: 18px; /* Slightly smaller than previous '20px' to integrate better */
182 | font-weight: bold;
183 | color: #0056b3; /* A slightly darker blue for main title */
184 | padding-bottom: 3px;
185 | margin-bottom: 5px;
186 | border-bottom: 1px solid #e0e0e0;
187 | }
188 |
189 | /* Progress Bar */
190 | QProgressBar {
191 | border: 1px solid #bbbbbb;
192 | border-radius: 3px;
193 | text-align: center;
194 | background-color: #e6e6e6;
195 | color: #333333;
196 | }
197 |
198 | QProgressBar::chunk {
199 | background-color: #4CAF50; /* A pleasant green */
200 | border-radius: 2px;
201 | }
202 |
203 | /* Input Widgets (ComboBox, LineEdit, SpinBox, DoubleSpinBox) */
204 | QComboBox, QLineEdit, QSpinBox, QDoubleSpinBox {
205 | border: 1px solid #cccccc;
206 | padding: 3px;
207 | background-color: #ffffff;
208 | color: #333333;
209 | border-radius: 3px; /* Apply rounded corners */
210 | }
211 |
212 | /* File label should visually match text inputs */
213 | QLabel#file_label {
214 | border: 1px solid #cccccc;
215 | padding: 3px;
216 | background-color: #ffffff;
217 | border-radius: 3px;
218 | }
219 |
220 | QFrame {
221 | border: none; /* Keep frames invisible unless needed for structure */
222 | }
223 |
224 | /* Status Bar Label */
225 | QLabel#status_bar_label {
226 | background-color: #e0e0e0; /* Light grey background */
227 | border-top: 1px solid #cccccc; /* Separator */
228 | color: #444444; /* Darker text */
229 | padding: 3px 5px; /* Add internal padding */
230 | font-size: 13px; /* Slightly smaller font */
231 | }
232 | """
233 |
234 |
235 | # --- Communication Object for Threading ---
236 | class WorkerSignals(QObject):
237 | """
238 | Defines signals available from a running worker thread.
239 | Supported signals are:
240 | - finished: No data
241 | - error: tuple (exctype, value, traceback.format_exc())
242 | - result: list (the transcribed segments)
243 | - progress: int (0-100)
244 | - status_update: str
245 | """
246 | finished = pyqtSignal()
247 | error = pyqtSignal(tuple)
248 | segment = pyqtSignal(Segment)
249 | result = pyqtSignal(list)
250 | progress = pyqtSignal(int)
251 | status_update = pyqtSignal(str)
252 |
253 |
254 | # --- Worker Thread for Transcription ---
255 | class PyWhisperCppWorker(threading.Thread):
256 |
257 | def __init__(self, audio_file_path, model_name, **transcribe_params):
258 | super().__init__()
259 | self.audio_file_path = audio_file_path
260 | self.model_name = model_name
261 | self.transcribe_params = transcribe_params
262 | self.signals = WorkerSignals()
263 | self._is_running = False
264 |
265 | def run(self):
266 | """
267 | Executes the transcription process.
268 | """
269 | try:
270 | self._is_running = True
271 | self.signals.status_update.emit(f"Loading model: {self.model_name}...")
272 |
273 | # pywhispercpp will download the specified model if not found
274 | model_init_params = {}
275 | if 'n_threads' in self.transcribe_params and self.transcribe_params['n_threads'] is not None:
276 | model_init_params['n_threads'] = self.transcribe_params['n_threads']
277 | # Remove from transcribe_params as it's a model init param
278 | del self.transcribe_params['n_threads']
279 |
280 | model = Model(self.model_name, **model_init_params)
281 |
282 | self.signals.status_update.emit("Model loaded. Starting transcription...")
283 |
284 | def new_segment_callback(segment):
285 | if not self._is_running:
286 | raise RuntimeError("Transcription manually stopped")
287 | self.signals.segment.emit(segment)
288 |
289 | segments = model.transcribe(self.audio_file_path,
290 | new_segment_callback=new_segment_callback,
291 | progress_callback=lambda progress: self.signals.progress.emit(progress),
292 | **self.transcribe_params)
293 |
294 | self.signals.status_update.emit("Transcription complete!")
295 | self.signals.result.emit(segments)
296 |
297 | except Exception as e:
298 | print(e)
299 | self.signals.status_update.emit(f"Error: {str(e)}")
300 | self.signals.error.emit((type(e), e, str(e)))
301 | finally:
302 | self._is_running = False
303 | self.signals.finished.emit()
304 |
305 | def stop(self):
306 | self._is_running = False
307 |
308 |
309 | # --- Main Application Window ---
310 | class TranscriptionApp(QWidget):
311 | def __init__(self):
312 | super().__init__()
313 | self.selected_file_path = None
314 | self.whisper_thread = None
315 | # Settings widgets
316 | self.model_combo = None
317 | self.language_input = None
318 | self.translate_checkbox = None
319 | self.n_threads_spinbox = None
320 | self.no_context_checkbox = None
321 | self.temperature_spinbox = None
322 | self.settings_content_frame = None # Frame to hold collapsible settings
323 | self.toggle_settings_button = None # Button to toggle settings
324 | self.status_bar_label = None # New label for the status bar
325 | self.about_button = None # About button
326 | self.segments = [] # Store segments for export
327 | self.copy_text_button = None # New button for copy text
328 |
329 | self.initUI()
330 |
331 | def initUI(self):
332 | """
333 | Initializes the user interface of the application.
334 | """
335 | self.setWindowTitle('PyWhisperCpp Simple GUI')
336 | self.setGeometry(100, 100, 450, 500)
337 | # Apply the updated stylesheet
338 | self.setStyleSheet(STYLESHEET)
339 |
340 | # Main vertical layout
341 | main_layout = QVBoxLayout()
342 | # Set bottom margin to 0 for the main layout to ensure status bar is flush
343 | main_layout.setContentsMargins(4, 4, 4, 0)
344 | main_layout.setSpacing(10)
345 |
346 | # --- Header (Title + About Button) ---
347 | header_layout = QHBoxLayout()
348 | title_label = QLabel("PyWhisperCpp Simple GUI") # Updated main title label
349 | title_label.setObjectName("TitleLabel") # Add objectName for styling
350 |         title_label.setAlignment(Qt.AlignLeft)  # Align the title to the left
351 |
352 | # Adding stretch before and after title to center it
353 | # header_layout.addStretch()
354 | header_layout.addWidget(title_label)
355 | header_layout.addStretch()
356 |
357 | # About button
358 | self.about_button = QPushButton("About")
359 | self.about_button.clicked.connect(self.show_about_dialog)
360 | # Removed setFixedSize to allow text to fit, or adjust as needed
361 | # self.about_button.setFixedSize(50, 25)
362 | header_layout.addWidget(self.about_button) # Add it to the header layout
363 |
364 | main_layout.addLayout(header_layout) # Add the combined header to main layout
365 |
366 | # --- File Selection Area ---
367 | file_frame = QFrame()
368 | file_layout = QHBoxLayout(file_frame)
369 | file_layout.setContentsMargins(0, 0, 0, 0)
370 | file_layout.setSpacing(10)
371 |
372 | self.select_button = QPushButton("Select Audio File")
373 | self.select_button.clicked.connect(self.select_file)
374 |
375 | self.file_label = QLabel("No file selected.")
376 | self.file_label.setObjectName("file_label") # Added objectName for styling
377 | self.file_label.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Preferred)
378 |
379 | file_layout.addWidget(self.select_button)
380 | file_layout.addWidget(self.file_label)
381 | main_layout.addWidget(file_frame)
382 |
383 | # --- Collapsible Settings Section ---
384 | settings_group = QGroupBox() # No title here, using QToolButton for title
385 | settings_group_layout = QVBoxLayout(settings_group)
386 | settings_group_layout.setContentsMargins(5, 5, 5, 5)
387 |
388 | # Custom title bar for the collapsible group box
389 | header_layout_settings = QHBoxLayout() # Renamed to avoid clash
390 | self.toggle_settings_button = QToolButton(settings_group)
391 | self.toggle_settings_button.setText("Transcription Settings")
392 | self.toggle_settings_button.setToolButtonStyle(Qt.ToolButtonTextBesideIcon)
393 | self.toggle_settings_button.setArrowType(Qt.RightArrow)
394 | self.toggle_settings_button.setCheckable(True)
395 | self.toggle_settings_button.setChecked(False) # Start collapsed
396 | self.toggle_settings_button.clicked.connect(self.toggle_settings_visibility)
397 |
398 | header_layout_settings.addWidget(self.toggle_settings_button)
399 | header_layout_settings.addStretch() # Push button to left
400 |
401 | settings_group_layout.addLayout(header_layout_settings)
402 |
403 | # Frame to hold the actual settings form (this will be hidden/shown)
404 | self.settings_content_frame = QFrame()
405 | settings_form_layout = QFormLayout(self.settings_content_frame)
406 | settings_form_layout.setContentsMargins(15, 5, 10, 10)
407 | settings_form_layout.setSpacing(8)
408 |
409 | # Model Selection
410 | self.model_combo = QComboBox()
411 | self.model_combo.addItems(AVAILABLE_MODELS)
412 | self.model_combo.setCurrentText("tiny") # Default to 'tiny' as requested
413 | settings_form_layout.addRow("Model:", self.model_combo)
414 |
415 | # Language Input
416 | self.language_input = QLineEdit()
417 | self.language_input.setPlaceholderText('e.g., "en", "es", or leave empty for auto-detect')
418 | self.language_input.setText("") # Default to auto-detect
419 | settings_form_layout.addRow("Language:", self.language_input)
420 |
421 | # Translate Checkbox
422 | self.translate_checkbox = QCheckBox("Translate to English")
423 | self.translate_checkbox.setChecked(False) # Default
424 | settings_form_layout.addRow("Translate:", self.translate_checkbox)
425 |
426 | # N Threads Spinbox
427 | self.n_threads_spinbox = QSpinBox()
428 |         self.n_threads_spinbox.setRange(1, os.cpu_count() or 8)  # Cap at the number of CPU cores (fallback: 8)
429 | self.n_threads_spinbox.setValue(4) # Sensible default
430 | settings_form_layout.addRow("Number of Threads:", self.n_threads_spinbox)
431 |
432 | # No Context Checkbox
433 | self.no_context_checkbox = QCheckBox("No Context (do not use past transcription)")
434 | self.no_context_checkbox.setChecked(False) # Default
435 | settings_form_layout.addRow("No Context:", self.no_context_checkbox)
436 |
437 | # Temperature Spinbox
438 | self.temperature_spinbox = QDoubleSpinBox()
439 | self.temperature_spinbox.setRange(0.0, 1.0)
440 | self.temperature_spinbox.setSingleStep(0.1)
441 | self.temperature_spinbox.setValue(0.0) # Default
442 | settings_form_layout.addRow("Temperature:", self.temperature_spinbox)
443 |
444 | settings_group_layout.addWidget(self.settings_content_frame)
445 | self.settings_content_frame.setVisible(False) # Initially hidden
446 |
447 | main_layout.addWidget(settings_group)
448 |
449 | # --- Transcription Button ---
450 | self.transcribe_button = QPushButton("Transcribe")
451 | self.transcribe_button.setObjectName("TranscribeButton") # Add objectName for styling
452 | self.transcribe_button.setEnabled(False)
453 | self.transcribe_button.clicked.connect(self.start_transcription)
454 | main_layout.addWidget(self.transcribe_button)
455 |
456 | # --- Stop Button ---
457 | self.stop_button = QPushButton("Stop")
458 | self.stop_button.setObjectName("StopButton") # Add objectName for styling
459 | self.stop_button.setEnabled(True)
460 | self.stop_button.setVisible(False)
461 | self.stop_button.clicked.connect(self.stop_transcription)
462 | main_layout.addWidget(self.stop_button)
463 |
464 | # --- Progress Bar ---
465 | progress_frame = QFrame()
466 | progress_layout = QVBoxLayout(progress_frame)
467 | progress_layout.setContentsMargins(0, 5, 0, 5)
468 | progress_layout.setSpacing(5)
469 |
470 | self.progress_bar = QProgressBar()
471 | self.progress_bar.setVisible(False)
472 |
473 | progress_layout.addWidget(self.progress_bar)
474 | main_layout.addWidget(progress_frame)
475 |
476 | # --- Transcription Output Table ---
477 | output_label = QLabel("Transcription Output:")
478 | main_layout.addWidget(output_label)
479 |
480 | self.results_table = QTableWidget()
481 | self.results_table.setColumnCount(3)
482 | self.results_table.setHorizontalHeaderLabels(["Start Time", "End Time", "Text"])
483 | header = self.results_table.horizontalHeader()
484 | header.setSectionResizeMode(0, QHeaderView.ResizeToContents)
485 | header.setSectionResizeMode(1, QHeaderView.ResizeToContents)
486 | header.setSectionResizeMode(2, QHeaderView.Stretch)
487 | self.results_table.verticalHeader().setVisible(False)
488 | main_layout.addWidget(self.results_table)
489 |
490 | # --- Output Buttons (Export and Copy) ---
491 | output_buttons_layout = QHBoxLayout()
492 | output_buttons_layout.addStretch() # Pushes buttons to the right
493 |
494 | # Export Button with Menu
495 | self.export_button = QPushButton("Export as...")
496 | self.export_button.setEnabled(False)
497 | self.export_menu = QMenu(self)
498 |
499 | self.export_action_txt = self.export_menu.addAction("Plain Text (.txt)")
500 | self.export_action_srt = self.export_menu.addAction("SRT Subtitle (.srt)")
501 | self.export_action_vtt = self.export_menu.addAction("VTT Subtitle (.vtt)")
502 | self.export_action_csv = self.export_menu.addAction("CSV (.csv)")
503 |
504 | self.export_action_txt.triggered.connect(lambda: self.export_transcription("txt"))
505 | self.export_action_srt.triggered.connect(lambda: self.export_transcription("srt"))
506 | self.export_action_vtt.triggered.connect(lambda: self.export_transcription("vtt"))
507 | self.export_action_csv.triggered.connect(lambda: self.export_transcription("csv"))
508 |
509 | self.export_button.setMenu(self.export_menu)
510 | output_buttons_layout.addWidget(self.export_button)
511 |
512 | # Copy Text Button
513 | self.copy_text_button = QPushButton("Copy Text")
514 | self.copy_text_button.setEnabled(False) # Initially disabled
515 | self.copy_text_button.clicked.connect(self.copy_all_text_to_clipboard) # Connect to new method
516 | output_buttons_layout.addWidget(self.copy_text_button)
517 |
518 | main_layout.addLayout(output_buttons_layout)
519 |
520 | # --- Status Bar at the very bottom ---
521 | self.status_bar_label = QLabel("Ready.")
522 | self.status_bar_label.setObjectName("status_bar_label") # Add objectName for styling
523 | self.status_bar_label.setAlignment(Qt.AlignLeft | Qt.AlignVCenter)
524 | self.status_bar_label.setContentsMargins(5, 2, 5, 2)
525 | main_layout.addWidget(self.status_bar_label)
526 |
527 | self.setLayout(main_layout)
528 |
529 | def toggle_settings_visibility(self):
530 | """Toggles the visibility of the settings content frame and updates the arrow."""
531 | is_visible = self.settings_content_frame.isVisible()
532 | self.settings_content_frame.setVisible(not is_visible)
533 | if not is_visible:
534 | self.toggle_settings_button.setArrowType(Qt.DownArrow)
535 | else:
536 | self.toggle_settings_button.setArrowType(Qt.RightArrow)
537 |
538 | def select_file(self):
539 | """
540 | Opens a file dialog to select an audio file.
541 | """
542 | options = QFileDialog.Options()
543 | file_path, _ = QFileDialog.getOpenFileName(
544 | self, "Select a Media File", "",
545 | "All Files (*)",
546 | options=options
547 | )
548 | if file_path:
549 | self.selected_file_path = file_path
550 | self.file_label.setText(f"Selected: {os.path.basename(file_path)}")
551 | self.transcribe_button.setEnabled(True)
552 | self.results_table.setRowCount(0)
553 | self.export_button.setEnabled(False) # Disable export until transcription
554 | self.copy_text_button.setEnabled(False) # Disable copy until transcription
555 | self.update_status("File selected: " + os.path.basename(file_path)) # Update new status bar
556 |
557 | def start_transcription(self):
558 | """
559 | Starts the transcription process in a separate thread, passing selected settings.
560 | """
561 | if self.selected_file_path:
562 | self.transcribe_button.setVisible(False)
563 | self.stop_button.setVisible(True)
564 | self.select_button.setEnabled(False)
565 | self.progress_bar.setVisible(True)
566 | self.progress_bar.setValue(0)
567 | self.results_table.setRowCount(0)
568 | self.export_button.setEnabled(False) # Disable export during transcription
569 | self.copy_text_button.setEnabled(False) # Disable copy during transcription
570 | self.update_status("Starting transcription...")
571 | self.segments = [] # Clear segments for new transcription
572 |
573 | # Gather settings from GUI widgets
574 | selected_model = self.model_combo.currentText()
575 | transcribe_params = {
576 | "language": self.language_input.text() if self.language_input.text() else None,
577 | "translate": self.translate_checkbox.isChecked(),
578 | "n_threads": self.n_threads_spinbox.value(),
579 | "no_context": self.no_context_checkbox.isChecked(),
580 | "temperature": self.temperature_spinbox.value(),
581 | }
582 | # Remove None values to use pywhispercpp defaults where applicable
583 | transcribe_params = {k: v for k, v in transcribe_params.items() if v is not None}
584 |
585 | # Create and start the worker thread
586 | self.whisper_thread = PyWhisperCppWorker(
587 | self.selected_file_path,
588 | selected_model,
589 | **transcribe_params
590 | )
591 | self.whisper_thread.signals.result.connect(self.on_transcription_result)
592 | self.whisper_thread.signals.segment.connect(self.on_new_segment)
593 | self.whisper_thread.signals.finished.connect(self.on_transcription_finished)
594 | self.whisper_thread.signals.error.connect(self.on_transcription_error)
595 | self.whisper_thread.signals.progress.connect(self.update_progress)
596 | self.whisper_thread.signals.status_update.connect(self.update_status)
597 | self.whisper_thread.start()
598 |
599 | def stop_transcription(self):
600 | if self.whisper_thread:
601 | self.whisper_thread.stop()
602 | # self.transcribe_button.setVisible(True)
603 | # self.stop_button.setVisible(False)
604 | # self.select_button.setEnabled(True)
605 | # self.progress_bar.setVisible(False)
606 | # self.on_transcription_finished()
607 |
608 | def update_progress(self, value):
609 | self.progress_bar.setValue(value)
610 |         # Show progress in the status bar unless a more specific terminal
611 |         # message (error/finished/exported/copied) is already displayed
612 |         terminal_prefixes = ("Error:", "Finished.", "Text exported", "Text copied")
613 |         current_status = self.status_bar_label.text()
614 |         if not current_status.startswith(terminal_prefixes):
615 |             self.update_status(f"Progress: {value}%")
616 |
617 | def update_status(self, status_text):
618 | # Update the new status bar label directly
619 | if self.status_bar_label:
620 | self.status_bar_label.setText(status_text)
621 | # Polish stylesheet for status_bar_label to ensure updates are reflected
622 | self.status_bar_label.style().unpolish(self.status_bar_label)
623 | self.status_bar_label.style().polish(self.status_bar_label)
624 |
625 | def format_time(self, milliseconds):
626 | """Converts milliseconds to HH:MM:SS.ms format."""
627 | seconds_total = milliseconds / 1000
628 | minutes, seconds = divmod(seconds_total, 60)
629 | hours, minutes = divmod(minutes, 60)
630 | return f"{int(hours):02d}:{int(minutes):02d}:{seconds:06.3f}"
631 |
632 | def on_new_segment(self, segment):
633 | row_position = self.results_table.rowCount()
634 | self.results_table.insertRow(row_position)
635 | start_time_str = self.format_time(segment.t0)
636 | end_time_str = self.format_time(segment.t1)
637 |
638 | start_item = QTableWidgetItem(start_time_str)
639 | end_item = QTableWidgetItem(end_time_str)
640 | text_item = QTableWidgetItem(segment.text.strip())
641 |
642 | self.results_table.setItem(row_position, 0, start_item)
643 | self.results_table.setItem(row_position, 1, end_item)
644 | self.results_table.setItem(row_position, 2, text_item)
645 |
646 | def on_transcription_result(self, segments):
647 | """
648 | Populates the results table with the transcription segments.
649 | Stores segments for export.
650 | """
651 | self.segments = segments # Store segments
652 |         self.export_button.setEnabled(bool(segments))  # Enable export only if segments exist
653 |         self.copy_text_button.setEnabled(bool(segments))  # Enable copy only if segments exist
654 |
655 | def on_transcription_finished(self):
656 | """
657 | Cleans up after the transcription thread is finished.
658 | """
659 | self.transcribe_button.setVisible(True)
660 | self.transcribe_button.setEnabled(True)
661 | self.stop_button.setVisible(False)
662 | self.select_button.setEnabled(True)
663 | self.progress_bar.setVisible(False)
664 | if self.results_table.rowCount() == 0:
665 | self.update_status("Finished. No transcription data.")
666 | else:
667 | self.update_status("Transcription finished successfully!")
668 | self.whisper_thread = None
669 |
670 | def on_transcription_error(self, err):
671 | """
672 | Displays an error message if transcription fails.
673 | """
674 | exctype, value, tb = err
675 | error_message = f"Error: {value}"
676 | self.update_status(error_message) # Update new status bar
677 | self.on_transcription_finished()
678 |
679 | def export_transcription(self, format_type):
680 | """
681 | Handles exporting the transcription to a chosen file format.
682 | """
683 | if not self.segments:
684 | self.update_status("No transcription data to export.")
685 | return
686 |
687 | file_dialog_filter = {
688 | "txt": "Plain Text Files (*.txt)",
689 | "srt": "SRT Subtitle Files (*.srt)",
690 | "vtt": "VTT Subtitle Files (*.vtt)",
691 | "csv": "CSV (Comma Separated Values) Files (*.csv)",
692 | }
693 |
694 |         base_name = os.path.basename(self.selected_file_path).rsplit('.', 1)[0] if self.selected_file_path else "transcription"
695 |         default_file_name = f"{base_name}.{format_type}"
696 |
697 | options = QFileDialog.Options()
698 | file_path, _ = QFileDialog.getSaveFileName(
699 | self, f"Save Transcription as {format_type.upper()}",
700 | default_file_name,
701 | file_dialog_filter.get(format_type, "All Files (*)"),
702 | options=options
703 | )
704 |
705 | if file_path:
706 | try:
707 | # Use pywhispercpp.utils functions based on format_type
708 | if format_type == "txt":
709 | # For TXT, we'll re-use the text from the table or segments
710 | all_text = []
711 | for segment in self.segments:
712 | all_text.append(segment.text.strip())
713 | output_txt_content = "\n".join(all_text)
714 | with open(file_path, 'w', encoding='utf-8') as f:
715 | f.write(output_txt_content)
716 |
717 | elif format_type == "srt":
718 | if output_srt:
719 | output_srt(self.segments, file_path)
720 | else:
721 | raise ImportError("pywhispercpp.utils.output_srt not available.")
722 | elif format_type == "vtt":
723 | if output_vtt:
724 | output_vtt(self.segments, file_path)
725 | else:
726 | raise ImportError("pywhispercpp.utils.output_vtt not available.")
727 | elif format_type == "csv":
728 | if output_csv:
729 | # For CSV, we need to pass a list of lists/tuples representing rows
730 | # pywhispercpp.utils.output_csv expects a list of segments and a file path
731 | output_csv(self.segments, file_path)
732 | else:
733 | raise ImportError("pywhispercpp.utils.output_csv not available.")
734 |
735 | self.update_status(f"Transcription successfully exported to {os.path.basename(file_path)}")
736 | except Exception as e:
737 | self.update_status(f"Error exporting to {format_type.upper()}: {e}")
738 | else:
739 | self.update_status("Export cancelled.")
740 |
741 | def copy_all_text_to_clipboard(self):
742 | """
743 | Concatenates all text from segments and copies it to the clipboard.
744 | """
745 | if not self.segments:
746 | self.update_status("No transcription data to copy.")
747 | return
748 |
749 | all_text = []
750 | for segment in self.segments:
751 | all_text.append(segment.text.strip())
752 |
753 | QApplication.clipboard().setText("\n".join(all_text))
754 | self.update_status("Text copied to clipboard!")
755 |
756 | def show_about_dialog(self):
757 | """Opens a small dialog with About information."""
758 | about_dialog = QDialog(self)
759 | about_dialog.setWindowTitle("About PyWhisperCPP Simple GUI")
760 | about_dialog.setFixedSize(400, 220)
761 |
762 | dialog_layout = QVBoxLayout(about_dialog)
763 | dialog_layout.setContentsMargins(20, 20, 20, 20)
764 |
765 | info_text = QLabel()
766 | info_text.setTextFormat(Qt.RichText)
767 |         info_text.setText(
768 |             "<b>PyWhisperCPP Simple GUI</b><br>"
769 |             f"Version {__version__}<br>"
770 |             "<br>"
771 |             "A simple graphical user interface for PyWhisperCpp using PyQt.<br>"
772 |             "<a href='https://github.com/absadiki/pywhispercpp'>PyWhisperCpp GitHub repository</a><br>"  # link target inferred from the author's GitHub handle
773 |             "<br>"
774 |             f"Copyright © {datetime.now().year}"
775 |         )
776 | info_text.setOpenExternalLinks(True)
777 |
778 | dialog_layout.addWidget(info_text)
779 |
780 | close_button = QPushButton("Close")
781 | close_button.clicked.connect(about_dialog.accept)
782 | dialog_layout.addWidget(close_button, alignment=Qt.AlignCenter)
783 |
784 | about_dialog.exec_()
785 |
786 |
787 | def _main():
788 | """Main function to run the application."""
789 | if Model is None:
790 | print("pywhispercpp is not installed.")
791 | print("Please install it by running: pip install pywhispercpp")
792 | print("You also need ffmpeg installed on your system.")
793 | return
794 |
795 | app = QApplication(sys.argv)
796 | ex = TranscriptionApp()
797 | ex.show()
798 | sys.exit(app.exec_())
799 |
800 |
801 | if __name__ == '__main__':
802 | _main()
803 |
--------------------------------------------------------------------------------
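For reference, a minimal sketch of the transcription flow this GUI wraps, assuming the high-level Model API from pywhispercpp.model; the model name, input path, parameter values, and the exact per-segment callback signature below are illustrative assumptions, not taken from this file:

# minimal_transcribe.py -- hedged sketch of the GUI's core flow
from pywhispercpp.model import Model

def on_segment(segment):
    # Assumed callback shape: called for each new segment, mirroring the
    # GUI's on_new_segment slot (segments carry t0, t1, and text).
    print(f"[{segment.t0} -> {segment.t1}] {segment.text.strip()}")

model = Model("tiny")  # same default model the GUI selects
segments = model.transcribe(
    "audio.wav",                      # illustrative input path
    translate=False,
    n_threads=4,
    temperature=0.0,
    new_segment_callback=on_segment,  # streamed results, as in the worker thread
)
print(f"{len(segments)} segments transcribed")
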
/src/main.cpp:
--------------------------------------------------------------------------------
1 | /**
2 | ********************************************************************************
3 | * @file main.cpp
4 | * @author [absadiki](https://github.com/absadiki)
5 | * @date 2023
6 | * @brief Python bindings for [whisper.cpp](https://github.com/ggerganov/whisper.cpp) using Pybind11
7 | *
8 | * @par
9 | * COPYRIGHT NOTICE: (c) 2023. All rights reserved.
10 | ********************************************************************************
11 | */
12 |
13 | #include <pybind11/pybind11.h>
14 | #include <pybind11/numpy.h>
15 | #include <pybind11/functional.h>
16 | #include <sstream>
17 |
18 | #include "whisper.h"
19 |
20 |
21 | #define STRINGIFY(x) #x
22 | #define MACRO_STRINGIFY(x) STRINGIFY(x)
23 |
24 | #define DEF_RELEASE_GIL(name, fn, doc) \
25 |     m.def(name, fn, doc, py::call_guard<py::gil_scoped_release>())
26 |
27 |
28 | namespace py = pybind11;
29 | using namespace pybind11::literals; // to bring in the `_a` literal
30 |
31 |
32 | py::function py_new_segment_callback;
33 | py::function py_encoder_begin_callback;
34 | py::function py_logits_filter_callback;
35 |
36 |
37 | // whisper context wrapper, to solve the incomplete type issue
38 | // Thanks to https://github.com/pybind/pybind11/issues/2770
39 | struct whisper_context_wrapper {
40 | whisper_context* ptr;
41 |
42 | };
43 |
44 |
45 | // struct inside params
46 | struct greedy{
47 | int best_of;
48 | };
49 |
50 | struct beam_search{
51 | int beam_size;
52 | float patience;
53 | };
54 |
55 |
56 | struct whisper_model_loader_wrapper {
57 | whisper_model_loader* ptr;
58 |
59 | };
60 |
61 | struct whisper_context_wrapper whisper_init_from_file_wrapper(const char * path_model){
62 | struct whisper_context * ctx = whisper_init_from_file(path_model);
63 | struct whisper_context_wrapper ctw_w;
64 | ctw_w.ptr = ctx;
65 | return ctw_w;
66 | }
67 |
68 | struct whisper_context_wrapper whisper_init_from_buffer_wrapper(void * buffer, size_t buffer_size){
69 | struct whisper_context * ctx = whisper_init_from_buffer(buffer, buffer_size);
70 | struct whisper_context_wrapper ctw_w;
71 | ctw_w.ptr = ctx;
72 | return ctw_w;
73 | }
74 |
75 | struct whisper_context_wrapper whisper_init_wrapper(struct whisper_model_loader_wrapper * loader){
76 | struct whisper_context * ctx = whisper_init(loader->ptr);
77 | struct whisper_context_wrapper ctw_w;
78 | ctw_w.ptr = ctx;
79 | return ctw_w;
80 | };
81 |
82 | void whisper_free_wrapper(struct whisper_context_wrapper * ctx_w){
83 | whisper_free(ctx_w->ptr);
84 | };
85 |
86 | int whisper_pcm_to_mel_wrapper(
87 | struct whisper_context_wrapper * ctx,
88 |         py::array_t<float> samples,
89 | int n_samples,
90 | int n_threads){
91 | py::buffer_info buf = samples.request();
92 |         float *samples_ptr = static_cast<float *>(buf.ptr);
93 | return whisper_pcm_to_mel(ctx->ptr, samples_ptr, n_samples, n_threads);
94 | };
95 |
96 | int whisper_set_mel_wrapper(
97 | struct whisper_context_wrapper * ctx,
98 |         py::array_t<float> data,
99 | int n_len,
100 | int n_mel){
101 | py::buffer_info buf = data.request();
102 |         float *data_ptr = static_cast<float *>(buf.ptr);
103 | return whisper_set_mel(ctx->ptr, data_ptr, n_len, n_mel);
104 |
105 | };
106 |
107 | int whisper_n_len_wrapper(struct whisper_context_wrapper * ctx_w){
108 | return whisper_n_len(ctx_w->ptr);
109 | };
110 |
111 | int whisper_n_vocab_wrapper(struct whisper_context_wrapper * ctx_w){
112 | return whisper_n_vocab(ctx_w->ptr);
113 | };
114 |
115 | int whisper_n_text_ctx_wrapper(struct whisper_context_wrapper * ctx_w){
116 | return whisper_n_text_ctx(ctx_w->ptr);
117 | };
118 |
119 | int whisper_n_audio_ctx_wrapper(struct whisper_context_wrapper * ctx_w){
120 | return whisper_n_audio_ctx(ctx_w->ptr);
121 | }
122 |
123 | int whisper_is_multilingual_wrapper(struct whisper_context_wrapper * ctx_w){
124 | return whisper_is_multilingual(ctx_w->ptr);
125 | }
126 |
127 |
128 | float * whisper_get_logits_wrapper(struct whisper_context_wrapper * ctx_w){
129 | return whisper_get_logits(ctx_w->ptr);
130 | };
131 |
132 | const char * whisper_token_to_str_wrapper(struct whisper_context_wrapper * ctx_w, whisper_token token){
133 | return whisper_token_to_str(ctx_w->ptr, token);
134 | };
135 |
136 | py::bytes whisper_token_to_bytes_wrapper(struct whisper_context_wrapper * ctx_w, whisper_token token){
137 | const char* str = whisper_token_to_str(ctx_w->ptr, token);
138 | size_t l = strlen(str);
139 | return py::bytes(str, l);
140 | }
141 |
142 | whisper_token whisper_token_eot_wrapper(struct whisper_context_wrapper * ctx_w){
143 | return whisper_token_eot(ctx_w->ptr);
144 | }
145 |
146 | whisper_token whisper_token_sot_wrapper(struct whisper_context_wrapper * ctx_w){
147 | return whisper_token_sot(ctx_w->ptr);
148 | }
149 |
150 | whisper_token whisper_token_prev_wrapper(struct whisper_context_wrapper * ctx_w){
151 | return whisper_token_prev(ctx_w->ptr);
152 | }
153 |
154 | whisper_token whisper_token_solm_wrapper(struct whisper_context_wrapper * ctx_w){
155 | return whisper_token_solm(ctx_w->ptr);
156 | }
157 |
158 | whisper_token whisper_token_not_wrapper(struct whisper_context_wrapper * ctx_w){
159 | return whisper_token_not(ctx_w->ptr);
160 | }
161 |
162 | whisper_token whisper_token_beg_wrapper(struct whisper_context_wrapper * ctx_w){
163 | return whisper_token_beg(ctx_w->ptr);
164 | }
165 |
166 | whisper_token whisper_token_lang_wrapper(struct whisper_context_wrapper * ctx_w, int lang_id){
167 | return whisper_token_lang(ctx_w->ptr, lang_id);
168 | }
169 |
170 | whisper_token whisper_token_translate_wrapper(struct whisper_context_wrapper * ctx_w){
171 | return whisper_token_translate(ctx_w->ptr);
172 | }
173 |
174 | whisper_token whisper_token_transcribe_wrapper(struct whisper_context_wrapper * ctx_w){
175 | return whisper_token_transcribe(ctx_w->ptr);
176 | }
177 |
178 | void whisper_print_timings_wrapper(struct whisper_context_wrapper * ctx_w){
179 | return whisper_print_timings(ctx_w->ptr);
180 | }
181 |
182 | void whisper_reset_timings_wrapper(struct whisper_context_wrapper * ctx_w){
183 | return whisper_reset_timings(ctx_w->ptr);
184 | }
185 |
186 | int whisper_encode_wrapper(
187 | struct whisper_context_wrapper * ctx,
188 | int offset,
189 | int n_threads){
190 | return whisper_encode(ctx->ptr, offset, n_threads);
191 | }
192 |
193 |
194 | int whisper_decode_wrapper(
195 | struct whisper_context_wrapper * ctx,
196 | const whisper_token * tokens,
197 | int n_tokens,
198 | int n_past,
199 | int n_threads){
200 | return whisper_decode(ctx->ptr, tokens, n_tokens, n_past, n_threads);
201 | };
202 |
203 | int whisper_tokenize_wrapper(
204 | struct whisper_context_wrapper * ctx,
205 | const char * text,
206 | whisper_token * tokens,
207 | int n_max_tokens){
208 | return whisper_tokenize(ctx->ptr, text, tokens, n_max_tokens);
209 | };
210 |
211 | int whisper_lang_auto_detect_wrapper(
212 | struct whisper_context_wrapper * ctx,
213 | int offset_ms,
214 | int n_threads,
215 |         py::array_t<float> lang_probs){
216 |
217 | py::buffer_info buf = lang_probs.request();
218 |     float *lang_probs_ptr = static_cast<float *>(buf.ptr);
219 | return whisper_lang_auto_detect(ctx->ptr, offset_ms, n_threads, lang_probs_ptr);
220 |
221 | }
222 |
223 | int whisper_full_wrapper(
224 | struct whisper_context_wrapper * ctx_w,
225 | struct whisper_full_params params,
226 |         py::array_t<float> samples,
227 | int n_samples){
228 | py::buffer_info buf = samples.request();
229 |     float *samples_ptr = static_cast<float *>(buf.ptr);
230 |
231 | py::gil_scoped_release release;
232 | return whisper_full(ctx_w->ptr, params, samples_ptr, n_samples);
233 | }
234 |
235 | int whisper_full_parallel_wrapper(
236 | struct whisper_context_wrapper * ctx_w,
237 | struct whisper_full_params params,
238 |         py::array_t<float> samples,
239 | int n_samples,
240 | int n_processors){
241 | py::buffer_info buf = samples.request();
242 |     float *samples_ptr = static_cast<float *>(buf.ptr);
243 |
244 | py::gil_scoped_release release;
245 | return whisper_full_parallel(ctx_w->ptr, params, samples_ptr, n_samples, n_processors);
246 | }
247 |
248 |
249 | int whisper_full_n_segments_wrapper(struct whisper_context_wrapper * ctx){
250 | py::gil_scoped_release release;
251 | return whisper_full_n_segments(ctx->ptr);
252 | }
253 |
254 | int whisper_full_lang_id_wrapper(struct whisper_context_wrapper * ctx){
255 | return whisper_full_lang_id(ctx->ptr);
256 | }
257 |
258 | int64_t whisper_full_get_segment_t0_wrapper(struct whisper_context_wrapper * ctx, int i_segment){
259 | return whisper_full_get_segment_t0(ctx->ptr, i_segment);
260 | }
261 |
262 | int64_t whisper_full_get_segment_t1_wrapper(struct whisper_context_wrapper * ctx, int i_segment){
263 | return whisper_full_get_segment_t1(ctx->ptr, i_segment);
264 | }
265 |
266 | // https://pybind11.readthedocs.io/en/stable/advanced/cast/strings.html
267 | const py::bytes whisper_full_get_segment_text_wrapper(struct whisper_context_wrapper * ctx, int i_segment){
268 | const char * c_array = whisper_full_get_segment_text(ctx->ptr, i_segment);
269 | size_t length = strlen(c_array); // Determine the length of the array
270 | return py::bytes(c_array, length); // Return the data without transcoding
271 | };
272 |
273 | int whisper_full_n_tokens_wrapper(struct whisper_context_wrapper * ctx, int i_segment){
274 | return whisper_full_n_tokens(ctx->ptr, i_segment);
275 | }
276 |
277 | const char * whisper_full_get_token_text_wrapper(struct whisper_context_wrapper * ctx, int i_segment, int i_token){
278 | return whisper_full_get_token_text(ctx->ptr, i_segment, i_token);
279 | }
280 |
281 | whisper_token whisper_full_get_token_id_wrapper(struct whisper_context_wrapper * ctx, int i_segment, int i_token){
282 | return whisper_full_get_token_id(ctx->ptr, i_segment, i_token);
283 | }
284 |
285 | whisper_token_data whisper_full_get_token_data_wrapper(struct whisper_context_wrapper * ctx, int i_segment, int i_token){
286 | return whisper_full_get_token_data(ctx->ptr, i_segment, i_token);
287 | }
288 |
289 | float whisper_full_get_token_p_wrapper(struct whisper_context_wrapper * ctx, int i_segment, int i_token){
290 | return whisper_full_get_token_p(ctx->ptr, i_segment, i_token);
291 | }
292 |
293 | int whisper_ctx_init_openvino_encoder_wrapper(struct whisper_context_wrapper * ctx, const char * model_path,
294 | const char * device,
295 | const char * cache_dir){
296 | return whisper_ctx_init_openvino_encoder(ctx->ptr, model_path, device, cache_dir);
297 | }
298 |
299 | class WhisperFullParamsWrapper : public whisper_full_params {
300 | std::string initial_prompt_str;
301 | std::string suppress_regex_str;
302 | public:
303 | py::function py_progress_callback;
304 | WhisperFullParamsWrapper(const whisper_full_params& params = whisper_full_params())
305 | : whisper_full_params(params),
306 | initial_prompt_str(params.initial_prompt ? params.initial_prompt : ""),
307 | suppress_regex_str(params.suppress_regex ? params.suppress_regex : "") {
308 | initial_prompt = initial_prompt_str.empty() ? nullptr : initial_prompt_str.c_str();
309 | suppress_regex = suppress_regex_str.empty() ? nullptr : suppress_regex_str.c_str();
310 | // progress callback
311 | progress_callback_user_data = this;
312 | progress_callback = [](struct whisper_context* ctx, struct whisper_state* state, int progress, void* user_data) {
313 |             auto* self = static_cast<WhisperFullParamsWrapper*>(user_data);
314 | if(self && self->print_progress){
315 | if (self->py_progress_callback) {
316 | // call the python callback
317 | py::gil_scoped_acquire gil;
318 | self->py_progress_callback(progress); // Call Python callback
319 | }
320 |                 else {
321 |                     fprintf(stderr, "Progress: %3d%%\n", progress); // Default message when no Python callback is set
322 |                 }
323 | }
324 | } ;
325 | }
326 |
327 | WhisperFullParamsWrapper(const WhisperFullParamsWrapper& other)
328 |         : WhisperFullParamsWrapper(static_cast<const whisper_full_params&>(other)) {}
329 |
330 | void set_initial_prompt(const std::string& prompt) {
331 | initial_prompt_str = prompt;
332 | initial_prompt = initial_prompt_str.c_str();
333 | }
334 |
335 | void set_suppress_regex(const std::string& regex) {
336 | suppress_regex_str = regex;
337 | suppress_regex = suppress_regex_str.c_str();
338 | }
339 | };
340 |
341 | WhisperFullParamsWrapper whisper_full_default_params_wrapper(enum whisper_sampling_strategy strategy) {
342 | return WhisperFullParamsWrapper(whisper_full_default_params(strategy));
343 | }
344 |
345 | // callbacks mechanism
346 |
347 | void _new_segment_callback(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data){
348 | struct whisper_context_wrapper ctx_w;
349 | ctx_w.ptr = ctx;
350 | // call the python callback
351 | py::gil_scoped_acquire gil; // Acquire the GIL while in this scope.
352 | py_new_segment_callback(ctx_w, n_new, user_data);
353 | };
354 |
355 | void assign_new_segment_callback(struct whisper_full_params *params, py::function f){
356 | params->new_segment_callback = _new_segment_callback;
357 | py_new_segment_callback = f;
358 | };
359 |
360 | bool _encoder_begin_callback(struct whisper_context * ctx, struct whisper_state * state, void * user_data){
361 | struct whisper_context_wrapper ctx_w;
362 | ctx_w.ptr = ctx;
363 | // call the python callback
364 | py::object result_py = py_encoder_begin_callback(ctx_w, user_data);
365 |     bool res = result_py.cast<bool>();
366 | return res;
367 | }
368 |
369 | void assign_encoder_begin_callback(struct whisper_full_params *params, py::function f){
370 | params->encoder_begin_callback = _encoder_begin_callback;
371 | py_encoder_begin_callback = f;
372 | }
373 |
374 | void _logits_filter_callback(
375 | struct whisper_context * ctx,
376 | struct whisper_state * state,
377 | const whisper_token_data * tokens,
378 | int n_tokens,
379 | float * logits,
380 | void * user_data){
381 | struct whisper_context_wrapper ctx_w;
382 | ctx_w.ptr = ctx;
383 | // call the python callback
384 | py_logits_filter_callback(ctx_w, n_tokens, logits, user_data);
385 | }
386 |
387 | void assign_logits_filter_callback(struct whisper_full_params *params, py::function f){
388 | params->logits_filter_callback = _logits_filter_callback;
389 | py_logits_filter_callback = f;
390 | }
391 |
392 | py::dict get_greedy(whisper_full_params * params){
393 | py::dict d("best_of"_a=params->greedy.best_of);
394 | return d;
395 | }
396 |
397 | PYBIND11_MODULE(_pywhispercpp, m) {
398 | m.doc() = R"pbdoc(
399 | Pywhispercpp: Python binding to whisper.cpp
400 | -----------------------
401 |
402 |         .. currentmodule:: _pywhispercpp
403 |
404 | .. autosummary::
405 | :toctree: _generate
406 |
407 | )pbdoc";
408 |
409 | m.attr("WHISPER_SAMPLE_RATE") = WHISPER_SAMPLE_RATE;
410 | m.attr("WHISPER_N_FFT") = WHISPER_N_FFT;
411 | m.attr("WHISPER_HOP_LENGTH") = WHISPER_HOP_LENGTH;
412 | m.attr("WHISPER_CHUNK_SIZE") = WHISPER_CHUNK_SIZE;
413 |
414 | py::class_(m, "whisper_context");
415 | py::class_(m, "whisper_token")
416 | .def(py::init<>());
417 | py::class_(m,"whisper_token_data")
418 | .def(py::init<>())
419 | .def_readwrite("id", &whisper_token_data::id)
420 | .def_readwrite("tid", &whisper_token_data::tid)
421 | .def_readwrite("p", &whisper_token_data::p)
422 | .def_readwrite("plog", &whisper_token_data::plog)
423 | .def_readwrite("pt", &whisper_token_data::pt)
424 | .def_readwrite("ptsum", &whisper_token_data::ptsum)
425 | .def_readwrite("t0", &whisper_token_data::t0)
426 | .def_readwrite("t1", &whisper_token_data::t1)
427 | .def_readwrite("vlen", &whisper_token_data::vlen);
428 |
429 | py::class_(m,"whisper_model_loader")
430 | .def(py::init<>());
431 |
432 | DEF_RELEASE_GIL("whisper_init_from_file", &whisper_init_from_file_wrapper, "Various functions for loading a ggml whisper model.\n"
433 | "Allocate (almost) all memory needed for the model.\n"
434 | "Return NULL on failure");
435 | DEF_RELEASE_GIL("whisper_init_from_buffer", &whisper_init_from_buffer_wrapper, "Various functions for loading a ggml whisper model.\n"
436 | "Allocate (almost) all memory needed for the model.\n"
437 | "Return NULL on failure");
438 | DEF_RELEASE_GIL("whisper_init", &whisper_init_wrapper, "Various functions for loading a ggml whisper model.\n"
439 | "Allocate (almost) all memory needed for the model.\n"
440 | "Return NULL on failure");
441 |
442 |
443 | m.def("whisper_free", &whisper_free_wrapper, "Frees all memory allocated by the model.");
444 |
445 | m.def("whisper_pcm_to_mel", &whisper_pcm_to_mel_wrapper, "Convert RAW PCM audio to log mel spectrogram.\n"
446 | "The resulting spectrogram is stored inside the provided whisper context.\n"
447 | "Returns 0 on success");
448 |
449 | m.def("whisper_set_mel", &whisper_set_mel_wrapper, " This can be used to set a custom log mel spectrogram inside the provided whisper context.\n"
450 | "Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.\n"
451 | "n_mel must be 80\n"
452 | "Returns 0 on success");
453 |
454 | m.def("whisper_encode", &whisper_encode_wrapper, "Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context.\n"
455 | "Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.\n"
456 | "offset can be used to specify the offset of the first frame in the spectrogram.\n"
457 | "Returns 0 on success");
458 |
459 | m.def("whisper_decode", &whisper_decode_wrapper, "Run the Whisper decoder to obtain the logits and probabilities for the next token.\n"
460 | "Make sure to call whisper_encode() first.\n"
461 | "tokens + n_tokens is the provided context for the decoder.\n"
462 | "n_past is the number of tokens to use from previous decoder calls.\n"
463 | "Returns 0 on success\n"
464 | "TODO: add support for multiple decoders");
465 |
466 | m.def("whisper_tokenize", &whisper_tokenize_wrapper, "Convert the provided text into tokens.\n"
467 | "The tokens pointer must be large enough to hold the resulting tokens.\n"
468 | "Returns the number of tokens on success, no more than n_max_tokens\n"
469 | "Returns -1 on failure\n"
470 | "TODO: not sure if correct");
471 |
472 | m.def("whisper_lang_max_id", &whisper_lang_max_id, "Largest language id (i.e. number of available languages - 1)");
473 | m.def("whisper_lang_id", &whisper_lang_id, "Return the id of the specified language, returns -1 if not found\n"
474 | "Examples:\n"
475 | "\"de\" -> 2\n"
476 | "\"german\" -> 2");
477 | m.def("whisper_lang_str", &whisper_lang_str, "Return the short string of the specified language id (e.g. 2 -> \"de\"), returns nullptr if not found");
478 |
479 |
480 |
481 |
482 |
483 |
484 |
485 | m.def("whisper_lang_auto_detect", &whisper_lang_auto_detect_wrapper, "Use mel data at offset_ms to try and auto-detect the spoken language\n"
486 | "Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first\n"
487 | "Returns the top language id or negative on failure\n"
488 | "If not null, fills the lang_probs array with the probabilities of all languages\n"
489 | "The array must be whispe_lang_max_id() + 1 in size\n"
490 | "ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69\n");
491 | m.def("whisper_n_len", &whisper_n_len_wrapper, "whisper_n_len");
492 | m.def("whisper_n_vocab", &whisper_n_vocab_wrapper, "wrapper_whisper_n_vocab");
493 | m.def("whisper_n_text_ctx", &whisper_n_text_ctx_wrapper, "whisper_n_text_ctx");
494 | m.def("whisper_n_audio_ctx", &whisper_n_audio_ctx_wrapper, "whisper_n_audio_ctx");
495 | m.def("whisper_is_multilingual", &whisper_is_multilingual_wrapper, "whisper_is_multilingual");
496 | m.def("whisper_get_logits", &whisper_get_logits_wrapper, "Token logits obtained from the last call to whisper_decode()\n"
497 | "The logits for the last token are stored in the last row\n"
498 | "Rows: n_tokens\n"
499 | "Cols: n_vocab");
500 |
501 |
502 | m.def("whisper_token_to_str", &whisper_token_to_str_wrapper, "whisper_token_to_str");
503 | m.def("whisper_token_to_bytes", &whisper_token_to_bytes_wrapper, "whisper_token_to_bytes");
504 | m.def("whisper_token_eot", &whisper_token_eot_wrapper, "whisper_token_eot");
505 | m.def("whisper_token_sot", &whisper_token_sot_wrapper, "whisper_token_sot");
506 | m.def("whisper_token_prev", &whisper_token_prev_wrapper);
507 | m.def("whisper_token_solm", &whisper_token_solm_wrapper);
508 | m.def("whisper_token_not", &whisper_token_not_wrapper);
509 | m.def("whisper_token_beg", &whisper_token_beg_wrapper);
510 | m.def("whisper_token_lang", &whisper_token_lang_wrapper);
511 |
512 | m.def("whisper_token_translate", &whisper_token_translate_wrapper);
513 | m.def("whisper_token_transcribe", &whisper_token_transcribe_wrapper);
514 |
515 | m.def("whisper_print_timings", &whisper_print_timings_wrapper);
516 | m.def("whisper_reset_timings", &whisper_reset_timings_wrapper);
517 |
518 | m.def("whisper_print_system_info", &whisper_print_system_info);
519 |
520 |
521 |
522 | //////////////////////
523 |
524 | py::enum_(m, "whisper_sampling_strategy")
525 | .value("WHISPER_SAMPLING_GREEDY", whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY)
526 | .value("WHISPER_SAMPLING_BEAM_SEARCH", whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH)
527 | .export_values();
528 |
529 | py::class_(m, "__whisper_full_params__internal")
530 | .def(py::init<>())
531 | .def("__repr__", [](const whisper_full_params& self) {
532 | std::ostringstream oss;
533 | oss << "whisper_full_params("
534 | << "strategy=" << self.strategy << ", "
535 | << "n_threads=" << self.n_threads << ", "
536 | << "n_max_text_ctx=" << self.n_max_text_ctx << ", "
537 | << "offset_ms=" << self.offset_ms << ", "
538 | << "duration_ms=" << self.duration_ms << ", "
539 | << "translate=" << (self.translate ? "True" : "False") << ", "
540 | << "no_context=" << (self.no_context ? "True" : "False") << ", "
541 | << "no_timestamps=" << (self.no_timestamps ? "True" : "False") << ", "
542 | << "single_segment=" << (self.single_segment ? "True" : "False") << ", "
543 | << "print_special=" << (self.print_special ? "True" : "False") << ", "
544 | << "print_progress=" << (self.print_progress ? "True" : "False") << ", "
545 | << "print_realtime=" << (self.print_realtime ? "True" : "False") << ", "
546 | << "print_timestamps=" << (self.print_timestamps ? "True" : "False") << ", "
547 | << "token_timestamps=" << (self.token_timestamps ? "True" : "False") << ", "
548 | << "thold_pt=" << self.thold_pt << ", "
549 | << "thold_ptsum=" << self.thold_ptsum << ", "
550 | << "max_len=" << self.max_len << ", "
551 | << "split_on_word=" << (self.split_on_word ? "True" : "False") << ", "
552 | << "max_tokens=" << self.max_tokens << ", "
553 | << "debug_mode=" << (self.debug_mode ? "True" : "False") << ", "
554 | << "audio_ctx=" << self.audio_ctx << ", "
555 | << "tdrz_enable=" << (self.tdrz_enable ? "True" : "False") << ", "
556 | << "suppress_regex=" << (self.suppress_regex ? self.suppress_regex : "None") << ", "
557 | << "initial_prompt=" << (self.initial_prompt ? self.initial_prompt : "None") << ", "
558 | << "prompt_tokens=" << (self.prompt_tokens ? "(whisper_token *)" : "None") << ", "
559 | << "prompt_n_tokens=" << self.prompt_n_tokens << ", "
560 | << "language=" << (self.language ? self.language : "None") << ", "
561 | << "detect_language=" << (self.detect_language ? "True" : "False") << ", "
562 | << "suppress_blank=" << (self.suppress_blank ? "True" : "False") << ", "
563 | << "temperature=" << self.temperature << ", "
564 | << "max_initial_ts=" << self.max_initial_ts << ", "
565 | << "length_penalty=" << self.length_penalty << ", "
566 | << "temperature_inc=" << self.temperature_inc << ", "
567 | << "entropy_thold=" << self.entropy_thold << ", "
568 | << "logprob_thold=" << self.logprob_thold << ", "
569 | << "no_speech_thold=" << self.no_speech_thold << ", "
570 | << "greedy={best_of=" << self.greedy.best_of << "}, "
571 | << "beam_search={beam_size=" << self.beam_search.beam_size << ", patience=" << self.beam_search.patience << "}, "
572 | << "new_segment_callback=" << (self.new_segment_callback ? "(function pointer)" : "None") << ", "
573 | << "progress_callback=" << (self.progress_callback ? "(function pointer)" : "None") << ", "
574 | << "encoder_begin_callback=" << (self.encoder_begin_callback ? "(function pointer)" : "None") << ", "
575 | << "abort_callback=" << (self.abort_callback ? "(function pointer)" : "None") << ", "
576 | << "logits_filter_callback=" << (self.logits_filter_callback ? "(function pointer)" : "None") << ", "
577 | << "grammar_rules=" << (self.grammar_rules ? "(whisper_grammar_element **)" : "None") << ", "
578 | << "n_grammar_rules=" << self.n_grammar_rules << ", "
579 | << "i_start_rule=" << self.i_start_rule << ", "
580 | << "grammar_penalty=" << self.grammar_penalty
581 | << ")";
582 | return oss.str();
583 | });
584 |
585 | py::class_(m, "whisper_full_params")
586 | .def(py::init<>())
587 | .def_readwrite("strategy", &WhisperFullParamsWrapper::strategy)
588 | .def_readwrite("n_threads", &WhisperFullParamsWrapper::n_threads)
589 | .def_readwrite("n_max_text_ctx", &WhisperFullParamsWrapper::n_max_text_ctx)
590 | .def_readwrite("offset_ms", &WhisperFullParamsWrapper::offset_ms)
591 | .def_readwrite("duration_ms", &WhisperFullParamsWrapper::duration_ms)
592 | .def_readwrite("translate", &WhisperFullParamsWrapper::translate)
593 | .def_readwrite("no_context", &WhisperFullParamsWrapper::no_context)
594 | .def_readwrite("single_segment", &WhisperFullParamsWrapper::single_segment)
595 | .def_readwrite("print_special", &WhisperFullParamsWrapper::print_special)
596 | .def_readwrite("print_progress", &WhisperFullParamsWrapper::print_progress)
597 | .def_readwrite("progress_callback", &WhisperFullParamsWrapper::py_progress_callback)
598 | .def_readwrite("print_realtime", &WhisperFullParamsWrapper::print_realtime)
599 | .def_readwrite("print_timestamps", &WhisperFullParamsWrapper::print_timestamps)
600 | .def_readwrite("token_timestamps", &WhisperFullParamsWrapper::token_timestamps)
601 | .def_readwrite("thold_pt", &WhisperFullParamsWrapper::thold_pt)
602 | .def_readwrite("thold_ptsum", &WhisperFullParamsWrapper::thold_ptsum)
603 | .def_readwrite("max_len", &WhisperFullParamsWrapper::max_len)
604 | .def_readwrite("split_on_word", &WhisperFullParamsWrapper::split_on_word)
605 | .def_readwrite("max_tokens", &WhisperFullParamsWrapper::max_tokens)
606 | .def_readwrite("audio_ctx", &WhisperFullParamsWrapper::audio_ctx)
607 | .def_property("suppress_regex",
608 | [](WhisperFullParamsWrapper &self) {
609 | return py::str(self.suppress_regex ? self.suppress_regex : "");
610 | },
611 | [](WhisperFullParamsWrapper &self, const std::string &new_c) {
612 | self.set_suppress_regex(new_c);
613 | })
614 | .def_property("initial_prompt",
615 | [](WhisperFullParamsWrapper &self) {
616 | return py::str(self.initial_prompt ? self.initial_prompt : "");
617 | },
618 | [](WhisperFullParamsWrapper &self, const std::string &initial_prompt) {
619 | self.set_initial_prompt(initial_prompt);
620 | }
621 | )
622 | .def_readwrite("prompt_tokens", &WhisperFullParamsWrapper::prompt_tokens)
623 | .def_readwrite("prompt_n_tokens", &WhisperFullParamsWrapper::prompt_n_tokens)
624 | .def_property("language",
625 | [](WhisperFullParamsWrapper &self) {
626 | return py::str(self.language);
627 | },
628 | [](WhisperFullParamsWrapper &self, const char *new_c) {// using lang_id let us avoid issues with memory management
629 | const int lang_id = (new_c && strlen(new_c) > 0) ? whisper_lang_id(new_c) : -1;
630 | if (lang_id != -1) {
631 | self.language = whisper_lang_str(lang_id);
632 | } else {
633 | self.language = ""; //defaults to auto-detect
634 | }
635 | })
636 | .def_readwrite("suppress_blank", &WhisperFullParamsWrapper::suppress_blank)
637 | .def_readwrite("temperature", &WhisperFullParamsWrapper::temperature)
638 | .def_readwrite("max_initial_ts", &WhisperFullParamsWrapper::max_initial_ts)
639 | .def_readwrite("length_penalty", &WhisperFullParamsWrapper::length_penalty)
640 | .def_readwrite("temperature_inc", &WhisperFullParamsWrapper::temperature_inc)
641 | .def_readwrite("entropy_thold", &WhisperFullParamsWrapper::entropy_thold)
642 | .def_readwrite("logprob_thold", &WhisperFullParamsWrapper::logprob_thold)
643 | .def_readwrite("no_speech_thold", &WhisperFullParamsWrapper::no_speech_thold)
644 |         // small hack to expose the nested structs (greedy / beam_search) as dicts
645 | .def_property("greedy", [](WhisperFullParamsWrapper &self) {return py::dict("best_of"_a=self.greedy.best_of);},
646 |                       [](WhisperFullParamsWrapper &self, py::dict dict) {self.greedy.best_of = dict["best_of"].cast<int>();})
647 | .def_property("beam_search", [](WhisperFullParamsWrapper &self) {return py::dict("beam_size"_a=self.beam_search.beam_size, "patience"_a=self.beam_search.patience);},
648 |                       [](WhisperFullParamsWrapper &self, py::dict dict) {self.beam_search.beam_size = dict["beam_size"].cast<int>(); self.beam_search.patience = dict["patience"].cast<float>();})
649 | .def_readwrite("new_segment_callback_user_data", &WhisperFullParamsWrapper::new_segment_callback_user_data)
650 | .def_readwrite("encoder_begin_callback_user_data", &WhisperFullParamsWrapper::encoder_begin_callback_user_data)
651 | .def_readwrite("logits_filter_callback_user_data", &WhisperFullParamsWrapper::logits_filter_callback_user_data);
652 |
653 |
654 |     py::implicitly_convertible<whisper_full_params, WhisperFullParamsWrapper>();
655 |
656 | m.def("whisper_full_default_params", &whisper_full_default_params_wrapper);
657 |
658 | m.def("whisper_full", &whisper_full_wrapper, "Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text\n"
659 | "Uses the specified decoding strategy to obtain the text.\n");
660 |
661 | m.def("whisper_full_parallel", &whisper_full_parallel_wrapper, "Split the input audio in chunks and process each chunk separately using whisper_full()\n"
662 | "It seems this approach can offer some speedup in some cases.\n"
663 | "However, the transcription accuracy can be worse at the beginning and end of each chunk.");
664 |
665 | m.def("whisper_full_n_segments", &whisper_full_n_segments_wrapper, "Number of generated text segments.\n"
666 | "A segment can be a few words, a sentence, or even a paragraph.\n");
667 |
668 | m.def("whisper_full_lang_id", &whisper_full_lang_id_wrapper, "Language id associated with the current context");
669 | m.def("whisper_full_get_segment_t0", &whisper_full_get_segment_t0_wrapper, "Get the start time of the specified segment");
670 | m.def("whisper_full_get_segment_t1", &whisper_full_get_segment_t1_wrapper, "Get the end time of the specified segment");
671 |
672 | m.def("whisper_full_get_segment_text", &whisper_full_get_segment_text_wrapper, "Get the text of the specified segment");
673 | m.def("whisper_full_n_tokens", &whisper_full_n_tokens_wrapper, "Get number of tokens in the specified segment.");
674 |
675 | m.def("whisper_full_get_token_text", &whisper_full_get_token_text_wrapper, "Get the token text of the specified token in the specified segment.");
676 | m.def("whisper_full_get_token_id", &whisper_full_get_token_id_wrapper, "Get the token text of the specified token in the specified segment.");
677 |
678 | m.def("whisper_full_get_token_data", &whisper_full_get_token_data_wrapper, "Get token data for the specified token in the specified segment.\n"
679 | "This contains probabilities, timestamps, etc.");
680 |
681 | m.def("whisper_full_get_token_p", &whisper_full_get_token_p_wrapper, "Get the probability of the specified token in the specified segment.");
682 |
683 | m.def("whisper_ctx_init_openvino_encoder", &whisper_ctx_init_openvino_encoder_wrapper, "Given a context, enable use of OpenVINO for encode inference.");
684 |
685 |
686 | ////////////////////////////////////////////////////////////////////////////
687 |
688 | m.def("whisper_bench_memcpy", &whisper_bench_memcpy, "Temporary helpers needed for exposing ggml interface");
689 | m.def("whisper_bench_ggml_mul_mat", &whisper_bench_ggml_mul_mat, "Temporary helpers needed for exposing ggml interface");
690 |
691 | ////////////////////////////////////////////////////////////////////////////
692 | // Helper mechanism to set callbacks from python
693 | // The only difference from the C-Style API
694 |
695 | m.def("assign_new_segment_callback", &assign_new_segment_callback, "Assigns a new_segment_callback, takes instance and a callable function with the same parameters which are defined in the interface",
696 | py::arg("params"), py::arg("callback"));
697 |
698 | m.def("assign_encoder_begin_callback", &assign_encoder_begin_callback, "Assigns an encoder_begin_callback, takes instance and a callable function with the same parameters which are defined in the interface",
699 | py::arg("params"), py::arg("callback"));
700 |
701 | m.def("assign_logits_filter_callback", &assign_logits_filter_callback, "Assigns a logits_filter_callback, takes instance and a callable function with the same parameters which are defined in the interface",
702 | py::arg("params"), py::arg("callback"));
703 |
704 |
705 | #ifdef VERSION_INFO
706 | m.attr("__version__") = MACRO_STRINGIFY(VERSION_INFO);
707 | #else
708 | m.attr("__version__") = "dev";
709 | #endif
710 | }
711 |
--------------------------------------------------------------------------------
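To close, a minimal sketch of driving the low-level bindings above directly from Python, using only names bound in this module; the model path and the pre-decoded PCM file are illustrative assumptions (whisper.cpp expects 16 kHz mono float32 samples):

# low_level_transcribe.py -- hedged sketch against the _pywhispercpp bindings
import numpy as np
import _pywhispercpp as pw

ctx = pw.whisper_init_from_file("ggml-tiny.bin")  # illustrative model path
params = pw.whisper_full_default_params(pw.WHISPER_SAMPLING_GREEDY)
params.n_threads = 4
params.print_progress = True  # the wrapper only fires progress callbacks while this is set
params.progress_callback = lambda p: print(f"progress: {p}%")

# Illustrative: raw 16 kHz mono float32 PCM, pre-decoded (e.g. via ffmpeg)
samples = np.fromfile("audio.raw", dtype=np.float32)

if pw.whisper_full(ctx, params, samples, len(samples)) == 0:
    for i in range(pw.whisper_full_n_segments(ctx)):
        t0 = pw.whisper_full_get_segment_t0(ctx, i)
        t1 = pw.whisper_full_get_segment_t1(ctx, i)
        # segment text is returned as bytes (see whisper_full_get_segment_text_wrapper)
        text = pw.whisper_full_get_segment_text(ctx, i).decode("utf-8", errors="replace")
        print(f"[{t0} -> {t1}] {text}")
pw.whisper_free(ctx)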