├── static ├── js │ └── script.js └── css │ ├── style.css │ └── styles.css ├── templates ├── index.html ├── chat_messages.html ├── settings.html ├── base.html └── chat.html ├── jarvis ├── tests │ ├── __init__.py │ ├── unit │ │ ├── __init__.py │ │ └── test_audio_processor.py │ └── integration │ │ ├── __init__.py │ │ └── test_pipeline.py ├── src │ ├── audio │ │ ├── __init__.py │ │ ├── recorder.py │ │ └── processor.py │ ├── gui │ │ ├── __init__.py │ │ ├── system_tray.py │ │ └── main_window.py │ ├── models │ │ ├── __init__.py │ │ ├── transcriber.py │ │ └── text_corrector.py │ ├── __init__.py │ ├── main.py │ ├── cli.py │ └── pipeline.py ├── run_gui.sh ├── requirements.txt ├── setup.py ├── Makefile ├── .gitignore ├── scripts │ └── download_models.py ├── README.md └── CLAUDE.md ├── requirements.txt ├── models ├── converters.py ├── indexer.py ├── retriever.py ├── model_loader.py └── responder.py ├── logger.py ├── .gitignore ├── README.md └── app.py /static/js/script.js: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /templates/index.html: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /jarvis/tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Tests package -------------------------------------------------------------------------------- /jarvis/tests/unit/__init__.py: -------------------------------------------------------------------------------- 1 | # Unit tests package -------------------------------------------------------------------------------- /jarvis/tests/integration/__init__.py: -------------------------------------------------------------------------------- 1 | # Integration tests package -------------------------------------------------------------------------------- /jarvis/src/audio/__init__.py: -------------------------------------------------------------------------------- 1 | """Audio processing modules""" 2 | 3 | from .recorder import AudioRecorder 4 | from .processor import AudioProcessor 5 | 6 | __all__ = ["AudioRecorder", "AudioProcessor"] -------------------------------------------------------------------------------- /jarvis/src/gui/__init__.py: -------------------------------------------------------------------------------- 1 | """GUI components for the speech transcription application""" 2 | 3 | from .main_window import MainWindow 4 | from .system_tray import SystemTray 5 | 6 | __all__ = ["MainWindow", "SystemTray"] -------------------------------------------------------------------------------- /jarvis/run_gui.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Set Qt plugin path for PyQt6 4 | export QT_QPA_PLATFORM_PLUGIN_PATH="/Users/prompt/anaconda3/lib/python3.11/site-packages/PyQt6/Qt6/plugins/platforms" 5 | 6 | # Run the GUI application 7 | python -m src.main -------------------------------------------------------------------------------- /jarvis/requirements.txt: -------------------------------------------------------------------------------- 1 | sounddevice>=0.4.6 2 | numpy>=1.24.0 3 | scipy>=1.10.0 4 | PyQt6>=6.5.0 5 | mlx>=0.5.0 6 | mlx-lm>=0.2.0 7 | mlx-whisper>=0.2.0 8 | requests>=2.31.0 9 | huggingface-hub>=0.16.0 10 | pynput>=1.7.0 11 | soundfile>=0.12.0 12 | pytest>=7.4.0 13 | black>=23.7.0 14 | ruff>=0.0.280 15 
| mypy>=1.4.0 -------------------------------------------------------------------------------- /jarvis/src/models/__init__.py: -------------------------------------------------------------------------------- 1 | """Machine learning models for transcription and text correction""" 2 | 3 | from .transcriber import SpeechTranscriber, AVAILABLE_MODELS 4 | from .text_corrector import TextCorrector, AVAILABLE_LLM_MODELS 5 | 6 | __all__ = ["SpeechTranscriber", "TextCorrector", "AVAILABLE_MODELS", "AVAILABLE_LLM_MODELS"] -------------------------------------------------------------------------------- /jarvis/src/__init__.py: -------------------------------------------------------------------------------- 1 | """Speech Transcription Application""" 2 | 3 | __version__ = "1.0.0" 4 | 5 | from .pipeline import TranscriptionPipeline 6 | from .models.transcriber import SpeechTranscriber 7 | from .models.text_corrector import TextCorrector 8 | 9 | __all__ = [ 10 | "TranscriptionPipeline", 11 | "SpeechTranscriber", 12 | "TextCorrector" 13 | ] -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Flask 2 | byaldi 3 | cmake 4 | pkgconfig 5 | python-poppler 6 | torch 7 | torchvision 8 | google-generativeai 9 | openai 10 | docx2pdf 11 | qwen-vl-utils 12 | vllm>=0.6.1.post1; sys_platform != 'darwin' 13 | mistral_common>=1.4.1 14 | einops 15 | mistral_common[opencv] 16 | mistral_common 17 | mistral_inference 18 | groq 19 | markdown 20 | hf_transfer 21 | ollama -------------------------------------------------------------------------------- /jarvis/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="speech-transcription", 5 | version="1.0.0", 6 | packages=find_packages(), 7 | install_requires=[ 8 | "sounddevice>=0.4.6", 9 | "numpy>=1.24.0", 10 | "scipy>=1.10.0", 11 | "PyQt6>=6.5.0", 12 | "mlx>=0.5.0", 13 | "mlx-lm>=0.2.0", 14 | "mlx-whisper>=0.2.0", 15 | "soundfile>=0.12.0", 16 | ], 17 | entry_points={ 18 | "console_scripts": [ 19 | "speech-transcribe=src.cli:main", 20 | "speech-transcribe-gui=src.main:main", 21 | ], 22 | }, 23 | python_requires=">=3.9", 24 | ) -------------------------------------------------------------------------------- /jarvis/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile 2 | .PHONY: install test run lint format clean 3 | 4 | install: 5 | pip install -r requirements.txt 6 | pip install -e . 7 | 8 | test: 9 | pytest tests/ -v 10 | 11 | test-unit: 12 | pytest tests/unit/ -v 13 | 14 | test-integration: 15 | pytest tests/integration/ -v 16 | 17 | run: 18 | python -m src.main 19 | 20 | run-cli: 21 | python -m src.cli --help 22 | 23 | lint: 24 | ruff check src tests 25 | mypy src 26 | 27 | format: 28 | black src tests 29 | ruff format src tests 30 | 31 | clean: 32 | find . -type f -name "*.pyc" -delete 33 | find . 
-type d -name "__pycache__" -delete 34 | rm -rf build dist *.egg-info 35 | 36 | download-models: 37 | python scripts/download_models.py -------------------------------------------------------------------------------- /jarvis/.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | .Python 7 | build/ 8 | develop-eggs/ 9 | dist/ 10 | downloads/ 11 | eggs/ 12 | .eggs/ 13 | lib/ 14 | lib64/ 15 | parts/ 16 | sdist/ 17 | var/ 18 | wheels/ 19 | *.egg-info/ 20 | .installed.cfg 21 | *.egg 22 | 23 | # Virtual Environment 24 | venv/ 25 | ENV/ 26 | env/ 27 | 28 | # IDE 29 | .vscode/ 30 | .idea/ 31 | *.swp 32 | *.swo 33 | 34 | # macOS 35 | .DS_Store 36 | .AppleDouble 37 | .LSOverride 38 | 39 | # Temporary files 40 | *.tmp 41 | *.temp 42 | .cache/ 43 | 44 | # Model cache (but not src/models/) 45 | /models/ 46 | *.bin 47 | *.safetensors 48 | 49 | # Audio files 50 | *.wav 51 | *.mp3 52 | *.m4a 53 | *.flac 54 | 55 | # Logs 56 | *.log 57 | logs/ 58 | 59 | # pytest 60 | .pytest_cache/ -------------------------------------------------------------------------------- /templates/chat_messages.html: -------------------------------------------------------------------------------- 1 | {% for message in messages %} 2 |
3 | {% if message.role == 'user' %} 4 | {{ message.content }} 5 | {% else %} 6 | {{ message.content|safe }} 7 | {% endif %} 8 | {% if message.images %} 9 |
10 | {% for image in message.images %} 11 | Retrieved Image 12 | {% endfor %} 13 |
14 | {% endif %} 15 |
16 | {% endfor %} -------------------------------------------------------------------------------- /models/converters.py: -------------------------------------------------------------------------------- 1 | # models/converters.py 2 | 3 | import os 4 | from docx2pdf import convert 5 | from logger import get_logger 6 | 7 | logger = get_logger(__name__) 8 | 9 | def convert_docs_to_pdfs(folder_path): 10 | """ 11 | Converts .doc and .docx files in the folder to PDFs. 12 | 13 | Args: 14 | folder_path (str): The path to the folder containing documents. 15 | """ 16 | try: 17 | for filename in os.listdir(folder_path): 18 | if filename.lower().endswith(('.doc', '.docx')): 19 | doc_path = os.path.join(folder_path, filename) 20 | pdf_path = os.path.splitext(doc_path)[0] + '.pdf' 21 | convert(doc_path, pdf_path) 22 | logger.info(f"Converted '{filename}' to PDF.") 23 | except Exception as e: 24 | logger.error(f"Error converting documents to PDFs: {e}") 25 | raise -------------------------------------------------------------------------------- /jarvis/tests/integration/test_pipeline.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | from src.pipeline import TranscriptionPipeline 4 | 5 | class TestPipeline: 6 | @pytest.fixture 7 | def pipeline(self): 8 | return TranscriptionPipeline(model_name="whisper-tiny") 9 | 10 | def test_transcription_flow(self, pipeline, tmp_path): 11 | # Create test audio 12 | duration = 3 # seconds 13 | sample_rate = 16000 14 | t = np.linspace(0, duration, duration * sample_rate) 15 | 16 | # Simple tone 17 | audio = np.sin(2 * np.pi * 440 * t) * 0.5 18 | 19 | # Save to file 20 | import soundfile as sf 21 | test_file = tmp_path / "test.wav" 22 | sf.write(test_file, audio, sample_rate) 23 | 24 | # Transcribe 25 | result = pipeline.transcribe_file(str(test_file)) 26 | 27 | # Should return empty for pure tone 28 | assert isinstance(result, str) -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | # logger.py 2 | 3 | import logging 4 | 5 | def get_logger(name): 6 | """ 7 | Creates a logger with the specified name. 8 | 9 | Args: 10 | name (str): The name of the logger. 11 | 12 | Returns: 13 | Logger: Configured logger instance. 
14 | """ 15 | logger = logging.getLogger(name) 16 | logger.setLevel(logging.DEBUG) 17 | 18 | if not logger.handlers: 19 | # Console handler 20 | c_handler = logging.StreamHandler() 21 | c_handler.setLevel(logging.INFO) 22 | 23 | # File handler 24 | f_handler = logging.FileHandler('app.log') 25 | f_handler.setLevel(logging.DEBUG) 26 | 27 | # Create formatters and add them to handlers 28 | c_format = logging.Formatter('%(asctime)s - %(levelname)s - %(name)s - %(message)s') 29 | f_format = logging.Formatter('%(asctime)s - %(levelname)s - %(name)s - %(message)s') 30 | c_handler.setFormatter(c_format) 31 | f_handler.setFormatter(f_format) 32 | 33 | # Add handlers to the logger 34 | logger.addHandler(c_handler) 35 | logger.addHandler(f_handler) 36 | 37 | return logger 38 | -------------------------------------------------------------------------------- /jarvis/scripts/download_models.py: -------------------------------------------------------------------------------- 1 | """Download required ML models""" 2 | 3 | from huggingface_hub import snapshot_download 4 | import os 5 | 6 | # Models to download 7 | MODELS = [ 8 | # Whisper models for transcription 9 | "mlx-community/whisper-tiny", 10 | "mlx-community/whisper-base-mlx", 11 | 12 | # LLM models for text correction (at least one) 13 | "mlx-community/Phi-3.5-mini-instruct-4bit", 14 | "mlx-community/Qwen2.5-0.5B-Instruct-4bit", # Tiny option 15 | # Add more as needed: 16 | # "mlx-community/gemma-2-2b-it-4bit", 17 | # "mlx-community/Mistral-7B-Instruct-v0.3-4bit", 18 | ] 19 | 20 | def download_models(): 21 | """Download all required models""" 22 | cache_dir = os.path.expanduser("~/.cache/huggingface/hub") 23 | 24 | for model_id in MODELS: 25 | print(f"Downloading {model_id}...") 26 | try: 27 | snapshot_download( 28 | repo_id=model_id, 29 | cache_dir=cache_dir, 30 | resume_download=True 31 | ) 32 | print(f"✓ {model_id} downloaded successfully") 33 | except Exception as e: 34 | print(f"✗ Failed to download {model_id}: {e}") 35 | 36 | if __name__ == "__main__": 37 | download_models() -------------------------------------------------------------------------------- /jarvis/tests/unit/test_audio_processor.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | from src.audio.processor import AudioProcessor 4 | 5 | class TestAudioProcessor: 6 | def test_normalize_audio(self): 7 | processor = AudioProcessor() 8 | audio = np.array([0.5, -0.5, 0.25, -0.25]) 9 | normalized = processor.normalize_audio(audio) 10 | assert normalized.max() == 1.0 or normalized.min() == -1.0 11 | 12 | def test_silence_detection(self): 13 | processor = AudioProcessor() 14 | 15 | # Test silence 16 | silence = np.zeros(16000) 17 | assert processor.is_silence(silence) 18 | 19 | # Test non-silence 20 | tone = np.sin(2 * np.pi * 440 * np.linspace(0, 1, 16000)) 21 | assert not processor.is_silence(tone) 22 | 23 | def test_resample(self): 24 | processor = AudioProcessor() 25 | 26 | # Create 1 second of audio at 44100 Hz 27 | original_sr = 44100 28 | audio = np.sin(2 * np.pi * 440 * np.linspace(0, 1, original_sr)) 29 | 30 | # Resample to 16000 Hz 31 | resampled = processor.resample(audio, original_sr, 16000) 32 | 33 | # Check length 34 | expected_length = 16000 35 | assert abs(len(resampled) - expected_length) < 10 -------------------------------------------------------------------------------- /static/css/style.css: -------------------------------------------------------------------------------- 1 | /* 
static/css/style.css */ 2 | 3 | /* Customize chat window */ 4 | .chat-window { 5 | background-color: #f8f9fa; 6 | border-radius: 5px; 7 | padding: 15px; 8 | } 9 | 10 | /* Customize messages */ 11 | .message { 12 | margin-bottom: 10px; 13 | } 14 | 15 | .user-message { 16 | text-align: right; 17 | } 18 | 19 | .bot-message { 20 | text-align: left; 21 | } 22 | 23 | /* Customize images */ 24 | .img-thumbnail { 25 | cursor: pointer; 26 | } 27 | 28 | /* Sidebar */ 29 | .list-group-item.active { 30 | background-color: #007bff; 31 | border-color: #007bff; 32 | color: #fff; 33 | } 34 | 35 | .list-group-item a { 36 | color: inherit; 37 | text-decoration: none; 38 | } 39 | 40 | .list-group-item a:hover { 41 | text-decoration: none; 42 | } 43 | 44 | /* Delete Session Button */ 45 | .delete-session-btn { 46 | font-size: 0.8rem; 47 | padding: 0.2rem 0.5rem; 48 | } 49 | 50 | /* Spinner Overlay */ 51 | .spinner-overlay { 52 | position: fixed; 53 | top: 0; 54 | left: 0; 55 | right: 0; 56 | bottom: 0; 57 | z-index: 1050; /* Higher than modal */ 58 | background-color: rgba(255, 255, 255, 0.7); 59 | display: flex; 60 | align-items: center; 61 | justify-content: center; 62 | } 63 | 64 | /* Spinner */ 65 | .spinner-border { 66 | width: 3rem; 67 | height: 3rem; 68 | } 69 | -------------------------------------------------------------------------------- /jarvis/src/main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | # Disable tokenizers parallelism warning 5 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 6 | 7 | # Set Qt plugin path for macOS 8 | import platform 9 | if platform.system() == "Darwin": # macOS 10 | # Try to find PyQt6 installation 11 | try: 12 | import PyQt6 13 | pyqt6_path = os.path.dirname(PyQt6.__file__) 14 | plugin_path = os.path.join(pyqt6_path, "Qt6", "plugins", "platforms") 15 | if os.path.exists(plugin_path): 16 | os.environ["QT_QPA_PLATFORM_PLUGIN_PATH"] = plugin_path 17 | except: 18 | pass 19 | 20 | from PyQt6.QtWidgets import QApplication 21 | from PyQt6.QtCore import Qt 22 | from .gui.main_window import MainWindow 23 | from .gui.system_tray import SystemTray 24 | 25 | def main(): 26 | """Main application entry point""" 27 | # Enable high DPI support 28 | if hasattr(Qt.ApplicationAttribute, 'AA_UseHighDpiPixmaps'): 29 | QApplication.setAttribute(Qt.ApplicationAttribute.AA_UseHighDpiPixmaps) 30 | 31 | app = QApplication(sys.argv) 32 | app.setApplicationName("Speech Transcription") 33 | app.setOrganizationName("SpeechTranscription") 34 | 35 | # Create main window 36 | window = MainWindow() 37 | 38 | # Create system tray 39 | tray = SystemTray() 40 | tray.showMainWindow.connect(window.show) 41 | tray.quitApp.connect(app.quit) 42 | 43 | # Show window 44 | window.show() 45 | 46 | # Run app 47 | sys.exit(app.exec()) 48 | 49 | if __name__ == "__main__": 50 | main() -------------------------------------------------------------------------------- /models/indexer.py: -------------------------------------------------------------------------------- 1 | # models/indexer.py 2 | 3 | import os 4 | from byaldi import RAGMultiModalModel 5 | from models.converters import convert_docs_to_pdfs 6 | from logger import get_logger 7 | 8 | logger = get_logger(__name__) 9 | 10 | def index_documents(folder_path, index_name='document_index', index_path=None, indexer_model='vidore/colpali'): 11 | """ 12 | Indexes documents in the specified folder using Byaldi. 
13 | 14 | Args: 15 | folder_path (str): The path to the folder containing documents to index. 16 | index_name (str): The name of the index to create or update. 17 | index_path (str): The path where the index should be saved. 18 | indexer_model (str): The name of the indexer model to use. 19 | 20 | Returns: 21 | RAGMultiModalModel: The RAG model with the indexed documents. 22 | """ 23 | try: 24 | logger.info(f"Starting document indexing in folder: {folder_path}") 25 | # Convert non-PDF documents to PDFs 26 | convert_docs_to_pdfs(folder_path) 27 | logger.info("Conversion of non-PDF documents to PDFs completed.") 28 | 29 | # Initialize RAG model 30 | RAG = RAGMultiModalModel.from_pretrained(indexer_model) 31 | if RAG is None: 32 | raise ValueError(f"Failed to initialize RAGMultiModalModel with model {indexer_model}") 33 | logger.info(f"RAG model initialized with {indexer_model}.") 34 | 35 | # Index the documents in the folder 36 | RAG.index( 37 | input_path=folder_path, 38 | index_name=index_name, 39 | store_collection_with_index=True, 40 | overwrite=True 41 | ) 42 | 43 | logger.info(f"Indexing completed. Index saved at '{index_path}'.") 44 | 45 | return RAG 46 | except Exception as e: 47 | logger.error(f"Error during indexing: {str(e)}") 48 | raise -------------------------------------------------------------------------------- /jarvis/src/audio/recorder.py: -------------------------------------------------------------------------------- 1 | import sounddevice as sd 2 | import numpy as np 3 | from typing import Callable, Optional 4 | import threading 5 | import queue 6 | 7 | class AudioRecorder: 8 | def __init__( 9 | self, 10 | sample_rate: int = 16000, 11 | chunk_duration: float = 0.5, 12 | callback: Optional[Callable[[np.ndarray], None]] = None 13 | ): 14 | self.sample_rate = sample_rate 15 | self.chunk_size = int(sample_rate * chunk_duration) 16 | self.callback = callback 17 | self.audio_queue = queue.Queue() 18 | self.is_recording = False 19 | 20 | def audio_callback(self, indata, frames, time, status): 21 | """Callback for sounddevice stream""" 22 | if status: 23 | print(f"Audio callback status: {status}") 24 | 25 | # Copy audio data to queue 26 | audio_chunk = indata[:, 0].copy() # Get first channel 27 | self.audio_queue.put(audio_chunk) 28 | 29 | # Call user callback if provided 30 | if self.callback: 31 | self.callback(audio_chunk) 32 | 33 | def start_recording(self, device_id: Optional[int] = None): 34 | """Start audio recording""" 35 | self.is_recording = True 36 | 37 | # Start audio stream 38 | self.stream = sd.InputStream( 39 | device=device_id, 40 | channels=1, 41 | samplerate=self.sample_rate, 42 | callback=self.audio_callback, 43 | blocksize=self.chunk_size 44 | ) 45 | self.stream.start() 46 | 47 | def stop_recording(self) -> np.ndarray: 48 | """Stop recording and return accumulated audio""" 49 | self.is_recording = False 50 | 51 | if hasattr(self, 'stream'): 52 | self.stream.stop() 53 | self.stream.close() 54 | 55 | # Collect all audio chunks 56 | audio_chunks = [] 57 | while not self.audio_queue.empty(): 58 | chunk = self.audio_queue.get() 59 | audio_chunks.append(chunk) 60 | 61 | if audio_chunks: 62 | return np.concatenate(audio_chunks) 63 | return np.array([]) 64 | 65 | def get_audio_level(self) -> float: 66 | """Get current audio level for visualization""" 67 | if not self.audio_queue.empty(): 68 | chunk = self.audio_queue.queue[-1] # Peek at last item 69 | return float(np.abs(chunk).mean()) 70 | return 0.0 
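
A minimal usage sketch for the `AudioRecorder` above (not part of the repository): it assumes the default input device, a working sounddevice backend, and that the package is importable as `src.audio.recorder`, as in the tests and CLI.

```python
# Hypothetical usage sketch for src/audio/recorder.py; assumes the default
# input device and a working sounddevice backend.
import time

import numpy as np

from src.audio.recorder import AudioRecorder


def on_chunk(chunk: np.ndarray) -> None:
    # Invoked for each ~0.5 s mono chunk; print a rough level meter.
    print(f"level: {np.abs(chunk).mean():.4f}")


recorder = AudioRecorder(sample_rate=16000, chunk_duration=0.5, callback=on_chunk)
recorder.start_recording()         # device_id=None -> default input device
time.sleep(3)                      # capture roughly three seconds
audio = recorder.stop_recording()  # 1-D array of all queued chunks
print(f"captured {len(audio) / 16000:.1f} s of audio")
```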
-------------------------------------------------------------------------------- /models/retriever.py: -------------------------------------------------------------------------------- 1 | # models/retriever.py 2 | 3 | import base64 4 | import os 5 | from PIL import Image 6 | from io import BytesIO 7 | from logger import get_logger 8 | import time 9 | import hashlib 10 | 11 | logger = get_logger(__name__) 12 | 13 | def retrieve_documents(RAG, query, session_id, k=3): 14 | """ 15 | Retrieves relevant documents based on the user query using Byaldi. 16 | 17 | Args: 18 | RAG (RAGMultiModalModel): The RAG model with the indexed documents. 19 | query (str): The user's query. 20 | session_id (str): The session ID to store images in per-session folder. 21 | k (int): The number of documents to retrieve. 22 | 23 | Returns: 24 | list: A list of image filenames corresponding to the retrieved documents. 25 | """ 26 | try: 27 | logger.info(f"Retrieving documents for query: {query}") 28 | results = RAG.search(query, k=k) 29 | images = [] 30 | session_images_folder = os.path.join('static', 'images', session_id) 31 | os.makedirs(session_images_folder, exist_ok=True) 32 | 33 | for i, result in enumerate(results): 34 | if result.base64: 35 | image_data = base64.b64decode(result.base64) 36 | image = Image.open(BytesIO(image_data)) 37 | 38 | # Generate a unique filename based on the image content 39 | image_hash = hashlib.md5(image_data).hexdigest() 40 | image_filename = f"retrieved_{image_hash}.png" 41 | image_path = os.path.join(session_images_folder, image_filename) 42 | 43 | if not os.path.exists(image_path): 44 | image.save(image_path, format='PNG') 45 | logger.debug(f"Retrieved and saved image: {image_path}") 46 | else: 47 | logger.debug(f"Image already exists: {image_path}") 48 | 49 | # Store the relative path from the static folder 50 | relative_path = os.path.join('images', session_id, image_filename) 51 | images.append(relative_path) 52 | logger.info(f"Added image to list: {relative_path}") 53 | else: 54 | logger.warning(f"No base64 data for document {result.doc_id}, page {result.page_num}") 55 | 56 | logger.info(f"Total {len(images)} documents retrieved. 
Image paths: {images}") 57 | return images 58 | except Exception as e: 59 | logger.error(f"Error retrieving documents: {e}") 60 | return [] -------------------------------------------------------------------------------- /jarvis/src/gui/system_tray.py: -------------------------------------------------------------------------------- 1 | from PyQt6.QtWidgets import QSystemTrayIcon, QMenu, QApplication 2 | from PyQt6.QtGui import QIcon, QAction 3 | from PyQt6.QtCore import QObject, pyqtSignal 4 | import os 5 | 6 | class SystemTray(QObject): 7 | """System tray integration for background operation""" 8 | 9 | showMainWindow = pyqtSignal() 10 | startRecording = pyqtSignal() 11 | quitApp = pyqtSignal() 12 | 13 | def __init__(self, parent=None): 14 | super().__init__(parent) 15 | self.tray_icon = None 16 | self.init_tray() 17 | 18 | def init_tray(self): 19 | """Initialize system tray icon""" 20 | # Create tray icon 21 | self.tray_icon = QSystemTrayIcon(self) 22 | self.tray_icon.setIcon(self.get_icon()) 23 | 24 | # Create menu 25 | tray_menu = QMenu() 26 | 27 | # Show action 28 | show_action = QAction("Show", self) 29 | show_action.triggered.connect(self.showMainWindow.emit) 30 | tray_menu.addAction(show_action) 31 | 32 | # Record action 33 | record_action = QAction("Quick Record", self) 34 | record_action.triggered.connect(self.startRecording.emit) 35 | tray_menu.addAction(record_action) 36 | 37 | tray_menu.addSeparator() 38 | 39 | # Quit action 40 | quit_action = QAction("Quit", self) 41 | quit_action.triggered.connect(self.quitApp.emit) 42 | tray_menu.addAction(quit_action) 43 | 44 | # Set menu and show 45 | self.tray_icon.setContextMenu(tray_menu) 46 | self.tray_icon.show() 47 | 48 | # Handle clicks 49 | self.tray_icon.activated.connect(self.on_tray_activated) 50 | 51 | def get_icon(self): 52 | """Get or create tray icon""" 53 | # Create a simple icon programmatically 54 | from PyQt6.QtGui import QPixmap, QPainter, QBrush 55 | from PyQt6.QtCore import Qt 56 | 57 | pixmap = QPixmap(32, 32) 58 | pixmap.fill(Qt.GlobalColor.transparent) 59 | 60 | painter = QPainter(pixmap) 61 | painter.setRenderHint(QPainter.RenderHint.Antialiasing) 62 | 63 | # Draw microphone icon 64 | painter.setBrush(QBrush(Qt.GlobalColor.white)) 65 | painter.setPen(Qt.PenStyle.NoPen) 66 | 67 | # Mic body 68 | painter.drawEllipse(10, 5, 12, 18) 69 | 70 | # Mic stand 71 | painter.drawRect(14, 23, 4, 5) 72 | 73 | # Mic base 74 | painter.drawRect(10, 28, 12, 2) 75 | 76 | painter.end() 77 | 78 | return QIcon(pixmap) 79 | 80 | def on_tray_activated(self, reason): 81 | """Handle tray icon activation""" 82 | if reason == QSystemTrayIcon.ActivationReason.DoubleClick: 83 | self.showMainWindow.emit() 84 | 85 | def set_recording_state(self, is_recording: bool): 86 | """Update tray icon for recording state""" 87 | if is_recording: 88 | self.tray_icon.setToolTip("Speech Transcription - Recording...") 89 | # Could update icon to show recording state 90 | else: 91 | self.tray_icon.setToolTip("Speech Transcription") -------------------------------------------------------------------------------- /templates/settings.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | {% extends 'base.html' %} 4 | 5 | {% block content %} 6 |
7 |

Settings

8 |
9 |

Retrieval Model

10 |
11 | 12 | 17 |
18 | 19 |

Generation Model

20 |
21 | 22 | 32 |
33 | 34 |

Image Settings

35 |
36 | 37 | 38 |
39 |
40 | 41 | 42 |
43 | 44 |
45 |
46 | {% endblock %} 47 | 48 | {% block styles %} 49 | 63 | {% endblock %} 64 | -------------------------------------------------------------------------------- /jarvis/src/cli.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | from pathlib import Path 4 | from .pipeline import TranscriptionPipeline 5 | from .models.transcriber import AVAILABLE_MODELS 6 | 7 | def main(): 8 | """CLI entry point""" 9 | parser = argparse.ArgumentParser( 10 | description="Speech Transcription CLI" 11 | ) 12 | 13 | parser.add_argument( 14 | "--audio", 15 | type=str, 16 | help="Path to audio file to transcribe" 17 | ) 18 | 19 | parser.add_argument( 20 | "--model", 21 | type=str, 22 | default="whisper-tiny", 23 | choices=list(AVAILABLE_MODELS.keys()), 24 | help="Model to use for transcription" 25 | ) 26 | 27 | parser.add_argument( 28 | "--output", 29 | type=str, 30 | help="Output file path (default: stdout)" 31 | ) 32 | 33 | parser.add_argument( 34 | "--no-correction", 35 | action="store_true", 36 | help="Skip AI text correction" 37 | ) 38 | 39 | parser.add_argument( 40 | "--context", 41 | type=str, 42 | help="Context for better corrections" 43 | ) 44 | 45 | parser.add_argument( 46 | "--list-models", 47 | action="store_true", 48 | help="List available models" 49 | ) 50 | 51 | parser.add_argument( 52 | "--keep-fillers", 53 | action="store_true", 54 | help="Keep filler words in transcription" 55 | ) 56 | 57 | args = parser.parse_args() 58 | 59 | # List models 60 | if args.list_models: 61 | print("\nAvailable models:") 62 | for name, config in AVAILABLE_MODELS.items(): 63 | print(f" {name}: {config['description']}") 64 | return 0 65 | 66 | # Validate audio file 67 | if not args.audio: 68 | parser.error("--audio is required") 69 | 70 | audio_path = Path(args.audio) 71 | if not audio_path.exists(): 72 | print(f"Error: Audio file not found: {audio_path}") 73 | return 1 74 | 75 | # Initialize pipeline 76 | try: 77 | pipeline = TranscriptionPipeline(model_name=args.model) 78 | except Exception as e: 79 | print(f"Error initializing pipeline: {e}") 80 | return 1 81 | 82 | # Transcribe 83 | print(f"Transcribing {audio_path} with {args.model}...") 84 | try: 85 | text = pipeline.transcribe_file(str(audio_path)) 86 | 87 | if not text: 88 | print("No speech detected in audio file") 89 | return 1 90 | 91 | # Apply correction if requested 92 | if not args.no_correction: 93 | print("Applying text correction...") 94 | text = pipeline.correct_text( 95 | text, 96 | context=args.context 97 | ) 98 | 99 | # Output result 100 | if args.output: 101 | with open(args.output, 'w') as f: 102 | f.write(text) 103 | print(f"Transcription saved to {args.output}") 104 | else: 105 | print("\nTranscription:") 106 | print("-" * 50) 107 | print(text) 108 | print("-" * 50) 109 | 110 | return 0 111 | 112 | except Exception as e: 113 | print(f"Error during transcription: {e}") 114 | return 1 115 | 116 | if __name__ == "__main__": 117 | sys.exit(main()) -------------------------------------------------------------------------------- /jarvis/src/pipeline.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Optional, Callable 3 | from .audio.recorder import AudioRecorder 4 | from .audio.processor import AudioProcessor 5 | from .models.transcriber import SpeechTranscriber 6 | from .models.text_corrector import TextCorrector 7 | 8 | class TranscriptionPipeline: 9 | """Main pipeline orchestrating transcription 
flow""" 10 | 11 | def __init__( 12 | self, 13 | model_name: str = "whisper-tiny", 14 | llm_model_name: str = "Phi-3.5-mini", 15 | sample_rate: int = 16000 16 | ): 17 | self.sample_rate = sample_rate 18 | self.audio_recorder = None 19 | self.audio_processor = AudioProcessor(sample_rate) 20 | self.transcriber = SpeechTranscriber(model_name) 21 | self.text_corrector = TextCorrector(llm_model_name) 22 | 23 | def set_model(self, model_name: str): 24 | """Change the transcription model""" 25 | self.transcriber.load_model(model_name) 26 | 27 | def set_llm_model(self, model_name: str): 28 | """Change the LLM model for text correction""" 29 | self.text_corrector.set_model(model_name) 30 | 31 | def start_recording(self, callback: Optional[Callable] = None, device_id: Optional[int] = None): 32 | """Start recording audio""" 33 | self.audio_recorder = AudioRecorder( 34 | sample_rate=self.sample_rate, 35 | callback=callback 36 | ) 37 | self.audio_recorder.start_recording(device_id=device_id) 38 | 39 | def stop_recording(self) -> str: 40 | """Stop recording and return transcription""" 41 | if not self.audio_recorder: 42 | return "" 43 | 44 | # Stop recording 45 | final_audio = self.audio_recorder.stop_recording() 46 | 47 | # Use only the audio from stop_recording (which already includes all chunks) 48 | audio = final_audio 49 | 50 | # Process audio 51 | processed_audio, has_speech = self.audio_processor.process( 52 | audio, self.sample_rate 53 | ) 54 | 55 | if not has_speech: 56 | return "" 57 | 58 | # Transcribe 59 | text = self.transcriber.transcribe(processed_audio, self.sample_rate) 60 | return text 61 | 62 | def transcribe_file(self, file_path: str) -> str: 63 | """Transcribe an audio file""" 64 | import soundfile as sf 65 | 66 | # Load audio 67 | audio, sr = sf.read(file_path) 68 | 69 | # Process 70 | processed_audio, has_speech = self.audio_processor.process(audio, sr) 71 | 72 | if not has_speech: 73 | return "" 74 | 75 | # Transcribe 76 | return self.transcriber.transcribe(processed_audio, self.sample_rate) 77 | 78 | def correct_text( 79 | self, 80 | text: str, 81 | context: Optional[str] = None 82 | ) -> str: 83 | """Apply AI correction to transcribed text""" 84 | return self.text_corrector.correct(text, context) 85 | 86 | def get_audio_level(self) -> float: 87 | """Get current audio level""" 88 | if self.audio_recorder: 89 | return self.audio_recorder.get_audio_level() 90 | return 0.0 91 | 92 | def process_stream( 93 | self, 94 | audio_chunk: np.ndarray, 95 | callback: Callable[[str], None] 96 | ): 97 | """Process audio stream in real-time""" 98 | # This would be used for streaming mode 99 | # Process chunk and transcribe incrementally 100 | pass -------------------------------------------------------------------------------- /jarvis/src/audio/processor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import signal 3 | from typing import Tuple, Optional 4 | 5 | class AudioProcessor: 6 | def __init__(self, target_sample_rate: int = 16000): 7 | self.target_sample_rate = target_sample_rate 8 | self.silence_threshold = 0.01 9 | self.min_audio_length = 0.5 # seconds 10 | 11 | def process( 12 | self, 13 | audio: np.ndarray, 14 | sample_rate: int 15 | ) -> Tuple[np.ndarray, bool]: 16 | """Process audio for transcription""" 17 | # Resample if needed 18 | if sample_rate != self.target_sample_rate: 19 | audio = self.resample(audio, sample_rate, self.target_sample_rate) 20 | 21 | # Normalize audio 22 | audio = 
self.normalize_audio(audio) 23 | 24 | # Check if audio is too short 25 | if len(audio) < self.target_sample_rate * self.min_audio_length: 26 | return audio, False 27 | 28 | # Apply noise reduction 29 | audio = self.reduce_noise(audio) 30 | 31 | # Check for silence 32 | is_silence = self.is_silence(audio) 33 | 34 | return audio, not is_silence 35 | 36 | def normalize_audio(self, audio: np.ndarray) -> np.ndarray: 37 | """Normalize audio to [-1, 1] range""" 38 | max_val = np.abs(audio).max() 39 | if max_val > 0: 40 | return audio / max_val 41 | return audio 42 | 43 | def resample( 44 | self, 45 | audio: np.ndarray, 46 | orig_sr: int, 47 | target_sr: int 48 | ) -> np.ndarray: 49 | """Resample audio to target sample rate""" 50 | if orig_sr == target_sr: 51 | return audio 52 | 53 | # Calculate resample ratio 54 | resample_ratio = target_sr / orig_sr 55 | new_length = int(len(audio) * resample_ratio) 56 | 57 | # Use scipy's resample 58 | return signal.resample(audio, new_length) 59 | 60 | def reduce_noise(self, audio: np.ndarray) -> np.ndarray: 61 | """Apply basic noise reduction using spectral subtraction""" 62 | # Apply high-pass filter to remove low-frequency noise 63 | if len(audio) > 13: # Minimum length for filter 64 | b, a = signal.butter(4, 100, 'hp', fs=self.target_sample_rate) 65 | audio = signal.filtfilt(b, a, audio) 66 | return audio 67 | 68 | def is_silence(self, audio: np.ndarray) -> bool: 69 | """Detect if audio is silence""" 70 | rms = np.sqrt(np.mean(audio**2)) 71 | return rms < self.silence_threshold 72 | 73 | def apply_voice_activity_detection( 74 | self, 75 | audio: np.ndarray 76 | ) -> np.ndarray: 77 | """Simple VAD to trim silence from beginning and end""" 78 | # Calculate energy for each frame 79 | frame_length = int(0.025 * self.target_sample_rate) # 25ms frames 80 | hop_length = int(0.010 * self.target_sample_rate) # 10ms hop 81 | 82 | energy = [] 83 | for i in range(0, len(audio) - frame_length, hop_length): 84 | frame = audio[i:i + frame_length] 85 | energy.append(np.sum(frame**2)) 86 | 87 | energy = np.array(energy) 88 | threshold = np.mean(energy) * 0.1 89 | 90 | # Find voice activity regions 91 | voice_activity = energy > threshold 92 | 93 | if np.any(voice_activity): 94 | # Find first and last active frame 95 | first_active = np.argmax(voice_activity) 96 | last_active = len(voice_activity) - np.argmax(voice_activity[::-1]) 97 | 98 | # Convert frame indices to sample indices 99 | start_sample = first_active * hop_length 100 | end_sample = min(last_active * hop_length + frame_length, len(audio)) 101 | 102 | return audio[start_sample:end_sample] 103 | 104 | return audio -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore vscode 2 | /.vscode 3 | /DB 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # poetry 102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 106 | #poetry.lock 107 | 108 | # pdm 109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 110 | #pdm.lock 111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 112 | # in version control. 113 | # https://pdm.fming.dev/#use-with-ide 114 | .pdm.toml 115 | 116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # pytype static type analyzer 154 | .pytype/ 155 | 156 | # Cython debug symbols 157 | cython_debug/ 158 | 159 | # PyCharm 160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 162 | # and can be added to the global gitignore or merged into this file. For a more nuclear 163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
164 | .idea/ 165 | 166 | #MacOS 167 | .DS_Store 168 | SOURCE_DOCUMENTS/.DS_Store 169 | 170 | 171 | .byaldi/ 172 | sessions/ 173 | static/images/ 174 | uploaded_documents/ 175 | -------------------------------------------------------------------------------- /jarvis/src/models/transcriber.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | import numpy as np 4 | import soundfile as sf 5 | from typing import Optional, Dict, Any 6 | from abc import ABC, abstractmethod 7 | 8 | # Model configurations 9 | AVAILABLE_MODELS = { 10 | "whisper-tiny": { 11 | "repo": "mlx-community/whisper-tiny", 12 | "type": "whisper", 13 | "description": "Very fast, lower accuracy" 14 | }, 15 | "whisper-base": { 16 | "repo": "mlx-community/whisper-base-mlx", 17 | "type": "whisper", 18 | "description": "Fast, decent accuracy" 19 | }, 20 | "whisper-small": { 21 | "repo": "mlx-community/whisper-small-mlx", 22 | "type": "whisper", 23 | "description": "Good balance" 24 | }, 25 | "whisper-medium": { 26 | "repo": "mlx-community/whisper-medium-mlx", 27 | "type": "whisper", 28 | "description": "High accuracy, slower" 29 | }, 30 | "whisper-large-v3": { 31 | "repo": "mlx-community/whisper-large-v3-mlx", 32 | "type": "whisper", 33 | "description": "Best accuracy, slowest" 34 | }, 35 | "distil-whisper-large-v3": { 36 | "repo": "mlx-community/distil-whisper-large-v3", 37 | "type": "whisper", 38 | "description": "Fast with high accuracy" 39 | } 40 | } 41 | 42 | class BaseTranscriber(ABC): 43 | """Base class for speech transcribers""" 44 | 45 | @abstractmethod 46 | def transcribe(self, audio_path: str) -> str: 47 | pass 48 | 49 | class WhisperTranscriber(BaseTranscriber): 50 | """Whisper model transcriber using MLX""" 51 | 52 | def __init__(self, model_id: str): 53 | try: 54 | import mlx_whisper 55 | self.mlx_whisper = mlx_whisper 56 | self.model_path = model_id 57 | except ImportError: 58 | raise ImportError("mlx-whisper is required. 
Install with: pip install mlx-whisper") 59 | 60 | def transcribe(self, audio_path: str) -> str: 61 | """Transcribe audio file using Whisper""" 62 | result = self.mlx_whisper.transcribe( 63 | audio_path, 64 | path_or_hf_repo=self.model_path 65 | ) 66 | return result.get("text", "") 67 | 68 | class SpeechTranscriber: 69 | """Main transcriber class with model management""" 70 | 71 | def __init__(self, model_name: str = "whisper-tiny"): 72 | self.model_name = model_name 73 | self.model = None 74 | self.load_model(model_name) 75 | 76 | def load_model(self, model_name: str): 77 | """Load the specified model""" 78 | if model_name not in AVAILABLE_MODELS: 79 | raise ValueError(f"Unknown model: {model_name}") 80 | 81 | config = AVAILABLE_MODELS[model_name] 82 | model_type = config["type"] 83 | repo_id = config["repo"] 84 | 85 | print(f"Loading {model_name} model...") 86 | 87 | if model_type == "whisper": 88 | self.model = WhisperTranscriber(repo_id) 89 | else: 90 | raise ValueError(f"Unknown model type: {model_type}") 91 | 92 | self.model_name = model_name 93 | print(f"Model {model_name} loaded successfully") 94 | 95 | def transcribe(self, audio: np.ndarray, sample_rate: int = 16000) -> str: 96 | """Transcribe audio array""" 97 | if self.model is None: 98 | raise RuntimeError("No model loaded") 99 | 100 | # Write audio to temporary file (required by models) 101 | with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp: 102 | sf.write(tmp.name, audio, sample_rate) 103 | temp_path = tmp.name 104 | 105 | try: 106 | # Transcribe 107 | text = self.model.transcribe(temp_path) 108 | return text.strip() 109 | finally: 110 | # Clean up temp file 111 | if os.path.exists(temp_path): 112 | os.unlink(temp_path) 113 | 114 | def list_models(self) -> Dict[str, Any]: 115 | """List available models""" 116 | return AVAILABLE_MODELS -------------------------------------------------------------------------------- /models/model_loader.py: -------------------------------------------------------------------------------- 1 | # models/model_loader.py 2 | 3 | import os 4 | import torch 5 | from transformers import Qwen2VLForConditionalGeneration, AutoProcessor 6 | from transformers import MllamaForConditionalGeneration 7 | from vllm.sampling_params import SamplingParams 8 | from transformers import AutoModelForCausalLM 9 | import google.generativeai as genai 10 | from vllm import LLM 11 | from groq import Groq 12 | from dotenv import load_dotenv 13 | 14 | # Load environment variables from .env file 15 | load_dotenv() 16 | 17 | from logger import get_logger 18 | 19 | logger = get_logger(__name__) 20 | 21 | # Cache for loaded models 22 | _model_cache = {} 23 | 24 | # Models that only support single image processing 25 | SINGLE_IMAGE_MODELS = { 26 | 'ollama-llama-vision': True, 27 | 'groq-llama-vision': True, 28 | 'llama-vision': True, 29 | 'pixtral': True, 30 | 'molmo': True 31 | } 32 | 33 | def is_single_image_model(model_choice): 34 | """Returns True if the model only supports processing a single image.""" 35 | return model_choice in SINGLE_IMAGE_MODELS 36 | 37 | def detect_device(): 38 | """ 39 | Detects the best available device (CUDA, MPS, or CPU). 40 | """ 41 | if torch.cuda.is_available(): 42 | return 'cuda' 43 | elif torch.backends.mps.is_available(): 44 | return 'mps' 45 | else: 46 | return 'cpu' 47 | 48 | def load_model(model_choice): 49 | """ 50 | Loads and caches the specified model. 
51 | """ 52 | global _model_cache 53 | 54 | if model_choice in _model_cache: 55 | logger.info(f"Model '{model_choice}' loaded from cache.") 56 | return _model_cache[model_choice] 57 | 58 | if model_choice == 'qwen': 59 | device = detect_device() 60 | model = Qwen2VLForConditionalGeneration.from_pretrained( 61 | "Qwen/Qwen2-VL-7B-Instruct", 62 | torch_dtype=torch.float16 if device != 'cpu' else torch.float32, 63 | device_map="auto" 64 | ) 65 | processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct") 66 | model.to(device) 67 | _model_cache[model_choice] = (model, processor, device) 68 | logger.info("Qwen model loaded and cached.") 69 | return _model_cache[model_choice] 70 | 71 | elif model_choice == 'gemini': 72 | api_key = os.getenv("GOOGLE_API_KEY") 73 | if not api_key: 74 | raise ValueError("GOOGLE_API_KEY not found in .env file") 75 | genai.configure(api_key=api_key) 76 | model = genai.GenerativeModel('gemini-1.5-flash-002') 77 | return model, None 78 | 79 | elif model_choice == 'llama-vision': 80 | device = detect_device() 81 | model_id = "alpindale/Llama-3.2-11B-Vision-Instruct" 82 | model = MllamaForConditionalGeneration.from_pretrained( 83 | model_id, 84 | torch_dtype=torch.float16 if device != 'cpu' else torch.float32, 85 | device_map="auto" 86 | ) 87 | processor = AutoProcessor.from_pretrained(model_id) 88 | model.to(device) 89 | _model_cache[model_choice] = (model, processor, device) 90 | logger.info("Llama-Vision model loaded and cached.") 91 | return _model_cache[model_choice] 92 | 93 | elif model_choice == "pixtral": 94 | device = detect_device() 95 | mistral_models_path = os.path.join(os.getcwd(), 'mistral_models', 'Pixtral') 96 | 97 | if not os.path.exists(mistral_models_path): 98 | os.makedirs(mistral_models_path, exist_ok=True) 99 | from huggingface_hub import snapshot_download 100 | snapshot_download(repo_id="mistralai/Pixtral-12B-2409", 101 | allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"], 102 | local_dir=mistral_models_path) 103 | 104 | from mistral_inference.transformer import Transformer 105 | from mistral_common.tokens.tokenizers.mistral import MistralTokenizer 106 | from mistral_common.generate import generate 107 | 108 | tokenizer = MistralTokenizer.from_file(os.path.join(mistral_models_path, "tekken.json")) 109 | model = Transformer.from_folder(mistral_models_path) 110 | 111 | _model_cache[model_choice] = (model, tokenizer, generate, device) 112 | logger.info("Pixtral model loaded and cached.") 113 | return _model_cache[model_choice] 114 | 115 | elif model_choice == "molmo": 116 | device = detect_device() 117 | processor = AutoProcessor.from_pretrained( 118 | 'allenai/MolmoE-1B-0924', 119 | trust_remote_code=True, 120 | torch_dtype='auto', 121 | device_map='auto' 122 | ) 123 | model = AutoModelForCausalLM.from_pretrained( 124 | 'allenai/MolmoE-1B-0924', 125 | trust_remote_code=True, 126 | torch_dtype='auto', 127 | device_map='auto' 128 | ) 129 | _model_cache[model_choice] = (model, processor, device) 130 | return _model_cache[model_choice] 131 | 132 | elif model_choice == 'groq-llama-vision': 133 | api_key = os.getenv("GROQ_API_KEY") 134 | if not api_key: 135 | raise ValueError("GROQ_API_KEY not found in .env file") 136 | client = Groq(api_key=api_key) 137 | _model_cache[model_choice] = client 138 | logger.info("Groq Llama Vision model loaded and cached.") 139 | return _model_cache[model_choice] 140 | 141 | elif model_choice == 'ollama-llama-vision': 142 | logger.info("Ollama Llama Vision model ready to use.") 143 | return 
None 144 | 145 | else: 146 | logger.error(f"Invalid model choice: {model_choice}") 147 | raise ValueError("Invalid model choice.") 148 | -------------------------------------------------------------------------------- /jarvis/README.md: -------------------------------------------------------------------------------- 1 | # Speech Transcription Application for macOS 2 | 3 | A privacy-focused, real-time speech transcription desktop application for macOS that runs entirely on-device using Apple Silicon optimization. Features speech-to-text transcription with AI-powered text correction, multiple UI modes, and support for various speech recognition models. 4 | 5 | ## Features 6 | 7 | - 🎤 **Real-time Speech Transcription**: Capture and transcribe audio from your microphone 8 | - 🔒 **Privacy-First**: All processing happens on-device, no data sent to cloud 9 | - ⚡ **Apple Silicon Optimized**: Uses MLX framework for fast inference on M1/M2/M3 chips 10 | - 🤖 **AI-Powered Correction**: Automatic removal of filler words and transcription errors 11 | - 🎯 **Multiple Models**: Support for various Whisper models with different speed/accuracy tradeoffs 12 | - 🖥️ **Multiple Interfaces**: GUI, CLI, and system tray modes 13 | - ⌨️ **Keyboard Shortcuts**: Space to start/stop, Escape to cancel, Cmd+S to save 14 | 15 | ## Requirements 16 | 17 | - macOS with Apple Silicon (M1/M2/M3) 18 | - Python 3.9 or higher 19 | - Minimum 8GB RAM (16GB recommended) 20 | - ~5GB disk space for models 21 | 22 | ## Installation 23 | 24 | ### 1. Clone the repository 25 | 26 | ```bash 27 | git clone https://github.com/yourusername/speech-transcription-app.git 28 | cd speech-transcription-app 29 | ``` 30 | 31 | ### 2. Create a virtual environment 32 | 33 | ```bash 34 | python -m venv venv 35 | source venv/bin/activate 36 | ``` 37 | 38 | ### 3. Install dependencies 39 | 40 | ```bash 41 | pip install -r requirements.txt 42 | ``` 43 | 44 | ### 4. 
Download ML models 45 | 46 | ```bash 47 | python scripts/download_models.py 48 | ``` 49 | 50 | ## Usage 51 | 52 | ### GUI Mode 53 | 54 | Launch the graphical interface: 55 | 56 | ```bash 57 | python -m src.main 58 | ``` 59 | 60 | Features: 61 | - Click "Start Recording" or press Space to begin 62 | - Real-time audio level visualization 63 | - Automatic text correction (can be disabled) 64 | - Save transcriptions with Cmd+S 65 | - Dark mode interface 66 | 67 | ### CLI Mode 68 | 69 | Transcribe audio files from the command line: 70 | 71 | ```bash 72 | # Basic transcription 73 | python -m src.cli --audio recording.wav 74 | 75 | # With specific model 76 | python -m src.cli --audio recording.wav --model whisper-medium 77 | 78 | # With text correction and context 79 | python -m src.cli --audio recording.wav --context "Technical meeting" 80 | 81 | # Save to file 82 | python -m src.cli --audio recording.wav --output transcript.txt 83 | 84 | # List available models 85 | python -m src.cli --list-models 86 | ``` 87 | 88 | ### Available Models 89 | 90 | - `whisper-tiny`: Very fast, lower accuracy 91 | - `whisper-base`: Fast, decent accuracy 92 | - `whisper-small`: Good balance 93 | - `whisper-medium`: High accuracy, slower 94 | - `whisper-large-v3`: Best accuracy, slowest 95 | - `distil-whisper-large-v3`: Fast with high accuracy 96 | 97 | ## Development 98 | 99 | ### Running Tests 100 | 101 | ```bash 102 | # Run all tests 103 | make test 104 | 105 | # Run unit tests only 106 | make test-unit 107 | 108 | # Run integration tests 109 | make test-integration 110 | ``` 111 | 112 | ### Code Formatting 113 | 114 | ```bash 115 | # Format code with black and ruff 116 | make format 117 | 118 | # Run linters 119 | make lint 120 | ``` 121 | 122 | ### Building for Distribution 123 | 124 | ```bash 125 | pip install -e . 126 | ``` 127 | 128 | This will install the package in development mode and create command-line scripts: 129 | - `speech-transcribe`: CLI interface 130 | - `speech-transcribe-gui`: GUI interface 131 | 132 | ## Project Structure 133 | 134 | ``` 135 | speech-transcription-app/ 136 | ├── src/ 137 | │ ├── audio/ # Audio recording and processing 138 | │ ├── models/ # ML models for transcription and correction 139 | │ ├── gui/ # GUI components 140 | │ ├── pipeline.py # Main orchestration 141 | │ ├── cli.py # CLI interface 142 | │ └── main.py # GUI entry point 143 | ├── tests/ # Unit and integration tests 144 | ├── scripts/ # Utility scripts 145 | ├── requirements.txt # Python dependencies 146 | └── README.md # This file 147 | ``` 148 | 149 | ## Performance 150 | 151 | - **Audio Latency**: < 100ms for recording start/stop 152 | - **Transcription Speed**: > 5x real-time (10s audio in < 2s) 153 | - **Memory Usage**: < 2GB with models loaded 154 | - **Model Loading**: < 10s for initial load 155 | 156 | ## Troubleshooting 157 | 158 | ### Common Issues 159 | 160 | 1. **"Could not find the Qt platform plugin 'cocoa'"** 161 | - This is fixed automatically in the code, but if it persists: 162 | - Run with: `./run_gui.sh` instead of `python -m src.main` 163 | - Or set manually: `export QT_QPA_PLATFORM_PLUGIN_PATH=$(python -c "import PyQt6, os; print(os.path.join(os.path.dirname(PyQt6.__file__), 'Qt6', 'plugins', 'platforms'))")` 164 | 165 | 2. **"No module named 'mlx'"** 166 | - Ensure you're on a Mac with Apple Silicon 167 | - Reinstall with: `pip install --upgrade mlx mlx-lm mlx-whisper` 168 | 169 | 3. 
**Audio permission denied** 170 | - Go to System Preferences → Security & Privacy → Microphone 171 | - Grant permission to Terminal/your IDE 172 | 173 | 4. **Model download fails** 174 | - Check internet connection 175 | - Manually download from Hugging Face if needed 176 | 177 | ### Debug Mode 178 | 179 | Run with verbose output: 180 | ```bash 181 | python -m src.cli --audio test.wav --debug 182 | ``` 183 | 184 | ## Contributing 185 | 186 | 1. Fork the repository 187 | 2. Create a feature branch 188 | 3. Make your changes 189 | 4. Run tests and linters 190 | 5. Submit a pull request 191 | 192 | ## License 193 | 194 | This project is licensed under the MIT License - see LICENSE file for details. 195 | 196 | ## Acknowledgments 197 | 198 | - Apple's MLX team for the optimization framework 199 | - OpenAI for Whisper models 200 | - Hugging Face for model hosting -------------------------------------------------------------------------------- /jarvis/src/models/text_corrector.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional, List 3 | import mlx.core as mx 4 | from mlx_lm import load, generate 5 | 6 | # Available LLM models for text correction 7 | AVAILABLE_LLM_MODELS = { 8 | # Small & Fast Models (< 2GB RAM) 9 | "Qwen2.5-0.5B": { 10 | "repo": "mlx-community/Qwen2.5-0.5B-Instruct-4bit", 11 | "description": "Tiny & fast (0.5B)", 12 | "size": "0.5B" 13 | }, 14 | "gemma-2b": { 15 | "repo": "mlx-community/gemma-2-2b-it-4bit", 16 | "description": "Google's efficient (2B)", 17 | "size": "2B" 18 | }, 19 | 20 | # Medium Models (2-4GB RAM) 21 | "Phi-3.5-mini": { 22 | "repo": "mlx-community/Phi-3.5-mini-instruct-4bit", 23 | "description": "Best balance (3.8B)", 24 | "size": "3.8B" 25 | }, 26 | "gemma-3-4b": { 27 | "repo": "mlx-community/gemma-3-4b-it-qat-4bit", 28 | "description": "Google's latest (4B)", 29 | "size": "4B" 30 | }, 31 | 32 | # Larger Models (4-6GB RAM) 33 | "Mistral-7B": { 34 | "repo": "mlx-community/Mistral-7B-Instruct-v0.3-4bit", 35 | "description": "Very capable (7B)", 36 | "size": "7B" 37 | }, 38 | "Llama-3.1-8B": { 39 | "repo": "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", 40 | "description": "State of the art (8B)", 41 | "size": "8B" 42 | }, 43 | } 44 | 45 | class TextCorrector: 46 | """LLM-based text correction for transcriptions""" 47 | 48 | def __init__(self, model_name: str = "Phi-3.5-mini"): 49 | self.model = None 50 | self.tokenizer = None 51 | self.model_name = None 52 | self.selected_model = model_name 53 | self._load_model(model_name) 54 | 55 | def _load_model(self, model_name: str): 56 | """Load the specified LLM model""" 57 | if model_name not in AVAILABLE_LLM_MODELS: 58 | print(f"Unknown model {model_name}, using default Phi-3.5-mini") 59 | model_name = "Phi-3.5-mini" 60 | 61 | model_info = AVAILABLE_LLM_MODELS[model_name] 62 | repo_id = model_info["repo"] 63 | 64 | try: 65 | print(f"Loading text correction model: {model_name} ({repo_id})") 66 | self.model, self.tokenizer = load(repo_id) 67 | self.model_name = model_name 68 | print(f"Successfully loaded {model_name}") 69 | except Exception as e: 70 | print(f"Failed to load {model_name}: {e}") 71 | # Try fallback to Phi-3.5-mini if different model was selected 72 | if model_name != "Phi-3.5-mini": 73 | print("Falling back to Phi-3.5-mini") 74 | self._load_model("Phi-3.5-mini") 75 | else: 76 | print("Warning: No text correction model could be loaded") 77 | 78 | def set_model(self, model_name: str): 79 | """Change the LLM model""" 80 | if 
model_name != self.model_name: 81 | self._load_model(model_name) 82 | 83 | def correct( 84 | self, 85 | text: str, 86 | context: Optional[str] = None, 87 | remove_fillers: bool = True 88 | ) -> str: 89 | """Correct transcribed text""" 90 | if not self.model or not text or len(text.strip()) < 10: 91 | return text 92 | 93 | # Build correction prompt 94 | prompt = self._build_prompt(text, context, remove_fillers) 95 | 96 | try: 97 | # Generate correction 98 | response = generate( 99 | self.model, 100 | self.tokenizer, 101 | prompt=prompt, 102 | max_tokens=len(text) * 2 103 | ) 104 | 105 | # Extract corrected text 106 | corrected = self._extract_correction(response) 107 | 108 | # Validate correction 109 | if self._is_valid_correction(text, corrected): 110 | return corrected 111 | else: 112 | return text 113 | 114 | except Exception as e: 115 | print(f"Text correction failed: {e}") 116 | return text 117 | 118 | def _build_prompt( 119 | self, 120 | text: str, 121 | context: Optional[str], 122 | remove_fillers: bool 123 | ) -> str: 124 | """Build the correction prompt""" 125 | base_prompt = """Fix ONLY spelling errors and remove filler words from this transcription. 126 | DO NOT rephrase or summarize. Keep the original wording and structure. 127 | Only fix obvious errors like duplicated words, misspellings, and remove filler words (um, uh, etc). 128 | Output ONLY the cleaned text, nothing else.""" 129 | 130 | if context: 131 | base_prompt = f"Context: {context}\n\n{base_prompt}" 132 | 133 | return f"""{base_prompt} 134 | 135 | Original: {text} 136 | 137 | Cleaned:""" 138 | 139 | def _extract_correction(self, response: str) -> str: 140 | """Extract the corrected text from model response""" 141 | # Remove any explanation or metadata 142 | lines = response.strip().split('\n') 143 | 144 | # Find the actual correction 145 | corrected_text = "" 146 | for line in lines: 147 | # Skip meta lines 148 | if any(marker in line.lower() for marker in ['original:', 'cleaned:', 'corrected:']): 149 | continue 150 | if line.strip(): 151 | corrected_text = line.strip() 152 | break 153 | 154 | return corrected_text 155 | 156 | def _is_valid_correction(self, original: str, corrected: str) -> bool: 157 | """Validate that correction is reasonable""" 158 | if not corrected: 159 | return False 160 | 161 | # Check length difference (shouldn't be too different) 162 | len_ratio = len(corrected) / len(original) 163 | if len_ratio < 0.5 or len_ratio > 1.5: 164 | return False 165 | 166 | # Check word count difference 167 | orig_words = original.split() 168 | corr_words = corrected.split() 169 | word_ratio = len(corr_words) / len(orig_words) 170 | if word_ratio < 0.5 or word_ratio > 1.2: 171 | return False 172 | 173 | return True 174 | 175 | def get_filler_words(self) -> List[str]: 176 | """Get list of filler words to remove""" 177 | return [ 178 | "um", "uh", "er", "ah", "like", "you know", "I mean", 179 | "actually", "basically", "literally", "right", "so" 180 | ] -------------------------------------------------------------------------------- /static/css/styles.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; 3 | background-color: #f8f9fa; 4 | font-size: 14px; /* Base font size */ 5 | } 6 | 7 | /* Sidebar styles */ 8 | #sidebar { 9 | min-height: 100vh; 10 | background-color: #343a40; 11 | color: #fff; 12 | transition: all 0.3s; 13 | display: flex; 14 | flex-direction: column; 15 | } 16 | 17 | #sidebar.active { 18 
| margin-left: -250px; 19 | } 20 | 21 | #sidebar .sidebar-header { 22 | padding: 20px; 23 | background: #212529; 24 | } 25 | 26 | #sidebar .sidebar-header h3 { 27 | font-size: 1.5em; /* Larger font for header */ 28 | } 29 | 30 | #sidebar .sidebar-content { 31 | flex-grow: 1; 32 | overflow-y: auto; 33 | } 34 | 35 | #sidebar ul.components { 36 | padding: 20px 0; 37 | border-bottom: 1px solid #47748b; 38 | } 39 | 40 | #sidebar ul p { 41 | color: #fff; 42 | padding: 10px; 43 | } 44 | 45 | #sidebar ul li a { 46 | padding: 8px 10px; 47 | font-size: 1em; /* Consistent font size */ 48 | display: block; 49 | color: #fff; 50 | text-decoration: none; 51 | } 52 | 53 | #sidebar ul li a:hover { 54 | color: #7386D5; 55 | background: #fff; 56 | } 57 | 58 | #sidebar ul li.active > a, a[aria-expanded="true"] { 59 | color: #fff; 60 | background: #6d7fcc; 61 | } 62 | 63 | .session-list { 64 | max-height: calc(100vh - 250px); 65 | overflow-y: auto; 66 | } 67 | 68 | .session-list .nav-item { 69 | margin-bottom: 2px; 70 | } 71 | 72 | /* Main content styles */ 73 | main { 74 | transition: all 0.3s; 75 | height: 100vh; 76 | overflow-y: auto; 77 | } 78 | 79 | main.active { 80 | margin-left: 0; 81 | } 82 | 83 | /* Top bar styles */ 84 | .btn-toolbar .btn { 85 | margin-left: 5px; 86 | font-size: 1em; /* Consistent font size */ 87 | } 88 | 89 | /* Chat container styles */ 90 | .chat-container { 91 | height: calc(100vh - 100px); 92 | display: flex; 93 | flex-direction: column; 94 | background-color: #ffffff; 95 | border-radius: 10px; 96 | box-shadow: 0 1px 3px rgba(0, 0, 0, 0.12), 0 1px 2px rgba(0, 0, 0, 0.24); 97 | margin: 20px auto; 98 | max-width: 800px; 99 | } 100 | 101 | .chat-messages { 102 | flex-grow: 1; 103 | overflow-y: auto; 104 | padding: 20px; 105 | } 106 | 107 | .chat-input-container { 108 | padding: 20px; 109 | background-color: #f8f9fa; 110 | border-top: 1px solid #dee2e6; 111 | border-bottom-left-radius: 10px; 112 | border-bottom-right-radius: 10px; 113 | } 114 | 115 | .chat-input-container .input-group { 116 | max-width: 600px; 117 | margin: 0 auto; 118 | } 119 | 120 | /* Message styles */ 121 | .message { 122 | margin-bottom: 15px; 123 | padding: 10px 15px; 124 | border-radius: 18px; 125 | max-width: 80%; 126 | word-wrap: break-word; 127 | font-size: 1em; /* Consistent font size */ 128 | } 129 | 130 | .user-message { 131 | background-color: #007bff; 132 | color: #fff; 133 | align-self: flex-end; 134 | margin-left: auto; 135 | } 136 | 137 | .ai-message { 138 | background-color: #f1f3f5; 139 | color: #343a40; 140 | align-self: flex-start; 141 | margin-right: auto; 142 | } 143 | 144 | /* Image styles */ 145 | .image-container { 146 | display: flex; 147 | flex-wrap: wrap; 148 | gap: 10px; 149 | margin-top: 10px; 150 | } 151 | 152 | .retrieved-image { 153 | max-width: 150px; 154 | max-height: 150px; 155 | object-fit: cover; 156 | cursor: zoom-in; 157 | transition: transform 0.3s ease; 158 | border-radius: 8px; 159 | } 160 | 161 | .retrieved-image:hover { 162 | transform: scale(1.05); 163 | } 164 | 165 | /* Loading indicator styles */ 166 | #loading-indicator { 167 | position: fixed; 168 | top: 50%; 169 | left: 50%; 170 | transform: translate(-50%, -50%); 171 | z-index: 1000; 172 | background-color: rgba(255, 255, 255, 0.9); 173 | padding: 20px; 174 | border-radius: 10px; 175 | box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1); 176 | } 177 | 178 | #loading-indicator .spinner-border { 179 | width: 3rem; 180 | height: 3rem; 181 | } 182 | 183 | #loading-indicator p { 184 | margin-top: 10px; 185 | font-weight: bold; 
186 | font-size: 1em; /* Consistent font size */ 187 | } 188 | 189 | /* Indexed files list styles */ 190 | #indexed-files-list { 191 | max-height: 300px; 192 | overflow-y: auto; 193 | } 194 | 195 | #indexed-files-list .list-group-item { 196 | padding: 8px 12px; 197 | font-size: 1em; /* Consistent font size */ 198 | border-left: none; 199 | border-right: none; 200 | } 201 | 202 | .session-name { 203 | cursor: pointer; 204 | transition: color 0.3s ease; 205 | font-size: 1em; /* Consistent font size */ 206 | } 207 | 208 | .session-name:hover { 209 | color: #007bff; 210 | } 211 | 212 | /* Responsive adjustments */ 213 | @media (max-width: 768px) { 214 | #sidebar { 215 | margin-left: -250px; 216 | } 217 | #sidebar.active { 218 | margin-left: 0; 219 | } 220 | main { 221 | margin-left: 0; 222 | } 223 | main.active { 224 | margin-left: 250px; 225 | } 226 | } 227 | 228 | /* Medium Zoom styles */ 229 | .medium-zoom-overlay { 230 | z-index: 1000; 231 | } 232 | 233 | .medium-zoom-image--opened { 234 | z-index: 1001; 235 | } 236 | 237 | .session-options { 238 | position: relative; 239 | } 240 | 241 | .fa-ellipsis-h { 242 | cursor: pointer; 243 | padding: 3px; 244 | color: #adb5bd; 245 | font-size: 1em; /* Consistent font size */ 246 | } 247 | 248 | .options-popup { 249 | display: none; 250 | position: absolute; 251 | right: 0; 252 | top: 100%; 253 | background-color: #343a40; 254 | border: 1px solid #495057; 255 | border-radius: 4px; 256 | box-shadow: 0 2px 10px rgba(0,0,0,0.2); 257 | z-index: 1000; 258 | min-width: 120px; 259 | } 260 | 261 | .option { 262 | padding: 8px 12px; 263 | cursor: pointer; 264 | white-space: nowrap; 265 | color: #fff; 266 | transition: background-color 0.2s ease; 267 | font-size: 1em; /* Consistent font size */ 268 | } 269 | 270 | .option:hover { 271 | background-color: #495057; 272 | } 273 | 274 | .session-list .nav-item.current-session { 275 | background-color: #495057; 276 | border-radius: 4px; 277 | } 278 | 279 | .session-list .nav-item.current-session .nav-link { 280 | color: #ffffff; 281 | font-weight: bold; 282 | } 283 | 284 | .session-list .nav-item.current-session .fa-ellipsis-h { 285 | color: #ffffff; 286 | } 287 | 288 | /* Markdown styles */ 289 | .ai-message h1, .ai-message h2, .ai-message h3, .ai-message h4, .ai-message h5, .ai-message h6 { 290 | margin-top: 10px; 291 | margin-bottom: 5px; 292 | } 293 | 294 | .ai-message p { 295 | margin-bottom: 10px; 296 | } 297 | 298 | .ai-message ul, .ai-message ol { 299 | margin-left: 20px; 300 | margin-bottom: 10px; 301 | } 302 | 303 | .ai-message code { 304 | background-color: #f0f0f0; 305 | padding: 2px 4px; 306 | border-radius: 4px; 307 | } 308 | 309 | .ai-message pre { 310 | background-color: #f0f0f0; 311 | padding: 10px; 312 | border-radius: 4px; 313 | overflow-x: auto; 314 | } -------------------------------------------------------------------------------- /templates/base.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | LocalGPT Vision 9 | 10 | 11 | 12 | 13 | {% block styles %}{% endblock %} 14 | 15 | 16 |
17 |
18 | 19 | 52 | 53 | 54 |
55 |
56 |
57 | 58 | Models 59 | 60 | 63 |
64 |
65 | 66 | {% block content %}{% endblock %} 67 |
68 |
69 |
70 | 71 | 72 | 87 | 88 | 89 | 90 | 91 | 185 | {% block scripts %}{% endblock %} 186 | 187 | -------------------------------------------------------------------------------- /jarvis/CLAUDE.md: -------------------------------------------------------------------------------- 1 | # CLAUDE.md 2 | 3 | This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. 4 | 5 | ## Project Overview 6 | 7 | This is a Real-Time Speech Transcription Application for macOS, designed to be privacy-focused with all processing done on-device using Apple Silicon optimization. The application features speech-to-text transcription with AI-powered text correction, multiple UI modes (GUI, CLI), and support for various speech recognition models. 8 | 9 | ### Original Requirements (PRD.md) 10 | The project was built following a comprehensive Product Requirements Document that specified: 11 | - Privacy-first approach with all processing on-device 12 | - Real-time transcription performance (>5x real-time) 13 | - Multiple STT model support (Whisper variants) 14 | - AI-powered text correction using LLMs 15 | - GUI and CLI interfaces 16 | - Apple Silicon optimization using MLX framework 17 | 18 | ### Recent Feature Additions 19 | 1. **Microphone Selection**: Dropdown to choose from all available audio input devices 20 | 2. **Copy Functionality**: Separate copy buttons for raw and corrected text 21 | 3. **LLM Model Selection**: Dynamic selection of text correction models from GUI 22 | 4. **Enhanced Error Handling**: Proper thread cleanup and crash prevention 23 | 5. **Auto Qt Path Detection**: Automatic resolution of Qt platform plugin issues 24 | 25 | ## Architecture 26 | 27 | ### Core Components 28 | 29 | 1. **Audio Pipeline** (`src/audio/`) 30 | - `AudioRecorder`: Manages sounddevice streams for microphone capture with device selection 31 | - `AudioProcessor`: Handles normalization, resampling, noise reduction, and VAD 32 | 33 | 2. **Model Layer** (`src/models/`) 34 | - `SpeechTranscriber`: Abstraction over multiple STT models (Whisper variants) 35 | - `TextCorrector`: LLM-based correction using MLX models with dynamic model selection 36 | - Model dictionaries: `AVAILABLE_MODELS` (STT) and `AVAILABLE_LLM_MODELS` (correction) 37 | 38 | 3. **User Interfaces** (`src/gui/`) 39 | - `MainWindow`: PyQt6-based GUI with real-time visualization 40 | - Microphone selection dropdown with refresh capability 41 | - Model selection dropdowns for both STT and LLM models 42 | - Copy functionality for both raw and corrected text 43 | - Dark theme with audio level visualization 44 | - CLI interface in `src/cli.py` 45 | 46 | 4. 
**Pipeline** (`src/pipeline.py`) 47 | - Orchestrates the full transcription flow 48 | - Manages audio buffering without duplication 49 | - Supports dynamic model switching for both STT and LLM 50 | - Coordinates all components 51 | 52 | ### Key Design Decisions 53 | 54 | - **Model Flexibility**: Support for multiple STT and LLM models with different accuracy/speed tradeoffs 55 | - **Privacy-First**: All processing on-device, no cloud APIs 56 | - **MLX Optimization**: Uses Apple's MLX framework for efficient inference on Apple Silicon 57 | - **Dynamic Model Selection**: Users can choose models from the GUI without restart 58 | - **Error Resilience**: Proper thread cleanup and error handling for stability 59 | 60 | ## Development Commands 61 | 62 | ### Project Setup 63 | ```bash 64 | # Create virtual environment 65 | python -m venv venv 66 | source venv/bin/activate 67 | 68 | # Install dependencies 69 | pip install -r requirements.txt 70 | 71 | # Download required ML models 72 | python scripts/download_models.py 73 | ``` 74 | 75 | ### Running the Application 76 | ```bash 77 | # GUI mode 78 | python -m src.main 79 | 80 | # CLI mode 81 | python -m src.cli --audio recording.wav --model whisper-small 82 | 83 | # With text correction 84 | python -m src.cli --audio recording.wav --context "Technical discussion" 85 | ``` 86 | 87 | ### Testing 88 | ```bash 89 | # Run all tests 90 | pytest tests/ -v 91 | 92 | # Run unit tests only 93 | pytest tests/unit/ -v 94 | 95 | # Run integration tests 96 | pytest tests/integration/ -v 97 | 98 | # Run specific test 99 | pytest tests/unit/test_audio_processor.py::TestAudioProcessor::test_normalize_audio -v 100 | ``` 101 | 102 | ### Development Workflow 103 | ```bash 104 | # Format code 105 | black src tests 106 | ruff format src tests 107 | 108 | # Lint code 109 | ruff check src tests 110 | mypy src 111 | 112 | # Clean build artifacts 113 | find . -type f -name "*.pyc" -delete 114 | find . -type d -name "__pycache__" -delete 115 | ``` 116 | 117 | ## Model Management 118 | 119 | ### Available STT Models 120 | - **whisper-tiny**: Very fast, lower accuracy (39M params) 121 | - **whisper-base**: Fast, decent accuracy (74M params) 122 | - **whisper-small**: Good balance (244M params) 123 | - **whisper-medium**: High accuracy, slower (769M params) 124 | - **whisper-large-v3**: Best accuracy, slowest (1550M params) 125 | - **distil-whisper-large-v3**: Fast with high accuracy (756M params) 126 | 127 | ### Available LLM Models for Text Correction 128 | - **Qwen2.5-0.5B**: Tiny & fast (0.5B params, <2GB RAM) 129 | - **Qwen2.5-1.5B**: Small & efficient (1.5B params, ~3GB RAM) 130 | - **gemma-2-2b-it**: Compact & capable (2B params, ~4GB RAM) 131 | - **Phi-3.5-mini**: Excellent quality (3.8B params, ~6GB RAM) - Default 132 | - **Qwen2.5-7B**: Large & powerful (7B params, ~12GB RAM) 133 | - **Mistral-7B-Instruct**: High quality (7B params, ~12GB RAM) 134 | 135 | ### Model Loading 136 | Models are loaded on-demand from Hugging Face hub. Both `SpeechTranscriber` and `TextCorrector` support dynamic model switching via the GUI without restart. 137 | 138 | ## Audio Processing Pipeline 139 | 140 | 1. **Recording**: 16kHz mono audio captured in 0.5s chunks 141 | - Device selection with automatic default detection 142 | - Real-time audio level monitoring 143 | - Proper buffering without duplication 144 | 2. **Processing**: Normalization → Resampling → Noise reduction → VAD 145 | 3. **Transcription**: Audio written to temp file → Model inference → Text output 146 | 4. 
**Correction**: Optional LLM-based correction with context awareness 147 | - Custom prompt structure with context integration 148 | - No temperature/top_p parameters (unsupported by mlx_lm) 149 | 150 | ## GUI Architecture 151 | 152 | - **Main Thread**: UI updates and user interaction 153 | - **TranscriptionThread**: Background audio recording and level monitoring 154 | - **Model Operations**: Synchronous but UI remains responsive via threading 155 | - **Real-time Feedback**: Audio level visualization at 50ms intervals 156 | 157 | ## Performance Considerations 158 | 159 | - Audio chunks processed at 0.5s intervals for low latency 160 | - Models use MLX for Apple Silicon optimization 161 | - Text correction uses default generation parameters (no temperature control) 162 | - Memory-mapped model loading for faster startup 163 | - Thread cleanup ensures no resource leaks on cancel/clear 164 | 165 | ## Common Issues and Fixes 166 | 167 | ### Qt Platform Plugin Error 168 | **Problem**: "Could not find the Qt platform plugin 'cocoa'" 169 | **Solution**: The code automatically detects and sets the Qt plugin path in `main.py`. If it persists: 170 | ```bash 171 | # Use the provided run script 172 | ./run_gui.sh 173 | 174 | # Or set manually 175 | export QT_QPA_PLATFORM_PLUGIN_PATH=$(python -c "import PyQt6, os; print(os.path.join(os.path.dirname(PyQt6.__file__), 'Qt6', 'plugins', 'platforms'))") 176 | ``` 177 | 178 | ### Duplicate Transcription 179 | **Problem**: Transcribed text appears twice 180 | **Solution**: Fixed by removing duplicate audio buffering in `pipeline.py`. The `stop_recording()` method now uses only the final audio from `AudioRecorder.stop_recording()`. 181 | 182 | ### LLM Temperature Error 183 | **Problem**: "generate_step() got an unexpected keyword argument 'temperature'" 184 | **Solution**: The mlx_lm library doesn't support temperature/top_p parameters. These have been removed from the `generate()` call in `TextCorrector`. 185 | 186 | ### Clear Button Crash 187 | **Problem**: Clicking clear button terminates the application 188 | **Solution**: Added proper error handling and thread cleanup in `clear_text()` method. The method now safely cancels any running recording before clearing. 189 | 190 | ### Text Correction Prompt 191 | The LLM receives this structured prompt: 192 | ``` 193 | You are a helpful assistant that corrects transcription errors. 194 | Context: {context if provided} 195 | 196 | Please correct any transcription errors in the following text, removing filler words, fixing grammar, and improving clarity while preserving the original meaning: 197 | 198 | {transcribed_text} 199 | 200 | Corrected text: 201 | ``` 202 | 203 | ### Whisper Model Repository Names 204 | **Problem**: 404 errors when loading Whisper models 205 | **Solution**: The mlx-community organization uses inconsistent naming for Whisper models: 206 | - `whisper-tiny` - no suffix needed 207 | - `whisper-base-mlx` - requires "-mlx" suffix 208 | - `whisper-small-mlx` - requires "-mlx" suffix 209 | - `whisper-medium-mlx` - requires "-mlx" suffix 210 | - `whisper-large-v3-mlx` - requires "-mlx" suffix 211 | - `distil-whisper-large-v3` - no suffix needed 212 | 213 | ## Common Patterns 214 | 215 | ### Adding New STT Models 216 | 1. Add model config to `AVAILABLE_MODELS` in `src/models/transcriber.py` 217 | 2. Ensure model is Whisper-based (current implementation) 218 | 3. Update model download script if needed 219 | 220 | ### Adding New LLM Models 221 | 1. 
Add model config to `AVAILABLE_LLM_MODELS` in `src/models/text_corrector.py` 222 | 2. Include repo, description, and size fields 223 | 3. Test with mlx_lm.generate() compatibility 224 | 4. Update download script for default models 225 | 226 | ### Modifying Audio Processing 227 | - All audio processing goes through `AudioProcessor.process()` 228 | - Maintain 16kHz target sample rate for model compatibility 229 | - Ensure processed audio is normalized to [-1, 1] range 230 | - Avoid duplicate buffering in the pipeline 231 | 232 | ### GUI Customization 233 | - Dark theme defined in `MainWindow.init_ui()` 234 | - Keyboard shortcuts in `setup_shortcuts()` 235 | - Audio visualization via `update_level()` callback 236 | - Model dropdowns update pipeline without restart 237 | - Copy buttons use QApplication.clipboard() 238 | 239 | ## Best Practices & Lessons Learned 240 | 241 | ### Thread Safety 242 | - Always use proper thread cleanup in GUI operations 243 | - Set `is_recording` flag before cleanup to prevent race conditions 244 | - Use `wait()` with timeout before `terminate()` for graceful shutdown 245 | 246 | ### Audio Buffer Management 247 | - Avoid duplicate buffering between recorder and pipeline 248 | - Use the final audio from `stop_recording()` which includes all chunks 249 | - Maintain consistent sample rates throughout the pipeline 250 | 251 | ### Model Parameter Compatibility 252 | - mlx_lm's `generate()` doesn't support temperature/top_p parameters 253 | - Use default generation settings for consistent behavior 254 | - Test parameter compatibility when adding new models 255 | 256 | ### Error Handling 257 | - Wrap all GUI callbacks with try/except blocks 258 | - Provide user-friendly error messages via status bar 259 | - Use QMessageBox for critical errors only 260 | 261 | ### Development Workflow 262 | 1. Always check PRD.md for requirements alignment 263 | 2. Test model downloads before assuming availability 264 | 3. Verify audio permissions before recording 265 | 4. Run both GUI and CLI modes during testing 266 | 5. Check memory usage with different model sizes -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # localGPT-Vision 2 | [![GitHub Stars](https://img.shields.io/github/stars/PromtEngineer/localGPT-Vision?style=social)](https://github.com/PromtEngineer/localGPT-Vision/stargazers) 3 | [![GitHub Forks](https://img.shields.io/github/forks/PromtEngineer/localGPT-Vision?style=social)](https://github.com/PromtEngineer/localGPT-Vision/network/members) 4 | [![GitHub Issues](https://img.shields.io/github/issues/PromtEngineer/localGPT-Vision)](https://github.com/PromtEngineer/localGPT-Vision/issues) 5 | [![GitHub Pull Requests](https://img.shields.io/github/issues-pr/PromtEngineer/localGPT-Vision)](https://github.com/PromtEngineer/localGPT-Vision/pulls) 6 | [![Twitter Follow](https://img.shields.io/twitter/follow/engineerrprompt?style=social)](https://twitter.com/engineerrprompt) 7 | 8 | 9 | [Watch the video on YouTube](https://youtu.be/YPs4eGDpIY4) 10 | 11 | 12 | localGPT-Vision is an end-to-end vision-based Retrieval-Augmented Generation (RAG) system. It allows users to upload and index documents (PDFs and images), ask questions about the content, and receive responses along with relevant document snippets. 
The retrieval is performed using the [Colqwen](https://huggingface.co/vidore/colqwen2-v0.1) or [ColPali](https://huggingface.co/blog/manu/colpali) models, and the retrieved pages are passed to a Vision Language Model (VLM) for generating responses. Currently, the code supports these VLMs: 13 | 14 | - [Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) 15 | - [LLAMA-3.2-11B-Vision](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision) 16 | - [Pixtral-12B-2409](https://huggingface.co/mistralai/Pixtral-12B-2409) 17 | - [Molmo-7B-O-0924](https://huggingface.co/allenai/Molmo-7B-O-0924) 18 | - [Google Gemini](https://aistudio.google.com/app/prompts/new_chat) 19 | - [OpenAI GPT-4o](https://platform.openai.com/docs/guides/vision) 20 | - [LLAMA-3.2 with Ollama](https://ollama.com/blog/llama3.2-vision) 21 | 22 | The project is built on top of the [Byaldi](https://github.com/AnswerDotAI/byaldi) library. 23 | 24 | ## Table of Contents 25 | - [Features](#features) 26 | - [Architecture](#architecture) 27 | - [Prerequisites](#prerequisites) 28 | - [Installation](#installation) 29 | - [Usage](#usage) 30 | - [Project Structure](#project-structure) 31 | - [System Workflow](#system-workflow) 32 | - [Contributing](#contributing) 33 | 34 | ## Features 35 | - End-to-End Vision-Based RAG: Combines visual document retrieval with language models for comprehensive answers. 36 | - Document Upload and Indexing: Upload PDFs and images, which are then indexed using ColPali for retrieval. 37 | - Chat Interface: Engage in a conversational interface to ask questions about the uploaded documents. 38 | - Session Management: Create, rename, switch between, and delete chat sessions. 39 | - Model Selection: Choose between different Vision Language Models (Qwen2-VL-7B-Instruct, Google Gemini, OpenAI GPT-4, etc.). 40 | - Persistent Indexes: Indexes are saved on disk and loaded upon application restart. 41 | 42 | ## Architecture 43 | localGPT-Vision is built as an end-to-end vision-based RAG system. The architecture comprises two main components: 44 | 45 | 1. Visual Document Retrieval with Colqwen and ColPali: 46 | - [Colqwen](https://huggingface.co/vidore/colqwen2-v0.1) and [ColPali](https://huggingface.co/blog/manu/colpali) are Vision Encoders designed for efficient document retrieval solely using the image representation of document pages. 47 | - They embed page images directly, leveraging visual cues like layout, fonts, figures, and tables without relying on OCR or text extraction. 48 | - During indexing, document pages are converted into image embeddings and stored. 49 | - During querying, the user query is matched against these embeddings to retrieve the most relevant document pages. 50 | 51 | ![ColPali](https://cdn-uploads.huggingface.co/production/uploads/60f2e021adf471cbdf8bb660/La8vRJ_dtobqs6WQGKTzB.png) 52 | 53 | 2. Response Generation with Vision Language Models: 54 | - The retrieved document images are passed to a Vision Language Model (VLM). 55 | - Supported models include Qwen2-VL-7B-Instruct, LLAMA-3.2, Pixtral, Molmo, Google Gemini, and OpenAI GPT-4. 56 | - These models generate responses by understanding both the visual and textual content of the documents. 57 | - NOTE: The quality of the responses is highly dependent on the VLM used and the resolution of the document images. 58 | 59 | This architecture eliminates the need for complex text extraction pipelines and provides a more holistic understanding of documents by considering their visual elements.
You don't need any chunking strategies or selection of embeddings model or retrieval strategy used in traditional RAG systems. 60 | 61 | ## Prerequisites 62 | - Anaconda or Miniconda installed on your system 63 | - Python 3.10 or higher 64 | - Git (optional, for cloning the repository) 65 | 66 | ## Installation 67 | Follow these steps to set up and run the application on your local machine. 68 | 69 | 1. Clone the Repository 70 | ```bash 71 | git clone https://github.com/PromtEngineer/localGPT-Vision.git 72 | cd localGPT-Vision 73 | ``` 74 | 75 | 2. Create a Conda Environment 76 | ```bash 77 | conda create -n localgpt-vision python=3.10 78 | conda activate localgpt-vision 79 | ``` 80 | 81 | 3a. Install Dependencies 82 | ```bash 83 | pip install -r requirements.txt 84 | ``` 85 | 86 | 3b. Install Transformers from HuggingFace - Dev version 87 | ```bash 88 | pip uninstall transformers 89 | pip install git+https://github.com/huggingface/transformers 90 | ``` 91 | 92 | 4. Set Environment Variables 93 | Set your API keys for Google Gemini and OpenAI GPT-4: 94 | 95 | ```bash 96 | export GENAI_API_KEY='your_genai_api_key' 97 | export OPENAI_API_KEY='your_openai_api_key' 98 | export GROQ_API_KEY='your_groq_api_key' 99 | ``` 100 | 101 | On Windows Command Prompt: 102 | ```cmd 103 | set GENAI_API_KEY=your_genai_api_key 104 | set OPENAI_API_KEY=your_openai_api_key 105 | set GROQ_API_KEY='your_groq_api_key' 106 | ``` 107 | 108 | 5. Run the Application 109 | ```bash 110 | python app.py 111 | ``` 112 | 113 | 6. Access the Application 114 | Open your web browser and navigate to: 115 | ``` 116 | http://localhost:5050/ 117 | ``` 118 | ## Debugging 119 | 120 | To assist with debugging localGPT-Vision, certain additional dependencies need to be installed on your system. Follow the instructions below to set up your environment for debugging purposes. 121 | 122 | ### Required Packages 123 | 124 | Install the necessary packages using the following commands: 125 | 126 | #### On Ubuntu/Debian-based Systems 127 | 128 | 1. **Update Package Lists** 129 | 130 | ```bash 131 | sudo apt update 132 | ``` 133 | 134 | 2. **Install Poppler Libraries and Utilities** 135 | 136 | These libraries are essential for handling PDF files. 137 | 138 | ```bash 139 | sudo apt install libpoppler-cpp-dev poppler-utils 140 | ``` 141 | 142 | 3. **Verify Installation of `pdftoppm`** 143 | 144 | Check the version to ensure it's installed correctly: 145 | 146 | ```bash 147 | pdftoppm -v 148 | ``` 149 | 150 | 4. **Install Additional Dependencies** 151 | 152 | These packages are required for building and managing libraries. 153 | 154 | ```bash 155 | sudo apt install cmake pkgconfig python3-poppler-qt5 156 | ``` 157 | 158 | ### Note 159 | 160 | Ensure that you have appropriate permissions to run `sudo` commands on your machine. This setup is specifically tailored for Ubuntu/Debian-based systems, and steps might vary slightly if using a different Linux distribution or macOS. 161 | 162 | --- 163 | 164 | Feel free to modify this section as needed based on any additional requirements or specific instructions pertinent to other operating systems. 165 | 166 | 167 | ## Usage 168 | ### Upload and Index Documents 169 | 1. Click on "New Chat" to start a new session. 170 | 2. Under "Upload and Index Documents", click "Choose Files" and select your PDF or image files. 171 | 3. Click "Upload and Index". The documents will be indexed using ColPali and ready for querying. 172 | 173 | ### Ask Questions 174 | 1. 
In the "Enter your question here" textbox, type your query related to the uploaded documents. 175 | 2. Click "Send". The system will retrieve relevant document pages and generate a response using the selected Vision Language Model. 176 | 177 | ### Manage Sessions 178 | - Rename Session: Click "Edit Name", enter a new name, and click "Save Name". 179 | - Switch Sessions: Click on a session name in the sidebar to switch to that session. 180 | - Delete Session: Click "Delete" next to a session to remove it permanently. 181 | 182 | ### Settings 183 | 1. Click on "Settings" in the navigation bar. 184 | 2. Select the desired language model and image dimensions. 185 | 3. Click "Save Settings". 186 | 187 | ## Project Structure 188 | ``` 189 | localGPT-Vision/ 190 | ├── app.py 191 | ├── logger.py 192 | ├── models/ 193 | │ ├── indexer.py 194 | │ ├── retriever.py 195 | │ ├── responder.py 196 | │ ├── model_loader.py 197 | │ └── converters.py 198 | ├── sessions/ 199 | ├── templates/ 200 | │ ├── base.html 201 | │ ├── chat.html 202 | │ ├── settings.html 203 | │ └── index.html 204 | ├── static/ 205 | │ ├── css/ 206 | │ │ └── style.css 207 | │ ├── js/ 208 | │ │ └── script.js 209 | │ └── images/ 210 | ├── uploaded_documents/ 211 | ├── byaldi_indices/ 212 | ├── requirements.txt 213 | ├── .gitignore 214 | └── README.md 215 | ``` 216 | 217 | - `app.py`: Main Flask application. 218 | - `logger.py`: Configures application logging. 219 | - `models/`: Contains modules for indexing, retrieving, and responding. 220 | - `templates/`: HTML templates for rendering views. 221 | - `static/`: Static files like CSS and JavaScript. 222 | - `sessions/`: Stores session data. 223 | - `uploaded_documents/`: Stores uploaded documents. 224 | - `.byaldi/`: Stores the indexes created by Byaldi. 225 | - `requirements.txt`: Python dependencies. 226 | - `.gitignore`: Files and directories to be ignored by Git. 227 | - `README.md`: Project documentation. 228 | 229 | ## System Workflow 230 | 1. User Interaction: The user interacts with the web interface to upload documents and ask questions. 231 | 2. Document Indexing with ColPali: 232 | - Uploaded documents are converted to PDFs if necessary. 233 | - Documents are indexed using ColPali, which creates embeddings based on the visual content of the document pages. 234 | - The indexes are stored in the byaldi_indices/ directory. 235 | 3. Session Management: 236 | - Each chat session has a unique ID and stores its own index and chat history. 237 | - Sessions are saved on disk and loaded upon application restart. 238 | 4. Query Processing: 239 | - User queries are sent to the backend. 240 | - The query is embedded and matched against the visual embeddings of document pages to retrieve relevant pages. 241 | 5. Response Generation with Vision Language Models: 242 | - The retrieved document images and the user query are passed to the selected Vision Language Model (Qwen, Gemini, or GPT-4). 243 | - The VLM generates a response by understanding both the visual and textual content of the documents. 244 | 6. Display Results: 245 | - The response and relevant document snippets are displayed in the chat interface. 
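
For reference, the same workflow can be exercised programmatically with the project's helper modules, outside the Flask routes. The snippet below is only a minimal sketch based on how `app.py` calls `index_documents`, `retrieve_documents`, and `generate_response`; the folder paths, session id, and query are illustrative placeholders, not part of the application.

```python
# Sketch only: index -> retrieve -> respond, using the same helpers app.py calls.
# The session id, folders, and query below are illustrative placeholders.
import os

from models.indexer import index_documents
from models.retriever import retrieve_documents
from models.responder import generate_response

session_id = "demo-session"                                   # hypothetical session id
docs_folder = os.path.join("uploaded_documents", session_id)  # uploaded PDFs/images
index_path = os.path.join(".byaldi", session_id)              # where the index is persisted

# 1. Index the uploaded documents (creates visual page embeddings on disk).
rag = index_documents(docs_folder, index_name=session_id,
                      index_path=index_path, indexer_model="vidore/colpali")

# 2. Retrieve the page images most relevant to a query.
#    The retriever returns paths relative to static/, as in app.py.
query = "What does the report say about quarterly revenue?"
retrieved = retrieve_documents(rag, query, session_id)
image_paths = [os.path.join("static", img) for img in retrieved]

# 3. Pass the query and the retrieved pages to the selected Vision Language Model.
answer, used_images = generate_response(image_paths, query, session_id,
                                        resized_height=280, resized_width=280,
                                        model_choice="qwen")
print(answer)
```

The Flask app performs these same steps per chat session, persisting chat history under `sessions/` and the visual indexes under the Byaldi index folder.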
246 | 247 | ```mermaid 248 | graph TD 249 | A[User] -->|Uploads Documents| B(Flask App) 250 | B -->|Saves Files| C[uploaded_documents/] 251 | B -->|Converts and Indexes with ColPali| D[Indexing Module] 252 | D -->|Creates Visual Embeddings| E[byaldi_indices/] 253 | A -->|Asks Question| B 254 | B -->|Embeds Query and Retrieves Pages| F[Retrieval Module] 255 | F -->|Retrieves Relevant Pages| E 256 | F -->|Passes Pages to| G[Vision Language Model] 257 | G -->|Generates Response| B 258 | B -->|Displays Response| A 259 | B -->|Saves Session Data| H[sessions/] 260 | subgraph Backend 261 | B 262 | D 263 | F 264 | G 265 | end 266 | subgraph Storage 267 | C 268 | E 269 | H 270 | end 271 | ``` 272 | 273 | ## Contributing 274 | Contributions are welcome! Please follow these steps: 275 | 276 | 1. Fork the repository. 277 | 2. Create a new branch for your feature: `git checkout -b feature-name`. 278 | 3. Commit your changes: `git commit -am 'Add new feature'`. 279 | 4. Push to the branch: `git push origin feature-name`. 280 | 5. Submit a pull request. 281 | 282 | ## Star History 283 | 284 | [![Star History Chart](https://api.star-history.com/svg?repos=PromtEngineer/localGPT-Vision&type=Date)](https://star-history.com/#PromtEngineer/localGPT-Vision&Date) 285 | 286 | -------------------------------------------------------------------------------- /models/responder.py: -------------------------------------------------------------------------------- 1 | # models/responder.py 2 | 3 | from models.model_loader import load_model, is_single_image_model 4 | from transformers import GenerationConfig 5 | import google.generativeai as genai 6 | from dotenv import load_dotenv 7 | from logger import get_logger 8 | from openai import OpenAI 9 | from PIL import Image 10 | import torch 11 | import base64 12 | import os 13 | import io 14 | # from langchain_core.messages import HumanMessage 15 | from io import BytesIO 16 | import ollama 17 | 18 | 19 | logger = get_logger(__name__) 20 | 21 | # Function to encode the image 22 | def encode_image(image_path): 23 | with open(image_path, "rb") as image_file: 24 | return base64.b64encode(image_file.read()).decode('utf-8') 25 | 26 | def generate_response(images, query, session_id, resized_height=280, resized_width=280, model_choice='qwen'): 27 | """ 28 | Generates a response using the selected model based on the query and images. 
29 | Returns: (response_text, used_images) 30 | """ 31 | try: 32 | logger.info(f"Generating response using model '{model_choice}'.") 33 | 34 | # Convert resized_height and resized_width to integers 35 | resized_height = int(resized_height) 36 | resized_width = int(resized_width) 37 | 38 | # Ensure images are full paths 39 | full_image_paths = [os.path.join('static', img) if not img.startswith('static') else img for img in images] 40 | 41 | # Check if any valid images exist 42 | valid_images = [img for img in full_image_paths if os.path.exists(img)] 43 | 44 | if not valid_images: 45 | logger.warning("No valid images found for analysis.") 46 | return "No images could be loaded for analysis.", [] 47 | 48 | # If model only supports single image, use only the first image 49 | if is_single_image_model(model_choice): 50 | valid_images = [valid_images[0]] 51 | logger.info(f"Model {model_choice} only supports single image, using first image only.") 52 | 53 | if model_choice == 'qwen': 54 | from qwen_vl_utils import process_vision_info 55 | # Load cached model 56 | model, processor, device = load_model('qwen') 57 | # Ensure dimensions are multiples of 28 58 | resized_height = (resized_height // 28) * 28 59 | resized_width = (resized_width // 28) * 28 60 | 61 | image_contents = [] 62 | for image in valid_images: 63 | image_contents.append({ 64 | "type": "image", 65 | "image": image, # Use the full path 66 | "resized_height": resized_height, 67 | "resized_width": resized_width 68 | }) 69 | messages = [ 70 | { 71 | "role": "user", 72 | "content": image_contents + [{"type": "text", "text": query}], 73 | } 74 | ] 75 | text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) 76 | image_inputs, video_inputs = process_vision_info(messages) 77 | inputs = processor( 78 | text=[text], 79 | images=image_inputs, 80 | videos=video_inputs, 81 | padding=True, 82 | return_tensors="pt", 83 | ) 84 | inputs = inputs.to(device) 85 | generated_ids = model.generate(**inputs, max_new_tokens=128) 86 | generated_ids_trimmed = [ 87 | out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) 88 | ] 89 | output_text = processor.batch_decode( 90 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False 91 | ) 92 | logger.info("Response generated using Qwen model.") 93 | return output_text[0], valid_images 94 | 95 | elif model_choice == 'gemini': 96 | model, _ = load_model('gemini') 97 | 98 | try: 99 | content = [query] # Add the text query first 100 | 101 | for img_path in valid_images: 102 | if os.path.exists(img_path): 103 | try: 104 | img = Image.open(img_path) 105 | content.append(img) 106 | except Exception as e: 107 | logger.error(f"Error opening image {img_path}: {e}") 108 | else: 109 | logger.warning(f"Image file not found: {img_path}") 110 | 111 | if len(content) == 1: # Only text, no images 112 | return "No images could be loaded for analysis.", [] 113 | 114 | response = model.generate_content(content) 115 | 116 | if response.text: 117 | generated_text = response.text 118 | logger.info("Response generated using Gemini model.") 119 | return generated_text, valid_images 120 | else: 121 | return "The Gemini model did not generate any text response.", [] 122 | 123 | except Exception as e: 124 | logger.error(f"Error in Gemini processing: {str(e)}", exc_info=True) 125 | return f"An error occurred while processing the images: {str(e)}", [] 126 | 127 | elif model_choice == 'gpt4': 128 | api_key = os.getenv("OPENAI_API_KEY") 129 | client = 
OpenAI(api_key=api_key) 130 | 131 | try: 132 | content = [{"type": "text", "text": query}] 133 | 134 | for img_path in valid_images: 135 | logger.info(f"Processing image: {img_path}") 136 | if os.path.exists(img_path): 137 | base64_image = encode_image(img_path) 138 | content.append({ 139 | "type": "image_url", 140 | "image_url": { 141 | "url": f"data:image/jpeg;base64,{base64_image}" 142 | } 143 | }) 144 | else: 145 | logger.warning(f"Image file not found: {img_path}") 146 | 147 | if len(content) == 1: # Only text, no images 148 | return "No images could be loaded for analysis.", [] 149 | 150 | response = client.chat.completions.create( 151 | model="gpt-4o", 152 | messages=[ 153 | { 154 | "role": "user", 155 | "content": content 156 | } 157 | ], 158 | max_tokens=1024 159 | ) 160 | 161 | generated_text = response.choices[0].message.content 162 | logger.info("Response generated using GPT-4 model.") 163 | return generated_text, valid_images 164 | 165 | except Exception as e: 166 | logger.error(f"Error in GPT-4 processing: {str(e)}", exc_info=True) 167 | return f"An error occurred while processing the images: {str(e)}", [] 168 | 169 | elif model_choice == 'llama-vision': 170 | # Load model, processor, and device 171 | model, processor, device = load_model('llama-vision') 172 | 173 | # Process images 174 | # For simplicity, use the first image 175 | image_path = valid_images[0] if valid_images else None 176 | if image_path and os.path.exists(image_path): 177 | image = Image.open(image_path).convert('RGB') 178 | else: 179 | return "No valid image found for analysis.", [] 180 | 181 | # Prepare messages 182 | messages = [ 183 | {"role": "user", "content": [ 184 | {"type": "image"}, 185 | {"type": "text", "text": query} 186 | ]} 187 | ] 188 | input_text = processor.apply_chat_template(messages, add_generation_prompt=True) 189 | inputs = processor(image, input_text, return_tensors="pt").to(device) 190 | 191 | # Generate response 192 | output = model.generate(**inputs, max_new_tokens=512) 193 | response = processor.decode(output[0], skip_special_tokens=True) 194 | return response, valid_images 195 | 196 | elif model_choice == "pixtral": 197 | model, tokenizer, generate_func, device = load_model('pixtral') 198 | 199 | def image_to_data_url(image_path): 200 | with open(image_path, "rb") as image_file: 201 | encoded_string = base64.b64encode(image_file.read()).decode('utf-8') 202 | ext = os.path.splitext(image_path)[1][1:] # Get the file extension 203 | return f"data:image/{ext};base64,{encoded_string}" 204 | 205 | from mistral_common.protocol.instruct.messages import UserMessage, TextChunk, ImageURLChunk 206 | from mistral_common.protocol.instruct.request import ChatCompletionRequest 207 | 208 | # Prepare the content with text and images 209 | content = [TextChunk(text=query)] 210 | for img_path in valid_images[:1]: # Use only the first image 211 | content.append(ImageURLChunk(image_url=image_to_data_url(img_path))) 212 | 213 | completion_request = ChatCompletionRequest(messages=[UserMessage(content=content)]) 214 | 215 | encoded = tokenizer.encode_chat_completion(completion_request) 216 | 217 | images = encoded.images 218 | tokens = encoded.tokens 219 | 220 | out_tokens, _ = generate_func([tokens], model, images=[images], max_tokens=256, temperature=0.35, eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id) 221 | result = tokenizer.decode(out_tokens[0]) 222 | 223 | logger.info("Response generated using Pixtral model.") 224 | return result, valid_images 225 | 226 | elif model_choice == "molmo": 227 | 
model, processor, device = load_model('molmo') 228 | model = model.half() # Convert model to half precision 229 | pil_images = [] 230 | for img_path in valid_images[:1]: # Process only the first image for now 231 | if os.path.exists(img_path): 232 | try: 233 | img = Image.open(img_path).convert('RGB') 234 | pil_images.append(img) 235 | except Exception as e: 236 | logger.error(f"Error opening image {img_path}: {e}") 237 | else: 238 | logger.warning(f"Image file not found: {img_path}") 239 | 240 | if not pil_images: 241 | return "No images could be loaded for analysis.", [] 242 | 243 | try: 244 | # Process the images and text 245 | inputs = processor.process( 246 | images=pil_images, 247 | text=query 248 | ) 249 | 250 | # Move inputs to the correct device and make a batch of size 1 251 | # Convert float tensors to half precision, but keep integer tensors as they are 252 | inputs = {k: (v.to(device).unsqueeze(0).half() if v.dtype in [torch.float32, torch.float64] else 253 | v.to(device).unsqueeze(0)) 254 | if isinstance(v, torch.Tensor) else v 255 | for k, v in inputs.items()} 256 | 257 | # Generate output 258 | with torch.no_grad(): # Disable gradient calculation 259 | output = model.generate_from_batch( 260 | inputs, 261 | GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"), 262 | tokenizer=processor.tokenizer 263 | ) 264 | 265 | # Only get generated tokens; decode them to text 266 | generated_tokens = output[0, inputs['input_ids'].size(1):] 267 | generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True) 268 | 269 | return generated_text, valid_images 270 | 271 | except Exception as e: 272 | logger.error(f"Error in Molmo processing: {str(e)}", exc_info=True) 273 | return f"An error occurred while processing the images: {str(e)}", [] 274 | finally: 275 | # Close the opened images to free up resources 276 | for img in pil_images: 277 | img.close() 278 | elif model_choice == 'groq-llama-vision': 279 | client = load_model('groq-llama-vision') 280 | 281 | content = [{"type": "text", "text": query}] 282 | 283 | # Use only the first image 284 | if valid_images: 285 | img_path = valid_images[0] 286 | if os.path.exists(img_path): 287 | base64_image = encode_image(img_path) 288 | content.append({ 289 | "type": "image_url", 290 | "image_url": { 291 | "url": f"data:image/jpeg;base64,{base64_image}" 292 | } 293 | }) 294 | else: 295 | logger.warning(f"Image file not found: {img_path}") 296 | 297 | if len(content) == 1: # Only text, no images 298 | return "No images could be loaded for analysis.", [] 299 | 300 | try: 301 | chat_completion = client.chat.completions.create( 302 | messages=[ 303 | { 304 | "role": "user", 305 | "content": content 306 | } 307 | ], 308 | model="llava-v1.5-7b-4096-preview", 309 | ) 310 | generated_text = chat_completion.choices[0].message.content 311 | logger.info("Response generated using Groq Llama Vision model.") 312 | return generated_text, valid_images 313 | except Exception as e: 314 | logger.error(f"Error in Groq Llama Vision processing: {str(e)}", exc_info=True) 315 | return f"An error occurred while processing the image: {str(e)}", [] 316 | elif model_choice == 'ollama-llama-vision': 317 | try: 318 | message = { 319 | 'role': 'user', 320 | 'content': query, 321 | 'images': [valid_images[0]] 322 | } 323 | 324 | response = ollama.chat( 325 | model='llama3.2-vision', 326 | messages=[message] 327 | ) 328 | 329 | logger.info("Response generated using Ollama Llama Vision model.") 330 | return response['message']['content'], 
valid_images 331 | 332 | except Exception as e: 333 | logger.error(f"Error in Ollama Llama Vision processing: {str(e)}", exc_info=True) 334 | return f"An error occurred while processing the image: {str(e)}", [] 335 | else: 336 | logger.error(f"Invalid model choice: {model_choice}") 337 | return "Invalid model selected.", [] 338 | except Exception as e: 339 | logger.error(f"Error generating response: {e}") 340 | return f"An error occurred while generating the response: {str(e)}", [] 341 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import os 2 | import uuid 3 | import json 4 | import time # Add this import at the top of the file 5 | from flask import Flask, render_template, request, redirect, url_for, session, flash, jsonify 6 | from markupsafe import Markup 7 | from models.indexer import index_documents 8 | from models.retriever import retrieve_documents 9 | from models.responder import generate_response 10 | from werkzeug.utils import secure_filename 11 | from logger import get_logger 12 | from byaldi import RAGMultiModalModel 13 | import markdown 14 | 15 | # Set the TOKENIZERS_PARALLELISM environment variable to suppress warnings 16 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 17 | 18 | # Initialize the Flask application 19 | app = Flask(__name__) 20 | app.secret_key = 'your_secret_key' # Replace with a secure secret key 21 | 22 | logger = get_logger(__name__) 23 | 24 | # Configure upload folders 25 | app.config['UPLOAD_FOLDER'] = 'uploaded_documents' 26 | app.config['STATIC_FOLDER'] = 'static' 27 | app.config['SESSION_FOLDER'] = 'sessions' 28 | app.config['INDEX_FOLDER'] = os.path.join(os.getcwd(), '.byaldi') # Set to .byaldi folder in current directory 29 | 30 | # Create necessary directories if they don't exist 31 | os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True) 32 | os.makedirs(app.config['STATIC_FOLDER'], exist_ok=True) 33 | os.makedirs(app.config['SESSION_FOLDER'], exist_ok=True) 34 | 35 | # Initialize global variables 36 | RAG_models = {} # Dictionary to store RAG models per session 37 | app.config['INITIALIZATION_DONE'] = False # Flag to track initialization 38 | logger.info("Application started.") 39 | 40 | def load_rag_model_for_session(session_id): 41 | """ 42 | Loads the RAG model for the given session_id from the index on disk. 43 | """ 44 | index_path = os.path.join(app.config['INDEX_FOLDER'], session_id) 45 | 46 | if os.path.exists(index_path): 47 | try: 48 | RAG = RAGMultiModalModel.from_index(index_path) 49 | RAG_models[session_id] = RAG 50 | logger.info(f"RAG model for session {session_id} loaded from index.") 51 | except Exception as e: 52 | logger.error(f"Error loading RAG model for session {session_id}: {e}") 53 | else: 54 | logger.warning(f"No index found for session {session_id}.") 55 | 56 | def load_existing_indexes(): 57 | """ 58 | Loads all existing indexes from the .byaldi folder when the application starts. 59 | """ 60 | global RAG_models 61 | if os.path.exists(app.config['INDEX_FOLDER']): 62 | for session_id in os.listdir(app.config['INDEX_FOLDER']): 63 | if os.path.isdir(os.path.join(app.config['INDEX_FOLDER'], session_id)): 64 | load_rag_model_for_session(session_id) 65 | else: 66 | logger.warning("No .byaldi folder found. No existing indexes to load.") 67 | 68 | @app.before_request 69 | def initialize_app(): 70 | """ 71 | Initializes the application by loading existing indexes. 
72 | This will run before the first request, but only once. 73 | """ 74 | if not app.config['INITIALIZATION_DONE']: 75 | load_existing_indexes() 76 | app.config['INITIALIZATION_DONE'] = True 77 | logger.info("Application initialized and indexes loaded.") 78 | 79 | @app.before_request 80 | def make_session_permanent(): 81 | session.permanent = True 82 | if 'session_id' not in session: 83 | session['session_id'] = str(uuid.uuid4()) 84 | 85 | 86 | @app.route('/', methods=['GET']) 87 | def home(): 88 | return redirect(url_for('chat')) 89 | 90 | @app.route('/chat', methods=['GET', 'POST']) 91 | def chat(): 92 | if 'session_id' not in session: 93 | session['session_id'] = str(uuid.uuid4()) 94 | 95 | session_id = session['session_id'] 96 | session_file = os.path.join(app.config['SESSION_FOLDER'], f"{session_id}.json") 97 | 98 | # Load session data from file 99 | if os.path.exists(session_file): 100 | with open(session_file, 'r') as f: 101 | session_data = json.load(f) 102 | chat_history = session_data.get('chat_history', []) 103 | session_name = session_data.get('session_name', 'Untitled Session') 104 | indexed_files = session_data.get('indexed_files', []) 105 | else: 106 | chat_history = [] 107 | session_name = 'Untitled Session' 108 | indexed_files = [] 109 | 110 | if request.method == 'POST': 111 | if 'upload' in request.form: 112 | # Handle file upload and indexing 113 | files = request.files.getlist('file') 114 | session_folder = os.path.join(app.config['UPLOAD_FOLDER'], session_id) 115 | os.makedirs(session_folder, exist_ok=True) 116 | uploaded_files = [] 117 | for file in files: 118 | if file and file.filename: 119 | filename = secure_filename(file.filename) 120 | file_path = os.path.join(session_folder, filename) 121 | file.save(file_path) 122 | uploaded_files.append(filename) 123 | logger.info(f"File saved: {file_path}") 124 | 125 | if uploaded_files: 126 | try: 127 | index_name = session_id 128 | index_path = os.path.join(app.config['INDEX_FOLDER'], index_name) 129 | indexer_model = session.get('indexer_model', 'vidore/colpali') 130 | RAG = index_documents(session_folder, index_name=index_name, index_path=index_path, indexer_model=indexer_model) 131 | if RAG is None: 132 | raise ValueError("Indexing failed: RAG model is None") 133 | RAG_models[session_id] = RAG 134 | session['index_name'] = index_name 135 | session['session_folder'] = session_folder 136 | indexed_files.extend(uploaded_files) 137 | session_data = { 138 | 'session_name': session_name, 139 | 'chat_history': chat_history, 140 | 'indexed_files': indexed_files 141 | } 142 | with open(session_file, 'w') as f: 143 | json.dump(session_data, f) 144 | logger.info("Documents indexed successfully.") 145 | return jsonify({ 146 | "success": True, 147 | "message": "Files indexed successfully.", 148 | "indexed_files": indexed_files 149 | }) 150 | except Exception as e: 151 | logger.error(f"Error indexing documents: {str(e)}") 152 | return jsonify({"success": False, "message": f"Error indexing files: {str(e)}"}) 153 | else: 154 | return jsonify({"success": False, "message": "No files were uploaded."}) 155 | 156 | elif 'send_query' in request.form: 157 | query = request.form['query'] 158 | 159 | try: 160 | generation_model = session.get('generation_model', 'qwen') 161 | resized_height = session.get('resized_height', 280) 162 | resized_width = session.get('resized_width', 280) 163 | 164 | # Retrieve relevant documents 165 | rag_model = RAG_models.get(session_id) 166 | if rag_model is None: 167 | logger.error(f"RAG model not found for 
session {session_id}") 168 | return jsonify({"success": False, "message": "RAG model not found for this session."}) 169 | 170 | retrieved_images = retrieve_documents(rag_model, query, session_id) 171 | logger.info(f"Retrieved images: {retrieved_images}") 172 | 173 | # Generate response with full image paths 174 | full_image_paths = [os.path.join(app.static_folder, img) for img in retrieved_images] 175 | response_text, used_images = generate_response( 176 | full_image_paths, 177 | query, 178 | session_id, 179 | resized_height, 180 | resized_width, 181 | generation_model 182 | ) 183 | 184 | # Parse markdown in the response 185 | parsed_response = Markup(markdown.markdown(response_text)) 186 | 187 | # Get relative paths for used images 188 | relative_images = [os.path.relpath(img, app.static_folder) for img in used_images] 189 | 190 | # Update chat history 191 | chat_history.append({"role": "user", "content": query}) 192 | chat_history.append({ 193 | "role": "assistant", 194 | "content": parsed_response, 195 | "images": relative_images # Use relative paths for frontend 196 | }) 197 | 198 | # Update session name if it's the first message 199 | if len(chat_history) == 2: # First user message and AI response 200 | session_name = query[:50] # Truncate to 50 characters 201 | 202 | session_data = { 203 | 'session_name': session_name, 204 | 'chat_history': chat_history, 205 | 'indexed_files': indexed_files 206 | } 207 | with open(session_file, 'w') as f: 208 | json.dump(session_data, f) 209 | 210 | # Render the new messages 211 | new_messages_html = render_template('chat_messages.html', messages=[ 212 | {"role": "user", "content": query}, 213 | {"role": "assistant", "content": parsed_response, "images": relative_images} 214 | ]) 215 | 216 | return jsonify({ 217 | "success": True, 218 | "html": new_messages_html 219 | }) 220 | except Exception as e: 221 | logger.error(f"Error generating response: {e}", exc_info=True) 222 | return jsonify({ 223 | "success": False, 224 | "message": f"An error occurred while generating the response: {str(e)}" 225 | }) 226 | 227 | # For GET requests, render the chat page 228 | session_files = os.listdir(app.config['SESSION_FOLDER']) 229 | chat_sessions = [] 230 | for file in session_files: 231 | if file.endswith('.json'): 232 | s_id = file[:-5] 233 | with open(os.path.join(app.config['SESSION_FOLDER'], file), 'r') as f: 234 | data = json.load(f) 235 | name = data.get('session_name', 'Untitled Session') 236 | chat_sessions.append({'id': s_id, 'name': name}) 237 | 238 | model_choice = session.get('model', 'qwen') 239 | resized_height = session.get('resized_height', 280) 240 | resized_width = session.get('resized_width', 280) 241 | 242 | return render_template('chat.html', chat_history=chat_history, chat_sessions=chat_sessions, 243 | current_session=session_id, model_choice=model_choice, 244 | resized_height=resized_height, resized_width=resized_width, 245 | session_name=session_name, indexed_files=indexed_files) 246 | 247 | @app.route('/switch_session/') 248 | def switch_session(session_id): 249 | session['session_id'] = session_id 250 | if session_id not in RAG_models: 251 | load_rag_model_for_session(session_id) 252 | flash(f"Switched to session.", "info") 253 | return redirect(url_for('chat')) 254 | 255 | @app.route('/rename_session', methods=['POST']) 256 | def rename_session(): 257 | session_id = request.form.get('session_id') 258 | new_session_name = request.form.get('new_session_name', 'Untitled Session') 259 | session_file = 
os.path.join(app.config['SESSION_FOLDER'], f"{session_id}.json") 260 | 261 | if os.path.exists(session_file): 262 | with open(session_file, 'r') as f: 263 | session_data = json.load(f) 264 | 265 | session_data['session_name'] = new_session_name 266 | 267 | with open(session_file, 'w') as f: 268 | json.dump(session_data, f) 269 | 270 | return jsonify({"success": True, "message": "Session name updated."}) 271 | else: 272 | return jsonify({"success": False, "message": "Session not found."}) 273 | 274 | @app.route('/delete_session/<session_id>', methods=['POST']) 275 | def delete_session(session_id): 276 | try: 277 | session_file = os.path.join(app.config['SESSION_FOLDER'], f"{session_id}.json") 278 | if os.path.exists(session_file): 279 | os.remove(session_file) 280 | 281 | session_folder = os.path.join(app.config['UPLOAD_FOLDER'], session_id) 282 | if os.path.exists(session_folder): 283 | import shutil 284 | shutil.rmtree(session_folder) 285 | 286 | session_images_folder = os.path.join('static', 'images', session_id) 287 | if os.path.exists(session_images_folder): 288 | import shutil 289 | shutil.rmtree(session_images_folder) 290 | 291 | RAG_models.pop(session_id, None) 292 | 293 | if session.get('session_id') == session_id: 294 | session['session_id'] = str(uuid.uuid4()) 295 | 296 | logger.info(f"Session {session_id} deleted.") 297 | return jsonify({"success": True, "message": "Session deleted successfully."}) 298 | except Exception as e: 299 | logger.error(f"Error deleting session {session_id}: {e}") 300 | return jsonify({"success": False, "message": f"An error occurred while deleting the session: {str(e)}"}) 301 | 302 | @app.route('/settings', methods=['GET', 'POST']) 303 | def settings(): 304 | if request.method == 'POST': 305 | indexer_model = request.form.get('indexer_model', 'vidore/colpali') 306 | generation_model = request.form.get('generation_model', 'qwen') 307 | resized_height = int(request.form.get('resized_height', 280)) 308 | resized_width = int(request.form.get('resized_width', 280)) 309 | session['indexer_model'] = indexer_model 310 | session['generation_model'] = generation_model 311 | session['resized_height'] = resized_height 312 | session['resized_width'] = resized_width 313 | session.modified = True 314 | logger.info(f"Settings updated: indexer_model={indexer_model}, generation_model={generation_model}, resized_height={resized_height}, resized_width={resized_width}") 315 | flash("Settings updated.", "success") 316 | return redirect(url_for('chat')) 317 | else: 318 | indexer_model = session.get('indexer_model', 'vidore/colpali') 319 | generation_model = session.get('generation_model', 'qwen') 320 | resized_height = session.get('resized_height', 280) 321 | resized_width = session.get('resized_width', 280) 322 | return render_template('settings.html', 323 | indexer_model=indexer_model, 324 | generation_model=generation_model, 325 | resized_height=resized_height, 326 | resized_width=resized_width) 327 | 328 | @app.route('/new_session') 329 | def new_session(): 330 | session_id = str(uuid.uuid4()) 331 | session['session_id'] = session_id 332 | session_files = os.listdir(app.config['SESSION_FOLDER']) 333 | session_number = len([f for f in session_files if f.endswith('.json')]) + 1 334 | session_name = f"Session {session_number}" 335 | session_file = os.path.join(app.config['SESSION_FOLDER'], f"{session_id}.json") 336 | session_data = { 337 | 'session_name': session_name, 338 | 'chat_history': [], 339 | 'indexed_files': [] 340 | } 341 | with open(session_file, 'w') as f: 342 | 
json.dump(session_data, f) 343 | flash("New chat session started.", "success") 344 | return redirect(url_for('chat')) 345 | 346 | @app.route('/get_indexed_files/<session_id>') 347 | def get_indexed_files(session_id): 348 | session_file = os.path.join(app.config['SESSION_FOLDER'], f"{session_id}.json") 349 | if os.path.exists(session_file): 350 | with open(session_file, 'r') as f: 351 | session_data = json.load(f) 352 | indexed_files = session_data.get('indexed_files', []) 353 | return jsonify({"success": True, "indexed_files": indexed_files}) 354 | else: 355 | return jsonify({"success": False, "message": "Session not found."}) 356 | 357 | if __name__ == '__main__': 358 | app.run(port=5050, debug=True) -------------------------------------------------------------------------------- /templates/chat.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | {% extends 'base.html' %} 4 | 5 | {% block content %} 6 |
7 |
8 | {% for message in chat_history %} 9 |
10 | {% if message.role == 'user' %} 11 | {{ message.content }} 12 | {% else %} 13 | {{ message.content|safe }} 14 | {% endif %} 15 | {% if message.images %} 16 |
17 | {% for image in message.images %} 18 | Retrieved Image 19 | {% endfor %} 20 |
21 | {% endif %} 22 |
23 | {% endfor %} 24 |
25 |
26 |
27 |
28 | 29 | 32 | 33 | 36 |
37 |
38 |
39 |
40 | 41 | 42 | 48 | 49 | 50 | 73 | 74 | 75 | 94 | 95 | 96 | 116 | 117 | 118 | 133 | 134 | {% endblock %} 135 | 136 | {% block scripts %} 137 | 378 | {% endblock %} -------------------------------------------------------------------------------- /jarvis/src/gui/main_window.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from PyQt6.QtWidgets import ( 3 | QMainWindow, QWidget, QVBoxLayout, QHBoxLayout, 4 | QPushButton, QTextEdit, QLabel, QComboBox, 5 | QCheckBox, QLineEdit, QProgressBar, QFileDialog, 6 | QMessageBox, QSplitter, QApplication 7 | ) 8 | from PyQt6.QtCore import Qt, QTimer, pyqtSignal, QThread 9 | from PyQt6.QtGui import QAction, QKeySequence, QFont 10 | from typing import Optional 11 | 12 | from ..pipeline import TranscriptionPipeline 13 | from ..models.transcriber import AVAILABLE_MODELS 14 | from ..models.text_corrector import AVAILABLE_LLM_MODELS 15 | import sounddevice as sd 16 | 17 | class TranscriptionThread(QThread): 18 | """Background thread for transcription""" 19 | textReady = pyqtSignal(str) 20 | errorOccurred = pyqtSignal(str) 21 | levelUpdate = pyqtSignal(float) 22 | 23 | def __init__(self, pipeline: TranscriptionPipeline, device_id: Optional[int] = None): 24 | super().__init__() 25 | self.pipeline = pipeline 26 | self.device_id = device_id 27 | self.is_recording = False 28 | 29 | def run(self): 30 | """Run transcription in background""" 31 | try: 32 | self.is_recording = True 33 | self.pipeline.start_recording(device_id=self.device_id) 34 | 35 | # Update audio levels 36 | while self.is_recording: 37 | level = self.pipeline.get_audio_level() 38 | self.levelUpdate.emit(level) 39 | self.msleep(50) # 50ms updates 40 | 41 | except Exception as e: 42 | self.errorOccurred.emit(str(e)) 43 | 44 | def stop_recording(self): 45 | """Stop recording and get transcription""" 46 | self.is_recording = False 47 | text = self.pipeline.stop_recording() 48 | self.textReady.emit(text) 49 | 50 | class MainWindow(QMainWindow): 51 | """Main application window""" 52 | 53 | def __init__(self): 54 | super().__init__() 55 | self.pipeline = TranscriptionPipeline() 56 | self.transcription_thread = None 57 | self.init_ui() 58 | self.setup_shortcuts() 59 | 60 | def init_ui(self): 61 | """Initialize the user interface""" 62 | self.setWindowTitle("Speech Transcription") 63 | self.setGeometry(100, 100, 900, 700) 64 | 65 | # Central widget 66 | central_widget = QWidget() 67 | self.setCentralWidget(central_widget) 68 | 69 | # Main layout 70 | layout = QVBoxLayout(central_widget) 71 | 72 | # Header section 73 | header_layout = QHBoxLayout() 74 | 75 | # Microphone selection 76 | mic_label = QLabel("Microphone:") 77 | self.mic_combo = QComboBox() 78 | self.refresh_microphones() 79 | self.mic_combo.currentIndexChanged.connect(self.on_mic_changed) 80 | 81 | # Add refresh button for microphones 82 | self.refresh_mic_button = QPushButton("🔄") 83 | self.refresh_mic_button.setMaximumWidth(30) 84 | self.refresh_mic_button.clicked.connect(self.refresh_microphones) 85 | self.refresh_mic_button.setToolTip("Refresh microphone list") 86 | 87 | header_layout.addWidget(mic_label) 88 | header_layout.addWidget(self.mic_combo) 89 | header_layout.addWidget(self.refresh_mic_button) 90 | 91 | # Model selection 92 | model_label = QLabel("Model:") 93 | self.model_combo = QComboBox() 94 | for model_name, config in AVAILABLE_MODELS.items(): 95 | self.model_combo.addItem( 96 | f"{model_name} - {config['description']}", 97 | model_name 98 | ) 99 | 
self.model_combo.currentIndexChanged.connect(self.on_model_changed) 100 | 101 | header_layout.addWidget(model_label) 102 | header_layout.addWidget(self.model_combo) 103 | header_layout.addStretch() 104 | 105 | # Options 106 | self.auto_correct_check = QCheckBox("Auto-correct after recording") 107 | self.auto_correct_check.setChecked(True) 108 | header_layout.addWidget(self.auto_correct_check) 109 | 110 | layout.addLayout(header_layout) 111 | 112 | # Second header row for LLM model and context 113 | second_header_layout = QHBoxLayout() 114 | 115 | # LLM Model selection 116 | llm_label = QLabel("LLM Model:") 117 | self.llm_combo = QComboBox() 118 | for model_name, config in AVAILABLE_LLM_MODELS.items(): 119 | self.llm_combo.addItem( 120 | f"{model_name} - {config['description']}", 121 | model_name 122 | ) 123 | # Set default to Phi-3.5-mini 124 | default_index = list(AVAILABLE_LLM_MODELS.keys()).index("Phi-3.5-mini") 125 | self.llm_combo.setCurrentIndex(default_index) 126 | self.llm_combo.currentIndexChanged.connect(self.on_llm_model_changed) 127 | 128 | second_header_layout.addWidget(llm_label) 129 | second_header_layout.addWidget(self.llm_combo) 130 | 131 | # Context input 132 | context_label = QLabel("Context:") 133 | self.context_input = QLineEdit() 134 | self.context_input.setPlaceholderText("e.g., Medical discussion, Technical meeting") 135 | second_header_layout.addWidget(context_label) 136 | second_header_layout.addWidget(self.context_input) 137 | 138 | layout.addLayout(second_header_layout) 139 | 140 | # Audio level indicator 141 | self.level_bar = QProgressBar() 142 | self.level_bar.setMaximum(100) 143 | self.level_bar.setTextVisible(False) 144 | self.level_bar.setFixedHeight(10) 145 | layout.addWidget(self.level_bar) 146 | 147 | # Control buttons 148 | button_layout = QHBoxLayout() 149 | 150 | self.record_button = QPushButton("Start Recording") 151 | self.record_button.clicked.connect(self.toggle_recording) 152 | self.record_button.setStyleSheet(""" 153 | QPushButton { 154 | background-color: #4CAF50; 155 | color: white; 156 | font-size: 16px; 157 | font-weight: bold; 158 | padding: 10px; 159 | border-radius: 5px; 160 | } 161 | QPushButton:hover { 162 | background-color: #45a049; 163 | } 164 | QPushButton:pressed { 165 | background-color: #3d8b40; 166 | } 167 | """) 168 | 169 | self.clear_button = QPushButton("Clear") 170 | self.clear_button.clicked.connect(self.clear_text) 171 | 172 | self.correct_button = QPushButton("Correct Text") 173 | self.correct_button.clicked.connect(self.correct_text) 174 | 175 | button_layout.addWidget(self.record_button) 176 | button_layout.addWidget(self.clear_button) 177 | button_layout.addWidget(self.correct_button) 178 | button_layout.addStretch() 179 | 180 | layout.addLayout(button_layout) 181 | 182 | # Text display area with splitter 183 | splitter = QSplitter(Qt.Orientation.Horizontal) 184 | 185 | # Raw transcription 186 | raw_container = QWidget() 187 | raw_layout = QVBoxLayout(raw_container) 188 | 189 | # Raw header with label and copy button 190 | raw_header = QHBoxLayout() 191 | raw_label = QLabel("Raw Transcription:") 192 | self.copy_raw_button = QPushButton("Copy") 193 | self.copy_raw_button.clicked.connect(self.copy_raw_text) 194 | raw_header.addWidget(raw_label) 195 | raw_header.addStretch() 196 | raw_header.addWidget(self.copy_raw_button) 197 | raw_layout.addLayout(raw_header) 198 | 199 | self.raw_text_edit = QTextEdit() 200 | self.raw_text_edit.setReadOnly(True) 201 | raw_layout.addWidget(self.raw_text_edit) 202 | 203 | # Corrected 
text 204 | corrected_container = QWidget() 205 | corrected_layout = QVBoxLayout(corrected_container) 206 | 207 | # Corrected header with label and copy button 208 | corrected_header = QHBoxLayout() 209 | corrected_label = QLabel("Corrected Text:") 210 | self.copy_corrected_button = QPushButton("Copy") 211 | self.copy_corrected_button.clicked.connect(self.copy_corrected_text) 212 | corrected_header.addWidget(corrected_label) 213 | corrected_header.addStretch() 214 | corrected_header.addWidget(self.copy_corrected_button) 215 | corrected_layout.addLayout(corrected_header) 216 | 217 | self.corrected_text_edit = QTextEdit() 218 | corrected_layout.addWidget(self.corrected_text_edit) 219 | 220 | splitter.addWidget(raw_container) 221 | splitter.addWidget(corrected_container) 222 | splitter.setSizes([450, 450]) 223 | 224 | layout.addWidget(splitter) 225 | 226 | # Status bar 227 | self.status_label = QLabel("Ready") 228 | layout.addWidget(self.status_label) 229 | 230 | # Apply dark theme 231 | self.setStyleSheet(""" 232 | QMainWindow { 233 | background-color: #2b2b2b; 234 | } 235 | QLabel { 236 | color: #ffffff; 237 | font-size: 14px; 238 | } 239 | QTextEdit { 240 | background-color: #3c3c3c; 241 | color: #ffffff; 242 | border: 1px solid #555; 243 | font-size: 14px; 244 | font-family: Monaco, Menlo, monospace; 245 | } 246 | QComboBox, QLineEdit { 247 | background-color: #3c3c3c; 248 | color: #ffffff; 249 | border: 1px solid #555; 250 | padding: 5px; 251 | } 252 | QComboBox::drop-down { 253 | border: none; 254 | } 255 | QComboBox::down-arrow { 256 | image: none; 257 | border-left: 5px solid transparent; 258 | border-right: 5px solid transparent; 259 | border-top: 5px solid #ffffff; 260 | margin-right: 5px; 261 | } 262 | QPushButton { 263 | background-color: #0d7377; 264 | color: white; 265 | border: none; 266 | padding: 8px 16px; 267 | font-size: 14px; 268 | border-radius: 4px; 269 | } 270 | QPushButton:hover { 271 | background-color: #14a085; 272 | } 273 | QCheckBox { 274 | color: #ffffff; 275 | font-size: 14px; 276 | } 277 | QCheckBox::indicator { 278 | width: 18px; 279 | height: 18px; 280 | } 281 | QProgressBar { 282 | background-color: #3c3c3c; 283 | border: 1px solid #555; 284 | } 285 | QProgressBar::chunk { 286 | background-color: #4CAF50; 287 | } 288 | """) 289 | 290 | def setup_shortcuts(self): 291 | """Setup keyboard shortcuts""" 292 | # Space to start/stop recording 293 | record_action = QAction("Record", self) 294 | record_action.setShortcut(QKeySequence(Qt.Key.Key_Space)) 295 | record_action.triggered.connect(self.toggle_recording) 296 | self.addAction(record_action) 297 | 298 | # Escape to cancel recording 299 | cancel_action = QAction("Cancel", self) 300 | cancel_action.setShortcut(QKeySequence(Qt.Key.Key_Escape)) 301 | cancel_action.triggered.connect(self.cancel_recording) 302 | self.addAction(cancel_action) 303 | 304 | # Cmd+S to save 305 | save_action = QAction("Save", self) 306 | save_action.setShortcut(QKeySequence.StandardKey.Save) 307 | save_action.triggered.connect(self.save_transcription) 308 | self.addAction(save_action) 309 | 310 | def toggle_recording(self): 311 | """Start or stop recording""" 312 | if self.transcription_thread and self.transcription_thread.isRunning(): 313 | self.stop_recording() 314 | else: 315 | self.start_recording() 316 | 317 | def start_recording(self): 318 | """Start recording audio""" 319 | # Update UI 320 | self.record_button.setText("Stop Recording") 321 | self.record_button.setStyleSheet(""" 322 | QPushButton { 323 | background-color: #f44336; 324 
| color: white; 325 | font-size: 16px; 326 | font-weight: bold; 327 | padding: 10px; 328 | border-radius: 5px; 329 | } 330 | """) 331 | self.status_label.setText("Recording...") 332 | 333 | # Get selected model 334 | model_name = self.model_combo.currentData() 335 | if model_name != self.pipeline.transcriber.model_name: 336 | self.pipeline.set_model(model_name) 337 | 338 | # Clear previous text 339 | self.raw_text_edit.clear() 340 | 341 | # Get selected microphone 342 | device_id = self.mic_combo.currentData() if self.mic_combo.currentIndex() >= 0 else None 343 | 344 | # Start recording in background thread 345 | self.transcription_thread = TranscriptionThread(self.pipeline, device_id) 346 | self.transcription_thread.textReady.connect(self.on_transcription_ready) 347 | self.transcription_thread.errorOccurred.connect(self.on_error) 348 | self.transcription_thread.levelUpdate.connect(self.update_level) 349 | self.transcription_thread.start() 350 | 351 | def stop_recording(self): 352 | """Stop recording and transcribe""" 353 | if self.transcription_thread: 354 | self.transcription_thread.stop_recording() 355 | self.status_label.setText("Transcribing...") 356 | 357 | # Reset button 358 | self.record_button.setText("Start Recording") 359 | self.record_button.setStyleSheet(""" 360 | QPushButton { 361 | background-color: #4CAF50; 362 | color: white; 363 | font-size: 16px; 364 | font-weight: bold; 365 | padding: 10px; 366 | border-radius: 5px; 367 | } 368 | """) 369 | 370 | def cancel_recording(self): 371 | """Cancel current recording""" 372 | try: 373 | if self.transcription_thread and self.transcription_thread.isRunning(): 374 | # Stop recording gracefully first 375 | self.transcription_thread.is_recording = False 376 | # Wait a bit for thread to finish 377 | if not self.transcription_thread.wait(500): # 500ms timeout 378 | # If it doesn't finish, terminate it 379 | self.transcription_thread.terminate() 380 | self.transcription_thread.wait() # Wait for termination 381 | 382 | self.transcription_thread = None 383 | self.record_button.setText("Start Recording") 384 | self.record_button.setStyleSheet(""" 385 | QPushButton { 386 | background-color: #4CAF50; 387 | color: white; 388 | font-size: 16px; 389 | font-weight: bold; 390 | padding: 10px; 391 | border-radius: 5px; 392 | } 393 | """) 394 | self.status_label.setText("Recording cancelled") 395 | self.level_bar.setValue(0) 396 | except Exception as e: 397 | print(f"Error in cancel_recording: {e}") 398 | 399 | def on_transcription_ready(self, text: str): 400 | """Handle transcription result""" 401 | self.raw_text_edit.setText(text) 402 | self.status_label.setText("Transcription complete") 403 | self.level_bar.setValue(0) 404 | 405 | # Auto-correct if enabled 406 | if self.auto_correct_check.isChecked() and text.strip(): 407 | self.correct_text() 408 | 409 | def correct_text(self): 410 | """Correct the transcribed text""" 411 | raw_text = self.raw_text_edit.toPlainText() 412 | if not raw_text.strip(): 413 | return 414 | 415 | self.status_label.setText("Correcting text...") 416 | context = self.context_input.text() 417 | 418 | # Run correction in background 419 | QTimer.singleShot(0, lambda: self._do_correction(raw_text, context)) 420 | 421 | def _do_correction(self, text: str, context: str): 422 | """Perform text correction""" 423 | try: 424 | corrected = self.pipeline.correct_text(text, context) 425 | self.corrected_text_edit.setText(corrected) 426 | self.status_label.setText("Text correction complete") 427 | except Exception as e: 428 | 
self.on_error(f"Correction failed: {str(e)}") 429 | 430 | def update_level(self, level: float): 431 | """Update audio level indicator""" 432 | # Convert to percentage (0-100) 433 | percentage = min(int(level * 1000), 100) 434 | self.level_bar.setValue(percentage) 435 | 436 | def on_model_changed(self): 437 | """Handle model selection change""" 438 | model_name = self.model_combo.currentData() 439 | self.status_label.setText(f"Model changed to {model_name}") 440 | 441 | def clear_text(self): 442 | """Clear all text fields""" 443 | try: 444 | # Stop any running recording first 445 | if self.transcription_thread and self.transcription_thread.isRunning(): 446 | self.cancel_recording() 447 | 448 | self.raw_text_edit.clear() 449 | self.corrected_text_edit.clear() 450 | self.context_input.clear() 451 | self.status_label.setText("Ready") 452 | except Exception as e: 453 | print(f"Error in clear_text: {e}") 454 | self.on_error(f"Clear failed: {str(e)}") 455 | 456 | def save_transcription(self): 457 | """Save transcription to file""" 458 | text = self.corrected_text_edit.toPlainText() 459 | if not text: 460 | text = self.raw_text_edit.toPlainText() 461 | 462 | if not text: 463 | QMessageBox.warning(self, "Warning", "No text to save") 464 | return 465 | 466 | filename, _ = QFileDialog.getSaveFileName( 467 | self, "Save Transcription", "", "Text Files (*.txt)" 468 | ) 469 | 470 | if filename: 471 | try: 472 | with open(filename, 'w') as f: 473 | f.write(text) 474 | self.status_label.setText(f"Saved to {filename}") 475 | except Exception as e: 476 | self.on_error(f"Save failed: {str(e)}") 477 | 478 | def on_error(self, error_msg: str): 479 | """Handle errors""" 480 | self.status_label.setText(f"Error: {error_msg}") 481 | QMessageBox.critical(self, "Error", error_msg) 482 | 483 | def refresh_microphones(self): 484 | """Refresh the list of available microphones""" 485 | self.mic_combo.clear() 486 | 487 | # Get list of audio input devices 488 | devices = sd.query_devices() 489 | current_default = sd.default.device[0] # Default input device 490 | 491 | default_index = 0 492 | for i, device in enumerate(devices): 493 | if device['max_input_channels'] > 0: # Only input devices 494 | device_name = f"{device['name']} ({device['hostapi']})" 495 | self.mic_combo.addItem(device_name, i) 496 | 497 | # Track default device 498 | if i == current_default: 499 | default_index = self.mic_combo.count() - 1 500 | 501 | # Set to default device 502 | self.mic_combo.setCurrentIndex(default_index) 503 | 504 | def on_mic_changed(self): 505 | """Handle microphone selection change""" 506 | if self.mic_combo.currentIndex() >= 0: 507 | device_id = self.mic_combo.currentData() 508 | device_name = self.mic_combo.currentText() 509 | self.status_label.setText(f"Microphone changed to {device_name}") 510 | 511 | def copy_raw_text(self): 512 | """Copy raw transcription to clipboard""" 513 | text = self.raw_text_edit.toPlainText() 514 | if text: 515 | clipboard = QApplication.clipboard() 516 | clipboard.setText(text) 517 | self.status_label.setText("Raw text copied to clipboard") 518 | 519 | def copy_corrected_text(self): 520 | """Copy corrected text to clipboard""" 521 | text = self.corrected_text_edit.toPlainText() 522 | if text: 523 | clipboard = QApplication.clipboard() 524 | clipboard.setText(text) 525 | self.status_label.setText("Corrected text copied to clipboard") 526 | 527 | def on_llm_model_changed(self): 528 | """Handle LLM model selection change""" 529 | model_name = self.llm_combo.currentData() 530 | 
self.status_label.setText(f"LLM model changed to {model_name}") 531 | # Update the pipeline with new LLM model 532 | self.pipeline.set_llm_model(model_name) --------------------------------------------------------------------------------
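A note on running the GUI module above: the following sketch is not part of the repository. It is a minimal, hypothetical launcher that only illustrates how the MainWindow class defined in main_window.py could be hosted in a standard PyQt6 event loop. The import path src.gui.main_window is an assumption based on the relative import from ..pipeline used in that file; adjust it to the actual package layout, and run from the jarvis directory so the src package is importable.

import sys

from PyQt6.QtWidgets import QApplication

# Hypothetical import path; assumes the src package layout implied by the relative imports above.
from src.gui.main_window import MainWindow


def launch() -> int:
    """Create the Qt application, show the transcription window, and run the event loop."""
    app = QApplication(sys.argv)
    window = MainWindow()  # constructs the TranscriptionPipeline and UI shown above
    window.show()
    return app.exec()      # PyQt6 uses exec(); it returns the application's exit code


if __name__ == "__main__":
    sys.exit(launch())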