├── static
│   ├── js
│   │   └── script.js
│   └── css
│       ├── style.css
│       └── styles.css
├── templates
│   ├── index.html
│   ├── chat_messages.html
│   ├── settings.html
│   ├── base.html
│   └── chat.html
├── jarvis
│   ├── tests
│   │   ├── __init__.py
│   │   ├── unit
│   │   │   ├── __init__.py
│   │   │   └── test_audio_processor.py
│   │   └── integration
│   │       ├── __init__.py
│   │       └── test_pipeline.py
│   ├── src
│   │   ├── audio
│   │   │   ├── __init__.py
│   │   │   ├── recorder.py
│   │   │   └── processor.py
│   │   ├── gui
│   │   │   ├── __init__.py
│   │   │   ├── system_tray.py
│   │   │   └── main_window.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── transcriber.py
│   │   │   └── text_corrector.py
│   │   ├── __init__.py
│   │   ├── main.py
│   │   ├── cli.py
│   │   └── pipeline.py
│   ├── run_gui.sh
│   ├── requirements.txt
│   ├── setup.py
│   ├── Makefile
│   ├── .gitignore
│   ├── scripts
│   │   └── download_models.py
│   ├── README.md
│   └── CLAUDE.md
├── requirements.txt
├── models
│   ├── converters.py
│   ├── indexer.py
│   ├── retriever.py
│   ├── model_loader.py
│   └── responder.py
├── logger.py
├── .gitignore
├── README.md
└── app.py
/static/js/script.js:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/templates/index.html:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/jarvis/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # Tests package
--------------------------------------------------------------------------------
/jarvis/tests/unit/__init__.py:
--------------------------------------------------------------------------------
1 | # Unit tests package
--------------------------------------------------------------------------------
/jarvis/tests/integration/__init__.py:
--------------------------------------------------------------------------------
1 | # Integration tests package
--------------------------------------------------------------------------------
/jarvis/src/audio/__init__.py:
--------------------------------------------------------------------------------
1 | """Audio processing modules"""
2 |
3 | from .recorder import AudioRecorder
4 | from .processor import AudioProcessor
5 |
6 | __all__ = ["AudioRecorder", "AudioProcessor"]
--------------------------------------------------------------------------------
/jarvis/src/gui/__init__.py:
--------------------------------------------------------------------------------
1 | """GUI components for the speech transcription application"""
2 |
3 | from .main_window import MainWindow
4 | from .system_tray import SystemTray
5 |
6 | __all__ = ["MainWindow", "SystemTray"]
--------------------------------------------------------------------------------
/jarvis/run_gui.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Set Qt plugin path for PyQt6
4 | export QT_QPA_PLATFORM_PLUGIN_PATH="$(python -c 'import PyQt6, os; print(os.path.join(os.path.dirname(PyQt6.__file__), "Qt6", "plugins", "platforms"))')"
5 |
6 | # Run the GUI application
7 | python -m src.main
--------------------------------------------------------------------------------
/jarvis/requirements.txt:
--------------------------------------------------------------------------------
1 | sounddevice>=0.4.6
2 | numpy>=1.24.0
3 | scipy>=1.10.0
4 | PyQt6>=6.5.0
5 | mlx>=0.5.0
6 | mlx-lm>=0.2.0
7 | mlx-whisper>=0.2.0
8 | requests>=2.31.0
9 | huggingface-hub>=0.16.0
10 | pynput>=1.7.0
11 | soundfile>=0.12.0
12 | pytest>=7.4.0
13 | black>=23.7.0
14 | ruff>=0.0.280
15 | mypy>=1.4.0
--------------------------------------------------------------------------------
/jarvis/src/models/__init__.py:
--------------------------------------------------------------------------------
1 | """Machine learning models for transcription and text correction"""
2 |
3 | from .transcriber import SpeechTranscriber, AVAILABLE_MODELS
4 | from .text_corrector import TextCorrector, AVAILABLE_LLM_MODELS
5 |
6 | __all__ = ["SpeechTranscriber", "TextCorrector", "AVAILABLE_MODELS", "AVAILABLE_LLM_MODELS"]
--------------------------------------------------------------------------------
/jarvis/src/__init__.py:
--------------------------------------------------------------------------------
1 | """Speech Transcription Application"""
2 |
3 | __version__ = "1.0.0"
4 |
5 | from .pipeline import TranscriptionPipeline
6 | from .models.transcriber import SpeechTranscriber
7 | from .models.text_corrector import TextCorrector
8 |
9 | __all__ = [
10 | "TranscriptionPipeline",
11 | "SpeechTranscriber",
12 | "TextCorrector"
13 | ]
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | Flask
2 | byaldi
3 | cmake
4 | pkgconfig
5 | python-poppler
6 | torch
7 | torchvision
8 | google-generativeai
9 | openai
10 | docx2pdf
11 | qwen-vl-utils
12 | vllm>=0.6.1.post1; sys_platform != 'darwin'
13 | mistral_common>=1.4.1
14 | einops
15 | mistral_common[opencv]
16 | mistral_common
17 | mistral_inference
18 | groq
19 | markdown
20 | hf_transfer
21 | ollama
--------------------------------------------------------------------------------
/jarvis/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name="speech-transcription",
5 | version="1.0.0",
6 | packages=find_packages(),
7 | install_requires=[
8 | "sounddevice>=0.4.6",
9 | "numpy>=1.24.0",
10 | "scipy>=1.10.0",
11 | "PyQt6>=6.5.0",
12 | "mlx>=0.5.0",
13 | "mlx-lm>=0.2.0",
14 | "mlx-whisper>=0.2.0",
15 | "soundfile>=0.12.0",
16 | ],
17 | entry_points={
18 | "console_scripts": [
19 | "speech-transcribe=src.cli:main",
20 | "speech-transcribe-gui=src.main:main",
21 | ],
22 | },
23 | python_requires=">=3.9",
24 | )
--------------------------------------------------------------------------------
/jarvis/Makefile:
--------------------------------------------------------------------------------
1 | # Makefile
2 | .PHONY: install test run lint format clean
3 |
4 | install:
5 | pip install -r requirements.txt
6 | pip install -e .
7 |
8 | test:
9 | pytest tests/ -v
10 |
11 | test-unit:
12 | pytest tests/unit/ -v
13 |
14 | test-integration:
15 | pytest tests/integration/ -v
16 |
17 | run:
18 | python -m src.main
19 |
20 | run-cli:
21 | python -m src.cli --help
22 |
23 | lint:
24 | ruff check src tests
25 | mypy src
26 |
27 | format:
28 | black src tests
29 | ruff format src tests
30 |
31 | clean:
32 | find . -type f -name "*.pyc" -delete
33 | find . -type d -name "__pycache__" -delete
34 | rm -rf build dist *.egg-info
35 |
36 | download-models:
37 | python scripts/download_models.py
--------------------------------------------------------------------------------
/jarvis/.gitignore:
--------------------------------------------------------------------------------
1 | # Python
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | *.so
6 | .Python
7 | build/
8 | develop-eggs/
9 | dist/
10 | downloads/
11 | eggs/
12 | .eggs/
13 | lib/
14 | lib64/
15 | parts/
16 | sdist/
17 | var/
18 | wheels/
19 | *.egg-info/
20 | .installed.cfg
21 | *.egg
22 |
23 | # Virtual Environment
24 | venv/
25 | ENV/
26 | env/
27 |
28 | # IDE
29 | .vscode/
30 | .idea/
31 | *.swp
32 | *.swo
33 |
34 | # macOS
35 | .DS_Store
36 | .AppleDouble
37 | .LSOverride
38 |
39 | # Temporary files
40 | *.tmp
41 | *.temp
42 | .cache/
43 |
44 | # Model cache (but not src/models/)
45 | /models/
46 | *.bin
47 | *.safetensors
48 |
49 | # Audio files
50 | *.wav
51 | *.mp3
52 | *.m4a
53 | *.flac
54 |
55 | # Logs
56 | *.log
57 | logs/
58 |
59 | # pytest
60 | .pytest_cache/
--------------------------------------------------------------------------------
/templates/chat_messages.html:
--------------------------------------------------------------------------------
1 | {% for message in messages %}
2 |
3 | {% if message.role == 'user' %}
4 | {{ message.content }}
5 | {% else %}
6 | {{ message.content|safe }}
7 | {% endif %}
8 | {% if message.images %}
9 |
10 | {% for image in message.images %}
11 |
12 | {% endfor %}
13 |
14 | {% endif %}
15 |
16 | {% endfor %}
--------------------------------------------------------------------------------
/models/converters.py:
--------------------------------------------------------------------------------
1 | # models/converters.py
2 |
3 | import os
4 | from docx2pdf import convert
5 | from logger import get_logger
6 |
7 | logger = get_logger(__name__)
8 |
9 | def convert_docs_to_pdfs(folder_path):
10 | """
11 | Converts .doc and .docx files in the folder to PDFs.
12 |
13 | Args:
14 | folder_path (str): The path to the folder containing documents.
15 | """
16 | try:
17 | for filename in os.listdir(folder_path):
18 | if filename.lower().endswith(('.doc', '.docx')):
19 | doc_path = os.path.join(folder_path, filename)
20 | pdf_path = os.path.splitext(doc_path)[0] + '.pdf'
21 | convert(doc_path, pdf_path)
22 | logger.info(f"Converted '{filename}' to PDF.")
23 | except Exception as e:
24 | logger.error(f"Error converting documents to PDFs: {e}")
25 | raise
--------------------------------------------------------------------------------
/jarvis/tests/integration/test_pipeline.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import numpy as np
3 | from src.pipeline import TranscriptionPipeline
4 |
5 | class TestPipeline:
6 | @pytest.fixture
7 | def pipeline(self):
8 | return TranscriptionPipeline(model_name="whisper-tiny")
9 |
10 | def test_transcription_flow(self, pipeline, tmp_path):
11 | # Create test audio
12 | duration = 3 # seconds
13 | sample_rate = 16000
14 | t = np.linspace(0, duration, duration * sample_rate)
15 |
16 | # Simple tone
17 | audio = np.sin(2 * np.pi * 440 * t) * 0.5
18 |
19 | # Save to file
20 | import soundfile as sf
21 | test_file = tmp_path / "test.wav"
22 | sf.write(test_file, audio, sample_rate)
23 |
24 | # Transcribe
25 | result = pipeline.transcribe_file(str(test_file))
26 |
27 | # Should return empty for pure tone
28 | assert isinstance(result, str)
--------------------------------------------------------------------------------
/logger.py:
--------------------------------------------------------------------------------
1 | # logger.py
2 |
3 | import logging
4 |
5 | def get_logger(name):
6 | """
7 | Creates a logger with the specified name.
8 |
9 | Args:
10 | name (str): The name of the logger.
11 |
12 | Returns:
13 | Logger: Configured logger instance.
14 | """
15 | logger = logging.getLogger(name)
16 | logger.setLevel(logging.DEBUG)
17 |
18 | if not logger.handlers:
19 | # Console handler
20 | c_handler = logging.StreamHandler()
21 | c_handler.setLevel(logging.INFO)
22 |
23 | # File handler
24 | f_handler = logging.FileHandler('app.log')
25 | f_handler.setLevel(logging.DEBUG)
26 |
27 | # Create formatters and add them to handlers
28 | c_format = logging.Formatter('%(asctime)s - %(levelname)s - %(name)s - %(message)s')
29 | f_format = logging.Formatter('%(asctime)s - %(levelname)s - %(name)s - %(message)s')
30 | c_handler.setFormatter(c_format)
31 | f_handler.setFormatter(f_format)
32 |
33 | # Add handlers to the logger
34 | logger.addHandler(c_handler)
35 | logger.addHandler(f_handler)
36 |
37 | return logger
38 |
--------------------------------------------------------------------------------
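
For orientation, the Flask-side modules in this dump (`models/converters.py`, `models/indexer.py`, `models/retriever.py`, `models/model_loader.py`) all consume this helper the same way: one named logger per module. A minimal sketch of that pattern, with illustrative messages:

```python
# Usage pattern taken from the other modules in this repo.
from logger import get_logger

logger = get_logger(__name__)

logger.info("shown on the console and written to app.log")  # console handler is INFO+
logger.debug("written to app.log only")                      # file handler is DEBUG+
```
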
/jarvis/scripts/download_models.py:
--------------------------------------------------------------------------------
1 | """Download required ML models"""
2 |
3 | from huggingface_hub import snapshot_download
4 | import os
5 |
6 | # Models to download
7 | MODELS = [
8 | # Whisper models for transcription
9 | "mlx-community/whisper-tiny",
10 | "mlx-community/whisper-base-mlx",
11 |
12 | # LLM models for text correction (at least one)
13 | "mlx-community/Phi-3.5-mini-instruct-4bit",
14 | "mlx-community/Qwen2.5-0.5B-Instruct-4bit", # Tiny option
15 | # Add more as needed:
16 | # "mlx-community/gemma-2-2b-it-4bit",
17 | # "mlx-community/Mistral-7B-Instruct-v0.3-4bit",
18 | ]
19 |
20 | def download_models():
21 | """Download all required models"""
22 | cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
23 |
24 | for model_id in MODELS:
25 | print(f"Downloading {model_id}...")
26 | try:
27 | snapshot_download(
28 | repo_id=model_id,
29 | cache_dir=cache_dir,
30 | resume_download=True
31 | )
32 | print(f"✓ {model_id} downloaded successfully")
33 | except Exception as e:
34 | print(f"✗ Failed to download {model_id}: {e}")
35 |
36 | if __name__ == "__main__":
37 | download_models()
--------------------------------------------------------------------------------
/jarvis/tests/unit/test_audio_processor.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import numpy as np
3 | from src.audio.processor import AudioProcessor
4 |
5 | class TestAudioProcessor:
6 | def test_normalize_audio(self):
7 | processor = AudioProcessor()
8 | audio = np.array([0.5, -0.5, 0.25, -0.25])
9 | normalized = processor.normalize_audio(audio)
10 | assert normalized.max() == 1.0 or normalized.min() == -1.0
11 |
12 | def test_silence_detection(self):
13 | processor = AudioProcessor()
14 |
15 | # Test silence
16 | silence = np.zeros(16000)
17 | assert processor.is_silence(silence)
18 |
19 | # Test non-silence
20 | tone = np.sin(2 * np.pi * 440 * np.linspace(0, 1, 16000))
21 | assert not processor.is_silence(tone)
22 |
23 | def test_resample(self):
24 | processor = AudioProcessor()
25 |
26 | # Create 1 second of audio at 44100 Hz
27 | original_sr = 44100
28 | audio = np.sin(2 * np.pi * 440 * np.linspace(0, 1, original_sr))
29 |
30 | # Resample to 16000 Hz
31 | resampled = processor.resample(audio, original_sr, 16000)
32 |
33 | # Check length
34 | expected_length = 16000
35 | assert abs(len(resampled) - expected_length) < 10
--------------------------------------------------------------------------------
/static/css/style.css:
--------------------------------------------------------------------------------
1 | /* static/css/style.css */
2 |
3 | /* Customize chat window */
4 | .chat-window {
5 | background-color: #f8f9fa;
6 | border-radius: 5px;
7 | padding: 15px;
8 | }
9 |
10 | /* Customize messages */
11 | .message {
12 | margin-bottom: 10px;
13 | }
14 |
15 | .user-message {
16 | text-align: right;
17 | }
18 |
19 | .bot-message {
20 | text-align: left;
21 | }
22 |
23 | /* Customize images */
24 | .img-thumbnail {
25 | cursor: pointer;
26 | }
27 |
28 | /* Sidebar */
29 | .list-group-item.active {
30 | background-color: #007bff;
31 | border-color: #007bff;
32 | color: #fff;
33 | }
34 |
35 | .list-group-item a {
36 | color: inherit;
37 | text-decoration: none;
38 | }
39 |
40 | .list-group-item a:hover {
41 | text-decoration: none;
42 | }
43 |
44 | /* Delete Session Button */
45 | .delete-session-btn {
46 | font-size: 0.8rem;
47 | padding: 0.2rem 0.5rem;
48 | }
49 |
50 | /* Spinner Overlay */
51 | .spinner-overlay {
52 | position: fixed;
53 | top: 0;
54 | left: 0;
55 | right: 0;
56 | bottom: 0;
57 | z-index: 1050; /* Higher than modal */
58 | background-color: rgba(255, 255, 255, 0.7);
59 | display: flex;
60 | align-items: center;
61 | justify-content: center;
62 | }
63 |
64 | /* Spinner */
65 | .spinner-border {
66 | width: 3rem;
67 | height: 3rem;
68 | }
69 |
--------------------------------------------------------------------------------
/jarvis/src/main.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | # Disable tokenizers parallelism warning
5 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
6 |
7 | # Set Qt plugin path for macOS
8 | import platform
9 | if platform.system() == "Darwin": # macOS
10 | # Try to find PyQt6 installation
11 | try:
12 | import PyQt6
13 | pyqt6_path = os.path.dirname(PyQt6.__file__)
14 | plugin_path = os.path.join(pyqt6_path, "Qt6", "plugins", "platforms")
15 | if os.path.exists(plugin_path):
16 | os.environ["QT_QPA_PLATFORM_PLUGIN_PATH"] = plugin_path
17 |     except Exception:
18 | pass
19 |
20 | from PyQt6.QtWidgets import QApplication
21 | from PyQt6.QtCore import Qt
22 | from .gui.main_window import MainWindow
23 | from .gui.system_tray import SystemTray
24 |
25 | def main():
26 | """Main application entry point"""
27 | # Enable high DPI support
28 | if hasattr(Qt.ApplicationAttribute, 'AA_UseHighDpiPixmaps'):
29 | QApplication.setAttribute(Qt.ApplicationAttribute.AA_UseHighDpiPixmaps)
30 |
31 | app = QApplication(sys.argv)
32 | app.setApplicationName("Speech Transcription")
33 | app.setOrganizationName("SpeechTranscription")
34 |
35 | # Create main window
36 | window = MainWindow()
37 |
38 | # Create system tray
39 | tray = SystemTray()
40 | tray.showMainWindow.connect(window.show)
41 | tray.quitApp.connect(app.quit)
42 |
43 | # Show window
44 | window.show()
45 |
46 | # Run app
47 | sys.exit(app.exec())
48 |
49 | if __name__ == "__main__":
50 | main()
--------------------------------------------------------------------------------
/models/indexer.py:
--------------------------------------------------------------------------------
1 | # models/indexer.py
2 |
3 | import os
4 | from byaldi import RAGMultiModalModel
5 | from models.converters import convert_docs_to_pdfs
6 | from logger import get_logger
7 |
8 | logger = get_logger(__name__)
9 |
10 | def index_documents(folder_path, index_name='document_index', index_path=None, indexer_model='vidore/colpali'):
11 | """
12 | Indexes documents in the specified folder using Byaldi.
13 |
14 | Args:
15 | folder_path (str): The path to the folder containing documents to index.
16 | index_name (str): The name of the index to create or update.
17 | index_path (str): The path where the index should be saved.
18 | indexer_model (str): The name of the indexer model to use.
19 |
20 | Returns:
21 | RAGMultiModalModel: The RAG model with the indexed documents.
22 | """
23 | try:
24 | logger.info(f"Starting document indexing in folder: {folder_path}")
25 | # Convert non-PDF documents to PDFs
26 | convert_docs_to_pdfs(folder_path)
27 | logger.info("Conversion of non-PDF documents to PDFs completed.")
28 |
29 | # Initialize RAG model
30 | RAG = RAGMultiModalModel.from_pretrained(indexer_model)
31 | if RAG is None:
32 | raise ValueError(f"Failed to initialize RAGMultiModalModel with model {indexer_model}")
33 | logger.info(f"RAG model initialized with {indexer_model}.")
34 |
35 | # Index the documents in the folder
36 | RAG.index(
37 | input_path=folder_path,
38 | index_name=index_name,
39 | store_collection_with_index=True,
40 | overwrite=True
41 | )
42 |
43 | logger.info(f"Indexing completed. Index saved at '{index_path}'.")
44 |
45 | return RAG
46 | except Exception as e:
47 | logger.error(f"Error during indexing: {str(e)}")
48 | raise
--------------------------------------------------------------------------------
/jarvis/src/audio/recorder.py:
--------------------------------------------------------------------------------
1 | import sounddevice as sd
2 | import numpy as np
3 | from typing import Callable, Optional
4 | import threading
5 | import queue
6 |
7 | class AudioRecorder:
8 | def __init__(
9 | self,
10 | sample_rate: int = 16000,
11 | chunk_duration: float = 0.5,
12 | callback: Optional[Callable[[np.ndarray], None]] = None
13 | ):
14 | self.sample_rate = sample_rate
15 | self.chunk_size = int(sample_rate * chunk_duration)
16 | self.callback = callback
17 | self.audio_queue = queue.Queue()
18 | self.is_recording = False
19 |
20 | def audio_callback(self, indata, frames, time, status):
21 | """Callback for sounddevice stream"""
22 | if status:
23 | print(f"Audio callback status: {status}")
24 |
25 | # Copy audio data to queue
26 | audio_chunk = indata[:, 0].copy() # Get first channel
27 | self.audio_queue.put(audio_chunk)
28 |
29 | # Call user callback if provided
30 | if self.callback:
31 | self.callback(audio_chunk)
32 |
33 | def start_recording(self, device_id: Optional[int] = None):
34 | """Start audio recording"""
35 | self.is_recording = True
36 |
37 | # Start audio stream
38 | self.stream = sd.InputStream(
39 | device=device_id,
40 | channels=1,
41 | samplerate=self.sample_rate,
42 | callback=self.audio_callback,
43 | blocksize=self.chunk_size
44 | )
45 | self.stream.start()
46 |
47 | def stop_recording(self) -> np.ndarray:
48 | """Stop recording and return accumulated audio"""
49 | self.is_recording = False
50 |
51 | if hasattr(self, 'stream'):
52 | self.stream.stop()
53 | self.stream.close()
54 |
55 | # Collect all audio chunks
56 | audio_chunks = []
57 | while not self.audio_queue.empty():
58 | chunk = self.audio_queue.get()
59 | audio_chunks.append(chunk)
60 |
61 | if audio_chunks:
62 | return np.concatenate(audio_chunks)
63 | return np.array([])
64 |
65 | def get_audio_level(self) -> float:
66 | """Get current audio level for visualization"""
67 | if not self.audio_queue.empty():
68 | chunk = self.audio_queue.queue[-1] # Peek at last item
69 | return float(np.abs(chunk).mean())
70 | return 0.0
--------------------------------------------------------------------------------
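
A minimal sketch of driving `AudioRecorder` directly, mirroring how `TranscriptionPipeline` (later in this dump) starts and stops it; the fixed 3-second sleep and the placeholder settings below are illustrative only:

```python
import time

from src.audio.recorder import AudioRecorder

recorder = AudioRecorder(sample_rate=16000, chunk_duration=0.5)
recorder.start_recording()           # device_id=None -> system default input
time.sleep(3)                        # let the stream callback fill the queue for ~3 s
audio = recorder.stop_recording()    # all queued chunks concatenated into one np.ndarray
print(f"captured {len(audio) / 16000:.1f} s of audio")
```
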
/models/retriever.py:
--------------------------------------------------------------------------------
1 | # models/retriever.py
2 |
3 | import base64
4 | import os
5 | from PIL import Image
6 | from io import BytesIO
7 | from logger import get_logger
8 | import time
9 | import hashlib
10 |
11 | logger = get_logger(__name__)
12 |
13 | def retrieve_documents(RAG, query, session_id, k=3):
14 | """
15 | Retrieves relevant documents based on the user query using Byaldi.
16 |
17 | Args:
18 | RAG (RAGMultiModalModel): The RAG model with the indexed documents.
19 | query (str): The user's query.
20 | session_id (str): The session ID to store images in per-session folder.
21 | k (int): The number of documents to retrieve.
22 |
23 | Returns:
24 | list: A list of image filenames corresponding to the retrieved documents.
25 | """
26 | try:
27 | logger.info(f"Retrieving documents for query: {query}")
28 | results = RAG.search(query, k=k)
29 | images = []
30 | session_images_folder = os.path.join('static', 'images', session_id)
31 | os.makedirs(session_images_folder, exist_ok=True)
32 |
33 | for i, result in enumerate(results):
34 | if result.base64:
35 | image_data = base64.b64decode(result.base64)
36 | image = Image.open(BytesIO(image_data))
37 |
38 | # Generate a unique filename based on the image content
39 | image_hash = hashlib.md5(image_data).hexdigest()
40 | image_filename = f"retrieved_{image_hash}.png"
41 | image_path = os.path.join(session_images_folder, image_filename)
42 |
43 | if not os.path.exists(image_path):
44 | image.save(image_path, format='PNG')
45 | logger.debug(f"Retrieved and saved image: {image_path}")
46 | else:
47 | logger.debug(f"Image already exists: {image_path}")
48 |
49 | # Store the relative path from the static folder
50 | relative_path = os.path.join('images', session_id, image_filename)
51 | images.append(relative_path)
52 | logger.info(f"Added image to list: {relative_path}")
53 | else:
54 | logger.warning(f"No base64 data for document {result.doc_id}, page {result.page_num}")
55 |
56 | logger.info(f"Total {len(images)} documents retrieved. Image paths: {images}")
57 | return images
58 | except Exception as e:
59 | logger.error(f"Error retrieving documents: {e}")
60 | return []
--------------------------------------------------------------------------------
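
Taken together with `models/indexer.py` above, the intended flow is index once, then retrieve per query. A rough sketch of that flow; the folder name and session id are placeholders, and the actual wiring presumably lives in `app.py`, which is not shown here:

```python
from models.indexer import index_documents
from models.retriever import retrieve_documents

# Build (or rebuild) the Byaldi index over a folder of documents.
RAG = index_documents('uploaded_documents', index_name='document_index')

# Retrieve the top-k matching pages as images saved under static/images/<session_id>/.
image_paths = retrieve_documents(RAG, "What does the architecture diagram show?",
                                 session_id='demo-session', k=3)
print(image_paths)  # e.g. ['images/demo-session/retrieved_<md5>.png', ...]
```
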
/jarvis/src/gui/system_tray.py:
--------------------------------------------------------------------------------
1 | from PyQt6.QtWidgets import QSystemTrayIcon, QMenu, QApplication
2 | from PyQt6.QtGui import QIcon, QAction
3 | from PyQt6.QtCore import QObject, pyqtSignal
4 | import os
5 |
6 | class SystemTray(QObject):
7 | """System tray integration for background operation"""
8 |
9 | showMainWindow = pyqtSignal()
10 | startRecording = pyqtSignal()
11 | quitApp = pyqtSignal()
12 |
13 | def __init__(self, parent=None):
14 | super().__init__(parent)
15 | self.tray_icon = None
16 | self.init_tray()
17 |
18 | def init_tray(self):
19 | """Initialize system tray icon"""
20 | # Create tray icon
21 | self.tray_icon = QSystemTrayIcon(self)
22 | self.tray_icon.setIcon(self.get_icon())
23 |
24 | # Create menu
25 | tray_menu = QMenu()
26 |
27 | # Show action
28 | show_action = QAction("Show", self)
29 | show_action.triggered.connect(self.showMainWindow.emit)
30 | tray_menu.addAction(show_action)
31 |
32 | # Record action
33 | record_action = QAction("Quick Record", self)
34 | record_action.triggered.connect(self.startRecording.emit)
35 | tray_menu.addAction(record_action)
36 |
37 | tray_menu.addSeparator()
38 |
39 | # Quit action
40 | quit_action = QAction("Quit", self)
41 | quit_action.triggered.connect(self.quitApp.emit)
42 | tray_menu.addAction(quit_action)
43 |
44 | # Set menu and show
45 | self.tray_icon.setContextMenu(tray_menu)
46 | self.tray_icon.show()
47 |
48 | # Handle clicks
49 | self.tray_icon.activated.connect(self.on_tray_activated)
50 |
51 | def get_icon(self):
52 | """Get or create tray icon"""
53 | # Create a simple icon programmatically
54 | from PyQt6.QtGui import QPixmap, QPainter, QBrush
55 | from PyQt6.QtCore import Qt
56 |
57 | pixmap = QPixmap(32, 32)
58 | pixmap.fill(Qt.GlobalColor.transparent)
59 |
60 | painter = QPainter(pixmap)
61 | painter.setRenderHint(QPainter.RenderHint.Antialiasing)
62 |
63 | # Draw microphone icon
64 | painter.setBrush(QBrush(Qt.GlobalColor.white))
65 | painter.setPen(Qt.PenStyle.NoPen)
66 |
67 | # Mic body
68 | painter.drawEllipse(10, 5, 12, 18)
69 |
70 | # Mic stand
71 | painter.drawRect(14, 23, 4, 5)
72 |
73 | # Mic base
74 | painter.drawRect(10, 28, 12, 2)
75 |
76 | painter.end()
77 |
78 | return QIcon(pixmap)
79 |
80 | def on_tray_activated(self, reason):
81 | """Handle tray icon activation"""
82 | if reason == QSystemTrayIcon.ActivationReason.DoubleClick:
83 | self.showMainWindow.emit()
84 |
85 | def set_recording_state(self, is_recording: bool):
86 | """Update tray icon for recording state"""
87 | if is_recording:
88 | self.tray_icon.setToolTip("Speech Transcription - Recording...")
89 | # Could update icon to show recording state
90 | else:
91 | self.tray_icon.setToolTip("Speech Transcription")
--------------------------------------------------------------------------------
/templates/settings.html:
--------------------------------------------------------------------------------
1 |
2 |
3 | {% extends 'base.html' %}
4 |
5 | {% block content %}
6 |
46 | {% endblock %}
47 |
48 | {% block styles %}
49 |
63 | {% endblock %}
64 |
--------------------------------------------------------------------------------
/jarvis/src/cli.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 | from pathlib import Path
4 | from .pipeline import TranscriptionPipeline
5 | from .models.transcriber import AVAILABLE_MODELS
6 |
7 | def main():
8 | """CLI entry point"""
9 | parser = argparse.ArgumentParser(
10 | description="Speech Transcription CLI"
11 | )
12 |
13 | parser.add_argument(
14 | "--audio",
15 | type=str,
16 | help="Path to audio file to transcribe"
17 | )
18 |
19 | parser.add_argument(
20 | "--model",
21 | type=str,
22 | default="whisper-tiny",
23 | choices=list(AVAILABLE_MODELS.keys()),
24 | help="Model to use for transcription"
25 | )
26 |
27 | parser.add_argument(
28 | "--output",
29 | type=str,
30 | help="Output file path (default: stdout)"
31 | )
32 |
33 | parser.add_argument(
34 | "--no-correction",
35 | action="store_true",
36 | help="Skip AI text correction"
37 | )
38 |
39 | parser.add_argument(
40 | "--context",
41 | type=str,
42 | help="Context for better corrections"
43 | )
44 |
45 | parser.add_argument(
46 | "--list-models",
47 | action="store_true",
48 | help="List available models"
49 | )
50 |
51 | parser.add_argument(
52 | "--keep-fillers",
53 | action="store_true",
54 | help="Keep filler words in transcription"
55 | )
56 |
57 | args = parser.parse_args()
58 |
59 | # List models
60 | if args.list_models:
61 | print("\nAvailable models:")
62 | for name, config in AVAILABLE_MODELS.items():
63 | print(f" {name}: {config['description']}")
64 | return 0
65 |
66 | # Validate audio file
67 | if not args.audio:
68 | parser.error("--audio is required")
69 |
70 | audio_path = Path(args.audio)
71 | if not audio_path.exists():
72 | print(f"Error: Audio file not found: {audio_path}")
73 | return 1
74 |
75 | # Initialize pipeline
76 | try:
77 | pipeline = TranscriptionPipeline(model_name=args.model)
78 | except Exception as e:
79 | print(f"Error initializing pipeline: {e}")
80 | return 1
81 |
82 | # Transcribe
83 | print(f"Transcribing {audio_path} with {args.model}...")
84 | try:
85 | text = pipeline.transcribe_file(str(audio_path))
86 |
87 | if not text:
88 | print("No speech detected in audio file")
89 | return 1
90 |
91 | # Apply correction if requested
92 | if not args.no_correction:
93 | print("Applying text correction...")
94 | text = pipeline.correct_text(
95 | text,
96 | context=args.context
97 | )
98 |
99 | # Output result
100 | if args.output:
101 | with open(args.output, 'w') as f:
102 | f.write(text)
103 | print(f"Transcription saved to {args.output}")
104 | else:
105 | print("\nTranscription:")
106 | print("-" * 50)
107 | print(text)
108 | print("-" * 50)
109 |
110 | return 0
111 |
112 | except Exception as e:
113 | print(f"Error during transcription: {e}")
114 | return 1
115 |
116 | if __name__ == "__main__":
117 | sys.exit(main())
--------------------------------------------------------------------------------
/jarvis/src/pipeline.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from typing import Optional, Callable
3 | from .audio.recorder import AudioRecorder
4 | from .audio.processor import AudioProcessor
5 | from .models.transcriber import SpeechTranscriber
6 | from .models.text_corrector import TextCorrector
7 |
8 | class TranscriptionPipeline:
9 | """Main pipeline orchestrating transcription flow"""
10 |
11 | def __init__(
12 | self,
13 | model_name: str = "whisper-tiny",
14 | llm_model_name: str = "Phi-3.5-mini",
15 | sample_rate: int = 16000
16 | ):
17 | self.sample_rate = sample_rate
18 | self.audio_recorder = None
19 | self.audio_processor = AudioProcessor(sample_rate)
20 | self.transcriber = SpeechTranscriber(model_name)
21 | self.text_corrector = TextCorrector(llm_model_name)
22 |
23 | def set_model(self, model_name: str):
24 | """Change the transcription model"""
25 | self.transcriber.load_model(model_name)
26 |
27 | def set_llm_model(self, model_name: str):
28 | """Change the LLM model for text correction"""
29 | self.text_corrector.set_model(model_name)
30 |
31 | def start_recording(self, callback: Optional[Callable] = None, device_id: Optional[int] = None):
32 | """Start recording audio"""
33 | self.audio_recorder = AudioRecorder(
34 | sample_rate=self.sample_rate,
35 | callback=callback
36 | )
37 | self.audio_recorder.start_recording(device_id=device_id)
38 |
39 | def stop_recording(self) -> str:
40 | """Stop recording and return transcription"""
41 | if not self.audio_recorder:
42 | return ""
43 |
44 | # Stop recording
45 | final_audio = self.audio_recorder.stop_recording()
46 |
47 | # Use only the audio from stop_recording (which already includes all chunks)
48 | audio = final_audio
49 |
50 | # Process audio
51 | processed_audio, has_speech = self.audio_processor.process(
52 | audio, self.sample_rate
53 | )
54 |
55 | if not has_speech:
56 | return ""
57 |
58 | # Transcribe
59 | text = self.transcriber.transcribe(processed_audio, self.sample_rate)
60 | return text
61 |
62 | def transcribe_file(self, file_path: str) -> str:
63 | """Transcribe an audio file"""
64 | import soundfile as sf
65 |
66 | # Load audio
67 | audio, sr = sf.read(file_path)
68 |
69 | # Process
70 | processed_audio, has_speech = self.audio_processor.process(audio, sr)
71 |
72 | if not has_speech:
73 | return ""
74 |
75 | # Transcribe
76 | return self.transcriber.transcribe(processed_audio, self.sample_rate)
77 |
78 | def correct_text(
79 | self,
80 | text: str,
81 | context: Optional[str] = None
82 | ) -> str:
83 | """Apply AI correction to transcribed text"""
84 | return self.text_corrector.correct(text, context)
85 |
86 | def get_audio_level(self) -> float:
87 | """Get current audio level"""
88 | if self.audio_recorder:
89 | return self.audio_recorder.get_audio_level()
90 | return 0.0
91 |
92 | def process_stream(
93 | self,
94 | audio_chunk: np.ndarray,
95 | callback: Callable[[str], None]
96 | ):
97 | """Process audio stream in real-time"""
98 | # This would be used for streaming mode
99 | # Process chunk and transcribe incrementally
100 | pass
--------------------------------------------------------------------------------
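
For reference, `src/cli.py` (shown earlier) exercises the pipeline exactly this way; a minimal file-based sketch, with a placeholder audio path:

```python
from src.pipeline import TranscriptionPipeline

pipeline = TranscriptionPipeline(model_name="whisper-tiny", llm_model_name="Phi-3.5-mini")

text = pipeline.transcribe_file("recording.wav")  # returns "" when no speech is detected
if text:
    text = pipeline.correct_text(text, context="Technical meeting")
print(text)
```
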
/jarvis/src/audio/processor.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy import signal
3 | from typing import Tuple, Optional
4 |
5 | class AudioProcessor:
6 | def __init__(self, target_sample_rate: int = 16000):
7 | self.target_sample_rate = target_sample_rate
8 | self.silence_threshold = 0.01
9 | self.min_audio_length = 0.5 # seconds
10 |
11 | def process(
12 | self,
13 | audio: np.ndarray,
14 | sample_rate: int
15 | ) -> Tuple[np.ndarray, bool]:
16 | """Process audio for transcription"""
17 | # Resample if needed
18 | if sample_rate != self.target_sample_rate:
19 | audio = self.resample(audio, sample_rate, self.target_sample_rate)
20 |
21 | # Normalize audio
22 | audio = self.normalize_audio(audio)
23 |
24 | # Check if audio is too short
25 | if len(audio) < self.target_sample_rate * self.min_audio_length:
26 | return audio, False
27 |
28 | # Apply noise reduction
29 | audio = self.reduce_noise(audio)
30 |
31 | # Check for silence
32 | is_silence = self.is_silence(audio)
33 |
34 | return audio, not is_silence
35 |
36 | def normalize_audio(self, audio: np.ndarray) -> np.ndarray:
37 | """Normalize audio to [-1, 1] range"""
38 | max_val = np.abs(audio).max()
39 | if max_val > 0:
40 | return audio / max_val
41 | return audio
42 |
43 | def resample(
44 | self,
45 | audio: np.ndarray,
46 | orig_sr: int,
47 | target_sr: int
48 | ) -> np.ndarray:
49 | """Resample audio to target sample rate"""
50 | if orig_sr == target_sr:
51 | return audio
52 |
53 | # Calculate resample ratio
54 | resample_ratio = target_sr / orig_sr
55 | new_length = int(len(audio) * resample_ratio)
56 |
57 | # Use scipy's resample
58 | return signal.resample(audio, new_length)
59 |
60 | def reduce_noise(self, audio: np.ndarray) -> np.ndarray:
61 |         """Apply basic noise reduction via a high-pass filter (removes low-frequency noise)"""
62 | # Apply high-pass filter to remove low-frequency noise
63 | if len(audio) > 13: # Minimum length for filter
64 | b, a = signal.butter(4, 100, 'hp', fs=self.target_sample_rate)
65 | audio = signal.filtfilt(b, a, audio)
66 | return audio
67 |
68 | def is_silence(self, audio: np.ndarray) -> bool:
69 | """Detect if audio is silence"""
70 | rms = np.sqrt(np.mean(audio**2))
71 | return rms < self.silence_threshold
72 |
73 | def apply_voice_activity_detection(
74 | self,
75 | audio: np.ndarray
76 | ) -> np.ndarray:
77 | """Simple VAD to trim silence from beginning and end"""
78 | # Calculate energy for each frame
79 | frame_length = int(0.025 * self.target_sample_rate) # 25ms frames
80 | hop_length = int(0.010 * self.target_sample_rate) # 10ms hop
81 |
82 | energy = []
83 | for i in range(0, len(audio) - frame_length, hop_length):
84 | frame = audio[i:i + frame_length]
85 | energy.append(np.sum(frame**2))
86 |
87 | energy = np.array(energy)
88 | threshold = np.mean(energy) * 0.1
89 |
90 | # Find voice activity regions
91 | voice_activity = energy > threshold
92 |
93 | if np.any(voice_activity):
94 | # Find first and last active frame
95 | first_active = np.argmax(voice_activity)
96 | last_active = len(voice_activity) - np.argmax(voice_activity[::-1])
97 |
98 | # Convert frame indices to sample indices
99 | start_sample = first_active * hop_length
100 | end_sample = min(last_active * hop_length + frame_length, len(audio))
101 |
102 | return audio[start_sample:end_sample]
103 |
104 | return audio
--------------------------------------------------------------------------------
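
A small sketch of `process()` on synthetic input, consistent with the unit tests earlier in this dump; the 440 Hz tone is only an illustration:

```python
import numpy as np

from src.audio.processor import AudioProcessor

processor = AudioProcessor(target_sample_rate=16000)

# One second of a 440 Hz tone at 44.1 kHz: process() resamples to 16 kHz, normalizes,
# high-pass filters, and reports whether the clip contains more than silence.
tone = np.sin(2 * np.pi * 440 * np.linspace(0, 1, 44100))
processed, has_speech = processor.process(tone, sample_rate=44100)
print(len(processed), has_speech)  # ~16000 samples, True (well above the RMS silence threshold)
```
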
/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore vscode
2 | /.vscode
3 | /DB
4 |
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 |
10 | # C extensions
11 | *.so
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 | cover/
57 |
58 | # Translations
59 | *.mo
60 | *.pot
61 |
62 | # Django stuff:
63 | *.log
64 | local_settings.py
65 | db.sqlite3
66 | db.sqlite3-journal
67 |
68 | # Flask stuff:
69 | instance/
70 | .webassets-cache
71 |
72 | # Scrapy stuff:
73 | .scrapy
74 |
75 | # Sphinx documentation
76 | docs/_build/
77 |
78 | # PyBuilder
79 | .pybuilder/
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # IPython
86 | profile_default/
87 | ipython_config.py
88 |
89 | # pyenv
90 | # For a library or package, you might want to ignore these files since the code is
91 | # intended to run in multiple environments; otherwise, check them in:
92 | # .python-version
93 |
94 | # pipenv
95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
98 | # install all needed dependencies.
99 | #Pipfile.lock
100 |
101 | # poetry
102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
103 | # This is especially recommended for binary packages to ensure reproducibility, and is more
104 | # commonly ignored for libraries.
105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
106 | #poetry.lock
107 |
108 | # pdm
109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
110 | #pdm.lock
111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
112 | # in version control.
113 | # https://pdm.fming.dev/#use-with-ide
114 | .pdm.toml
115 |
116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
117 | __pypackages__/
118 |
119 | # Celery stuff
120 | celerybeat-schedule
121 | celerybeat.pid
122 |
123 | # SageMath parsed files
124 | *.sage.py
125 |
126 | # Environments
127 | .env
128 | .venv
129 | env/
130 | venv/
131 | ENV/
132 | env.bak/
133 | venv.bak/
134 |
135 | # Spyder project settings
136 | .spyderproject
137 | .spyproject
138 |
139 | # Rope project settings
140 | .ropeproject
141 |
142 | # mkdocs documentation
143 | /site
144 |
145 | # mypy
146 | .mypy_cache/
147 | .dmypy.json
148 | dmypy.json
149 |
150 | # Pyre type checker
151 | .pyre/
152 |
153 | # pytype static type analyzer
154 | .pytype/
155 |
156 | # Cython debug symbols
157 | cython_debug/
158 |
159 | # PyCharm
160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
162 | # and can be added to the global gitignore or merged into this file. For a more nuclear
163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
164 | .idea/
165 |
166 | #MacOS
167 | .DS_Store
168 | SOURCE_DOCUMENTS/.DS_Store
169 |
170 |
171 | .byaldi/
172 | sessions/
173 | static/images/
174 | uploaded_documents/
175 |
--------------------------------------------------------------------------------
/jarvis/src/models/transcriber.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 | import numpy as np
4 | import soundfile as sf
5 | from typing import Optional, Dict, Any
6 | from abc import ABC, abstractmethod
7 |
8 | # Model configurations
9 | AVAILABLE_MODELS = {
10 | "whisper-tiny": {
11 | "repo": "mlx-community/whisper-tiny",
12 | "type": "whisper",
13 | "description": "Very fast, lower accuracy"
14 | },
15 | "whisper-base": {
16 | "repo": "mlx-community/whisper-base-mlx",
17 | "type": "whisper",
18 | "description": "Fast, decent accuracy"
19 | },
20 | "whisper-small": {
21 | "repo": "mlx-community/whisper-small-mlx",
22 | "type": "whisper",
23 | "description": "Good balance"
24 | },
25 | "whisper-medium": {
26 | "repo": "mlx-community/whisper-medium-mlx",
27 | "type": "whisper",
28 | "description": "High accuracy, slower"
29 | },
30 | "whisper-large-v3": {
31 | "repo": "mlx-community/whisper-large-v3-mlx",
32 | "type": "whisper",
33 | "description": "Best accuracy, slowest"
34 | },
35 | "distil-whisper-large-v3": {
36 | "repo": "mlx-community/distil-whisper-large-v3",
37 | "type": "whisper",
38 | "description": "Fast with high accuracy"
39 | }
40 | }
41 |
42 | class BaseTranscriber(ABC):
43 | """Base class for speech transcribers"""
44 |
45 | @abstractmethod
46 | def transcribe(self, audio_path: str) -> str:
47 | pass
48 |
49 | class WhisperTranscriber(BaseTranscriber):
50 | """Whisper model transcriber using MLX"""
51 |
52 | def __init__(self, model_id: str):
53 | try:
54 | import mlx_whisper
55 | self.mlx_whisper = mlx_whisper
56 | self.model_path = model_id
57 | except ImportError:
58 | raise ImportError("mlx-whisper is required. Install with: pip install mlx-whisper")
59 |
60 | def transcribe(self, audio_path: str) -> str:
61 | """Transcribe audio file using Whisper"""
62 | result = self.mlx_whisper.transcribe(
63 | audio_path,
64 | path_or_hf_repo=self.model_path
65 | )
66 | return result.get("text", "")
67 |
68 | class SpeechTranscriber:
69 | """Main transcriber class with model management"""
70 |
71 | def __init__(self, model_name: str = "whisper-tiny"):
72 | self.model_name = model_name
73 | self.model = None
74 | self.load_model(model_name)
75 |
76 | def load_model(self, model_name: str):
77 | """Load the specified model"""
78 | if model_name not in AVAILABLE_MODELS:
79 | raise ValueError(f"Unknown model: {model_name}")
80 |
81 | config = AVAILABLE_MODELS[model_name]
82 | model_type = config["type"]
83 | repo_id = config["repo"]
84 |
85 | print(f"Loading {model_name} model...")
86 |
87 | if model_type == "whisper":
88 | self.model = WhisperTranscriber(repo_id)
89 | else:
90 | raise ValueError(f"Unknown model type: {model_type}")
91 |
92 | self.model_name = model_name
93 | print(f"Model {model_name} loaded successfully")
94 |
95 | def transcribe(self, audio: np.ndarray, sample_rate: int = 16000) -> str:
96 | """Transcribe audio array"""
97 | if self.model is None:
98 | raise RuntimeError("No model loaded")
99 |
100 | # Write audio to temporary file (required by models)
101 | with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
102 | sf.write(tmp.name, audio, sample_rate)
103 | temp_path = tmp.name
104 |
105 | try:
106 | # Transcribe
107 | text = self.model.transcribe(temp_path)
108 | return text.strip()
109 | finally:
110 | # Clean up temp file
111 | if os.path.exists(temp_path):
112 | os.unlink(temp_path)
113 |
114 | def list_models(self) -> Dict[str, Any]:
115 | """List available models"""
116 | return AVAILABLE_MODELS
--------------------------------------------------------------------------------
/models/model_loader.py:
--------------------------------------------------------------------------------
1 | # models/model_loader.py
2 |
3 | import os
4 | import torch
5 | from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
6 | from transformers import MllamaForConditionalGeneration
7 | # from vllm.sampling_params import SamplingParams  # unused here; vllm is only installed on non-macOS (see requirements.txt)
8 | from transformers import AutoModelForCausalLM
9 | import google.generativeai as genai
10 | # from vllm import LLM  # unused here; vllm is only installed on non-macOS
11 | from groq import Groq
12 | from dotenv import load_dotenv
13 |
14 | # Load environment variables from .env file
15 | load_dotenv()
16 |
17 | from logger import get_logger
18 |
19 | logger = get_logger(__name__)
20 |
21 | # Cache for loaded models
22 | _model_cache = {}
23 |
24 | # Models that only support single image processing
25 | SINGLE_IMAGE_MODELS = {
26 | 'ollama-llama-vision': True,
27 | 'groq-llama-vision': True,
28 | 'llama-vision': True,
29 | 'pixtral': True,
30 | 'molmo': True
31 | }
32 |
33 | def is_single_image_model(model_choice):
34 | """Returns True if the model only supports processing a single image."""
35 | return model_choice in SINGLE_IMAGE_MODELS
36 |
37 | def detect_device():
38 | """
39 | Detects the best available device (CUDA, MPS, or CPU).
40 | """
41 | if torch.cuda.is_available():
42 | return 'cuda'
43 | elif torch.backends.mps.is_available():
44 | return 'mps'
45 | else:
46 | return 'cpu'
47 |
48 | def load_model(model_choice):
49 | """
50 | Loads and caches the specified model.
51 | """
52 | global _model_cache
53 |
54 | if model_choice in _model_cache:
55 | logger.info(f"Model '{model_choice}' loaded from cache.")
56 | return _model_cache[model_choice]
57 |
58 | if model_choice == 'qwen':
59 | device = detect_device()
60 | model = Qwen2VLForConditionalGeneration.from_pretrained(
61 | "Qwen/Qwen2-VL-7B-Instruct",
62 | torch_dtype=torch.float16 if device != 'cpu' else torch.float32,
63 | device_map="auto"
64 | )
65 | processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
66 |         # Note: device_map="auto" already places the model; an explicit model.to(device) is unnecessary here
67 | _model_cache[model_choice] = (model, processor, device)
68 | logger.info("Qwen model loaded and cached.")
69 | return _model_cache[model_choice]
70 |
71 | elif model_choice == 'gemini':
72 | api_key = os.getenv("GOOGLE_API_KEY")
73 | if not api_key:
74 | raise ValueError("GOOGLE_API_KEY not found in .env file")
75 | genai.configure(api_key=api_key)
76 | model = genai.GenerativeModel('gemini-1.5-flash-002')
77 | return model, None
78 |
79 | elif model_choice == 'llama-vision':
80 | device = detect_device()
81 | model_id = "alpindale/Llama-3.2-11B-Vision-Instruct"
82 | model = MllamaForConditionalGeneration.from_pretrained(
83 | model_id,
84 | torch_dtype=torch.float16 if device != 'cpu' else torch.float32,
85 | device_map="auto"
86 | )
87 | processor = AutoProcessor.from_pretrained(model_id)
88 |         # Note: device_map="auto" already places the model; an explicit model.to(device) is unnecessary here
89 | _model_cache[model_choice] = (model, processor, device)
90 | logger.info("Llama-Vision model loaded and cached.")
91 | return _model_cache[model_choice]
92 |
93 | elif model_choice == "pixtral":
94 | device = detect_device()
95 | mistral_models_path = os.path.join(os.getcwd(), 'mistral_models', 'Pixtral')
96 |
97 | if not os.path.exists(mistral_models_path):
98 | os.makedirs(mistral_models_path, exist_ok=True)
99 | from huggingface_hub import snapshot_download
100 | snapshot_download(repo_id="mistralai/Pixtral-12B-2409",
101 | allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"],
102 | local_dir=mistral_models_path)
103 |
104 | from mistral_inference.transformer import Transformer
105 | from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
106 |     from mistral_inference.generate import generate
107 |
108 | tokenizer = MistralTokenizer.from_file(os.path.join(mistral_models_path, "tekken.json"))
109 | model = Transformer.from_folder(mistral_models_path)
110 |
111 | _model_cache[model_choice] = (model, tokenizer, generate, device)
112 | logger.info("Pixtral model loaded and cached.")
113 | return _model_cache[model_choice]
114 |
115 | elif model_choice == "molmo":
116 | device = detect_device()
117 | processor = AutoProcessor.from_pretrained(
118 | 'allenai/MolmoE-1B-0924',
119 | trust_remote_code=True,
120 | torch_dtype='auto',
121 | device_map='auto'
122 | )
123 | model = AutoModelForCausalLM.from_pretrained(
124 | 'allenai/MolmoE-1B-0924',
125 | trust_remote_code=True,
126 | torch_dtype='auto',
127 | device_map='auto'
128 | )
129 | _model_cache[model_choice] = (model, processor, device)
130 | return _model_cache[model_choice]
131 |
132 | elif model_choice == 'groq-llama-vision':
133 | api_key = os.getenv("GROQ_API_KEY")
134 | if not api_key:
135 | raise ValueError("GROQ_API_KEY not found in .env file")
136 | client = Groq(api_key=api_key)
137 | _model_cache[model_choice] = client
138 | logger.info("Groq Llama Vision model loaded and cached.")
139 | return _model_cache[model_choice]
140 |
141 | elif model_choice == 'ollama-llama-vision':
142 | logger.info("Ollama Llama Vision model ready to use.")
143 | return None
144 |
145 | else:
146 | logger.error(f"Invalid model choice: {model_choice}")
147 | raise ValueError("Invalid model choice.")
148 |
--------------------------------------------------------------------------------
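
One thing to note when calling `load_model`: the return shape differs per backend, so callers have to unpack according to `model_choice`. A hedged sketch of that branching; the dispatch in the real app presumably lives in `models/responder.py`, which is not included in this dump:

```python
from models.model_loader import load_model

choice = 'qwen'  # e.g. 'qwen', 'llama-vision', 'molmo', 'pixtral', 'gemini', 'groq-llama-vision'
loaded = load_model(choice)

if choice in ('qwen', 'llama-vision', 'molmo'):
    model, processor, device = loaded             # local Hugging Face models
elif choice == 'pixtral':
    model, tokenizer, generate_fn, device = loaded
elif choice == 'gemini':
    model, _ = loaded                             # google-generativeai model client
elif choice == 'groq-llama-vision':
    client = loaded                               # Groq API client
else:
    pass                                          # 'ollama-llama-vision' returns None
```
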
/jarvis/README.md:
--------------------------------------------------------------------------------
1 | # Speech Transcription Application for macOS
2 |
3 | A privacy-focused, real-time speech transcription desktop application for macOS that runs entirely on-device using Apple Silicon optimization. Features speech-to-text transcription with AI-powered text correction, multiple UI modes, and support for various speech recognition models.
4 |
5 | ## Features
6 |
7 | - 🎤 **Real-time Speech Transcription**: Capture and transcribe audio from your microphone
8 | - 🔒 **Privacy-First**: All processing happens on-device, no data sent to cloud
9 | - ⚡ **Apple Silicon Optimized**: Uses MLX framework for fast inference on M1/M2/M3 chips
10 | - 🤖 **AI-Powered Correction**: Automatic removal of filler words and transcription errors
11 | - 🎯 **Multiple Models**: Support for various Whisper models with different speed/accuracy tradeoffs
12 | - 🖥️ **Multiple Interfaces**: GUI, CLI, and system tray modes
13 | - ⌨️ **Keyboard Shortcuts**: Space to start/stop, Escape to cancel, Cmd+S to save
14 |
15 | ## Requirements
16 |
17 | - macOS with Apple Silicon (M1/M2/M3)
18 | - Python 3.9 or higher
19 | - Minimum 8GB RAM (16GB recommended)
20 | - ~5GB disk space for models
21 |
22 | ## Installation
23 |
24 | ### 1. Clone the repository
25 |
26 | ```bash
27 | git clone https://github.com/yourusername/speech-transcription-app.git
28 | cd speech-transcription-app
29 | ```
30 |
31 | ### 2. Create a virtual environment
32 |
33 | ```bash
34 | python -m venv venv
35 | source venv/bin/activate
36 | ```
37 |
38 | ### 3. Install dependencies
39 |
40 | ```bash
41 | pip install -r requirements.txt
42 | ```
43 |
44 | ### 4. Download ML models
45 |
46 | ```bash
47 | python scripts/download_models.py
48 | ```
49 |
50 | ## Usage
51 |
52 | ### GUI Mode
53 |
54 | Launch the graphical interface:
55 |
56 | ```bash
57 | python -m src.main
58 | ```
59 |
60 | Features:
61 | - Click "Start Recording" or press Space to begin
62 | - Real-time audio level visualization
63 | - Automatic text correction (can be disabled)
64 | - Save transcriptions with Cmd+S
65 | - Dark mode interface
66 |
67 | ### CLI Mode
68 |
69 | Transcribe audio files from the command line:
70 |
71 | ```bash
72 | # Basic transcription
73 | python -m src.cli --audio recording.wav
74 |
75 | # With specific model
76 | python -m src.cli --audio recording.wav --model whisper-medium
77 |
78 | # With text correction and context
79 | python -m src.cli --audio recording.wav --context "Technical meeting"
80 |
81 | # Save to file
82 | python -m src.cli --audio recording.wav --output transcript.txt
83 |
84 | # List available models
85 | python -m src.cli --list-models
86 | ```
87 |
88 | ### Available Models
89 |
90 | - `whisper-tiny`: Very fast, lower accuracy
91 | - `whisper-base`: Fast, decent accuracy
92 | - `whisper-small`: Good balance
93 | - `whisper-medium`: High accuracy, slower
94 | - `whisper-large-v3`: Best accuracy, slowest
95 | - `distil-whisper-large-v3`: Fast with high accuracy
96 |
97 | ## Development
98 |
99 | ### Running Tests
100 |
101 | ```bash
102 | # Run all tests
103 | make test
104 |
105 | # Run unit tests only
106 | make test-unit
107 |
108 | # Run integration tests
109 | make test-integration
110 | ```
111 |
112 | ### Code Formatting
113 |
114 | ```bash
115 | # Format code with black and ruff
116 | make format
117 |
118 | # Run linters
119 | make lint
120 | ```
121 |
122 | ### Building for Distribution
123 |
124 | ```bash
125 | pip install -e .
126 | ```
127 |
128 | This will install the package in development mode and create command-line scripts:
129 | - `speech-transcribe`: CLI interface
130 | - `speech-transcribe-gui`: GUI interface
131 |
132 | ## Project Structure
133 |
134 | ```
135 | speech-transcription-app/
136 | ├── src/
137 | │ ├── audio/ # Audio recording and processing
138 | │ ├── models/ # ML models for transcription and correction
139 | │ ├── gui/ # GUI components
140 | │ ├── pipeline.py # Main orchestration
141 | │ ├── cli.py # CLI interface
142 | │ └── main.py # GUI entry point
143 | ├── tests/ # Unit and integration tests
144 | ├── scripts/ # Utility scripts
145 | ├── requirements.txt # Python dependencies
146 | └── README.md # This file
147 | ```
148 |
149 | ## Performance
150 |
151 | - **Audio Latency**: < 100ms for recording start/stop
152 | - **Transcription Speed**: > 5x real-time (10s audio in < 2s)
153 | - **Memory Usage**: < 2GB with models loaded
154 | - **Model Loading**: < 10s for initial load
155 |
156 | ## Troubleshooting
157 |
158 | ### Common Issues
159 |
160 | 1. **"Could not find the Qt platform plugin 'cocoa'"**
161 | - This is fixed automatically in the code, but if it persists:
162 | - Run with: `./run_gui.sh` instead of `python -m src.main`
163 | - Or set manually: `export QT_QPA_PLATFORM_PLUGIN_PATH=$(python -c "import PyQt6, os; print(os.path.join(os.path.dirname(PyQt6.__file__), 'Qt6', 'plugins', 'platforms'))")`
164 |
165 | 2. **"No module named 'mlx'"**
166 | - Ensure you're on a Mac with Apple Silicon
167 | - Reinstall with: `pip install --upgrade mlx mlx-lm mlx-whisper`
168 |
169 | 3. **Audio permission denied**
170 | - Go to System Preferences → Security & Privacy → Microphone
171 | - Grant permission to Terminal/your IDE
172 |
173 | 4. **Model download fails**
174 | - Check internet connection
175 | - Manually download from Hugging Face if needed
176 |
177 | ### Debug Mode
178 |
179 | Run with verbose output:
180 | ```bash
181 | python -m src.cli --audio test.wav --debug
182 | ```
183 |
184 | ## Contributing
185 |
186 | 1. Fork the repository
187 | 2. Create a feature branch
188 | 3. Make your changes
189 | 4. Run tests and linters
190 | 5. Submit a pull request
191 |
192 | ## License
193 |
194 | This project is licensed under the MIT License - see LICENSE file for details.
195 |
196 | ## Acknowledgments
197 |
198 | - Apple's MLX team for the optimization framework
199 | - OpenAI for Whisper models
200 | - Hugging Face for model hosting
--------------------------------------------------------------------------------
/jarvis/src/models/text_corrector.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Optional, List
3 | import mlx.core as mx
4 | from mlx_lm import load, generate
5 |
6 | # Available LLM models for text correction
7 | AVAILABLE_LLM_MODELS = {
8 | # Small & Fast Models (< 2GB RAM)
9 | "Qwen2.5-0.5B": {
10 | "repo": "mlx-community/Qwen2.5-0.5B-Instruct-4bit",
11 | "description": "Tiny & fast (0.5B)",
12 | "size": "0.5B"
13 | },
14 | "gemma-2b": {
15 | "repo": "mlx-community/gemma-2-2b-it-4bit",
16 | "description": "Google's efficient (2B)",
17 | "size": "2B"
18 | },
19 |
20 | # Medium Models (2-4GB RAM)
21 | "Phi-3.5-mini": {
22 | "repo": "mlx-community/Phi-3.5-mini-instruct-4bit",
23 | "description": "Best balance (3.8B)",
24 | "size": "3.8B"
25 | },
26 | "gemma-3-4b": {
27 | "repo": "mlx-community/gemma-3-4b-it-qat-4bit",
28 | "description": "Google's latest (4B)",
29 | "size": "4B"
30 | },
31 |
32 | # Larger Models (4-6GB RAM)
33 | "Mistral-7B": {
34 | "repo": "mlx-community/Mistral-7B-Instruct-v0.3-4bit",
35 | "description": "Very capable (7B)",
36 | "size": "7B"
37 | },
38 | "Llama-3.1-8B": {
39 | "repo": "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
40 | "description": "State of the art (8B)",
41 | "size": "8B"
42 | },
43 | }
44 |
45 | class TextCorrector:
46 | """LLM-based text correction for transcriptions"""
47 |
48 | def __init__(self, model_name: str = "Phi-3.5-mini"):
49 | self.model = None
50 | self.tokenizer = None
51 | self.model_name = None
52 | self.selected_model = model_name
53 | self._load_model(model_name)
54 |
55 | def _load_model(self, model_name: str):
56 | """Load the specified LLM model"""
57 | if model_name not in AVAILABLE_LLM_MODELS:
58 | print(f"Unknown model {model_name}, using default Phi-3.5-mini")
59 | model_name = "Phi-3.5-mini"
60 |
61 | model_info = AVAILABLE_LLM_MODELS[model_name]
62 | repo_id = model_info["repo"]
63 |
64 | try:
65 | print(f"Loading text correction model: {model_name} ({repo_id})")
66 | self.model, self.tokenizer = load(repo_id)
67 | self.model_name = model_name
68 | print(f"Successfully loaded {model_name}")
69 | except Exception as e:
70 | print(f"Failed to load {model_name}: {e}")
71 | # Try fallback to Phi-3.5-mini if different model was selected
72 | if model_name != "Phi-3.5-mini":
73 | print("Falling back to Phi-3.5-mini")
74 | self._load_model("Phi-3.5-mini")
75 | else:
76 | print("Warning: No text correction model could be loaded")
77 |
78 | def set_model(self, model_name: str):
79 | """Change the LLM model"""
80 | if model_name != self.model_name:
81 | self._load_model(model_name)
82 |
83 | def correct(
84 | self,
85 | text: str,
86 | context: Optional[str] = None,
87 | remove_fillers: bool = True
88 | ) -> str:
89 | """Correct transcribed text"""
90 | if not self.model or not text or len(text.strip()) < 10:
91 | return text
92 |
93 | # Build correction prompt
94 | prompt = self._build_prompt(text, context, remove_fillers)
95 |
96 | try:
97 | # Generate correction
98 | response = generate(
99 | self.model,
100 | self.tokenizer,
101 | prompt=prompt,
102 | max_tokens=len(text) * 2
103 | )
104 |
105 | # Extract corrected text
106 | corrected = self._extract_correction(response)
107 |
108 | # Validate correction
109 | if self._is_valid_correction(text, corrected):
110 | return corrected
111 | else:
112 | return text
113 |
114 | except Exception as e:
115 | print(f"Text correction failed: {e}")
116 | return text
117 |
118 | def _build_prompt(
119 | self,
120 | text: str,
121 | context: Optional[str],
122 | remove_fillers: bool
123 | ) -> str:
124 | """Build the correction prompt"""
125 | base_prompt = """Fix ONLY spelling errors and remove filler words from this transcription.
126 | DO NOT rephrase or summarize. Keep the original wording and structure.
127 | Only fix obvious errors like duplicated words, misspellings, and remove filler words (um, uh, etc).
128 | Output ONLY the cleaned text, nothing else."""
129 |
130 | if context:
131 | base_prompt = f"Context: {context}\n\n{base_prompt}"
132 |
133 | return f"""{base_prompt}
134 |
135 | Original: {text}
136 |
137 | Cleaned:"""
138 |
139 | def _extract_correction(self, response: str) -> str:
140 | """Extract the corrected text from model response"""
141 | # Remove any explanation or metadata
142 | lines = response.strip().split('\n')
143 |
144 | # Find the actual correction
145 | corrected_text = ""
146 | for line in lines:
147 | # Skip meta lines
148 | if any(marker in line.lower() for marker in ['original:', 'cleaned:', 'corrected:']):
149 | continue
150 | if line.strip():
151 | corrected_text = line.strip()
152 | break
153 |
154 | return corrected_text
155 |
156 | def _is_valid_correction(self, original: str, corrected: str) -> bool:
157 | """Validate that correction is reasonable"""
158 | if not corrected:
159 | return False
160 |
161 | # Check length difference (shouldn't be too different)
162 | len_ratio = len(corrected) / len(original)
163 | if len_ratio < 0.5 or len_ratio > 1.5:
164 | return False
165 |
166 | # Check word count difference
167 | orig_words = original.split()
168 | corr_words = corrected.split()
169 | word_ratio = len(corr_words) / len(orig_words)
170 | if word_ratio < 0.5 or word_ratio > 1.2:
171 | return False
172 |
173 | return True
174 |
175 | def get_filler_words(self) -> List[str]:
176 | """Get list of filler words to remove"""
177 | return [
178 | "um", "uh", "er", "ah", "like", "you know", "I mean",
179 | "actually", "basically", "literally", "right", "so"
180 | ]
--------------------------------------------------------------------------------
/static/css/styles.css:
--------------------------------------------------------------------------------
1 | body {
2 | font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
3 | background-color: #f8f9fa;
4 | font-size: 14px; /* Base font size */
5 | }
6 |
7 | /* Sidebar styles */
8 | #sidebar {
9 | min-height: 100vh;
10 | background-color: #343a40;
11 | color: #fff;
12 | transition: all 0.3s;
13 | display: flex;
14 | flex-direction: column;
15 | }
16 |
17 | #sidebar.active {
18 | margin-left: -250px;
19 | }
20 |
21 | #sidebar .sidebar-header {
22 | padding: 20px;
23 | background: #212529;
24 | }
25 |
26 | #sidebar .sidebar-header h3 {
27 | font-size: 1.5em; /* Larger font for header */
28 | }
29 |
30 | #sidebar .sidebar-content {
31 | flex-grow: 1;
32 | overflow-y: auto;
33 | }
34 |
35 | #sidebar ul.components {
36 | padding: 20px 0;
37 | border-bottom: 1px solid #47748b;
38 | }
39 |
40 | #sidebar ul p {
41 | color: #fff;
42 | padding: 10px;
43 | }
44 |
45 | #sidebar ul li a {
46 | padding: 8px 10px;
47 | font-size: 1em; /* Consistent font size */
48 | display: block;
49 | color: #fff;
50 | text-decoration: none;
51 | }
52 |
53 | #sidebar ul li a:hover {
54 | color: #7386D5;
55 | background: #fff;
56 | }
57 |
58 | #sidebar ul li.active > a, a[aria-expanded="true"] {
59 | color: #fff;
60 | background: #6d7fcc;
61 | }
62 |
63 | .session-list {
64 | max-height: calc(100vh - 250px);
65 | overflow-y: auto;
66 | }
67 |
68 | .session-list .nav-item {
69 | margin-bottom: 2px;
70 | }
71 |
72 | /* Main content styles */
73 | main {
74 | transition: all 0.3s;
75 | height: 100vh;
76 | overflow-y: auto;
77 | }
78 |
79 | main.active {
80 | margin-left: 0;
81 | }
82 |
83 | /* Top bar styles */
84 | .btn-toolbar .btn {
85 | margin-left: 5px;
86 | font-size: 1em; /* Consistent font size */
87 | }
88 |
89 | /* Chat container styles */
90 | .chat-container {
91 | height: calc(100vh - 100px);
92 | display: flex;
93 | flex-direction: column;
94 | background-color: #ffffff;
95 | border-radius: 10px;
96 | box-shadow: 0 1px 3px rgba(0, 0, 0, 0.12), 0 1px 2px rgba(0, 0, 0, 0.24);
97 | margin: 20px auto;
98 | max-width: 800px;
99 | }
100 |
101 | .chat-messages {
102 | flex-grow: 1;
103 | overflow-y: auto;
104 | padding: 20px;
105 | }
106 |
107 | .chat-input-container {
108 | padding: 20px;
109 | background-color: #f8f9fa;
110 | border-top: 1px solid #dee2e6;
111 | border-bottom-left-radius: 10px;
112 | border-bottom-right-radius: 10px;
113 | }
114 |
115 | .chat-input-container .input-group {
116 | max-width: 600px;
117 | margin: 0 auto;
118 | }
119 |
120 | /* Message styles */
121 | .message {
122 | margin-bottom: 15px;
123 | padding: 10px 15px;
124 | border-radius: 18px;
125 | max-width: 80%;
126 | word-wrap: break-word;
127 | font-size: 1em; /* Consistent font size */
128 | }
129 |
130 | .user-message {
131 | background-color: #007bff;
132 | color: #fff;
133 | align-self: flex-end;
134 | margin-left: auto;
135 | }
136 |
137 | .ai-message {
138 | background-color: #f1f3f5;
139 | color: #343a40;
140 | align-self: flex-start;
141 | margin-right: auto;
142 | }
143 |
144 | /* Image styles */
145 | .image-container {
146 | display: flex;
147 | flex-wrap: wrap;
148 | gap: 10px;
149 | margin-top: 10px;
150 | }
151 |
152 | .retrieved-image {
153 | max-width: 150px;
154 | max-height: 150px;
155 | object-fit: cover;
156 | cursor: zoom-in;
157 | transition: transform 0.3s ease;
158 | border-radius: 8px;
159 | }
160 |
161 | .retrieved-image:hover {
162 | transform: scale(1.05);
163 | }
164 |
165 | /* Loading indicator styles */
166 | #loading-indicator {
167 | position: fixed;
168 | top: 50%;
169 | left: 50%;
170 | transform: translate(-50%, -50%);
171 | z-index: 1000;
172 | background-color: rgba(255, 255, 255, 0.9);
173 | padding: 20px;
174 | border-radius: 10px;
175 | box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
176 | }
177 |
178 | #loading-indicator .spinner-border {
179 | width: 3rem;
180 | height: 3rem;
181 | }
182 |
183 | #loading-indicator p {
184 | margin-top: 10px;
185 | font-weight: bold;
186 | font-size: 1em; /* Consistent font size */
187 | }
188 |
189 | /* Indexed files list styles */
190 | #indexed-files-list {
191 | max-height: 300px;
192 | overflow-y: auto;
193 | }
194 |
195 | #indexed-files-list .list-group-item {
196 | padding: 8px 12px;
197 | font-size: 1em; /* Consistent font size */
198 | border-left: none;
199 | border-right: none;
200 | }
201 |
202 | .session-name {
203 | cursor: pointer;
204 | transition: color 0.3s ease;
205 | font-size: 1em; /* Consistent font size */
206 | }
207 |
208 | .session-name:hover {
209 | color: #007bff;
210 | }
211 |
212 | /* Responsive adjustments */
213 | @media (max-width: 768px) {
214 | #sidebar {
215 | margin-left: -250px;
216 | }
217 | #sidebar.active {
218 | margin-left: 0;
219 | }
220 | main {
221 | margin-left: 0;
222 | }
223 | main.active {
224 | margin-left: 250px;
225 | }
226 | }
227 |
228 | /* Medium Zoom styles */
229 | .medium-zoom-overlay {
230 | z-index: 1000;
231 | }
232 |
233 | .medium-zoom-image--opened {
234 | z-index: 1001;
235 | }
236 |
237 | .session-options {
238 | position: relative;
239 | }
240 |
241 | .fa-ellipsis-h {
242 | cursor: pointer;
243 | padding: 3px;
244 | color: #adb5bd;
245 | font-size: 1em; /* Consistent font size */
246 | }
247 |
248 | .options-popup {
249 | display: none;
250 | position: absolute;
251 | right: 0;
252 | top: 100%;
253 | background-color: #343a40;
254 | border: 1px solid #495057;
255 | border-radius: 4px;
256 | box-shadow: 0 2px 10px rgba(0,0,0,0.2);
257 | z-index: 1000;
258 | min-width: 120px;
259 | }
260 |
261 | .option {
262 | padding: 8px 12px;
263 | cursor: pointer;
264 | white-space: nowrap;
265 | color: #fff;
266 | transition: background-color 0.2s ease;
267 | font-size: 1em; /* Consistent font size */
268 | }
269 |
270 | .option:hover {
271 | background-color: #495057;
272 | }
273 |
274 | .session-list .nav-item.current-session {
275 | background-color: #495057;
276 | border-radius: 4px;
277 | }
278 |
279 | .session-list .nav-item.current-session .nav-link {
280 | color: #ffffff;
281 | font-weight: bold;
282 | }
283 |
284 | .session-list .nav-item.current-session .fa-ellipsis-h {
285 | color: #ffffff;
286 | }
287 |
288 | /* Markdown styles */
289 | .ai-message h1, .ai-message h2, .ai-message h3, .ai-message h4, .ai-message h5, .ai-message h6 {
290 | margin-top: 10px;
291 | margin-bottom: 5px;
292 | }
293 |
294 | .ai-message p {
295 | margin-bottom: 10px;
296 | }
297 |
298 | .ai-message ul, .ai-message ol {
299 | margin-left: 20px;
300 | margin-bottom: 10px;
301 | }
302 |
303 | .ai-message code {
304 | background-color: #f0f0f0;
305 | padding: 2px 4px;
306 | border-radius: 4px;
307 | }
308 |
309 | .ai-message pre {
310 | background-color: #f0f0f0;
311 | padding: 10px;
312 | border-radius: 4px;
313 | overflow-x: auto;
314 | }
--------------------------------------------------------------------------------
/templates/base.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 | LocalGPT Vision
9 |
10 |
11 |
12 |
13 | {% block styles %}{% endblock %}
14 |
15 |
16 |
17 |
18 |
19 |
52 |
53 |
54 |
55 |
65 |
66 | {% block content %}{% endblock %}
67 |
68 |
69 |
70 |
71 |
72 |
87 |
88 |
89 |
90 |
91 |
185 | {% block scripts %}{% endblock %}
186 |
187 |
--------------------------------------------------------------------------------
/jarvis/CLAUDE.md:
--------------------------------------------------------------------------------
1 | # CLAUDE.md
2 |
3 | This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
4 |
5 | ## Project Overview
6 |
7 | This is a Real-Time Speech Transcription Application for macOS, designed to be privacy-focused with all processing done on-device using Apple Silicon optimization. The application features speech-to-text transcription with AI-powered text correction, multiple UI modes (GUI, CLI), and support for various speech recognition models.
8 |
9 | ### Original Requirements (PRD.md)
10 | The project was built following a comprehensive Product Requirements Document that specified:
11 | - Privacy-first approach with all processing on-device
12 | - Real-time transcription performance (>5x real-time)
13 | - Multiple STT model support (Whisper variants)
14 | - AI-powered text correction using LLMs
15 | - GUI and CLI interfaces
16 | - Apple Silicon optimization using MLX framework
17 |
18 | ### Recent Feature Additions
19 | 1. **Microphone Selection**: Dropdown to choose from all available audio input devices
20 | 2. **Copy Functionality**: Separate copy buttons for raw and corrected text
21 | 3. **LLM Model Selection**: Dynamic selection of text correction models from GUI
22 | 4. **Enhanced Error Handling**: Proper thread cleanup and crash prevention
23 | 5. **Auto Qt Path Detection**: Automatic resolution of Qt platform plugin issues
24 |
25 | ## Architecture
26 |
27 | ### Core Components
28 |
29 | 1. **Audio Pipeline** (`src/audio/`)
30 | - `AudioRecorder`: Manages sounddevice streams for microphone capture with device selection
31 | - `AudioProcessor`: Handles normalization, resampling, noise reduction, and VAD
32 |
33 | 2. **Model Layer** (`src/models/`)
34 | - `SpeechTranscriber`: Abstraction over multiple STT models (Whisper variants)
35 | - `TextCorrector`: LLM-based correction using MLX models with dynamic model selection
36 | - Model dictionaries: `AVAILABLE_MODELS` (STT) and `AVAILABLE_LLM_MODELS` (correction)
37 |
38 | 3. **User Interfaces** (`src/gui/`)
39 | - `MainWindow`: PyQt6-based GUI with real-time visualization
40 | - Microphone selection dropdown with refresh capability
41 | - Model selection dropdowns for both STT and LLM models
42 | - Copy functionality for both raw and corrected text
43 | - Dark theme with audio level visualization
44 | - CLI interface in `src/cli.py`
45 |
46 | 4. **Pipeline** (`src/pipeline.py`)
47 | - Orchestrates the full transcription flow
48 | - Manages audio buffering without duplication
49 | - Supports dynamic model switching for both STT and LLM
50 | - Coordinates all components
51 |
52 | ### Key Design Decisions
53 |
54 | - **Model Flexibility**: Support for multiple STT and LLM models with different accuracy/speed tradeoffs
55 | - **Privacy-First**: All processing on-device, no cloud APIs
56 | - **MLX Optimization**: Uses Apple's MLX framework for efficient inference on Apple Silicon
57 | - **Dynamic Model Selection**: Users can choose models from the GUI without restart
58 | - **Error Resilience**: Proper thread cleanup and error handling for stability
59 |
60 | ## Development Commands
61 |
62 | ### Project Setup
63 | ```bash
64 | # Create virtual environment
65 | python -m venv venv
66 | source venv/bin/activate
67 |
68 | # Install dependencies
69 | pip install -r requirements.txt
70 |
71 | # Download required ML models
72 | python scripts/download_models.py
73 | ```
74 |
75 | ### Running the Application
76 | ```bash
77 | # GUI mode
78 | python -m src.main
79 |
80 | # CLI mode
81 | python -m src.cli --audio recording.wav --model whisper-small
82 |
83 | # With text correction
84 | python -m src.cli --audio recording.wav --context "Technical discussion"
85 | ```
86 |
87 | ### Testing
88 | ```bash
89 | # Run all tests
90 | pytest tests/ -v
91 |
92 | # Run unit tests only
93 | pytest tests/unit/ -v
94 |
95 | # Run integration tests
96 | pytest tests/integration/ -v
97 |
98 | # Run specific test
99 | pytest tests/unit/test_audio_processor.py::TestAudioProcessor::test_normalize_audio -v
100 | ```
101 |
102 | ### Development Workflow
103 | ```bash
104 | # Format code
105 | black src tests
106 | ruff format src tests
107 |
108 | # Lint code
109 | ruff check src tests
110 | mypy src
111 |
112 | # Clean build artifacts
113 | find . -type f -name "*.pyc" -delete
114 | find . -type d -name "__pycache__" -delete
115 | ```
116 |
117 | ## Model Management
118 |
119 | ### Available STT Models
120 | - **whisper-tiny**: Very fast, lower accuracy (39M params)
121 | - **whisper-base**: Fast, decent accuracy (74M params)
122 | - **whisper-small**: Good balance (244M params)
123 | - **whisper-medium**: High accuracy, slower (769M params)
124 | - **whisper-large-v3**: Best accuracy, slowest (1550M params)
125 | - **distil-whisper-large-v3**: Fast with high accuracy (756M params)
126 |
127 | ### Available LLM Models for Text Correction
128 | - **Qwen2.5-0.5B**: Tiny & fast (0.5B params, <2GB RAM)
129 | - **gemma-2b**: Google's efficient (2B params, <2GB RAM)
130 | - **Phi-3.5-mini**: Best balance (3.8B params, 2-4GB RAM) - Default
131 | - **gemma-3-4b**: Google's latest (4B params, 2-4GB RAM)
132 | - **Mistral-7B**: Very capable (7B params, 4-6GB RAM)
133 | - **Llama-3.1-8B**: State of the art (8B params, 4-6GB RAM)
134 |
135 | ### Model Loading
136 | Models are loaded on-demand from Hugging Face hub. Both `SpeechTranscriber` and `TextCorrector` support dynamic model switching via the GUI without restart.
137 |
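A minimal sketch of dynamic switching on the correction side, using the `TextCorrector` API defined in `src/models/text_corrector.py` (run from the project root; the GUI dropdown ultimately drives the same calls):

```python
from src.models.text_corrector import TextCorrector, AVAILABLE_LLM_MODELS

# Load the default correction model (downloaded from Hugging Face on first use)
corrector = TextCorrector()              # defaults to "Phi-3.5-mini"
print(list(AVAILABLE_LLM_MODELS))        # the names selectable from the GUI dropdown

# Switch models at runtime; unknown names fall back to the default model
corrector.set_model("Qwen2.5-0.5B")
```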
138 | ## Audio Processing Pipeline
139 |
140 | 1. **Recording**: 16kHz mono audio captured in 0.5s chunks
141 | - Device selection with automatic default detection
142 | - Real-time audio level monitoring
143 | - Proper buffering without duplication
144 | 2. **Processing**: Normalization → Resampling → Noise reduction → VAD
145 | 3. **Transcription**: Audio written to temp file → Model inference → Text output
146 | 4. **Correction**: Optional LLM-based correction with context awareness
147 | - Custom prompt structure with context integration
148 | - No temperature/top_p parameters (unsupported by mlx_lm)
149 |
150 | ## GUI Architecture
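Conceptually, step 2 is a chain of array transforms. The real logic lives in `AudioProcessor.process()` (`src/audio/processor.py`); the helper below is only an illustrative stand-in for that chain, not the actual implementation:

```python
import numpy as np
from scipy.signal import resample

def process_chunk(audio: np.ndarray, sample_rate: int, target_rate: int = 16_000) -> np.ndarray:
    """Illustrative stand-in for AudioProcessor.process(): normalize -> resample -> denoise -> VAD."""
    audio = audio.astype(np.float32)
    # 1. Normalize to [-1, 1] so every downstream model sees a consistent amplitude range
    peak = float(np.max(np.abs(audio))) or 1.0
    audio = audio / peak
    # 2. Resample to the 16 kHz rate the Whisper models expect
    if sample_rate != target_rate:
        audio = resample(audio, int(len(audio) * target_rate / sample_rate))
    # 3./4. Noise reduction and voice-activity detection run here in the real pipeline
    return audio
```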
151 |
152 | - **Main Thread**: UI updates and user interaction
153 | - **TranscriptionThread**: Background audio recording and level monitoring
154 | - **Model Operations**: Synchronous but UI remains responsive via threading
155 | - **Real-time Feedback**: Audio level visualization at 50ms intervals
156 |
157 | ## Performance Considerations
158 |
159 | - Audio chunks processed at 0.5s intervals for low latency
160 | - Models use MLX for Apple Silicon optimization
161 | - Text correction uses default generation parameters (no temperature control)
162 | - Memory-mapped model loading for faster startup
163 | - Thread cleanup ensures no resource leaks on cancel/clear
164 |
165 | ## Common Issues and Fixes
166 |
167 | ### Qt Platform Plugin Error
168 | **Problem**: "Could not find the Qt platform plugin 'cocoa'"
169 | **Solution**: The code automatically detects and sets the Qt plugin path in `main.py`. If it persists:
170 | ```bash
171 | # Use the provided run script
172 | ./run_gui.sh
173 |
174 | # Or set manually
175 | export QT_QPA_PLATFORM_PLUGIN_PATH=$(python -c "import PyQt6, os; print(os.path.join(os.path.dirname(PyQt6.__file__), 'Qt6', 'plugins', 'platforms'))")
176 | ```
177 |
178 | ### Duplicate Transcription
179 | **Problem**: Transcribed text appears twice
180 | **Solution**: Fixed by removing duplicate audio buffering in `pipeline.py`. The `stop_recording()` method now uses only the final audio from `AudioRecorder.stop_recording()`.
181 |
182 | ### LLM Temperature Error
183 | **Problem**: "generate_step() got an unexpected keyword argument 'temperature'"
184 | **Solution**: The mlx_lm library doesn't support temperature/top_p parameters. These have been removed from the `generate()` call in `TextCorrector`.
185 |
186 | ### Clear Button Crash
187 | **Problem**: Clicking clear button terminates the application
188 | **Solution**: Added proper error handling and thread cleanup in `clear_text()` method. The method now safely cancels any running recording before clearing.
189 |
190 | ### Text Correction Prompt
191 | The LLM receives the prompt built in `TextCorrector._build_prompt()`. When a context string is supplied, `Context: {context}` is prepended to it:
192 | ```
193 | Fix ONLY spelling errors and remove filler words from this transcription.
194 | DO NOT rephrase or summarize. Keep the original wording and structure.
195 | Only fix obvious errors like duplicated words, misspellings, and remove filler words (um, uh, etc).
196 | Output ONLY the cleaned text, nothing else.
197 | 
198 | Original: {transcribed_text}
199 | 
200 | Cleaned:
201 | ```
202 |
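In code, the correction step that builds and sends this prompt is invoked like so (the sample text is made up; context is optional):

```python
from src.models.text_corrector import TextCorrector

corrector = TextCorrector()
raw = "um so we we need to finish the the transcription feature by friday"
cleaned = corrector.correct(raw, context="Engineering stand-up", remove_fillers=True)
print(cleaned)  # falls back to the raw text if the model output fails validation
```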
203 | ### Whisper Model Repository Names
204 | **Problem**: 404 errors when loading Whisper models
205 | **Solution**: The mlx-community organization uses inconsistent naming for Whisper models:
206 | - `whisper-tiny` - no suffix needed
207 | - `whisper-base-mlx` - requires "-mlx" suffix
208 | - `whisper-small-mlx` - requires "-mlx" suffix
209 | - `whisper-medium-mlx` - requires "-mlx" suffix
210 | - `whisper-large-v3-mlx` - requires "-mlx" suffix
211 | - `distil-whisper-large-v3` - no suffix needed
212 |
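Spelled out as repo IDs, the rules above give a mapping like the following; the authoritative table is `AVAILABLE_MODELS` in `src/models/transcriber.py`, whose exact structure is not reproduced here:

```python
# Illustrative mapping of model name -> mlx-community repo ID
WHISPER_REPOS = {
    "whisper-tiny": "mlx-community/whisper-tiny",
    "whisper-base": "mlx-community/whisper-base-mlx",
    "whisper-small": "mlx-community/whisper-small-mlx",
    "whisper-medium": "mlx-community/whisper-medium-mlx",
    "whisper-large-v3": "mlx-community/whisper-large-v3-mlx",
    "distil-whisper-large-v3": "mlx-community/distil-whisper-large-v3",
}
```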
213 | ## Common Patterns
214 |
215 | ### Adding New STT Models
216 | 1. Add model config to `AVAILABLE_MODELS` in `src/models/transcriber.py`
217 | 2. Ensure model is Whisper-based (current implementation)
218 | 3. Update model download script if needed
219 |
220 | ### Adding New LLM Models
221 | 1. Add model config to `AVAILABLE_LLM_MODELS` in `src/models/text_corrector.py`
222 | 2. Include repo, description, and size fields
223 | 3. Test with mlx_lm.generate() compatibility
224 | 4. Update download script for default models
225 |
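The entry shape matches the existing ones in `AVAILABLE_LLM_MODELS`; a hypothetical addition (the key, repo, and size below are placeholders, not a real model) looks like this:

```python
from src.models.text_corrector import AVAILABLE_LLM_MODELS

# Hypothetical entry: the name, repo, and size are placeholders
AVAILABLE_LLM_MODELS["My-New-Model-3B"] = {
    "repo": "mlx-community/My-New-Model-3B-Instruct-4bit",  # must point to an MLX-converted repo
    "description": "Short label shown in the GUI dropdown (3B)",
    "size": "3B",
}
```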
226 | ### Modifying Audio Processing
227 | - All audio processing goes through `AudioProcessor.process()`
228 | - Maintain 16kHz target sample rate for model compatibility
229 | - Ensure processed audio is normalized to [-1, 1] range
230 | - Avoid duplicate buffering in the pipeline
231 |
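A small, illustrative guard (not part of the current codebase) makes those invariants explicit when touching the processing code:

```python
import numpy as np

def check_processed_audio(audio: np.ndarray, sample_rate: int) -> None:
    """Illustrative sanity checks for the invariants listed above."""
    assert sample_rate == 16_000, "models expect 16 kHz audio"
    assert audio.ndim == 1, "the pipeline works on mono buffers"
    assert float(np.max(np.abs(audio))) <= 1.0, "processed audio must stay within [-1, 1]"
```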
232 | ### GUI Customization
233 | - Dark theme defined in `MainWindow.init_ui()`
234 | - Keyboard shortcuts in `setup_shortcuts()`
235 | - Audio visualization via `update_level()` callback
236 | - Model dropdowns update pipeline without restart
237 | - Copy buttons use QApplication.clipboard()
238 |
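The copy buttons reduce to a call like this sketch (the actual slot names in `MainWindow` may differ):

```python
from PyQt6.QtWidgets import QApplication

def copy_to_clipboard(text: str) -> None:
    """Place text on the system clipboard, as the raw/corrected copy buttons do."""
    QApplication.clipboard().setText(text)
```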
239 | ## Best Practices & Lessons Learned
240 |
241 | ### Thread Safety
242 | - Always use proper thread cleanup in GUI operations
243 | - Set `is_recording` flag before cleanup to prevent race conditions
244 | - Use `wait()` with timeout before `terminate()` for graceful shutdown
245 |
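A sketch of that shutdown sequence (illustrative; the real cleanup lives in the GUI thread-handling code, e.g. `clear_text()`):

```python
from PyQt6.QtCore import QThread

def stop_worker(thread: QThread, timeout_ms: int = 2000) -> None:
    """Ask the worker to stop, falling back to terminate() only if it will not exit."""
    if thread.isRunning():
        thread.requestInterruption()   # cooperative cancel; the worker should check this flag
        thread.quit()                  # exit the thread's event loop if one is running
        if not thread.wait(timeout_ms):
            thread.terminate()         # last resort after the graceful wait times out
            thread.wait()
```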
246 | ### Audio Buffer Management
247 | - Avoid duplicate buffering between recorder and pipeline
248 | - Use the final audio from `stop_recording()` which includes all chunks
249 | - Maintain consistent sample rates throughout the pipeline
250 |
251 | ### Model Parameter Compatibility
252 | - mlx_lm's `generate()` doesn't support temperature/top_p parameters
253 | - Use default generation settings for consistent behavior
254 | - Test parameter compatibility when adding new models
255 |
256 | ### Error Handling
257 | - Wrap all GUI callbacks with try/except blocks
258 | - Provide user-friendly error messages via status bar
259 | - Use QMessageBox for critical errors only
260 |
261 | ### Development Workflow
262 | 1. Always check PRD.md for requirements alignment
263 | 2. Test model downloads before assuming availability
264 | 3. Verify audio permissions before recording
265 | 4. Run both GUI and CLI modes during testing
266 | 5. Check memory usage with different model sizes
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # localGPT-Vision
2 | [](https://github.com/PromtEngineer/localGPT-Vision/stargazers)
3 | [](https://github.com/PromtEngineer/localGPT-Vision/network/members)
4 | [](https://github.com/PromtEngineer/localGPT-Vision/issues)
5 | [](https://github.com/PromtEngineer/localGPT-Vision/pulls)
6 | [](https://twitter.com/engineerrprompt)
7 |
8 |
9 | [Watch the video on YouTube](https://youtu.be/YPs4eGDpIY4)
10 |
11 |
12 | localGPT-Vision is an end-to-end vision-based Retrieval-Augmented Generation (RAG) system. It allows users to upload and index documents (PDFs and images), ask questions about the content, and receive responses along with relevant document snippets. The retrieval is performed using the [Colqwen](https://huggingface.co/vidore/colqwen2-v0.1) or [ColPali](https://huggingface.co/blog/manu/colpali) models, and the retrieved pages are passed to a Vision Language Model (VLM) for generating responses. Currently, the code supports these VLMs:
13 |
14 | - [Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct)
15 | - [LLAMA-3.2-11B-Vision](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision)
16 | - [Pixtral-12B-2409](https://huggingface.co/mistralai/Pixtral-12B-2409)
17 | - [Molmo-7B-O-0924](https://huggingface.co/allenai/Molmo-7B-O-0924)
18 | - [Google Gemini](https://aistudio.google.com/app/prompts/new_chat)
19 | - [OpenAI GPT-4o](https://platform.openai.com/docs/guides/vision)
20 | - [LLAMA-3.2 with Ollama](https://ollama.com/blog/llama3.2-vision)
21 |
22 | The project is built on top of the [Byaldi](https://github.com/AnswerDotAI/byaldi) library.
23 |
24 | ## Table of Contents
25 | - [Features](#features)
26 | - [Architecture](#architecture)
27 | - [Prerequisites](#prerequisites)
28 | - [Installation](#installation)
29 | - [Usage](#usage)
30 | - [Project Structure](#project-structure)
31 | - [System Workflow](#system-workflow)
32 | - [Contributing](#contributing)
33 |
34 | ## Features
35 | - End-to-End Vision-Based RAG: Combines visual document retrieval with language models for comprehensive answers.
36 | - Document Upload and Indexing: Upload PDFs and images, which are then indexed using ColPali for retrieval.
37 | - Chat Interface: Engage in a conversational interface to ask questions about the uploaded documents.
38 | - Session Management: Create, rename, switch between, and delete chat sessions.
39 | - Model Selection: Choose between different Vision Language Models (Qwen2-VL-7B-Instruct, Google Gemini, OpenAI GPT-4o, etc.).
40 | - Persistent Indexes: Indexes are saved on disk and loaded upon application restart.
41 |
42 | ## Architecture
43 | localGPT-Vision is built as an end-to-end vision-based RAG system. The architecture comprises two main components:
44 |
45 | 1. Visual Document Retrieval with Colqwen and ColPali:
46 |    - [Colqwen](https://huggingface.co/vidore/colqwen2-v0.1) and [ColPali](https://huggingface.co/blog/manu/colpali) are Vision Encoders designed for efficient document retrieval using only the image representation of document pages.
47 |    - They embed page images directly, leveraging visual cues like layout, fonts, figures, and tables, without relying on OCR or text extraction.
48 | - During indexing, document pages are converted into image embeddings and stored.
49 | - During querying, the user query is matched against these embeddings to retrieve the most relevant document pages.
50 |
51 | 
52 |
53 | 2. Response Generation with Vision Language Models:
54 | - The retrieved document images are passed to a Vision Language Model (VLM).
55 |    - Supported models include Qwen2-VL-7B-Instruct, LLAMA-3.2, Pixtral, Molmo, Google Gemini, and OpenAI GPT-4o.
56 | - These models generate responses by understanding both the visual and textual content of the documents.
57 | - NOTE: The quality of the responses is highly dependent on the VLM used and the resolution of the document images.
58 |
59 | This architecture eliminates the need for complex text-extraction pipelines and provides a more holistic understanding of documents by considering their visual elements. You don't need the chunking strategies, embedding-model selection, or retrieval tuning required by traditional text-based RAG systems.
60 |
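Under the hood, the indexing and retrieval wrapped by `models/indexer.py` and `models/retriever.py` follow the standard Byaldi flow sketched below; the paths and arguments are illustrative, not a drop-in replacement for those modules:

```python
from byaldi import RAGMultiModalModel

# Build visual embeddings for every page image in a folder of PDFs/images
rag = RAGMultiModalModel.from_pretrained("vidore/colpali")
rag.index(input_path="uploaded_documents/demo/", index_name="demo", overwrite=True)

# At query time, embed the question and return the best-matching pages
results = rag.search("What does figure 3 show?", k=3)
for r in results:
    print(r.doc_id, r.page_num, r.score)
```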
61 | ## Prerequisites
62 | - Anaconda or Miniconda installed on your system
63 | - Python 3.10 or higher
64 | - Git (optional, for cloning the repository)
65 |
66 | ## Installation
67 | Follow these steps to set up and run the application on your local machine.
68 |
69 | 1. Clone the Repository
70 | ```bash
71 | git clone https://github.com/PromtEngineer/localGPT-Vision.git
72 | cd localGPT-Vision
73 | ```
74 |
75 | 2. Create a Conda Environment
76 | ```bash
77 | conda create -n localgpt-vision python=3.10
78 | conda activate localgpt-vision
79 | ```
80 |
81 | 3a. Install Dependencies
82 | ```bash
83 | pip install -r requirements.txt
84 | ```
85 |
86 | 3b. Install Transformers from HuggingFace - Dev version
87 | ```bash
88 | pip uninstall transformers
89 | pip install git+https://github.com/huggingface/transformers
90 | ```
91 |
92 | 4. Set Environment Variables
93 | Set your API keys for Google Gemini and OpenAI GPT-4:
94 |
95 | ```bash
96 | export GENAI_API_KEY='your_genai_api_key'
97 | export OPENAI_API_KEY='your_openai_api_key'
98 | export GROQ_API_KEY='your_groq_api_key'
99 | ```
100 |
101 | On Windows Command Prompt:
102 | ```cmd
103 | set GENAI_API_KEY=your_genai_api_key
104 | set OPENAI_API_KEY=your_openai_api_key
105 | set GROQ_API_KEY=your_groq_api_key
106 | ```
107 |
108 | 5. Run the Application
109 | ```bash
110 | python app.py
111 | ```
112 |
113 | 6. Access the Application
114 | Open your web browser and navigate to:
115 | ```
116 | http://localhost:5050/
117 | ```
118 | ## Debugging
119 |
120 | To assist with debugging localGPT-Vision, certain additional dependencies need to be installed on your system. Follow the instructions below to set up your environment for debugging purposes.
121 |
122 | ### Required Packages
123 |
124 | Install the necessary packages using the following commands:
125 |
126 | #### On Ubuntu/Debian-based Systems
127 |
128 | 1. **Update Package Lists**
129 |
130 | ```bash
131 | sudo apt update
132 | ```
133 |
134 | 2. **Install Poppler Libraries and Utilities**
135 |
136 | These libraries are essential for handling PDF files.
137 |
138 | ```bash
139 | sudo apt install libpoppler-cpp-dev poppler-utils
140 | ```
141 |
142 | 3. **Verify Installation of `pdftoppm`**
143 |
144 | Check the version to ensure it's installed correctly:
145 |
146 | ```bash
147 | pdftoppm -v
148 | ```
149 |
150 | 4. **Install Additional Dependencies**
151 |
152 | These packages are required for building and managing libraries.
153 |
154 | ```bash
155 |    sudo apt install cmake pkg-config python3-poppler-qt5
156 | ```
157 |
158 | ### Note
159 |
160 | Ensure that you have appropriate permissions to run `sudo` commands on your machine. This setup is specifically tailored for Ubuntu/Debian-based systems, and steps might vary slightly if using a different Linux distribution or macOS.
161 |
166 |
167 | ## Usage
168 | ### Upload and Index Documents
169 | 1. Click on "New Chat" to start a new session.
170 | 2. Under "Upload and Index Documents", click "Choose Files" and select your PDF or image files.
171 | 3. Click "Upload and Index". The documents will be indexed using ColPali and ready for querying.
172 |
173 | ### Ask Questions
174 | 1. In the "Enter your question here" textbox, type your query related to the uploaded documents.
175 | 2. Click "Send". The system will retrieve relevant document pages and generate a response using the selected Vision Language Model.
176 |
177 | ### Manage Sessions
178 | - Rename Session: Click "Edit Name", enter a new name, and click "Save Name".
179 | - Switch Sessions: Click on a session name in the sidebar to switch to that session.
180 | - Delete Session: Click "Delete" next to a session to remove it permanently.
181 |
182 | ### Settings
183 | 1. Click on "Settings" in the navigation bar.
184 | 2. Select the desired language model and image dimensions.
185 | 3. Click "Save Settings".
186 |
187 | ## Project Structure
188 | ```
189 | localGPT-Vision/
190 | ├── app.py
191 | ├── logger.py
192 | ├── models/
193 | │ ├── indexer.py
194 | │ ├── retriever.py
195 | │ ├── responder.py
196 | │ ├── model_loader.py
197 | │ └── converters.py
198 | ├── sessions/
199 | ├── templates/
200 | │ ├── base.html
201 | │ ├── chat.html
202 | │ ├── settings.html
203 | │ └── index.html
204 | ├── static/
205 | │ ├── css/
206 | │ │ └── style.css
207 | │ ├── js/
208 | │ │ └── script.js
209 | │ └── images/
210 | ├── uploaded_documents/
211 | ├── .byaldi/
212 | ├── requirements.txt
213 | ├── .gitignore
214 | └── README.md
215 | ```
216 |
217 | - `app.py`: Main Flask application.
218 | - `logger.py`: Configures application logging.
219 | - `models/`: Contains modules for indexing, retrieving, and responding.
220 | - `templates/`: HTML templates for rendering views.
221 | - `static/`: Static files like CSS and JavaScript.
222 | - `sessions/`: Stores session data.
223 | - `uploaded_documents/`: Stores uploaded documents.
224 | - `.byaldi/`: Stores the indexes created by Byaldi.
225 | - `requirements.txt`: Python dependencies.
226 | - `.gitignore`: Files and directories to be ignored by Git.
227 | - `README.md`: Project documentation.
228 |
229 | ## System Workflow
230 | 1. User Interaction: The user interacts with the web interface to upload documents and ask questions.
231 | 2. Document Indexing with ColPali:
232 | - Uploaded documents are converted to PDFs if necessary.
233 | - Documents are indexed using ColPali, which creates embeddings based on the visual content of the document pages.
234 |    - The indexes are stored in the .byaldi/ directory.
235 | 3. Session Management:
236 | - Each chat session has a unique ID and stores its own index and chat history.
237 | - Sessions are saved on disk and loaded upon application restart.
238 | 4. Query Processing:
239 | - User queries are sent to the backend.
240 | - The query is embedded and matched against the visual embeddings of document pages to retrieve relevant pages.
241 | 5. Response Generation with Vision Language Models:
242 | - The retrieved document images and the user query are passed to the selected Vision Language Model (Qwen, Gemini, or GPT-4).
243 | - The VLM generates a response by understanding both the visual and textual content of the documents.
244 | 6. Display Results:
245 | - The response and relevant document snippets are displayed in the chat interface.
246 |
247 | ```mermaid
248 | graph TD
249 | A[User] -->|Uploads Documents| B(Flask App)
250 | B -->|Saves Files| C[uploaded_documents/]
251 | B -->|Converts and Indexes with ColPali| D[Indexing Module]
252 |     D -->|Creates Visual Embeddings| E[.byaldi/]
253 | A -->|Asks Question| B
254 | B -->|Embeds Query and Retrieves Pages| F[Retrieval Module]
255 | F -->|Retrieves Relevant Pages| E
256 | F -->|Passes Pages to| G[Vision Language Model]
257 | G -->|Generates Response| B
258 | B -->|Displays Response| A
259 | B -->|Saves Session Data| H[sessions/]
260 | subgraph Backend
261 | B
262 | D
263 | F
264 | G
265 | end
266 | subgraph Storage
267 | C
268 | E
269 | H
270 | end
271 | ```
272 |
273 | ## Contributing
274 | Contributions are welcome! Please follow these steps:
275 |
276 | 1. Fork the repository.
277 | 2. Create a new branch for your feature: `git checkout -b feature-name`.
278 | 3. Commit your changes: `git commit -am 'Add new feature'`.
279 | 4. Push to the branch: `git push origin feature-name`.
280 | 5. Submit a pull request.
281 |
282 | ## Star History
283 |
284 | [](https://star-history.com/#PromtEngineer/localGPT-Vision&Date)
285 |
286 |
--------------------------------------------------------------------------------
/models/responder.py:
--------------------------------------------------------------------------------
1 | # models/responder.py
2 |
3 | from models.model_loader import load_model, is_single_image_model
4 | from transformers import GenerationConfig
5 | import google.generativeai as genai
6 | from dotenv import load_dotenv
7 | from logger import get_logger
8 | from openai import OpenAI
9 | from PIL import Image
10 | import torch
11 | import base64
12 | import os
13 | import io
14 | # from langchain_core.messages import HumanMessage
15 | from io import BytesIO
16 | import ollama
17 |
18 |
19 | logger = get_logger(__name__)
20 |
21 | # Function to encode the image
22 | def encode_image(image_path):
23 | with open(image_path, "rb") as image_file:
24 | return base64.b64encode(image_file.read()).decode('utf-8')
25 |
26 | def generate_response(images, query, session_id, resized_height=280, resized_width=280, model_choice='qwen'):
27 | """
28 | Generates a response using the selected model based on the query and images.
29 | Returns: (response_text, used_images)
30 | """
31 | try:
32 | logger.info(f"Generating response using model '{model_choice}'.")
33 |
34 | # Convert resized_height and resized_width to integers
35 | resized_height = int(resized_height)
36 | resized_width = int(resized_width)
37 |
38 | # Ensure images are full paths
39 | full_image_paths = [os.path.join('static', img) if not img.startswith('static') else img for img in images]
40 |
41 | # Check if any valid images exist
42 | valid_images = [img for img in full_image_paths if os.path.exists(img)]
43 |
44 | if not valid_images:
45 | logger.warning("No valid images found for analysis.")
46 | return "No images could be loaded for analysis.", []
47 |
48 | # If model only supports single image, use only the first image
49 | if is_single_image_model(model_choice):
50 | valid_images = [valid_images[0]]
51 | logger.info(f"Model {model_choice} only supports single image, using first image only.")
52 |
53 | if model_choice == 'qwen':
54 | from qwen_vl_utils import process_vision_info
55 | # Load cached model
56 | model, processor, device = load_model('qwen')
57 | # Ensure dimensions are multiples of 28
58 | resized_height = (resized_height // 28) * 28
59 | resized_width = (resized_width // 28) * 28
60 |
61 | image_contents = []
62 | for image in valid_images:
63 | image_contents.append({
64 | "type": "image",
65 | "image": image, # Use the full path
66 | "resized_height": resized_height,
67 | "resized_width": resized_width
68 | })
69 | messages = [
70 | {
71 | "role": "user",
72 | "content": image_contents + [{"type": "text", "text": query}],
73 | }
74 | ]
75 | text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
76 | image_inputs, video_inputs = process_vision_info(messages)
77 | inputs = processor(
78 | text=[text],
79 | images=image_inputs,
80 | videos=video_inputs,
81 | padding=True,
82 | return_tensors="pt",
83 | )
84 | inputs = inputs.to(device)
85 | generated_ids = model.generate(**inputs, max_new_tokens=128)
86 | generated_ids_trimmed = [
87 | out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
88 | ]
89 | output_text = processor.batch_decode(
90 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
91 | )
92 | logger.info("Response generated using Qwen model.")
93 | return output_text[0], valid_images
94 |
95 | elif model_choice == 'gemini':
96 | model, _ = load_model('gemini')
97 |
98 | try:
99 | content = [query] # Add the text query first
100 |
101 | for img_path in valid_images:
102 | if os.path.exists(img_path):
103 | try:
104 | img = Image.open(img_path)
105 | content.append(img)
106 | except Exception as e:
107 | logger.error(f"Error opening image {img_path}: {e}")
108 | else:
109 | logger.warning(f"Image file not found: {img_path}")
110 |
111 | if len(content) == 1: # Only text, no images
112 | return "No images could be loaded for analysis.", []
113 |
114 | response = model.generate_content(content)
115 |
116 | if response.text:
117 | generated_text = response.text
118 | logger.info("Response generated using Gemini model.")
119 | return generated_text, valid_images
120 | else:
121 | return "The Gemini model did not generate any text response.", []
122 |
123 | except Exception as e:
124 | logger.error(f"Error in Gemini processing: {str(e)}", exc_info=True)
125 | return f"An error occurred while processing the images: {str(e)}", []
126 |
127 | elif model_choice == 'gpt4':
128 | api_key = os.getenv("OPENAI_API_KEY")
129 | client = OpenAI(api_key=api_key)
130 |
131 | try:
132 | content = [{"type": "text", "text": query}]
133 |
134 | for img_path in valid_images:
135 | logger.info(f"Processing image: {img_path}")
136 | if os.path.exists(img_path):
137 | base64_image = encode_image(img_path)
138 | content.append({
139 | "type": "image_url",
140 | "image_url": {
141 | "url": f"data:image/jpeg;base64,{base64_image}"
142 | }
143 | })
144 | else:
145 | logger.warning(f"Image file not found: {img_path}")
146 |
147 | if len(content) == 1: # Only text, no images
148 | return "No images could be loaded for analysis.", []
149 |
150 | response = client.chat.completions.create(
151 | model="gpt-4o",
152 | messages=[
153 | {
154 | "role": "user",
155 | "content": content
156 | }
157 | ],
158 | max_tokens=1024
159 | )
160 |
161 | generated_text = response.choices[0].message.content
162 | logger.info("Response generated using GPT-4 model.")
163 | return generated_text, valid_images
164 |
165 | except Exception as e:
166 | logger.error(f"Error in GPT-4 processing: {str(e)}", exc_info=True)
167 | return f"An error occurred while processing the images: {str(e)}", []
168 |
169 | elif model_choice == 'llama-vision':
170 | # Load model, processor, and device
171 | model, processor, device = load_model('llama-vision')
172 |
173 | # Process images
174 | # For simplicity, use the first image
175 | image_path = valid_images[0] if valid_images else None
176 | if image_path and os.path.exists(image_path):
177 | image = Image.open(image_path).convert('RGB')
178 | else:
179 | return "No valid image found for analysis.", []
180 |
181 | # Prepare messages
182 | messages = [
183 | {"role": "user", "content": [
184 | {"type": "image"},
185 | {"type": "text", "text": query}
186 | ]}
187 | ]
188 | input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
189 | inputs = processor(image, input_text, return_tensors="pt").to(device)
190 |
191 | # Generate response
192 | output = model.generate(**inputs, max_new_tokens=512)
193 | response = processor.decode(output[0], skip_special_tokens=True)
194 | return response, valid_images
195 |
196 | elif model_choice == "pixtral":
197 | model, tokenizer, generate_func, device = load_model('pixtral')
198 |
199 | def image_to_data_url(image_path):
200 | with open(image_path, "rb") as image_file:
201 | encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
202 | ext = os.path.splitext(image_path)[1][1:] # Get the file extension
203 | return f"data:image/{ext};base64,{encoded_string}"
204 |
205 | from mistral_common.protocol.instruct.messages import UserMessage, TextChunk, ImageURLChunk
206 | from mistral_common.protocol.instruct.request import ChatCompletionRequest
207 |
208 | # Prepare the content with text and images
209 | content = [TextChunk(text=query)]
210 | for img_path in valid_images[:1]: # Use only the first image
211 | content.append(ImageURLChunk(image_url=image_to_data_url(img_path)))
212 |
213 | completion_request = ChatCompletionRequest(messages=[UserMessage(content=content)])
214 |
215 | encoded = tokenizer.encode_chat_completion(completion_request)
216 |
217 | images = encoded.images
218 | tokens = encoded.tokens
219 |
220 | out_tokens, _ = generate_func([tokens], model, images=[images], max_tokens=256, temperature=0.35, eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id)
221 | result = tokenizer.decode(out_tokens[0])
222 |
223 | logger.info("Response generated using Pixtral model.")
224 | return result, valid_images
225 |
226 | elif model_choice == "molmo":
227 | model, processor, device = load_model('molmo')
228 | model = model.half() # Convert model to half precision
229 | pil_images = []
230 | for img_path in valid_images[:1]: # Process only the first image for now
231 | if os.path.exists(img_path):
232 | try:
233 | img = Image.open(img_path).convert('RGB')
234 | pil_images.append(img)
235 | except Exception as e:
236 | logger.error(f"Error opening image {img_path}: {e}")
237 | else:
238 | logger.warning(f"Image file not found: {img_path}")
239 |
240 | if not pil_images:
241 | return "No images could be loaded for analysis.", []
242 |
243 | try:
244 | # Process the images and text
245 | inputs = processor.process(
246 | images=pil_images,
247 | text=query
248 | )
249 |
250 | # Move inputs to the correct device and make a batch of size 1
251 | # Convert float tensors to half precision, but keep integer tensors as they are
252 | inputs = {k: (v.to(device).unsqueeze(0).half() if v.dtype in [torch.float32, torch.float64] else
253 | v.to(device).unsqueeze(0))
254 | if isinstance(v, torch.Tensor) else v
255 | for k, v in inputs.items()}
256 |
257 | # Generate output
258 | with torch.no_grad(): # Disable gradient calculation
259 | output = model.generate_from_batch(
260 | inputs,
261 | GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
262 | tokenizer=processor.tokenizer
263 | )
264 |
265 | # Only get generated tokens; decode them to text
266 | generated_tokens = output[0, inputs['input_ids'].size(1):]
267 | generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
268 |
269 | return generated_text, valid_images
270 |
271 | except Exception as e:
272 | logger.error(f"Error in Molmo processing: {str(e)}", exc_info=True)
273 | return f"An error occurred while processing the images: {str(e)}", []
274 | finally:
275 | # Close the opened images to free up resources
276 | for img in pil_images:
277 | img.close()
278 | elif model_choice == 'groq-llama-vision':
279 | client = load_model('groq-llama-vision')
280 |
281 | content = [{"type": "text", "text": query}]
282 |
283 | # Use only the first image
284 | if valid_images:
285 | img_path = valid_images[0]
286 | if os.path.exists(img_path):
287 | base64_image = encode_image(img_path)
288 | content.append({
289 | "type": "image_url",
290 | "image_url": {
291 | "url": f"data:image/jpeg;base64,{base64_image}"
292 | }
293 | })
294 | else:
295 | logger.warning(f"Image file not found: {img_path}")
296 |
297 | if len(content) == 1: # Only text, no images
298 | return "No images could be loaded for analysis.", []
299 |
300 | try:
301 | chat_completion = client.chat.completions.create(
302 | messages=[
303 | {
304 | "role": "user",
305 | "content": content
306 | }
307 | ],
308 | model="llava-v1.5-7b-4096-preview",
309 | )
310 | generated_text = chat_completion.choices[0].message.content
311 | logger.info("Response generated using Groq Llama Vision model.")
312 | return generated_text, valid_images
313 | except Exception as e:
314 | logger.error(f"Error in Groq Llama Vision processing: {str(e)}", exc_info=True)
315 | return f"An error occurred while processing the image: {str(e)}", []
316 | elif model_choice == 'ollama-llama-vision':
317 | try:
318 | message = {
319 | 'role': 'user',
320 | 'content': query,
321 | 'images': [valid_images[0]]
322 | }
323 |
324 | response = ollama.chat(
325 | model='llama3.2-vision',
326 | messages=[message]
327 | )
328 |
329 | logger.info("Response generated using Ollama Llama Vision model.")
330 | return response['message']['content'], valid_images
331 |
332 | except Exception as e:
333 | logger.error(f"Error in Ollama Llama Vision processing: {str(e)}", exc_info=True)
334 | return f"An error occurred while processing the image: {str(e)}", []
335 | else:
336 | logger.error(f"Invalid model choice: {model_choice}")
337 | return "Invalid model selected.", []
338 | except Exception as e:
339 | logger.error(f"Error generating response: {e}")
340 | return f"An error occurred while generating the response: {str(e)}", []
341 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | import os
2 | import uuid
3 | import json
4 | import time
5 | from flask import Flask, render_template, request, redirect, url_for, session, flash, jsonify
6 | from markupsafe import Markup
7 | from models.indexer import index_documents
8 | from models.retriever import retrieve_documents
9 | from models.responder import generate_response
10 | from werkzeug.utils import secure_filename
11 | from logger import get_logger
12 | from byaldi import RAGMultiModalModel
13 | import markdown
14 |
15 | # Set the TOKENIZERS_PARALLELISM environment variable to suppress warnings
16 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
17 |
18 | # Initialize the Flask application
19 | app = Flask(__name__)
20 | app.secret_key = 'your_secret_key' # Replace with a secure secret key
21 |
22 | logger = get_logger(__name__)
23 |
24 | # Configure upload folders
25 | app.config['UPLOAD_FOLDER'] = 'uploaded_documents'
26 | app.config['STATIC_FOLDER'] = 'static'
27 | app.config['SESSION_FOLDER'] = 'sessions'
28 | app.config['INDEX_FOLDER'] = os.path.join(os.getcwd(), '.byaldi') # Set to .byaldi folder in current directory
29 |
30 | # Create necessary directories if they don't exist
31 | os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
32 | os.makedirs(app.config['STATIC_FOLDER'], exist_ok=True)
33 | os.makedirs(app.config['SESSION_FOLDER'], exist_ok=True)
34 |
35 | # Initialize global variables
36 | RAG_models = {} # Dictionary to store RAG models per session
37 | app.config['INITIALIZATION_DONE'] = False # Flag to track initialization
38 | logger.info("Application started.")
39 |
40 | def load_rag_model_for_session(session_id):
41 | """
42 | Loads the RAG model for the given session_id from the index on disk.
43 | """
44 | index_path = os.path.join(app.config['INDEX_FOLDER'], session_id)
45 |
46 | if os.path.exists(index_path):
47 | try:
48 | RAG = RAGMultiModalModel.from_index(index_path)
49 | RAG_models[session_id] = RAG
50 | logger.info(f"RAG model for session {session_id} loaded from index.")
51 | except Exception as e:
52 | logger.error(f"Error loading RAG model for session {session_id}: {e}")
53 | else:
54 | logger.warning(f"No index found for session {session_id}.")
55 |
56 | def load_existing_indexes():
57 | """
58 | Loads all existing indexes from the .byaldi folder when the application starts.
59 | """
60 | global RAG_models
61 | if os.path.exists(app.config['INDEX_FOLDER']):
62 | for session_id in os.listdir(app.config['INDEX_FOLDER']):
63 | if os.path.isdir(os.path.join(app.config['INDEX_FOLDER'], session_id)):
64 | load_rag_model_for_session(session_id)
65 | else:
66 | logger.warning("No .byaldi folder found. No existing indexes to load.")
67 |
68 | @app.before_request
69 | def initialize_app():
70 | """
71 | Initializes the application by loading existing indexes.
72 | This will run before the first request, but only once.
73 | """
74 | if not app.config['INITIALIZATION_DONE']:
75 | load_existing_indexes()
76 | app.config['INITIALIZATION_DONE'] = True
77 | logger.info("Application initialized and indexes loaded.")
78 |
79 | @app.before_request
80 | def make_session_permanent():
81 | session.permanent = True
82 | if 'session_id' not in session:
83 | session['session_id'] = str(uuid.uuid4())
84 |
85 |
86 | @app.route('/', methods=['GET'])
87 | def home():
88 | return redirect(url_for('chat'))
89 |
90 | @app.route('/chat', methods=['GET', 'POST'])
91 | def chat():
92 | if 'session_id' not in session:
93 | session['session_id'] = str(uuid.uuid4())
94 |
95 | session_id = session['session_id']
96 | session_file = os.path.join(app.config['SESSION_FOLDER'], f"{session_id}.json")
97 |
98 | # Load session data from file
99 | if os.path.exists(session_file):
100 | with open(session_file, 'r') as f:
101 | session_data = json.load(f)
102 | chat_history = session_data.get('chat_history', [])
103 | session_name = session_data.get('session_name', 'Untitled Session')
104 | indexed_files = session_data.get('indexed_files', [])
105 | else:
106 | chat_history = []
107 | session_name = 'Untitled Session'
108 | indexed_files = []
109 |
110 | if request.method == 'POST':
111 | if 'upload' in request.form:
112 | # Handle file upload and indexing
113 | files = request.files.getlist('file')
114 | session_folder = os.path.join(app.config['UPLOAD_FOLDER'], session_id)
115 | os.makedirs(session_folder, exist_ok=True)
116 | uploaded_files = []
117 | for file in files:
118 | if file and file.filename:
119 | filename = secure_filename(file.filename)
120 | file_path = os.path.join(session_folder, filename)
121 | file.save(file_path)
122 | uploaded_files.append(filename)
123 | logger.info(f"File saved: {file_path}")
124 |
125 | if uploaded_files:
126 | try:
127 | index_name = session_id
128 | index_path = os.path.join(app.config['INDEX_FOLDER'], index_name)
129 | indexer_model = session.get('indexer_model', 'vidore/colpali')
130 | RAG = index_documents(session_folder, index_name=index_name, index_path=index_path, indexer_model=indexer_model)
131 | if RAG is None:
132 | raise ValueError("Indexing failed: RAG model is None")
133 | RAG_models[session_id] = RAG
134 | session['index_name'] = index_name
135 | session['session_folder'] = session_folder
136 | indexed_files.extend(uploaded_files)
137 | session_data = {
138 | 'session_name': session_name,
139 | 'chat_history': chat_history,
140 | 'indexed_files': indexed_files
141 | }
142 | with open(session_file, 'w') as f:
143 | json.dump(session_data, f)
144 | logger.info("Documents indexed successfully.")
145 | return jsonify({
146 | "success": True,
147 | "message": "Files indexed successfully.",
148 | "indexed_files": indexed_files
149 | })
150 | except Exception as e:
151 | logger.error(f"Error indexing documents: {str(e)}")
152 | return jsonify({"success": False, "message": f"Error indexing files: {str(e)}"})
153 | else:
154 | return jsonify({"success": False, "message": "No files were uploaded."})
155 |
156 | elif 'send_query' in request.form:
157 | query = request.form['query']
158 |
159 | try:
160 | generation_model = session.get('generation_model', 'qwen')
161 | resized_height = session.get('resized_height', 280)
162 | resized_width = session.get('resized_width', 280)
163 |
164 | # Retrieve relevant documents
165 | rag_model = RAG_models.get(session_id)
166 | if rag_model is None:
167 | logger.error(f"RAG model not found for session {session_id}")
168 | return jsonify({"success": False, "message": "RAG model not found for this session."})
169 |
170 | retrieved_images = retrieve_documents(rag_model, query, session_id)
171 | logger.info(f"Retrieved images: {retrieved_images}")
172 |
173 | # Generate response with full image paths
174 | full_image_paths = [os.path.join(app.static_folder, img) for img in retrieved_images]
175 | response_text, used_images = generate_response(
176 | full_image_paths,
177 | query,
178 | session_id,
179 | resized_height,
180 | resized_width,
181 | generation_model
182 | )
183 |
184 | # Parse markdown in the response
185 | parsed_response = Markup(markdown.markdown(response_text))
186 |
187 | # Get relative paths for used images
188 | relative_images = [os.path.relpath(img, app.static_folder) for img in used_images]
189 |
190 | # Update chat history
191 | chat_history.append({"role": "user", "content": query})
192 | chat_history.append({
193 | "role": "assistant",
194 | "content": parsed_response,
195 | "images": relative_images # Use relative paths for frontend
196 | })
197 |
198 | # Update session name if it's the first message
199 | if len(chat_history) == 2: # First user message and AI response
200 | session_name = query[:50] # Truncate to 50 characters
201 |
202 | session_data = {
203 | 'session_name': session_name,
204 | 'chat_history': chat_history,
205 | 'indexed_files': indexed_files
206 | }
207 | with open(session_file, 'w') as f:
208 | json.dump(session_data, f)
209 |
210 | # Render the new messages
211 | new_messages_html = render_template('chat_messages.html', messages=[
212 | {"role": "user", "content": query},
213 | {"role": "assistant", "content": parsed_response, "images": relative_images}
214 | ])
215 |
216 | return jsonify({
217 | "success": True,
218 | "html": new_messages_html
219 | })
220 | except Exception as e:
221 | logger.error(f"Error generating response: {e}", exc_info=True)
222 | return jsonify({
223 | "success": False,
224 | "message": f"An error occurred while generating the response: {str(e)}"
225 | })
226 |
227 | # For GET requests, render the chat page
228 | session_files = os.listdir(app.config['SESSION_FOLDER'])
229 | chat_sessions = []
230 | for file in session_files:
231 | if file.endswith('.json'):
232 | s_id = file[:-5]
233 | with open(os.path.join(app.config['SESSION_FOLDER'], file), 'r') as f:
234 | data = json.load(f)
235 | name = data.get('session_name', 'Untitled Session')
236 | chat_sessions.append({'id': s_id, 'name': name})
237 |
238 | model_choice = session.get('model', 'qwen')
239 | resized_height = session.get('resized_height', 280)
240 | resized_width = session.get('resized_width', 280)
241 |
242 | return render_template('chat.html', chat_history=chat_history, chat_sessions=chat_sessions,
243 | current_session=session_id, model_choice=model_choice,
244 | resized_height=resized_height, resized_width=resized_width,
245 | session_name=session_name, indexed_files=indexed_files)
246 |
247 | @app.route('/switch_session/<session_id>')
248 | def switch_session(session_id):
249 | session['session_id'] = session_id
250 | if session_id not in RAG_models:
251 | load_rag_model_for_session(session_id)
252 | flash(f"Switched to session.", "info")
253 | return redirect(url_for('chat'))
254 |
255 | @app.route('/rename_session', methods=['POST'])
256 | def rename_session():
257 | session_id = request.form.get('session_id')
258 | new_session_name = request.form.get('new_session_name', 'Untitled Session')
259 | session_file = os.path.join(app.config['SESSION_FOLDER'], f"{session_id}.json")
260 |
261 | if os.path.exists(session_file):
262 | with open(session_file, 'r') as f:
263 | session_data = json.load(f)
264 |
265 | session_data['session_name'] = new_session_name
266 |
267 | with open(session_file, 'w') as f:
268 | json.dump(session_data, f)
269 |
270 | return jsonify({"success": True, "message": "Session name updated."})
271 | else:
272 | return jsonify({"success": False, "message": "Session not found."})
273 |
274 | @app.route('/delete_session/<session_id>', methods=['POST'])
275 | def delete_session(session_id):
276 | try:
277 | session_file = os.path.join(app.config['SESSION_FOLDER'], f"{session_id}.json")
278 | if os.path.exists(session_file):
279 | os.remove(session_file)
280 |
281 | session_folder = os.path.join(app.config['UPLOAD_FOLDER'], session_id)
282 | if os.path.exists(session_folder):
283 | import shutil
284 | shutil.rmtree(session_folder)
285 |
286 | session_images_folder = os.path.join('static', 'images', session_id)
287 | if os.path.exists(session_images_folder):
288 | import shutil
289 | shutil.rmtree(session_images_folder)
290 |
291 | RAG_models.pop(session_id, None)
292 |
293 | if session.get('session_id') == session_id:
294 | session['session_id'] = str(uuid.uuid4())
295 |
296 | logger.info(f"Session {session_id} deleted.")
297 | return jsonify({"success": True, "message": "Session deleted successfully."})
298 | except Exception as e:
299 | logger.error(f"Error deleting session {session_id}: {e}")
300 | return jsonify({"success": False, "message": f"An error occurred while deleting the session: {str(e)}"})
301 |
302 | @app.route('/settings', methods=['GET', 'POST'])
303 | def settings():
304 | if request.method == 'POST':
305 | indexer_model = request.form.get('indexer_model', 'vidore/colpali')
306 | generation_model = request.form.get('generation_model', 'qwen')
307 | resized_height = int(request.form.get('resized_height', 280))
308 | resized_width = int(request.form.get('resized_width', 280))
309 | session['indexer_model'] = indexer_model
310 | session['generation_model'] = generation_model
311 | session['resized_height'] = resized_height
312 | session['resized_width'] = resized_width
313 | session.modified = True
314 | logger.info(f"Settings updated: indexer_model={indexer_model}, generation_model={generation_model}, resized_height={resized_height}, resized_width={resized_width}")
315 | flash("Settings updated.", "success")
316 | return redirect(url_for('chat'))
317 | else:
318 | indexer_model = session.get('indexer_model', 'vidore/colpali')
319 | generation_model = session.get('generation_model', 'qwen')
320 | resized_height = session.get('resized_height', 280)
321 | resized_width = session.get('resized_width', 280)
322 | return render_template('settings.html',
323 | indexer_model=indexer_model,
324 | generation_model=generation_model,
325 | resized_height=resized_height,
326 | resized_width=resized_width)
327 |
328 | @app.route('/new_session')
329 | def new_session():
330 | session_id = str(uuid.uuid4())
331 | session['session_id'] = session_id
332 | session_files = os.listdir(app.config['SESSION_FOLDER'])
333 | session_number = len([f for f in session_files if f.endswith('.json')]) + 1
334 | session_name = f"Session {session_number}"
335 | session_file = os.path.join(app.config['SESSION_FOLDER'], f"{session_id}.json")
336 | session_data = {
337 | 'session_name': session_name,
338 | 'chat_history': [],
339 | 'indexed_files': []
340 | }
341 | with open(session_file, 'w') as f:
342 | json.dump(session_data, f)
343 | flash("New chat session started.", "success")
344 | return redirect(url_for('chat'))
345 |
346 | @app.route('/get_indexed_files/<session_id>')
347 | def get_indexed_files(session_id):
348 | session_file = os.path.join(app.config['SESSION_FOLDER'], f"{session_id}.json")
349 | if os.path.exists(session_file):
350 | with open(session_file, 'r') as f:
351 | session_data = json.load(f)
352 | indexed_files = session_data.get('indexed_files', [])
353 | return jsonify({"success": True, "indexed_files": indexed_files})
354 | else:
355 | return jsonify({"success": False, "message": "Session not found."})
356 |
357 | if __name__ == '__main__':
358 | app.run(port=5050, debug=True)
--------------------------------------------------------------------------------
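The session-management routes above (`/new_session`, `/rename_session`, `/delete_session/<session_id>`, `/get_indexed_files/<session_id>`) are plain Flask endpoints driven by form data and the session cookie. A minimal client-side sketch, not part of the repository, assuming the server is running locally on port 5050 as configured at the bottom of `app.py`; the session id below is a placeholder that would normally come from the chat page:

```python
# Illustrative only -- not part of the repo. Exercises the session endpoints
# defined in app.py against a locally running instance (port 5050, see above).
import requests

BASE = "http://127.0.0.1:5050"
session_id = "00000000-0000-0000-0000-000000000000"  # placeholder UUID

with requests.Session() as s:
    # Create a new chat session; the server stores its id in the Flask cookie.
    s.get(f"{BASE}/new_session")

    # Rename a session using the form fields rename_session() expects.
    r = s.post(f"{BASE}/rename_session",
               data={"session_id": session_id, "new_session_name": "Demo"})
    print(r.json())  # {"success": ..., "message": ...}

    # List the files indexed for that session.
    print(s.get(f"{BASE}/get_indexed_files/{session_id}").json())
```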
/templates/chat.html:
--------------------------------------------------------------------------------
1 |
2 |
3 | {% extends 'base.html' %}
4 |
5 | {% block content %}
6 |
7 |
8 | {% for message in chat_history %}
9 |
10 | {% if message.role == 'user' %}
11 | {{ message.content }}
12 | {% else %}
13 | {{ message.content|safe }}
14 | {% endif %}
15 | {% if message.images %}
16 |
17 | {% for image in message.images %}
18 | <img src="{{ url_for('static', filename=image) }}" alt="Retrieved Image">
19 | {% endfor %}
20 |
21 | {% endif %}
22 |
23 | {% endfor %}
24 |
25 |
39 |
40 |
41 |
42 |
43 |
44 | Loading...
45 |
46 | Generating response...
47 |
48 |
49 |
50 |
51 |
52 |
53 |
57 |
58 | Do you want to index the selected files?
59 |
60 |
63 | Indexing in progress. This may take a while...
64 |
65 |
66 |
70 |
71 |
72 |
73 |
74 |
75 |
94 |
95 |
96 |
116 |
117 |
118 |
133 |
134 | {% endblock %}
135 |
136 | {% block scripts %}
137 |
378 | {% endblock %}
--------------------------------------------------------------------------------
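Assistant messages are converted from markdown to HTML in `app.py` via `Markup(markdown.markdown(...))` and rendered unescaped in `chat.html` through the `{{ message.content|safe }}` filter. A small sketch of that conversion step, assuming `Markup` comes from `markupsafe` (app.py's imports are not shown in this section) and using an invented sample reply:

```python
# Illustrative sketch of the markdown-to-HTML step applied to assistant replies.
# Markup marks the string as safe so Jinja's |safe filter renders it as HTML.
import markdown
from markupsafe import Markup

response_text = "**Answer:** the figure on *page 3* shows the trend."  # sample input
parsed_response = Markup(markdown.markdown(response_text))
print(parsed_response)
# <p><strong>Answer:</strong> the figure on <em>page 3</em> shows the trend.</p>
```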
/jarvis/src/gui/main_window.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from PyQt6.QtWidgets import (
3 | QMainWindow, QWidget, QVBoxLayout, QHBoxLayout,
4 | QPushButton, QTextEdit, QLabel, QComboBox,
5 | QCheckBox, QLineEdit, QProgressBar, QFileDialog,
6 | QMessageBox, QSplitter, QApplication
7 | )
8 | from PyQt6.QtCore import Qt, QTimer, pyqtSignal, QThread
9 | from PyQt6.QtGui import QAction, QKeySequence, QFont
10 | from typing import Optional
11 |
12 | from ..pipeline import TranscriptionPipeline
13 | from ..models.transcriber import AVAILABLE_MODELS
14 | from ..models.text_corrector import AVAILABLE_LLM_MODELS
15 | import sounddevice as sd
16 |
17 | class TranscriptionThread(QThread):
18 | """Background thread for transcription"""
19 | textReady = pyqtSignal(str)
20 | errorOccurred = pyqtSignal(str)
21 | levelUpdate = pyqtSignal(float)
22 |
23 | def __init__(self, pipeline: TranscriptionPipeline, device_id: Optional[int] = None):
24 | super().__init__()
25 | self.pipeline = pipeline
26 | self.device_id = device_id
27 | self.is_recording = False
28 |
29 | def run(self):
30 | """Run transcription in background"""
31 | try:
32 | self.is_recording = True
33 | self.pipeline.start_recording(device_id=self.device_id)
34 |
35 | # Update audio levels
36 | while self.is_recording:
37 | level = self.pipeline.get_audio_level()
38 | self.levelUpdate.emit(level)
39 | self.msleep(50) # 50ms updates
40 |
41 | except Exception as e:
42 | self.errorOccurred.emit(str(e))
43 |
44 | def stop_recording(self):
45 | """Stop recording and get transcription"""
46 | self.is_recording = False
47 | text = self.pipeline.stop_recording()
48 | self.textReady.emit(text)
49 |
50 | class MainWindow(QMainWindow):
51 | """Main application window"""
52 |
53 | def __init__(self):
54 | super().__init__()
55 | self.pipeline = TranscriptionPipeline()
56 | self.transcription_thread = None
57 | self.init_ui()
58 | self.setup_shortcuts()
59 |
60 | def init_ui(self):
61 | """Initialize the user interface"""
62 | self.setWindowTitle("Speech Transcription")
63 | self.setGeometry(100, 100, 900, 700)
64 |
65 | # Central widget
66 | central_widget = QWidget()
67 | self.setCentralWidget(central_widget)
68 |
69 | # Main layout
70 | layout = QVBoxLayout(central_widget)
71 |
72 | # Header section
73 | header_layout = QHBoxLayout()
74 |
75 | # Microphone selection
76 | mic_label = QLabel("Microphone:")
77 | self.mic_combo = QComboBox()
78 | self.refresh_microphones()
79 | self.mic_combo.currentIndexChanged.connect(self.on_mic_changed)
80 |
81 | # Add refresh button for microphones
82 | self.refresh_mic_button = QPushButton("🔄")
83 | self.refresh_mic_button.setMaximumWidth(30)
84 | self.refresh_mic_button.clicked.connect(self.refresh_microphones)
85 | self.refresh_mic_button.setToolTip("Refresh microphone list")
86 |
87 | header_layout.addWidget(mic_label)
88 | header_layout.addWidget(self.mic_combo)
89 | header_layout.addWidget(self.refresh_mic_button)
90 |
91 | # Model selection
92 | model_label = QLabel("Model:")
93 | self.model_combo = QComboBox()
94 | for model_name, config in AVAILABLE_MODELS.items():
95 | self.model_combo.addItem(
96 | f"{model_name} - {config['description']}",
97 | model_name
98 | )
99 | self.model_combo.currentIndexChanged.connect(self.on_model_changed)
100 |
101 | header_layout.addWidget(model_label)
102 | header_layout.addWidget(self.model_combo)
103 | header_layout.addStretch()
104 |
105 | # Options
106 | self.auto_correct_check = QCheckBox("Auto-correct after recording")
107 | self.auto_correct_check.setChecked(True)
108 | header_layout.addWidget(self.auto_correct_check)
109 |
110 | layout.addLayout(header_layout)
111 |
112 | # Second header row for LLM model and context
113 | second_header_layout = QHBoxLayout()
114 |
115 | # LLM Model selection
116 | llm_label = QLabel("LLM Model:")
117 | self.llm_combo = QComboBox()
118 | for model_name, config in AVAILABLE_LLM_MODELS.items():
119 | self.llm_combo.addItem(
120 | f"{model_name} - {config['description']}",
121 | model_name
122 | )
123 | # Set default to Phi-3.5-mini
124 | default_index = list(AVAILABLE_LLM_MODELS.keys()).index("Phi-3.5-mini")
125 | self.llm_combo.setCurrentIndex(default_index)
126 | self.llm_combo.currentIndexChanged.connect(self.on_llm_model_changed)
127 |
128 | second_header_layout.addWidget(llm_label)
129 | second_header_layout.addWidget(self.llm_combo)
130 |
131 | # Context input
132 | context_label = QLabel("Context:")
133 | self.context_input = QLineEdit()
134 | self.context_input.setPlaceholderText("e.g., Medical discussion, Technical meeting")
135 | second_header_layout.addWidget(context_label)
136 | second_header_layout.addWidget(self.context_input)
137 |
138 | layout.addLayout(second_header_layout)
139 |
140 | # Audio level indicator
141 | self.level_bar = QProgressBar()
142 | self.level_bar.setMaximum(100)
143 | self.level_bar.setTextVisible(False)
144 | self.level_bar.setFixedHeight(10)
145 | layout.addWidget(self.level_bar)
146 |
147 | # Control buttons
148 | button_layout = QHBoxLayout()
149 |
150 | self.record_button = QPushButton("Start Recording")
151 | self.record_button.clicked.connect(self.toggle_recording)
152 | self.record_button.setStyleSheet("""
153 | QPushButton {
154 | background-color: #4CAF50;
155 | color: white;
156 | font-size: 16px;
157 | font-weight: bold;
158 | padding: 10px;
159 | border-radius: 5px;
160 | }
161 | QPushButton:hover {
162 | background-color: #45a049;
163 | }
164 | QPushButton:pressed {
165 | background-color: #3d8b40;
166 | }
167 | """)
168 |
169 | self.clear_button = QPushButton("Clear")
170 | self.clear_button.clicked.connect(self.clear_text)
171 |
172 | self.correct_button = QPushButton("Correct Text")
173 | self.correct_button.clicked.connect(self.correct_text)
174 |
175 | button_layout.addWidget(self.record_button)
176 | button_layout.addWidget(self.clear_button)
177 | button_layout.addWidget(self.correct_button)
178 | button_layout.addStretch()
179 |
180 | layout.addLayout(button_layout)
181 |
182 | # Text display area with splitter
183 | splitter = QSplitter(Qt.Orientation.Horizontal)
184 |
185 | # Raw transcription
186 | raw_container = QWidget()
187 | raw_layout = QVBoxLayout(raw_container)
188 |
189 | # Raw header with label and copy button
190 | raw_header = QHBoxLayout()
191 | raw_label = QLabel("Raw Transcription:")
192 | self.copy_raw_button = QPushButton("Copy")
193 | self.copy_raw_button.clicked.connect(self.copy_raw_text)
194 | raw_header.addWidget(raw_label)
195 | raw_header.addStretch()
196 | raw_header.addWidget(self.copy_raw_button)
197 | raw_layout.addLayout(raw_header)
198 |
199 | self.raw_text_edit = QTextEdit()
200 | self.raw_text_edit.setReadOnly(True)
201 | raw_layout.addWidget(self.raw_text_edit)
202 |
203 | # Corrected text
204 | corrected_container = QWidget()
205 | corrected_layout = QVBoxLayout(corrected_container)
206 |
207 | # Corrected header with label and copy button
208 | corrected_header = QHBoxLayout()
209 | corrected_label = QLabel("Corrected Text:")
210 | self.copy_corrected_button = QPushButton("Copy")
211 | self.copy_corrected_button.clicked.connect(self.copy_corrected_text)
212 | corrected_header.addWidget(corrected_label)
213 | corrected_header.addStretch()
214 | corrected_header.addWidget(self.copy_corrected_button)
215 | corrected_layout.addLayout(corrected_header)
216 |
217 | self.corrected_text_edit = QTextEdit()
218 | corrected_layout.addWidget(self.corrected_text_edit)
219 |
220 | splitter.addWidget(raw_container)
221 | splitter.addWidget(corrected_container)
222 | splitter.setSizes([450, 450])
223 |
224 | layout.addWidget(splitter)
225 |
226 | # Status bar
227 | self.status_label = QLabel("Ready")
228 | layout.addWidget(self.status_label)
229 |
230 | # Apply dark theme
231 | self.setStyleSheet("""
232 | QMainWindow {
233 | background-color: #2b2b2b;
234 | }
235 | QLabel {
236 | color: #ffffff;
237 | font-size: 14px;
238 | }
239 | QTextEdit {
240 | background-color: #3c3c3c;
241 | color: #ffffff;
242 | border: 1px solid #555;
243 | font-size: 14px;
244 | font-family: Monaco, Menlo, monospace;
245 | }
246 | QComboBox, QLineEdit {
247 | background-color: #3c3c3c;
248 | color: #ffffff;
249 | border: 1px solid #555;
250 | padding: 5px;
251 | }
252 | QComboBox::drop-down {
253 | border: none;
254 | }
255 | QComboBox::down-arrow {
256 | image: none;
257 | border-left: 5px solid transparent;
258 | border-right: 5px solid transparent;
259 | border-top: 5px solid #ffffff;
260 | margin-right: 5px;
261 | }
262 | QPushButton {
263 | background-color: #0d7377;
264 | color: white;
265 | border: none;
266 | padding: 8px 16px;
267 | font-size: 14px;
268 | border-radius: 4px;
269 | }
270 | QPushButton:hover {
271 | background-color: #14a085;
272 | }
273 | QCheckBox {
274 | color: #ffffff;
275 | font-size: 14px;
276 | }
277 | QCheckBox::indicator {
278 | width: 18px;
279 | height: 18px;
280 | }
281 | QProgressBar {
282 | background-color: #3c3c3c;
283 | border: 1px solid #555;
284 | }
285 | QProgressBar::chunk {
286 | background-color: #4CAF50;
287 | }
288 | """)
289 |
290 | def setup_shortcuts(self):
291 | """Setup keyboard shortcuts"""
292 | # Space to start/stop recording
293 | record_action = QAction("Record", self)
294 | record_action.setShortcut(QKeySequence(Qt.Key.Key_Space))
295 | record_action.triggered.connect(self.toggle_recording)
296 | self.addAction(record_action)
297 |
298 | # Escape to cancel recording
299 | cancel_action = QAction("Cancel", self)
300 | cancel_action.setShortcut(QKeySequence(Qt.Key.Key_Escape))
301 | cancel_action.triggered.connect(self.cancel_recording)
302 | self.addAction(cancel_action)
303 |
304 | # Cmd+S to save
305 | save_action = QAction("Save", self)
306 | save_action.setShortcut(QKeySequence.StandardKey.Save)
307 | save_action.triggered.connect(self.save_transcription)
308 | self.addAction(save_action)
309 |
310 | def toggle_recording(self):
311 | """Start or stop recording"""
312 | if self.transcription_thread and self.transcription_thread.isRunning():
313 | self.stop_recording()
314 | else:
315 | self.start_recording()
316 |
317 | def start_recording(self):
318 | """Start recording audio"""
319 | # Update UI
320 | self.record_button.setText("Stop Recording")
321 | self.record_button.setStyleSheet("""
322 | QPushButton {
323 | background-color: #f44336;
324 | color: white;
325 | font-size: 16px;
326 | font-weight: bold;
327 | padding: 10px;
328 | border-radius: 5px;
329 | }
330 | """)
331 | self.status_label.setText("Recording...")
332 |
333 | # Get selected model
334 | model_name = self.model_combo.currentData()
335 | if model_name != self.pipeline.transcriber.model_name:
336 | self.pipeline.set_model(model_name)
337 |
338 | # Clear previous text
339 | self.raw_text_edit.clear()
340 |
341 | # Get selected microphone
342 | device_id = self.mic_combo.currentData() if self.mic_combo.currentIndex() >= 0 else None
343 |
344 | # Start recording in background thread
345 | self.transcription_thread = TranscriptionThread(self.pipeline, device_id)
346 | self.transcription_thread.textReady.connect(self.on_transcription_ready)
347 | self.transcription_thread.errorOccurred.connect(self.on_error)
348 | self.transcription_thread.levelUpdate.connect(self.update_level)
349 | self.transcription_thread.start()
350 |
351 | def stop_recording(self):
352 | """Stop recording and transcribe"""
353 | if self.transcription_thread:
354 | self.transcription_thread.stop_recording()
355 | self.status_label.setText("Transcribing...")
356 |
357 | # Reset button
358 | self.record_button.setText("Start Recording")
359 | self.record_button.setStyleSheet("""
360 | QPushButton {
361 | background-color: #4CAF50;
362 | color: white;
363 | font-size: 16px;
364 | font-weight: bold;
365 | padding: 10px;
366 | border-radius: 5px;
367 | }
368 | """)
369 |
370 | def cancel_recording(self):
371 | """Cancel current recording"""
372 | try:
373 | if self.transcription_thread and self.transcription_thread.isRunning():
374 | # Stop recording gracefully first
375 | self.transcription_thread.is_recording = False
376 | # Wait a bit for thread to finish
377 | if not self.transcription_thread.wait(500): # 500ms timeout
378 | # If it doesn't finish, terminate it
379 | self.transcription_thread.terminate()
380 | self.transcription_thread.wait() # Wait for termination
381 |
382 | self.transcription_thread = None
383 | self.record_button.setText("Start Recording")
384 | self.record_button.setStyleSheet("""
385 | QPushButton {
386 | background-color: #4CAF50;
387 | color: white;
388 | font-size: 16px;
389 | font-weight: bold;
390 | padding: 10px;
391 | border-radius: 5px;
392 | }
393 | """)
394 | self.status_label.setText("Recording cancelled")
395 | self.level_bar.setValue(0)
396 | except Exception as e:
397 | print(f"Error in cancel_recording: {e}")
398 |
399 | def on_transcription_ready(self, text: str):
400 | """Handle transcription result"""
401 | self.raw_text_edit.setText(text)
402 | self.status_label.setText("Transcription complete")
403 | self.level_bar.setValue(0)
404 |
405 | # Auto-correct if enabled
406 | if self.auto_correct_check.isChecked() and text.strip():
407 | self.correct_text()
408 |
409 | def correct_text(self):
410 | """Correct the transcribed text"""
411 | raw_text = self.raw_text_edit.toPlainText()
412 | if not raw_text.strip():
413 | return
414 |
415 | self.status_label.setText("Correcting text...")
416 | context = self.context_input.text()
417 |
418 |         # Defer correction to the next event-loop iteration (runs on the GUI thread)
419 | QTimer.singleShot(0, lambda: self._do_correction(raw_text, context))
420 |
421 | def _do_correction(self, text: str, context: str):
422 | """Perform text correction"""
423 | try:
424 | corrected = self.pipeline.correct_text(text, context)
425 | self.corrected_text_edit.setText(corrected)
426 | self.status_label.setText("Text correction complete")
427 | except Exception as e:
428 | self.on_error(f"Correction failed: {str(e)}")
429 |
430 | def update_level(self, level: float):
431 | """Update audio level indicator"""
432 | # Convert to percentage (0-100)
433 | percentage = min(int(level * 1000), 100)
434 | self.level_bar.setValue(percentage)
435 |
436 | def on_model_changed(self):
437 | """Handle model selection change"""
438 | model_name = self.model_combo.currentData()
439 | self.status_label.setText(f"Model changed to {model_name}")
440 |
441 | def clear_text(self):
442 | """Clear all text fields"""
443 | try:
444 | # Stop any running recording first
445 | if self.transcription_thread and self.transcription_thread.isRunning():
446 | self.cancel_recording()
447 |
448 | self.raw_text_edit.clear()
449 | self.corrected_text_edit.clear()
450 | self.context_input.clear()
451 | self.status_label.setText("Ready")
452 | except Exception as e:
453 | print(f"Error in clear_text: {e}")
454 | self.on_error(f"Clear failed: {str(e)}")
455 |
456 | def save_transcription(self):
457 | """Save transcription to file"""
458 | text = self.corrected_text_edit.toPlainText()
459 | if not text:
460 | text = self.raw_text_edit.toPlainText()
461 |
462 | if not text:
463 | QMessageBox.warning(self, "Warning", "No text to save")
464 | return
465 |
466 | filename, _ = QFileDialog.getSaveFileName(
467 | self, "Save Transcription", "", "Text Files (*.txt)"
468 | )
469 |
470 | if filename:
471 | try:
472 | with open(filename, 'w') as f:
473 | f.write(text)
474 | self.status_label.setText(f"Saved to {filename}")
475 | except Exception as e:
476 | self.on_error(f"Save failed: {str(e)}")
477 |
478 | def on_error(self, error_msg: str):
479 | """Handle errors"""
480 | self.status_label.setText(f"Error: {error_msg}")
481 | QMessageBox.critical(self, "Error", error_msg)
482 |
483 | def refresh_microphones(self):
484 | """Refresh the list of available microphones"""
485 | self.mic_combo.clear()
486 |
487 | # Get list of audio input devices
488 | devices = sd.query_devices()
489 | current_default = sd.default.device[0] # Default input device
490 |
491 | default_index = 0
492 | for i, device in enumerate(devices):
493 | if device['max_input_channels'] > 0: # Only input devices
494 | device_name = f"{device['name']} ({device['hostapi']})"
495 | self.mic_combo.addItem(device_name, i)
496 |
497 | # Track default device
498 | if i == current_default:
499 | default_index = self.mic_combo.count() - 1
500 |
501 | # Set to default device
502 | self.mic_combo.setCurrentIndex(default_index)
503 |
504 | def on_mic_changed(self):
505 | """Handle microphone selection change"""
506 | if self.mic_combo.currentIndex() >= 0:
507 | device_id = self.mic_combo.currentData()
508 | device_name = self.mic_combo.currentText()
509 | self.status_label.setText(f"Microphone changed to {device_name}")
510 |
511 | def copy_raw_text(self):
512 | """Copy raw transcription to clipboard"""
513 | text = self.raw_text_edit.toPlainText()
514 | if text:
515 | clipboard = QApplication.clipboard()
516 | clipboard.setText(text)
517 | self.status_label.setText("Raw text copied to clipboard")
518 |
519 | def copy_corrected_text(self):
520 | """Copy corrected text to clipboard"""
521 | text = self.corrected_text_edit.toPlainText()
522 | if text:
523 | clipboard = QApplication.clipboard()
524 | clipboard.setText(text)
525 | self.status_label.setText("Corrected text copied to clipboard")
526 |
527 | def on_llm_model_changed(self):
528 | """Handle LLM model selection change"""
529 | model_name = self.llm_combo.currentData()
530 | self.status_label.setText(f"LLM model changed to {model_name}")
531 | # Update the pipeline with new LLM model
532 | self.pipeline.set_llm_model(model_name)
--------------------------------------------------------------------------------
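`MainWindow` is a self-contained `QMainWindow`; the repository's actual entry point is `src/main.py`, which is not shown in this section. A minimal launcher sketch under that assumption, illustrating how the window would typically be hosted in a `QApplication` event loop:

```python
# Illustrative launcher sketch -- not the repository's src/main.py.
# Shows the conventional way MainWindow would be hosted in a Qt event loop.
import sys

from PyQt6.QtWidgets import QApplication

from src.gui.main_window import MainWindow


def main() -> int:
    app = QApplication(sys.argv)
    window = MainWindow()  # constructs the UI and the TranscriptionPipeline
    window.show()
    return app.exec()


if __name__ == "__main__":
    sys.exit(main())
```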