├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── PROJECT_PLANS.md
├── README.md
├── assets
│   └── logo.jpg
├── config.json
├── pyproject.toml
├── src
│   └── whisperchain
│       ├── cli
│       │   ├── run.py
│       │   ├── run_client.py
│       │   └── run_server.py
│       ├── client
│       │   ├── key_listener.py
│       │   └── stream_client.py
│       ├── core
│       │   ├── __init__.py
│       │   ├── audio.py
│       │   ├── chain.py
│       │   └── config.py
│       ├── prompts
│       │   └── transcription_cleanup.txt
│       ├── server
│       │   ├── __init__.py
│       │   └── server.py
│       ├── ui
│       │   └── streamlit_app.py
│       └── utils
│           ├── decorators.py
│           ├── logger.py
│           ├── secrets.py
│           └── segment.py
└── tests
    ├── test_audio_capture.py
    ├── test_chain.py
    ├── test_key_listener.py
    ├── test_pywhispercpp.py
    ├── test_stream_client.py
    └── test_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Python cache
2 | __pycache__/
3 | *.pyc
4 | *.egg-info/
5 |
6 | # Build
7 | build/
8 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | default_language_version:
2 | python: python3
3 |
4 | repos:
5 | - repo: https://github.com/pre-commit/pre-commit-hooks
6 | rev: v4.4.0
7 | hooks:
8 | - id: check-ast
9 | - id: check-merge-conflict
10 | - id: check-yaml
11 | - id: end-of-file-fixer
12 | - id: trailing-whitespace
13 | args: [--markdown-linebreak-ext=md]
14 |
15 | - repo: https://github.com/psf/black
16 | rev: 23.3.0
17 | hooks:
18 | - id: black
19 | language_version: python3
20 | args: ["--line-length", "99"]
21 |
22 | - repo: https://github.com/pycqa/isort
23 | rev: 5.12.0
24 | hooks:
25 | - id: isort
26 | exclude: README.md
27 | args: ["--profile", "black"]
28 |
29 | # jupyter notebook cell output clearing
30 | - repo: https://github.com/kynan/nbstripout
31 | rev: 0.6.1
32 | hooks:
33 | - id: nbstripout
34 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2025, Chris Choy
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are met:
6 |
7 | * Redistributions of source code must retain the above copyright notice, this
8 | list of conditions and the following disclaimer.
9 |
10 | * Redistributions in binary form must reproduce the above copyright notice,
11 | this list of conditions and the following disclaimer in the documentation
12 | and/or other materials provided with the distribution.
13 |
14 | * Neither the name of WhisperChain nor the names of its
15 |   contributors may be used to endorse or promote products derived from
16 |   this software without specific prior written permission.
17 |
18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 |
--------------------------------------------------------------------------------
/PROJECT_PLANS.md:
--------------------------------------------------------------------------------
1 | # Project Plan
2 |
3 | ## Final Product
4 |
5 | - Press to talk
6 | - Transcribe speech to text using pywhispercpp
7 | - Use [pywhispercpp](https://github.com/absadiki/pywhispercpp) for Whisper.cpp integration
8 | - Support Apple silicon chips (M1, M2, ...)
9 | - Support CUDA for GPU acceleration
10 | - Support real-time transcription via WebSocket
11 | - Use LangChain to parse the text and clean up the text
12 | - E.g. "Ehh what is the emm wheather like in SF? no, Salt Lake City" -> "What is the weather like in Salt Lake City?"
13 | - Support multiple LLM providers
14 |
15 | ## Milestones
16 |
17 | - [x] Speech to text setup
18 | - [x] Install pywhispercpp with CoreML support for Apple Silicon
19 | - [x] Basic transcription test
20 | - [x] Basic git commit hook to check if the code is formatted
21 | - [x] Format the code
22 | - [x] Voice Processing Server
23 | - [x] FastAPI server setup
24 | - [x] Audio upload endpoint
25 | - [x] Streaming audio support
26 | - [x] LangChain Integration
27 | - [x] Test OpenAI API Key loading
28 | - [x] Chain configuration
29 | - [x] Text processing pipeline
30 | - [x] Response formatting
31 | - [ ] Support other LLMs (DeepSeek, Gemini, ...)
32 | - [ ] Local LLM support
33 | - [ ] Press to talk
34 | - [x] Key listener
35 | - [x] Capture a hot key regardless of the current application
36 | - [x] Put the final result in the system clipboard
37 | - [ ] Show an icon when voice control is active
38 | - [x] Command line interface
39 | - [x] Add a command line interface using `click`
40 | - [x] Web UI
41 | - [x] Streamlit UI
42 | - [x] Visualize (input audio), transcription, and output text
43 | - [x] Visualize transcription history
44 | - [ ] Prompt config
45 | - [ ] LangChain config
46 | - [ ] Context Management
47 | - [ ] System prompt configuration
48 | - [ ] Chat history persistence
49 | - [ ] Documentation
50 | - [ ] API Documentation
51 | - [ ] Usage examples and guides
52 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Whisper Chain
2 |
3 |
4 |
5 |
6 |
7 | ## Overview
8 |
9 | Typing is boring; use your voice to speed up your workflow instead. This project combines:
10 | - Real-time speech recognition using Whisper.cpp
11 | - Transcription cleanup using LangChain
12 | - Global hotkey support for voice control
13 | - Automatic clipboard integration for the cleaned transcription
14 |
15 | ## Requirements
16 |
17 | - Python 3.8+
18 | - OpenAI API Key
19 | - For macOS:
20 | - ffmpeg (for audio processing)
21 | - portaudio (for audio capture)
22 |
23 | ## Installation
24 |
25 | 1. Install system dependencies (macOS):
26 | ```bash
27 | # Install ffmpeg and portaudio using Homebrew
28 | brew install ffmpeg portaudio
29 | ```
30 |
31 | 2. Install the project:
32 | ```bash
33 | pip install whisperchain
34 | ```
35 |
36 | ## Configuration
37 |
38 | WhisperChain will look for configuration in the following locations:
39 | 1. Environment variables
40 | 2. .env file in the current directory
41 | 3. ~/.whisperchain/.env file
42 |
43 | On first run, if no configuration is found, you will be prompted to enter your OpenAI API key. The key will be saved in `~/.whisperchain/.env` for future use.
44 |
45 | You can also manually set your OpenAI API key in any of these ways:
46 | ```bash
47 | # Option 1: Environment variable
48 | export OPENAI_API_KEY=your-api-key-here
49 |
50 | # Option 2: Create .env file in current directory
51 | echo "OPENAI_API_KEY=your-api-key-here" > .env
52 |
53 | # Option 3: Create global config
54 | mkdir -p ~/.whisperchain
55 | echo "OPENAI_API_KEY=your-api-key-here" > ~/.whisperchain/.env
56 | ```
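
The three options above follow the lookup order listed earlier: environment variable first, then the local `.env`, then the global one. The precedence can be sketched as a small stdlib-only resolver (an illustration, not the project's actual loader; `find_api_key` is a hypothetical helper):

```python
import os
from pathlib import Path
from typing import Optional


def find_api_key() -> Optional[str]:
    """Resolve OPENAI_API_KEY using the documented precedence."""
    # 1. Environment variable wins.
    key = os.environ.get("OPENAI_API_KEY")
    if key:
        return key
    # 2. .env in the current directory, then 3. the global ~/.whisperchain/.env.
    for env_file in (Path(".env"), Path.home() / ".whisperchain" / ".env"):
        if env_file.exists():
            for line in env_file.read_text().splitlines():
                name, _, value = line.partition("=")
                if name.strip() == "OPENAI_API_KEY" and value.strip():
                    return value.strip()
    return None
```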
57 |
58 | ## Usage
59 |
60 | 1. Start the application:
61 | ```bash
62 | # Run with default settings
63 | whisperchain
64 |
65 | # Run with custom configuration
66 | whisperchain --config config.json
67 |
68 | # Override specific settings
69 | whisperchain --port 8080 --hotkey "++t" --model "large" --debug
70 | ```
71 |
72 | 2. Use the global hotkey (`++r` by default):
73 | - Press and hold to start recording
74 | - Speak your text
75 | - Release to stop recording
76 | - The cleaned transcription will be copied to your clipboard automatically
77 |    - Paste (Ctrl+V) to insert the transcription anywhere
78 |
79 | ## Development
80 |
81 | ### Streamlit UI
82 |
83 | ```bash
84 | streamlit run src/whisperchain/ui/streamlit_app.py
85 | ```
86 |
87 | If the Streamlit UI gets stuck, you can kill whatever process is still bound to its default port (8501):
88 |
89 | ```bash
90 | lsof -ti :8501 | xargs kill -9
91 | ```
92 |
93 | ### Running Tests
94 |
95 | Install test dependencies:
96 | ```bash
97 | pip install -e ".[test]"
98 | ```
99 |
100 | Run tests:
101 | ```bash
102 | pytest tests/
103 | ```
104 |
105 | Run tests with microphone input:
106 | ```bash
107 | # Run specific microphone test
108 | TEST_WITH_MIC=1 pytest tests/test_stream_client.py -v -k test_stream_client_with_real_mic
109 |
110 | # Run all tests including microphone test
111 | TEST_WITH_MIC=1 pytest tests/
112 | ```
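
The `TEST_WITH_MIC` flag uses the standard pytest idiom of gating hardware-dependent tests behind an environment variable, so a plain `pytest tests/` stays green on machines without audio hardware. A sketch of how such a marker is typically defined (`test_records_some_audio` is a hypothetical example, not the project's test):

```python
import os

import pytest

# Skip microphone-dependent tests unless explicitly enabled.
requires_mic = pytest.mark.skipif(
    not os.environ.get("TEST_WITH_MIC"),
    reason="set TEST_WITH_MIC=1 to run tests that need a real microphone",
)


@requires_mic
def test_records_some_audio():
    # A real hardware test body would capture audio here.
    assert True
```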
113 |
114 | ### Building the project
115 |
116 | ```bash
117 | python -m build
118 | pip install .
119 | ```
120 |
121 | ### Publishing to PyPI
122 |
123 | ```bash
124 | python -m build
125 | twine upload --repository pypi dist/*
126 | ```
127 |
128 | ## License
129 |
130 | [LICENSE](LICENSE)
131 |
132 | ## Acknowledgments
133 |
134 | - [Whisper.cpp](https://github.com/ggerganov/whisper.cpp)
135 | - [pywhispercpp](https://github.com/absadiki/pywhispercpp.git)
136 | - [LangChain](https://github.com/langchain-ai/langchain)
137 |
138 |
139 | ## Architecture
140 |
141 | ```mermaid
142 | graph TB
143 | subgraph "Client Options"
144 | K[Key Listener]
145 | A[Audio Stream]
146 | C[Clipboard]
147 | end
148 |
149 | subgraph "Streamlit Web UI :8501"
150 | WebP[Prompt]
151 | WebH[History]
152 | end
153 |
154 | subgraph "FastAPI Server :8000"
155 | WS[WebSocket /stream]
156 | W[Whisper Model]
157 | LC[LangChain Processor]
158 | H[History]
159 | end
160 |
161 | K -->|"Hot Key"| A
162 | A -->|"Audio Stream"| WS
163 | WS --> W
164 | W --> LC
165 | WebP --> LC
166 | LC --> C
167 | LC --> H
168 | H --> WebH
169 | ```
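
The WebSocket leg of this diagram speaks a simple JSON protocol (as seen in `stream_client.py` and `key_listener.py`): interim messages report progress such as `processed_bytes`, and a final message sets `is_final` and carries the `cleaned_transcription`. A sketch of how a client distinguishes the two (`handle_message` is a hypothetical helper):

```python
import json
from typing import Optional


def handle_message(raw: str) -> Optional[str]:
    """Return the cleaned transcription once the final message arrives, else None."""
    msg = json.loads(raw)
    if msg.get("is_final"):
        # Final message: carries the LangChain-cleaned text for the clipboard.
        return msg["cleaned_transcription"]
    # Interim message: progress info only (e.g. bytes transcribed so far).
    _ = msg.get("processed_bytes", 0)
    return None
```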
170 |
--------------------------------------------------------------------------------
/assets/logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chrischoy/WhisperChain/9020bf96c3e1ed82543bb595976d9aa90e66cf2c/assets/logo.jpg
--------------------------------------------------------------------------------
/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "client": {
3 | "server_url": "ws://localhost:8000/stream",
4 | "hotkey": "++r",
5 | "audio": {
6 | "sample_rate": 16000,
7 | "channels": 1,
8 | "chunk_size": 4096,
9 | "format": "int16"
10 | },
11 | "stream": {
12 | "min_buffer_size": 32000,
13 | "timeout": 0.1,
14 | "end_marker": "END\n"
15 | }
16 | },
17 | "server": {
18 | "host": "0.0.0.0",
19 | "port": 8000,
20 | "model_name": "base.en",
21 | "debug": false
22 | }
23 | }
24 |
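
As a quick sanity check, the sections and keys above can be validated with the standard library before handing the file to the pydantic models in `core/config.py` (a sketch with a hypothetical `check_config` helper, not part of the project):

```python
import json

# Keys each section of config.json is expected to carry.
REQUIRED = {
    "client": {"server_url", "hotkey", "audio", "stream"},
    "server": {"host", "port", "model_name", "debug"},
}


def check_config(raw: str) -> dict:
    """Parse a WhisperChain-style config and verify the expected keys exist."""
    cfg = json.loads(raw)
    for section, keys in REQUIRED.items():
        missing = keys - set(cfg.get(section, {}))
        if missing:
            raise ValueError(f"{section!r} is missing keys: {sorted(missing)}")
    return cfg
```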
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=45", "wheel", "build"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "whisperchain"
7 | version = "0.1.3"
8 | description = "Voice control using Whisper.cpp with LangChain cleanup"
9 | readme = "README.md"
10 | requires-python = ">=3.8"
11 | keywords = ["whisper", "langchain", "voice-control", "speech-to-text"]
12 | classifiers = [
13 | "Development Status :: 3 - Alpha",
14 | "Intended Audience :: Developers",
15 | "License :: OSI Approved :: MIT License",
16 | "Programming Language :: Python :: 3",
17 | "Programming Language :: Python :: 3.8",
18 | "Programming Language :: Python :: 3.9",
19 | "Programming Language :: Python :: 3.10",
20 | "Topic :: Software Development :: Libraries :: Python Modules",
21 | "Topic :: Multimedia :: Sound/Audio :: Speech",
22 | ]
23 | license = { text = "MIT" }
24 | authors = [
25 | { name="Chris Choy", email="chrischoy@ai.stanford.edu" }
26 | ]
27 | dependencies = [
28 | "click>=8.0.0",
29 | "pydantic>=2.0.0",
30 | "pynput>=1.7.7",
31 | "pyperclip",
32 | "openai>=1.0.0",
33 | "pywhispercpp>=1.3.0",
34 | "fastapi>=0.100.0",
35 | "uvicorn>=0.22.0",
36 | "pyaudio>=0.2.11",
37 | "langchain>=0.1.0",
38 | "langchain-openai>=0.1.0",
39 | "websockets>=11.0.0",
40 | "streamlit>=1.20.0",
41 | ]
42 |
43 | [project.optional-dependencies]
44 | test = [
45 | "pytest>=7.0.0",
46 | "pytest-asyncio>=0.21.0", # Change from trio to asyncio
47 | "httpx>=0.24.0", # For testing FastAPI
48 | ]
49 | dev = [
50 | "pre-commit>=3.0.0",
51 | "black>=22.0.0",
52 | "isort>=5.0.0",
53 | "build>=0.10.0",
54 | "twine>=4.0.0",
55 | ]
56 |
57 | [project.scripts]
58 | whisperchain = "whisperchain.cli.run:main"
59 | whisperchain-client = "whisperchain.cli.run_client:main"
60 | whisperchain-server = "whisperchain.cli.run_server:main"
61 |
62 | [project.urls]
63 | Homepage = "https://github.com/chrischoy/whisperchain"
64 | "Bug Tracker" = "https://github.com/chrischoy/whisperchain/issues"
65 |
66 | [tool.black]
67 | line-length = 99
68 | target-version = ['py38']
69 |
70 | [tool.isort]
71 | profile = "black"
72 | multi_line_output = 3
73 |
74 | [tool.setuptools]
75 | package-dir = {"" = "src"} # Tells setuptools that packages are under src
76 |
77 | [tool.setuptools.packages.find]
78 | where = ["src"]
79 | include = ["whisperchain*"]
80 | exclude = ["tests*"]
81 |
82 | [tool.pytest.ini_options]
83 | asyncio_mode = "auto"
84 | asyncio_default_fixture_loop_scope = "function"
85 |
--------------------------------------------------------------------------------
/src/whisperchain/cli/run.py:
--------------------------------------------------------------------------------
1 | import multiprocessing as mp
2 | import sys
3 | from pathlib import Path
4 | from typing import Optional
5 |
6 | import click
7 | import streamlit.web.cli as stcli
8 | import uvicorn
9 |
10 | from whisperchain.client.key_listener import HotKeyRecordingListener
11 | from whisperchain.core.config import ClientConfig, ServerConfig, config
12 | from whisperchain.server.server import WhisperServer
13 |
14 |
15 | def run_server(server_config: ServerConfig):
16 | """Run the WhisperServer with given config"""
17 | server = WhisperServer(config=server_config)
18 | uvicorn.run(server.app, host=server_config.host, port=server_config.port)
19 |
20 |
21 | def run_ui():
22 | """Run the Streamlit UI"""
23 | # Ensure config is up to date
24 | config.generate_streamlit_config()
25 |
26 | # Get the UI script path
27 | ui_path = Path(__file__).parent.parent / "ui" / "streamlit_app.py"
28 |
29 | # Just run the script, config.toml will be used automatically
30 | sys.argv = ["streamlit", "run", str(ui_path)]
31 | stcli.main()
32 |
33 |
34 | def run_client(client_config: ClientConfig):
35 | """Run the recording client"""
36 | listener = HotKeyRecordingListener(config=client_config)
37 | listener.start()
38 |
39 |
40 | @click.command()
41 | @click.option("--server-only", is_flag=True, help="Run only the server")
42 | @click.option("--ui-only", is_flag=True, help="Run only the UI")
43 | @click.option("--client-only", is_flag=True, help="Run only the client")
44 | @click.option("--debug", is_flag=True, help="Enable debug mode")
45 | def main(server_only: bool, ui_only: bool, client_only: bool, debug: bool):
46 | """Run WhisperChain components"""
47 | server_process = None
48 | ui_process = None
49 | client_process = None
50 |
51 | try:
52 | # Load configs
53 | server_config = ServerConfig(debug=debug)
54 | client_config = ClientConfig()
55 |
56 | # Start components based on flags
57 | if server_only:
58 | run_server(server_config)
59 | elif ui_only:
60 | run_ui()
61 | elif client_only:
62 | run_client(client_config)
63 | else:
64 | # Start server in separate process
65 | server_process = mp.Process(target=run_server, args=(server_config,))
66 | server_process.start()
67 |
68 | # Start UI in separate process
69 | ui_process = mp.Process(target=run_ui)
70 | ui_process.start()
71 |
72 | # Run client in main process
73 | run_client(client_config)
74 |
75 | except KeyboardInterrupt:
76 | print("\nShutting down...")
77 | finally:
78 | # Cleanup processes
79 | if server_process and server_process.is_alive():
80 | server_process.terminate()
81 | server_process.join()
82 | if ui_process and ui_process.is_alive():
83 | ui_process.terminate()
84 | ui_process.join()
85 |
86 |
87 | if __name__ == "__main__":
88 | main()
89 |
--------------------------------------------------------------------------------
/src/whisperchain/cli/run_client.py:
--------------------------------------------------------------------------------
1 | import json
2 | from typing import Optional
3 |
4 | import click
5 |
6 | from whisperchain.client.key_listener import HotKeyRecordingListener
7 | from whisperchain.core.config import ClientConfig
8 |
9 |
10 | @click.command()
11 | @click.option("--hotkey", default=None, help='Hotkey to start/stop recording (default: "++r")')
12 | @click.option("--config", type=click.Path(exists=True), help="Path to config JSON file")
13 | @click.option("--sample-rate", type=int, help="Audio sample rate in Hz")
14 | @click.option("--channels", type=int, help="Number of audio channels")
15 | @click.option("--chunk-size", type=int, help="Audio chunk size")
16 | @click.option("--server-url", help="WebSocket server URL")
17 | def main(
18 | hotkey: str,
19 | config: Optional[str],
20 | sample_rate: Optional[int],
21 | channels: Optional[int],
22 | chunk_size: Optional[int],
23 | server_url: Optional[str],
24 | ):
25 | """Start the voice control client."""
26 | # Load base configuration
27 | if config:
28 | with open(config) as f:
29 | config_dict = json.load(f)
30 |         client_config = ClientConfig.model_validate(config_dict)
31 | else:
32 | client_config = ClientConfig()
33 |
34 | # Override with command line arguments if provided
35 | if hotkey:
36 | client_config.hotkey = hotkey
37 | if sample_rate:
38 | client_config.audio.sample_rate = sample_rate
39 | if channels:
40 | client_config.audio.channels = channels
41 | if chunk_size:
42 | client_config.audio.chunk_size = chunk_size
43 | if server_url:
44 | client_config.server_url = server_url
45 |
46 | listener = HotKeyRecordingListener(hotkey=client_config.hotkey, config=client_config)
47 | listener.start()
48 |
49 |
50 | if __name__ == "__main__":
51 | main()
52 |
--------------------------------------------------------------------------------
/src/whisperchain/cli/run_server.py:
--------------------------------------------------------------------------------
1 | import click
2 | import uvicorn
3 |
4 | from whisperchain.core.config import ServerConfig
5 | from whisperchain.server.server import WhisperServer
6 | from whisperchain.utils.secrets import load_secrets
7 |
8 |
9 | @click.command()
10 | @click.option("--host", default="0.0.0.0", help="Server host")
11 | @click.option("--port", default=8000, help="Server port")
12 | @click.option("--model", default="base.en", help="Whisper model name")
13 | @click.option("--debug", is_flag=True, help="Enable debug mode")
14 | def main(host: str, port: int, model: str, debug: bool):
15 | """Run the FastAPI server."""
16 | # Initialize secrets
17 | load_secrets()
18 |
19 | config = ServerConfig(host=host, port=port, model_name=model, debug=debug)
20 | server = WhisperServer(config)
21 | uvicorn.run(server.app, host=config.host, port=config.port)
22 |
--------------------------------------------------------------------------------
/src/whisperchain/client/key_listener.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import multiprocessing as mp
3 | from threading import Thread
4 |
5 | import pyperclip
6 | from pynput import keyboard
7 |
8 | from whisperchain.client.stream_client import StreamClient
9 | from whisperchain.core.config import ClientConfig
10 | from whisperchain.utils.decorators import handle_exceptions
11 | from whisperchain.utils.logger import get_logger
12 |
13 | logger = get_logger(__name__)
14 |
15 |
16 | class HotKeyListener:
17 | def __init__(self, combination_str="++r"):
18 | self.pressed = False
19 | # Parse the string into a combination set.
20 | self.combination = keyboard.HotKey.parse(combination_str)
21 | # Initialize a HotKey object with on_activate callback.
22 | self.hotkey = keyboard.HotKey(self.combination, self.on_activate)
23 |
24 | def on_activate(self):
25 | if not self.pressed:
26 | logger.info(f"Global hotkey: {self.combination} activated!")
27 | self.pressed = True
28 |
29 | def on_deactivate(self, key):
30 | # Release the key in the HotKey object.
31 | self.hotkey.release(key)
32 | if self.pressed:
33 | logger.info(f"Global hotkey: {self.combination} deactivated!")
34 | self.pressed = False
35 |
36 | def start(self):
37 | # Helper function to transform key press events.
38 | def for_canonical(f):
39 | return lambda k: f(listener.canonical(k))
40 |
41 | # Create a Listener using the canonical transformation.
42 | with keyboard.Listener(
43 | on_press=for_canonical(self.hotkey.press),
44 | on_release=for_canonical(self.on_deactivate),
45 | ) as listener:
46 | listener.join()
47 |
48 |
49 | class HotKeyRecordingListener(HotKeyListener):
50 | def __init__(self, hotkey: str = "++r", config: ClientConfig = None):
51 | super().__init__(hotkey)
52 | self.config = config or ClientConfig(hotkey=hotkey)
53 | self.recording = False
54 | self.stop_event = mp.Event()
55 | self.streaming_thread = None
56 |
57 | @handle_exceptions
58 | async def _streaming_loop(self):
59 | messages = []
60 | total_bytes_sent = 0
61 |
62 | async with StreamClient(config=self.config) as client:
63 | async for message in client.stream_microphone():
64 | if self.stop_event.is_set():
65 | logger.info("Stopping audio capture")
66 | client.stop()
67 |
68 | messages.append(message)
69 | # Extract byte count from message text if available.
70 | if not message.get("is_final"):
71 | try:
72 | byte_count = int(message["processed_bytes"])
73 | total_bytes_sent += byte_count
74 |                     except (KeyError, ValueError):
75 | pass
76 | if message.get("is_final"):
77 | final_cleaned = message["cleaned_transcription"]
78 | pyperclip.copy(final_cleaned)
79 | logger.info(f"Copied to clipboard: {final_cleaned}")
80 | break
81 | # Optionally, you can log or store the messages/byte counts.
82 | logger.info(f"Async streaming loop finished. Total bytes sent: {total_bytes_sent}")
83 |
84 | def on_activate(self):
85 | super().on_activate()
86 | if not self.recording:
87 | self.stop_event.clear()
88 | logger.info("Starting async streaming loop")
89 | # Run the async _streaming_loop() in a background thread.
90 | self.streaming_thread = Thread(
91 | target=lambda: asyncio.run(self._streaming_loop()), daemon=True
92 | )
93 | self.streaming_thread.start()
94 | self.recording = True
95 |
96 | def on_deactivate(self, key):
97 | super().on_deactivate(key)
98 | if self.recording:
99 | self.stop_event.set()
100 | self.recording = False
101 | logger.info("Joining streaming thread")
102 | self.streaming_thread.join()
103 | logger.info("Streaming thread joined")
104 |
105 |
106 | if __name__ == "__main__":
107 | listener = HotKeyRecordingListener("++r")
108 | listener.start()
109 |
--------------------------------------------------------------------------------
/src/whisperchain/client/stream_client.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import json
3 | import multiprocessing as mp
4 | import queue
5 | import threading
6 | import time
7 |
8 | import websockets
9 |
10 | from whisperchain.core.audio import AudioCapture
11 | from whisperchain.core.config import ClientConfig
12 | from whisperchain.utils.decorators import handle_exceptions
13 | from whisperchain.utils.logger import get_logger
14 |
15 | logger = get_logger(__name__)
16 |
17 |
18 | # StreamClient manages the connection to the WebSocket server and sends audio captured by AudioCapture.
19 | class StreamClient:
20 | def __init__(self, config: ClientConfig = None):
21 | self.config = config or ClientConfig()
22 | self.server_url = self.config.server_url
23 | self.min_buffer_size = self.config.stream.min_buffer_size
24 | self.audio_queue = queue.Queue()
25 | self.is_audio_capturing = threading.Event()
26 | self.stop_event = threading.Event()
27 | self.audio_thread = None
28 |
29 | def _start_audio_capture(self):
30 | self.stop_event.clear()
31 | self.is_audio_capturing.set()
32 | capture = AudioCapture(self.audio_queue, self.is_audio_capturing, config=self.config.audio)
33 | self.audio_thread = threading.Thread(target=capture.start)
34 | self.audio_thread.start()
35 | logger.info("StreamClient: Started recording thread")
36 |
37 | def _stop_audio_capture(self):
38 | if self.is_audio_capturing.is_set():
39 | logger.info("StreamClient: Stopping audio capture")
40 | self.is_audio_capturing.clear()
41 | if self.audio_thread:
42 | self.audio_thread.join(timeout=2.0)
43 | if self.audio_thread.is_alive():
44 | logger.warning("StreamClient: Audio thread still running")
45 | self.audio_thread = None
46 | logger.info("StreamClient: Audio capture stopped")
47 |
48 | def stop(self):
49 | self.stop_event.set()
50 |
51 | @handle_exceptions
52 | async def stream_microphone(self):
53 | audio_buffer = bytearray()
54 | end_sent = False
55 | logger.info("StreamClient: Connecting to server")
56 | async with websockets.connect(self.server_url) as websocket:
57 | logger.info("StreamClient: Connected to server")
58 | self._start_audio_capture()
59 | while True:
60 |                 # Stop requested (e.g., hotkey released): flush any buffered audio and send the END marker.
61 | if self.stop_event.is_set() and not end_sent:
62 | self._stop_audio_capture()
63 | if audio_buffer:
64 | await websocket.send(bytes(audio_buffer))
65 | logger.info("StreamClient: Sent remaining audio, cleared buffer")
66 | audio_buffer.clear()
67 | logger.info("StreamClient: Sending END marker")
68 | await websocket.send(self.config.stream.end_marker.encode())
69 | end_sent = True
70 |
71 | if not end_sent:
72 | try:
73 | data = self.audio_queue.get_nowait()
74 | logger.info(f"StreamClient: Got {len(data)} bytes from queue")
75 | audio_buffer.extend(data)
76 | if len(audio_buffer) >= self.min_buffer_size:
77 | await websocket.send(bytes(audio_buffer))
78 | logger.info("StreamClient: Sent audio chunk")
79 | audio_buffer.clear()
80 |                     except queue.Empty:
81 | await asyncio.sleep(0.01)
82 |
83 | try:
84 | message = await asyncio.wait_for(
85 | websocket.recv(), timeout=self.config.stream.timeout
86 | )
87 | msg = json.loads(message)
88 | logger.info(f"StreamClient: Received message: {msg}")
89 | yield msg
90 | if msg.get("is_final"):
91 | break
92 | except asyncio.TimeoutError:
93 | continue
94 | await asyncio.sleep(0.1)
95 | logger.info("StreamClient: Stream ended")
96 |
97 | async def __aenter__(self):
98 | return self
99 |
100 | async def __aexit__(self, exc_type, exc, tb):
101 | self._stop_audio_capture()
102 |
--------------------------------------------------------------------------------
/src/whisperchain/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chrischoy/WhisperChain/9020bf96c3e1ed82543bb595976d9aa90e66cf2c/src/whisperchain/core/__init__.py
--------------------------------------------------------------------------------
/src/whisperchain/core/audio.py:
--------------------------------------------------------------------------------
1 | import multiprocessing as mp
2 |
3 | import pyaudio
4 |
5 | from whisperchain.core.config import AudioConfig
6 | from whisperchain.utils.logger import get_logger
7 |
8 | logger = get_logger(__name__)
9 |
10 |
11 | class AudioCapture:
12 | def __init__(self, queue: mp.Queue, is_recording: mp.Event, config: AudioConfig = None):
13 | self.queue = queue
14 | self.is_recording = is_recording
15 | self.config = config or AudioConfig()
16 | self.audio = None
17 | self.stream = None
18 |
19 | def start(self):
20 | self.audio = pyaudio.PyAudio()
21 | self.stream = self.audio.open(
22 | format=getattr(pyaudio, f"pa{self.config.format.capitalize()}"),
23 | channels=self.config.channels,
24 | rate=self.config.sample_rate,
25 | input=True,
26 | frames_per_buffer=self.config.chunk_size,
27 | )
28 | logger.info("AudioCapture: Started capturing audio")
29 | while self.is_recording.is_set():
30 | try:
31 | data = self.stream.read(self.config.chunk_size, exception_on_overflow=False)
32 | logger.info(f"AudioCapture: Captured {len(data)} bytes")
33 | self.queue.put(data)
34 | except Exception as e:
35 | logger.error(f"AudioCapture error: {e}")
36 | break
37 | self.cleanup()
38 |
39 | def cleanup(self):
40 | if self.stream:
41 | self.stream.stop_stream()
42 | self.stream.close()
43 | if self.audio:
44 | self.audio.terminate()
45 | logger.info("AudioCapture: Stopped capturing audio")
46 |
--------------------------------------------------------------------------------
/src/whisperchain/core/chain.py:
--------------------------------------------------------------------------------
1 | import getpass
2 | import os
3 | from pathlib import Path
4 |
5 | from langchain.prompts.chat import ChatPromptTemplate
6 | from langchain.schema import AIMessage
7 | from langchain_openai import ChatOpenAI
8 |
9 | from whisperchain.utils.logger import get_logger
10 |
11 | logger = get_logger(__name__)
12 |
13 |
14 | def load_prompt(prompt_path: str | Path) -> str:
15 | """
16 |     Load a prompt template, resolving relative paths against the whisperchain package directory.
17 | 
18 |     Args:
19 |         prompt_path (str | Path): Path to the prompt file (e.g., "prompts/transcription_cleanup.txt").
20 |
21 | Returns:
22 | str: The content of the prompt template.
23 | """
24 | if isinstance(prompt_path, str):
25 | prompt_path = Path(prompt_path)
26 |
27 | # if the prompt path is relative, convert it to an absolute path
28 | if not prompt_path.is_absolute():
29 | prompt_path = Path(__file__).parent.parent / prompt_path
30 |
31 | assert prompt_path.exists(), f"prompt path: {prompt_path} does not exist"
32 | with open(str(prompt_path), "r", encoding="utf-8") as file:
33 | return file.read()
34 |
35 |
36 | class TranscriptionCleaner:
37 | """
38 | Uses a composed (chained) runnable to clean up raw transcription text.
39 |
40 | This class builds a chain by composing a runnable prompt with an LLM. The prompt instructs
41 | the LLM to remove filler words, fix grammatical errors, and produce a coherent cleaned transcription.
42 | This composition via the pipe operator leverages the new RunnableSequence interface.
43 | """
44 |
45 | def __init__(
46 | self,
47 | model_name: str = "gpt-3.5-turbo",
48 | prompt_path: str = "prompts/transcription_cleanup.txt", # relative to the whisperchain package
49 | verbose: bool = False,
50 | ):
51 | # Load and convert the prompt text into a runnable ChatPromptTemplate.
52 | prompt_text = load_prompt(prompt_path)
53 | self.prompt_template = ChatPromptTemplate.from_template(prompt_text)
54 | self.llm = ChatOpenAI(model_name=model_name, temperature=0, verbose=verbose)
55 | self.runnable_chain = self.prompt_template | self.llm
56 |
57 | def clean(self, transcription: str) -> str:
58 | """
59 | Synchronously clean the provided transcription text by invoking the composed chain.
60 |
61 | Args:
62 | transcription (str): The raw transcription text.
63 |
64 | Returns:
65 | str: The cleaned transcription text.
66 | """
67 | result: AIMessage = self.runnable_chain.invoke({"transcription": transcription})
68 | return result.content.strip()
69 |
70 | async def aclean(self, transcription: str) -> str:
71 | """
72 | Asynchronously clean the provided transcription text by invoking the composed chain.
73 |
74 | Args:
75 | transcription (str): The raw transcription text.
76 |
77 | Returns:
78 | str: The cleaned transcription text.
79 | """
80 | result: AIMessage = await self.runnable_chain.ainvoke({"transcription": transcription})
81 | return result.content.strip()
82 |
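The `prompt | llm` composition in `TranscriptionCleaner.__init__` relies on the `Runnable.__or__` operator that LangChain's RunnableSequence interface defines. As a stdlib-only sketch of the idea (a hypothetical `Runnable` stand-in, not LangChain's actual implementation):

```python
class Runnable:
    """Minimal stand-in for LangChain's Runnable protocol (illustration only)."""

    def __init__(self, fn):
        self.fn = fn

    def invoke(self, value):
        return self.fn(value)

    def __or__(self, other):
        # `a | b` builds a new Runnable that pipes a's output into b.
        return Runnable(lambda value: other.invoke(self.invoke(value)))


# Stand-ins for the prompt template and the LLM used in TranscriptionCleaner.
format_prompt = Runnable(lambda d: f"Raw transcription: {d['transcription']}")
fake_llm = Runnable(lambda prompt: prompt.upper())

chain = format_prompt | fake_llm
result = chain.invoke({"transcription": "hello"})
```

The real chain works the same way: `invoke` on the composed object formats the prompt first, then forwards the result to the model.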
--------------------------------------------------------------------------------
/src/whisperchain/core/config.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 |
4 | import toml
5 | from pydantic import BaseModel, Field, validator
6 |
7 |
8 | class AudioConfig(BaseModel):
9 | """Audio capture configuration."""
10 |
11 | sample_rate: int = Field(default=16000, description="Sample rate in Hz")
12 | channels: int = Field(default=1, description="Number of audio channels")
13 | chunk_size: int = Field(
14 | default=4096, description="Chunk size for audio capture (~256ms at 16kHz)"
15 | )
16 | format: str = Field(default="int16", description="Audio format (int16, float32, etc.)")
17 |
18 |
19 | class StreamConfig(BaseModel):
20 | """Stream client configuration."""
21 |
22 | min_buffer_size: int = Field(
23 | default=32000, description="Minimum buffer size in bytes before sending"
24 | )
25 | timeout: float = Field(default=0.1, description="Timeout for websocket operations in seconds")
26 | end_marker: str = Field(default="END\n", description="Marker to indicate end of stream")
27 |
28 |
29 | class ClientConfig(BaseModel):
30 | """Client configuration including audio and stream settings."""
31 |
32 | server_url: str = Field(
33 | default="ws://localhost:8000/stream", description="WebSocket server URL"
34 | )
35 | hotkey: str = Field(default="++r", description="Global hotkey combination")
36 | audio: AudioConfig = Field(default_factory=AudioConfig, description="Audio capture settings")
37 | stream: StreamConfig = Field(
38 | default_factory=StreamConfig, description="Stream client settings"
39 | )
40 |
41 |
42 | class ServerConfig(BaseModel):
43 | host: str = "0.0.0.0"
44 | port: int = 8000
45 | model_name: str = "base.en"
46 | debug: bool = False
47 |
48 |     @validator("model_name")
49 |     def validate_model_name(cls, v):
50 |         from pywhispercpp.constants import AVAILABLE_MODELS
51 |
52 |         if v not in AVAILABLE_MODELS:
53 |             raise ValueError(f"Model {v} not found in {AVAILABLE_MODELS}")
54 |         return v
54 |
55 |
56 | class UIConfig(BaseModel):
57 | """Streamlit UI configuration"""
58 |
59 | title: str = "WhisperChain Dashboard"
60 | page_icon: str = "🎙️"
61 | layout: str = "wide"
62 |
63 | # Server connection
64 | server_url: str = "http://localhost:8000"
65 | refresh_interval: float = 5.0 # seconds
66 | quick_refresh: float = 0.1 # seconds
67 |
68 | # Display settings
69 | history_limit: int = 100
70 | default_expanded: bool = False
71 |
72 | # Theme
73 | theme_base: str = "light"
74 | theme_primary_color: str = "#F63366"
75 |
76 |
77 | class ConfigManager:
78 | """Central configuration manager"""
79 |
80 | _instance = None
81 |
82 | def __init__(self):
83 | self.config_dir = Path.home() / ".whisperchain"
84 | self.config_dir.mkdir(exist_ok=True)
85 |
86 | # Load configs
87 | self.ui_config = self._load_ui_config()
88 |
89 | @classmethod
90 | def get_instance(cls):
91 | """Get singleton instance"""
92 | if cls._instance is None:
93 | cls._instance = ConfigManager()
94 | return cls._instance
95 |
96 | def _load_ui_config(self) -> UIConfig:
97 | """Load UI config from file or create default"""
98 | config_file = self.config_dir / "ui_config.json"
99 |
100 | if config_file.exists():
101 | return UIConfig.parse_file(config_file)
102 |
103 | # Create default config
104 | config = UIConfig()
105 | self.save_ui_config(config)
106 | return config
107 |
108 | def save_ui_config(self, config: UIConfig):
109 | """Save UI config to file"""
110 | config_file = self.config_dir / "ui_config.json"
111 | with open(config_file, "w") as f:
112 | json.dump(config.dict(), f, indent=2)
113 |
114 | def generate_streamlit_config(self):
115 | """Generate Streamlit config.toml content"""
116 | streamlit_config = {
117 | "browser": {
118 | "gatherUsageStats": False,
119 | },
120 | "server": {
121 | "headless": True,
122 | "runOnSave": True,
123 | "address": "localhost",
124 | "port": 8501,
125 | "enableCORS": True,
126 | },
127 | "theme": {
128 | "base": self.ui_config.theme_base,
129 | "primaryColor": self.ui_config.theme_primary_color,
130 | },
131 | }
132 |
133 | # Write to .streamlit/config.toml
134 | config_path = Path.home() / ".streamlit" / "config.toml"
135 | config_path.parent.mkdir(exist_ok=True)
136 |
137 | with open(config_path, "w") as f:
138 | toml.dump(streamlit_config, f)
139 |
140 |
141 | # Global config instance
142 | config = ConfigManager.get_instance()
143 |
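`ConfigManager._load_ui_config` follows a load-or-create-default pattern: read the JSON file if it exists, otherwise write the defaults and return them. A stdlib-only sketch of that pattern (hypothetical `UIDefaults` dataclass standing in for the pydantic `UIConfig`):

```python
import json
import tempfile
from dataclasses import asdict, dataclass
from pathlib import Path


@dataclass
class UIDefaults:
    """Stand-in for UIConfig (illustration only)."""

    title: str = "WhisperChain Dashboard"
    refresh_interval: float = 5.0


def load_or_create(config_file: Path) -> UIDefaults:
    # Mirrors ConfigManager._load_ui_config: read if present, else persist defaults.
    if config_file.exists():
        return UIDefaults(**json.loads(config_file.read_text()))
    config = UIDefaults()
    config_file.write_text(json.dumps(asdict(config), indent=2))
    return config


config_dir = Path(tempfile.mkdtemp())
first = load_or_create(config_dir / "ui_config.json")   # creates the file
second = load_or_create(config_dir / "ui_config.json")  # reads it back
```

Because the defaults are written on first access, later edits to `~/.whisperchain/ui_config.json` survive restarts.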
--------------------------------------------------------------------------------
/src/whisperchain/prompts/transcription_cleanup.txt:
--------------------------------------------------------------------------------
1 | You are a helpful assistant that cleans up raw transcription text by removing filler words, correcting misinterpretations, and producing a coherent sentence.
2 |
3 | Raw transcription: {transcription}
4 | Cleaned transcription:
5 |
--------------------------------------------------------------------------------
/src/whisperchain/server/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chrischoy/WhisperChain/9020bf96c3e1ed82543bb595976d9aa90e66cf2c/src/whisperchain/server/__init__.py
--------------------------------------------------------------------------------
/src/whisperchain/server/server.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import os
3 | from datetime import datetime
4 | from typing import List
5 |
6 | import numpy as np
7 | import pyaudio
8 | import uvicorn
9 | from fastapi import FastAPI, WebSocket, WebSocketDisconnect
10 | from fastapi.responses import FileResponse
11 | from fastapi.staticfiles import StaticFiles
12 | from pywhispercpp.constants import AVAILABLE_MODELS
13 | from pywhispercpp.model import Model, Segment
14 |
15 | from whisperchain.core.chain import TranscriptionCleaner
16 | from whisperchain.core.config import ServerConfig
17 | from whisperchain.utils.logger import get_logger
18 | from whisperchain.utils.segment import (
19 | list_of_segments_to_text,
20 | list_of_segments_to_text_with_timestamps,
21 | )
22 |
23 | logger = get_logger(__name__)
24 |
25 |
26 | class WhisperServer:
27 | def __init__(self, config: ServerConfig = None):
28 | self.config = config or ServerConfig()
29 | self.whisper_model = None
30 | self.transcription_cleaner = None
31 | self.app = FastAPI()
32 | self.transcription_history = []
33 | self.setup_routes()
34 |
35 | def setup_routes(self):
36 | self.app.add_event_handler("startup", self.startup_event)
37 | self.app.add_websocket_route("/stream", self.websocket_endpoint)
38 |
39 | @self.app.get("/")
40 | async def get_root():
41 | """Health check endpoint"""
42 | return {"status": "ok"}
43 |
44 | @self.app.get("/history")
45 | async def get_history():
46 | """Get transcription history"""
47 | return self.transcription_history
48 |
49 | @self.app.delete("/history")
50 | async def clear_history():
51 | """Clear transcription history"""
52 | self.transcription_history.clear()
53 | return {"status": "cleared"}
54 |
55 | async def startup_event(self):
56 | logger.info(f"Initializing Whisper model {self.config.model_name}...")
57 | self.whisper_model = Model(model=self.config.model_name)
58 | logger.info("Initializing transcription cleaner...")
59 | self.transcription_cleaner = TranscriptionCleaner()
60 | if self.config.debug:
61 | logger.info("Running in DEBUG mode - audio playback enabled. Printing all chain logs.")
62 |
63 | async def play_audio(self, audio_data: bytes):
64 | """Play the received audio data using PyAudio."""
65 | p = pyaudio.PyAudio()
66 | stream = p.open(format=pyaudio.paInt16, channels=1, rate=16000, output=True)
67 | stream.write(audio_data)
68 | stream.stop_stream()
69 | stream.close()
70 | p.terminate()
71 |
72 | async def transcribe_audio(self, audio_data: bytes) -> List[Segment]:
73 | """Transcribe audio data using the whisper model."""
74 | # Convert bytes to a numpy array
75 | audio_array = np.frombuffer(audio_data, dtype=np.int16)
76 | # Convert to float32
77 | audio_array = audio_array.astype(np.float32) / np.iinfo(np.int16).max
78 | # Transcribe the audio
79 | result: List[Segment] = self.whisper_model.transcribe(audio_array)
80 | return result
81 |
82 | async def websocket_endpoint(self, websocket: WebSocket):
83 | await websocket.accept()
84 | received_data = b""
85 | while True:
86 | try:
87 | data = await websocket.receive_bytes()
88 | except WebSocketDisconnect:
89 | logger.info("Server: WebSocket disconnected")
90 | break
91 |
92 | if data.endswith(b"END\n"):
93 | # Remove the END marker and accumulate any remaining data.
94 | data_without_end = data[:-4]
95 | received_data += data_without_end
96 | # Transcribe the received audio
97 | segments = await self.transcribe_audio(received_data)
98 | # Clean the transcription
99 | cleaned_transcription = self.transcription_cleaner.clean(
100 | list_of_segments_to_text(segments)
101 | )
102 | # Build a final message
103 | final_message = {
104 | "type": "transcription",
105 | "processed_bytes": len(received_data),
106 | "is_final": True,
107 | "transcription": list_of_segments_to_text_with_timestamps(segments),
108 | "cleaned_transcription": cleaned_transcription,
109 | "timestamp": datetime.now().isoformat(),
110 | }
111 | logger.info("Server: Sending final message: %s", final_message)
112 | self.transcription_history.append(final_message)
113 | await websocket.send_json(final_message)
114 | # Play back the received audio only in debug mode
115 | if self.config.debug:
116 | logger.info("Server: Playing back received audio (DEBUG mode)...")
117 | await self.play_audio(received_data)
118 | await asyncio.sleep(0.1)
119 | break
120 | else:
121 | # Accumulate the incoming bytes and send an intermediate echo message.
122 | received_data += data
123 | echo_message = {
124 | "type": "transcription",
125 | "processed_bytes": len(data),
126 | "is_final": False,
127 | }
128 | logger.info("Server: Echoing message: %s", echo_message)
129 | await websocket.send_json(echo_message)
130 | try:
131 | await websocket.close()
132 | except RuntimeError as e:
133 | # Ignore errors if the connection is already closed/completed
134 | logger.warning("Server: Warning while closing websocket: %s", e)
135 |
136 |
137 | # Create default instance
138 | default_server = WhisperServer()
139 | app = default_server.app
140 |
141 | if __name__ == "__main__":
142 | uvicorn.run(app, host="0.0.0.0", port=8000)
143 |
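The `/stream` endpoint frames the audio stream with a trailing `b"END\n"` marker: chunks are accumulated until one ends with the 4-byte marker, which is then stripped before transcription. The framing logic in isolation, as a stdlib-only sketch:

```python
END_MARKER = b"END\n"


def accumulate_until_end(chunks):
    """Collect byte chunks until one ends with the END marker, as the
    /stream websocket loop does; return the payload with the marker stripped."""
    received = b""
    for chunk in chunks:
        if chunk.endswith(END_MARKER):
            received += chunk[: -len(END_MARKER)]
            return received
        received += chunk
    return received


payload = accumulate_until_end([b"abc", b"def", b"ghiEND\n"])
```

Using `len(END_MARKER)` instead of a hardcoded `-4` keeps the slice correct if the marker ever changes (the server currently hardcodes both the marker and the slice).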
--------------------------------------------------------------------------------
/src/whisperchain/ui/streamlit_app.py:
--------------------------------------------------------------------------------
1 | # Run this file to test the Streamlit UI
2 | #
3 | # $ streamlit run src/whisperchain/ui/streamlit_app.py
4 | #
5 | # If you get an error about the config file, run the following command to kill all running Streamlit processes:
6 | #
7 | # $ lsof -ti :8501 | xargs kill -9
8 |
9 | import time
10 | from datetime import datetime
11 |
12 | import requests
13 | import streamlit as st
14 |
15 | from whisperchain.core.config import config
16 |
17 |
18 | def main():
19 | # Set page config
20 | st.set_page_config(
21 | page_title=config.ui_config.title,
22 | page_icon=config.ui_config.page_icon,
23 | layout=config.ui_config.layout,
24 | )
25 |
26 | st.title("WhisperChain Dashboard")
27 |
28 | # Initialize session state
29 | if "last_history" not in st.session_state:
30 | st.session_state.last_history = []
31 | st.session_state.last_check_time = time.time()
32 | st.session_state.server_online = False # Initialize server status
33 |
34 |     # Server status check - only check once per refresh interval
35 |     current_time = time.time()
36 |     if current_time - st.session_state.last_check_time >= config.ui_config.refresh_interval:
37 | try:
38 | response = requests.get(config.ui_config.server_url)
39 | st.session_state.server_online = response.status_code == 200
40 | st.session_state.last_check_time = current_time
41 |         except requests.exceptions.RequestException:
42 | st.session_state.server_online = False
43 |
44 | # Display server status
45 | st.sidebar.header("Server Status")
46 | if st.session_state.server_online:
47 | st.sidebar.success("🟢 Server Online")
48 | else:
49 | st.sidebar.error("🔴 Server Offline")
50 | st.header(
51 | "Once the server is online, the UI will automatically refresh and display the transcription history."
52 | )
53 | time.sleep(5) # Longer delay when server is offline
54 | st.rerun()
55 |
56 | # Transcription History
57 | st.header("Transcription History")
58 |
59 | # Clear history button
60 | if st.button("Clear History"):
61 | requests.delete(config.ui_config.server_url + "/history")
62 | st.session_state.last_history = []
63 | st.rerun()
64 |
65 | try:
66 | # Fetch history from server
67 | response = requests.get(config.ui_config.server_url + "/history")
68 | history = response.json()
69 |
70 | # Check if history has changed
71 | history_changed = len(history) != len(st.session_state.last_history)
72 | st.session_state.last_history = history
73 |
74 | # Display transcriptions
75 | for idx, entry in enumerate(reversed(history)):
76 | with st.expander(f"Transcription {len(history) - idx}", expanded=False):
77 | col1, col2 = st.columns(2)
78 | with col1:
79 | st.subheader("Raw Transcription")
80 | st.text(entry.get("transcription", ""))
81 | st.caption(f"Processed bytes: {entry.get('processed_bytes', 0)}")
82 | with col2:
83 | st.subheader("Cleaned Transcription")
84 | st.text(entry.get("cleaned_transcription", ""))
85 | st.caption(f"Timestamp: {entry.get('timestamp', '')}")
86 |
87 | # Only rerun if history has changed
88 | if history_changed:
89 | print("History changed, rerunning...")
90 | time.sleep(config.ui_config.quick_refresh)
91 | st.rerun()
92 | else:
93 | time.sleep(config.ui_config.refresh_interval)
94 | st.rerun()
95 |
96 |     except Exception as e:
97 |         st.error(f"Error fetching history: {str(e)}")
98 |         time.sleep(config.ui_config.refresh_interval)
99 |         st.rerun()
99 |
100 |
101 | if __name__ == "__main__":
102 | # Ensure Streamlit config is up to date
103 | config.generate_streamlit_config()
104 | main()
105 |
--------------------------------------------------------------------------------
/src/whisperchain/utils/decorators.py:
--------------------------------------------------------------------------------
1 | from whisperchain.utils.logger import get_logger
2 |
3 | logger = get_logger(__name__)
4 |
5 |
6 | def handle_exceptions(func):
7 | """Decorator to handle exceptions in a function."""
8 |
9 | def wrapper(*args, **kwargs):
10 | try:
11 | return func(*args, **kwargs)
12 | except Exception as e:
13 | logger.error(f"Error in {func.__name__}: {e}")
14 | raise e
15 |
16 | return wrapper
17 |
--------------------------------------------------------------------------------
/src/whisperchain/utils/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import sys
3 | from pathlib import Path
4 |
5 |
6 | class ColorFormatter(logging.Formatter):
7 | """Custom formatter with colors"""
8 |
9 | COLORS = {
10 | "DEBUG": "\033[37m", # White
11 | "INFO": "\033[32m", # Green
12 | "WARNING": "\033[33m", # Yellow
13 | "ERROR": "\033[31m", # Red
14 | "CRITICAL": "\033[41m", # Red background
15 | }
16 | RESET = "\033[0m"
17 |
18 | def format(self, record):
19 | # Add color to the level name
20 | color = self.COLORS.get(record.levelname, "")
21 | record.levelname = f"{color}{record.levelname}{self.RESET}"
22 |
23 | return super().format(record)
24 |
25 |
26 | def get_logger(name: str = None) -> logging.Logger:
27 | """
28 | Create a logger with consistent formatting including filename and line number
29 |
30 | Args:
31 | name: Logger name, defaults to file name if None
32 |
33 | Returns:
34 | Configured logger instance
35 | """
36 | if name is None:
37 | # Get the caller's filename if no name provided
38 | frame = sys._getframe(1)
39 | name = Path(frame.f_code.co_filename).stem
40 |
41 | # Create logger
42 | logger = logging.getLogger(name)
43 |
44 | # Only add handler if logger doesn't have one
45 | if not logger.handlers:
46 | # Create stderr handler
47 | handler = logging.StreamHandler(sys.stderr)
48 |
49 | # Format: [LEVEL] filename:line - message
50 | formatter = ColorFormatter(
51 | fmt="[%(levelname)s] %(filename)s:%(lineno)d - %(message)s",
52 | datefmt="%Y-%m-%d %H:%M:%S",
53 | )
54 |
55 | handler.setFormatter(formatter)
56 | logger.addHandler(handler)
57 |
58 | # Set default level to DEBUG to see all messages
59 | logger.setLevel(logging.DEBUG)
60 |
61 | # Prevent propagation to avoid duplicate logs
62 | logger.propagate = False
63 |
64 | return logger
65 |
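`ColorFormatter` works by rewriting `record.levelname` with ANSI escape codes before delegating to the base formatter. A quick stdlib check of that trick, using a hypothetical mini-formatter that mirrors the class above:

```python
import logging

COLORS = {"INFO": "\033[32m"}  # green, as in ColorFormatter.COLORS
RESET = "\033[0m"


class MiniColorFormatter(logging.Formatter):
    # Same approach as ColorFormatter: wrap the levelname, then format normally.
    def format(self, record):
        color = COLORS.get(record.levelname, "")
        record.levelname = f"{color}{record.levelname}{RESET}"
        return super().format(record)


formatter = MiniColorFormatter(fmt="[%(levelname)s] %(message)s")
record = logging.LogRecord("demo", logging.INFO, __file__, 1, "hello", None, None)
line = formatter.format(record)
```

One caveat of mutating `record.levelname` in place: if another handler formats the same record afterwards, it sees the already-colored name, which is why `get_logger` attaches only a single handler and disables propagation.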
--------------------------------------------------------------------------------
/src/whisperchain/utils/secrets.py:
--------------------------------------------------------------------------------
1 | import getpass
2 | import os
3 | from pathlib import Path
4 |
5 | from dotenv import load_dotenv
6 |
7 |
8 | def get_config_dir() -> Path:
9 | """Get or create configuration directory."""
10 | config_dir = Path.home() / ".whisperchain"
11 | config_dir.mkdir(exist_ok=True)
12 | return config_dir
13 |
14 |
15 | def setup_openai_api_key() -> str:
16 | """Interactive setup for OpenAI API key with automatic .env creation."""
17 | print("\nOpenAI API key")
18 | print("You can find your OpenAI API keys at: https://platform.openai.com/api-keys")
19 |
20 | api_key = getpass.getpass("\nEnter your OpenAI API key: ").strip()
21 |
22 | if not api_key:
23 | raise ValueError("OpenAI API key is required")
24 | return api_key
25 |
26 |
27 | def setup_secrets() -> dict:
28 | """Interactive setup for API keys with automatic .env creation."""
29 | print("\nWhisperChain Setup")
30 | print("------------------")
31 | print("API keys are required for this application.")
32 |
33 | openai_api_key = setup_openai_api_key()
34 |
35 | # Save to .env file in user's config directory
36 | env_path = get_config_dir() / ".env"
37 |
38 | # Create or update .env file
39 | with open(env_path, "a+") as f:
40 | f.seek(0)
41 | content = f.read()
42 | if "OPENAI_API_KEY" not in content:
43 | f.write(f"\nOPENAI_API_KEY={openai_api_key}")
44 |
45 | os.environ["OPENAI_API_KEY"] = openai_api_key
46 | print(f"\nAPI keys have been saved to {env_path}")
47 | return {"openai_api_key": openai_api_key}
48 |
49 |
50 | def load_secrets() -> dict:
51 | """Get API keys from environment or prompt for setup."""
52 | # Load from all possible .env locations
53 | load_dotenv() # Load from current directory
54 | load_dotenv(get_config_dir() / ".env") # Load from ~/.whisperchain/.env
55 |
56 | # Try environment variable first
57 | api_key = os.environ.get("OPENAI_API_KEY")
58 |     if api_key:
59 |         return {"openai_api_key": api_key}
60 |
61 | # If not found, run interactive setup
62 | return setup_secrets()
63 |
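`setup_secrets` appends to the `.env` file with mode `"a+"` and a `seek(0)` read so the key is only written once even across repeated runs. The idempotent append pattern in isolation, as a stdlib sketch (hypothetical helper name):

```python
import tempfile
from pathlib import Path


def append_if_missing(env_path: Path, key: str, value: str) -> None:
    # "a+" creates the file if needed and allows reading after seeking to the
    # start, so the key is written at most once no matter how often this runs.
    with open(env_path, "a+") as f:
        f.seek(0)
        if key not in f.read():
            f.write(f"\n{key}={value}")


env_file = Path(tempfile.mkdtemp()) / ".env"
append_if_missing(env_file, "OPENAI_API_KEY", "sk-test")
append_if_missing(env_file, "OPENAI_API_KEY", "sk-test")  # second call is a no-op
content = env_file.read_text()
```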
--------------------------------------------------------------------------------
/src/whisperchain/utils/segment.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | from pywhispercpp.model import Segment
4 |
5 |
6 | def list_of_segments_to_text(segments: List[Segment]) -> str:
7 | return " ".join([segment.text for segment in segments])
8 |
9 |
10 | def list_of_segments_to_text_with_timestamps(segments: List[Segment]) -> str:
11 | return " ".join([f"[{segment.t0}-{segment.t1}] {segment.text}" for segment in segments])
12 |
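These helpers only touch the `.text`, `.t0`, and `.t1` attributes, so their output format can be shown with a namedtuple stand-in (hypothetical `FakeSegment`, not pywhispercpp's class):

```python
from collections import namedtuple

# Stand-in for pywhispercpp's Segment; only the fields the helpers read.
FakeSegment = namedtuple("FakeSegment", ["t0", "t1", "text"])

segments = [FakeSegment(0, 100, "ummm"), FakeSegment(100, 200, "Hello, world!")]

# Equivalent to list_of_segments_to_text / ..._with_timestamps:
plain = " ".join(s.text for s in segments)
stamped = " ".join(f"[{s.t0}-{s.t1}] {s.text}" for s in segments)
```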
--------------------------------------------------------------------------------
/tests/test_audio_capture.py:
--------------------------------------------------------------------------------
1 | import multiprocessing as mp
2 | import os
3 | import time
4 |
5 | import pyaudio
6 | import pytest
7 |
8 | from whisperchain.core.audio import AudioCapture
9 | from whisperchain.core.config import AudioConfig
10 |
11 |
12 | @pytest.mark.skipif(not os.getenv("TEST_WITH_MIC"), reason="Requires microphone input")
13 | def test_audio_capture():
14 | q = mp.Queue()
15 | is_recording = mp.Event()
16 | is_recording.set()
17 | # Create an AudioCapture instance with default config
18 | capture_instance = AudioCapture(q, is_recording, config=AudioConfig())
19 | process = mp.Process(target=capture_instance.start)
20 | process.start()
21 |
22 | record_duration = 5 # seconds
23 | print("Test: Recording for 5 seconds...", flush=True)
24 | time.sleep(record_duration)
25 | # Stop recording gracefully and wait for the process to finish.
26 | is_recording.clear()
27 |
28 | # Get the total number of bytes captured.
29 | total_bytes = 0
30 | while not q.empty():
31 | try:
32 | data = q.get_nowait()
33 | total_bytes += len(data)
34 | except Exception:
35 | break
36 | assert total_bytes > 0, "No audio was captured"
37 |
38 | process.join(timeout=2.0)
39 | if process.is_alive():
40 | process.terminate()
41 | process.join()
42 |
43 |
44 | @pytest.mark.skipif(not os.getenv("TEST_WITH_MIC"), reason="Requires microphone input")
45 | def test_audio_playback():
46 | # Capture audio for 5 seconds and play it back.
47 | q = mp.Queue()
48 | is_recording = mp.Event()
49 | is_recording.set()
50 | config = AudioConfig() # Use default config for playback
51 | # Create an AudioCapture instance with default config
52 | capture_instance = AudioCapture(q, is_recording, config=config)
53 | process = mp.Process(target=capture_instance.start)
54 | process.start()
55 |
56 | record_duration = 5 # seconds
57 | print("Test: Recording for 5 seconds...", flush=True)
58 | time.sleep(record_duration)
59 | # Stop recording gracefully and wait for the process to finish.
60 | is_recording.clear()
61 |
62 | # Get the total number of bytes captured.
63 | audio_data = bytearray()
64 | total_bytes = 0
65 | while not q.empty():
66 | try:
67 | data = q.get_nowait()
68 | audio_data.extend(data)
69 | total_bytes += len(data)
70 | except Exception:
71 | break
72 | assert total_bytes > 0, "No audio was captured"
73 |
74 |     # Play the audio back using the same format as configured
75 |     player = pyaudio.PyAudio()
76 |     stream = player.open(
77 |         format=getattr(pyaudio, f"pa{config.format.capitalize()}"),
78 |         channels=config.channels,
79 |         rate=config.sample_rate,
80 |         output=True,
81 |     )
82 |     stream.write(bytes(audio_data))
83 |     stream.stop_stream()
84 |     stream.close()
85 |     player.terminate()
81 |
--------------------------------------------------------------------------------
/tests/test_chain.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pywhispercpp.model import Segment
3 |
4 | from whisperchain.core.chain import TranscriptionCleaner
5 | from whisperchain.utils.segment import list_of_segments_to_text
6 |
7 |
8 | @pytest.fixture
9 | def cleaner():
10 | return TranscriptionCleaner()
11 |
12 |
13 | def test_transcription_cleaner(cleaner):
14 | assert cleaner.clean("Hello, world!") == "Hello, world!"
15 |
16 |
17 | def test_list_of_segments(cleaner):
18 | segments = [
19 | Segment(0, 100, "ummm"),
20 | Segment(100, 200, "Hello, world!"),
21 | ]
22 | final_text = list_of_segments_to_text(segments)
23 | assert final_text == "ummm Hello, world!"
24 |
25 | cleaned_text = cleaner.clean(final_text)
26 | assert cleaned_text == "Hello, world!"
27 |
--------------------------------------------------------------------------------
/tests/test_key_listener.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from whisperchain.client.key_listener import HotKeyRecordingListener
4 |
5 |
6 | @pytest.fixture
7 | def listener():
8 | # Create an instance of HotKeyRecordingListener with the default hotkey.
9 | listener = HotKeyRecordingListener("++r")
10 | return listener
11 |
12 |
13 | def test_on_activate(listener):
14 | # Initially, no recording and not pressed.
15 | assert listener.pressed is False
16 | print(f"Please press the hotkey: {listener.combination}")
17 | listener.start()
18 |
--------------------------------------------------------------------------------
/tests/test_pywhispercpp.py:
--------------------------------------------------------------------------------
1 | import os
2 | import platform
3 | import tempfile
4 | import urllib.request
5 |
6 | import pytest
7 | from pywhispercpp.model import Model
8 |
9 |
10 | @pytest.fixture
11 | def test_audio_path():
12 | """Fixture to download and provide a test audio file"""
13 | # Using a small public domain audio file from Wikimedia
14 | audio_url = "https://upload.wikimedia.org/wikipedia/commons/c/c8/Example.ogg"
15 |
16 | # Create a temporary file
17 | with tempfile.NamedTemporaryFile(delete=False, suffix=".ogg") as tmp_file:
18 | # Download the file
19 | urllib.request.urlretrieve(audio_url, tmp_file.name)
20 | audio_path = tmp_file.name
21 |
22 | yield audio_path
23 |
24 | # Cleanup after test
25 | if os.path.exists(audio_path):
26 | os.unlink(audio_path)
27 |
28 |
29 | def test_whisper_model_loading():
30 | """Test that we can load the basic whisper model"""
31 | try:
32 | # Initialize model with base.en model
33 | model = Model("base.en", print_realtime=False, print_progress=False)
34 | assert model is not None, "Model should be loaded successfully"
35 | except Exception as e:
36 | pytest.fail(f"Failed to load whisper model: {str(e)}")
37 |
38 |
39 | def test_basic_transcription(test_audio_path):
40 | """Test basic transcription functionality"""
41 | model = Model("base.en", print_realtime=False, print_progress=False)
42 |
43 | # Test transcription
44 | result = model.transcribe(test_audio_path)
45 | print(result)
46 | assert isinstance(result, list), "Transcription result should be a list of segments"
47 | assert len(result) > 0, "Transcription result should not be empty"
48 | assert hasattr(result[0], "text"), "Segments should have text attribute"
49 |
50 |
51 | @pytest.mark.skipif(
52 | platform.system() != "Darwin" or not os.environ.get("WHISPER_COREML"),
53 | reason="CoreML tests only run on MacOS with WHISPER_COREML=1",
54 | )
55 | def test_coreml_support():
56 | """Test CoreML support on MacOS"""
57 | model = Model("base.en", print_realtime=False, print_progress=False)
58 |
59 | # Get system info to verify CoreML
60 | system_info = Model.system_info()
61 | # CoreML support is shown in system info
62 | assert "CoreML" in str(system_info), "CoreML support should be enabled on MacOS"
63 |
--------------------------------------------------------------------------------
/tests/test_stream_client.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import os
3 | from time import sleep
4 |
5 | import pytest
6 |
7 | from whisperchain.client.stream_client import StreamClient
8 | from whisperchain.utils.logger import get_logger
9 |
10 | logger = get_logger(__name__)
11 |
12 |
13 | async def stop_after(client, seconds):
14 | await asyncio.sleep(seconds)
15 | # Clear the recording flag to trigger the stop logic in stream_microphone()
16 | client.stop()
17 | logger.info(f"Test: Cleared audio capturing flag after {seconds} seconds.")
18 |
19 |
20 | @pytest.mark.skipif(
21 | not os.getenv("TEST_WITH_MIC"),
22 | reason="Requires microphone input. Run with TEST_WITH_MIC=1 to enable.",
23 | )
24 | @pytest.mark.asyncio
25 | async def test_stream_client_with_real_mic():
26 | """
27 | Test StreamClient with actual microphone input.
28 | This test will record for 5 seconds and then force-stop the streamer.
29 | """
30 | print("\n=== Real Microphone Test ===")
31 | print("Please speak into your microphone when recording starts")
32 | print("Recording will last for 5 seconds")
33 | print("3...")
34 | sleep(1)
35 | print("2...")
36 | sleep(1)
37 | print("1...")
38 | sleep(1)
39 | print("Recording NOW!")
40 |
41 | messages = []
42 | total_bytes_sent = 0
43 |
44 | async with StreamClient() as client:
45 | # Schedule clearing the recording flag after 5 seconds.
46 | asyncio.create_task(stop_after(client, 5))
47 | async for message in client.stream_microphone():
48 | messages.append(message)
49 |             # Extract the byte count from the message, if present
50 |             if not message.get("is_final"):
51 |                 try:
52 |                     byte_count = int(message["processed_bytes"])
53 |                     total_bytes_sent += byte_count
54 |                 except (KeyError, ValueError):
55 | pass
56 | if message.get("is_final"):
57 | final_byte_count = int(message["processed_bytes"])
58 | break
59 |
60 | print("\nTest Results:")
61 | print(f"Total bytes sent in chunks: {total_bytes_sent}")
62 | print(f"Final bytes received by server: {final_byte_count}")
63 | print("\nServer should now play back the received audio.")
64 | print("Please verify that the playback matches what you spoke.")
65 |
66 | # Basic assertions
67 | assert len(messages) > 0, "Should receive at least one message"
68 | assert any(msg.get("is_final") for msg in messages), "Should receive final message"
69 | assert final_byte_count > 0, "Server should receive nonzero bytes"
70 | assert (
71 | abs(total_bytes_sent - final_byte_count) < 8192
72 | ), "Bytes sent should approximately match bytes received"
73 |
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from whisperchain.utils.logger import get_logger
4 |
5 |
6 | def test_logger(capsys):
7 | """Test that logger prints to terminal"""
8 | logger = get_logger("test")
9 |
10 | # Test all log levels
11 | logger.debug("Debug message")
12 | logger.info("Info message")
13 | logger.warning("Warning message")
14 | logger.error("Error message")
15 |
16 | # Get captured output
17 | captured = capsys.readouterr()
18 |
19 | # Check that messages appear in stderr
20 | assert "Debug message" in captured.err
21 | assert "Info message" in captured.err
22 | assert "Warning message" in captured.err
23 | assert "Error message" in captured.err
24 |
25 | # Check formatting with color codes
26 | assert "\033[37mDEBUG\033[0m" in captured.err # White DEBUG
27 | assert "test_utils.py:" in captured.err
28 |
--------------------------------------------------------------------------------