├── .env
├── README.md
├── backend
│   ├── asr
│   │   ├── Dockerfile
│   │   ├── app.py
│   │   └── requirements.txt
│   ├── llm
│   │   ├── Dockerfile
│   │   ├── app.py
│   │   └── requirements.txt
│   ├── orchestrator
│   │   ├── Dockerfile
│   │   ├── orchestrator.py
│   │   └── requirements.txt
│   └── tts
│       ├── Dockerfile
│       ├── app.py
│       ├── csm_utils
│       │   ├── __init__.py
│       │   └── loader.py
│       ├── generator.py
│       ├── models.py
│       └── requirements.txt
├── docker-compose.yml
├── frontend
│   ├── .gitignore
│   ├── .vscode
│   │   └── extensions.json
│   ├── README.md
│   ├── index.html
│   ├── package-lock.json
│   ├── package.json
│   ├── public
│   │   ├── tauri.svg
│   │   └── vite.svg
│   ├── src-tauri
│   │   ├── .gitignore
│   │   ├── Cargo.lock
│   │   ├── Cargo.toml
│   │   ├── build.rs
│   │   ├── capabilities
│   │   │   └── default.json
│   │   ├── icons
│   │   │   ├── 128x128.png
│   │   │   ├── 128x128@2x.png
│   │   │   ├── 32x32.png
│   │   │   ├── Square107x107Logo.png
│   │   │   ├── Square142x142Logo.png
│   │   │   ├── Square150x150Logo.png
│   │   │   ├── Square284x284Logo.png
│   │   │   ├── Square30x30Logo.png
│   │   │   ├── Square310x310Logo.png
│   │   │   ├── Square44x44Logo.png
│   │   │   ├── Square71x71Logo.png
│   │   │   ├── Square89x89Logo.png
│   │   │   ├── StoreLogo.png
│   │   │   ├── icon.icns
│   │   │   ├── icon.ico
│   │   │   └── icon.png
│   │   ├── src
│   │   │   ├── lib.rs
│   │   │   └── main.rs
│   │   └── tauri.conf.json
│   ├── src
│   │   ├── App.css
│   │   ├── App.tsx
│   │   ├── assets
│   │   │   └── react.svg
│   │   ├── components
│   │   │   ├── MicrophoneButton.css
│   │   │   ├── MicrophoneButton.tsx
│   │   │   ├── SpeakingAnimation.css
│   │   │   ├── SpeakingAnimation.tsx
│   │   │   ├── StatusDisplay.css
│   │   │   ├── StatusDisplay.tsx
│   │   │   ├── TranscriptDisplay.css
│   │   │   ├── TranscriptDisplay.tsx
│   │   │   └── icons
│   │   │       ├── MicIcon.tsx
│   │   │       ├── ResetIcon.tsx
│   │   │       └── StopIcon.tsx
│   │   ├── index.css
│   │   ├── lib
│   │   │   ├── api.ts
│   │   │   ├── audioUtils.ts
│   │   │   └── store.ts
│   │   ├── main.tsx
│   │   └── vite-env.d.ts
│   ├── tsconfig.json
│   ├── tsconfig.node.json
│   └── vite.config.ts
└── shared
    ├── config.yaml
    ├── logs
    │   ├── .keep
    │   ├── service_asr.log
    │   ├── service_llm.log
    │   ├── service_orchestrator.log
    │   └── service_tts_streaming.log
    └── models
        └── .keep

/.env:
--------------------------------------------------------------------------------
1 | FRONTEND_DEV_PORT=1420
2 | ORCHESTRATOR_PORT=5000
3 | ASR_PORT=5001
4 | LLM_PORT=5002
5 | TTS_PORT=5003
6 | 
7 | CACHE_DIR=/cache
8 | 
9 | USE_GPU=true
10 | NVIDIA_VISIBLE_DEVICES=all
11 | 
12 | # Replace with your Hugging Face token (required to download the gated models)
13 | HUGGING_FACE_TOKEN=hf_changeME
14 | 
15 | LOG_LEVEL=info
16 | PYTHONUNBUFFERED=1
17 | 
18 | TTS_SPEAKER_ID=4
19 | CSM_MAX_SEQ_LEN=65536
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Sesame CSM Voice Assistant
2 | 
3 | ## Overview
4 | A high-performance, local voice assistant with real-time transcription, LLM reasoning, and text-to-speech. It runs fully offline after setup and uses Sesame CSM for expressive speech synthesis. Real-time factor: 0.6x on an NVIDIA RTX 4070 Ti Super.
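The backend services load their model and runtime settings from `shared/config.yaml`, expected at `/app/config.yaml` inside each container (`CONFIG_PATH`), and referenced throughout the Setup steps below. A minimal sketch of the structure the services expect, inferred from the configuration loaders in the backend code in this repository — the keys are the ones actually read, while the model identifiers and concrete values shown are illustrative assumptions, not the shipped file:

```yaml
# shared/config.yaml — illustrative sketch (keys taken from the service code; values are assumptions)
asr:
  model_name: distil-whisper/distil-large-v3.5   # as named in the README; must be a faster-whisper-compatible model id
  device: auto                                   # "auto", "cpu", or "cuda"
  compute_type: int8
  beam_size: 5
  language: null                                 # null = auto-detect
  vad_filter: true
  vad_parameters:
    threshold: 0.5

llm:
  model_repo_id: bartowski/Llama-3.2-1B-Instruct-GGUF   # assumed GGUF repo for Llama 3.2 1B
  model_filename: Llama-3.2-1B-Instruct-Q4_K_M.gguf     # assumed quantization
  n_gpu_layers: -1        # -1 = offload all layers; 0 = CPU only
  n_ctx: 2048
  chat_format: llama-3
  temperature: 0.7
  max_tokens: 512
  top_p: 0.9

orchestrator:
  system_prompt: "You are a helpful voice assistant."
  max_history_turns: 5

api_endpoints:            # optional; defaults are derived from the service URLs/ports
  asr: http://asr:5001/transcribe
  llm: http://llm:5002/generate
  tts_ws: ws://tts:5003/synthesize_stream
```

The `api_endpoints` block mirrors the defaults the orchestrator derives from `ASR_SERVICE_URL`, `LLM_SERVICE_URL`, and `TTS_SERVICE_URL`, so it can usually be omitted; TTS-specific settings are not sketched here because the TTS loader is not shown in this excerpt.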
5 | 
6 | ## Features
7 | - Real-time Speech-to-Text using `distil-whisper`
8 | - On-device LLM using Llama 3.2 1B
9 | - Natural TTS via Sesame CSM (`senstella/csm-expressiva-1b`)
10 | - Desktop GUI with Tauri/React
11 | - Conversation history and speaking animations
12 | - GPU acceleration with CUDA
13 | - Modular Docker-based backend
14 | 
15 | ## Tech Stack
16 | - **Frontend**: Tauri 2.5.1, React 18+, TypeScript
17 | - **Backend**: Python 3.10, FastAPI, Uvicorn
18 | - **Models**: `distil-whisper` (large-v3.5), Llama 3.2 1B (GGUF), Sesame CSM
19 | 
20 | ## Requirements
21 | - NVIDIA GPU: 8GB+ VRAM
22 | - 32GB RAM
23 | - Docker Desktop
24 | - NVIDIA GPU Drivers (CUDA 12.1+)
25 | - NVIDIA Container Toolkit
26 | - Node.js & npm (v18+)
27 | - Rust & Cargo
28 | - Hugging Face access to Llama 3.2 1B
29 | 
30 | ## Setup
31 | 1. **Prerequisites**:
32 |    - Install Docker Desktop and ensure it's running
33 |    - Install Rust, Tauri, and the NVIDIA Container Toolkit
34 |    - Request access to Llama 3.2 1B on Hugging Face
35 | 
36 | 2. **Configuration**:
37 |    - Edit the `.env` file and set `HUGGING_FACE_TOKEN=hf_yourTokenHere`
38 | 
39 | 3. **Backend**:
40 |    - Build: `docker compose build`
41 |    - Run: `docker compose up -d`
42 | 
43 | 4. **Frontend**:
44 |    - Install dependencies: `cd frontend && npm install && npm install uuid`
45 |    - Start: `npm run tauri dev`
46 | 
47 | ## Usage
48 | - Add your Hugging Face token to `.env` and request access to the required models (Llama 3.2 1B and Sesame CSM) on Hugging Face
49 | - Build backend: `docker compose build`
50 | - Start backend: `docker compose up -d`
51 | - Install frontend dependencies: `cd frontend && npm install && npm install uuid`
52 | - Start frontend: `npm run tauri dev` (from the `frontend` directory)
53 | - View logs: `docker compose logs -f`
54 | - Stop: `docker compose down`
55 | 
--------------------------------------------------------------------------------
/backend/asr/Dockerfile:
--------------------------------------------------------------------------------
1 | # backend/asr/Dockerfile
2 | # --- Use CUDA base image that includes cuDNN 8 ---
3 | FROM nvidia/cuda:12.1.1-cudnn8-runtime-ubuntu22.04
4 | 
5 | ARG PYTHON_VERSION=3.10
6 | ARG USER=appuser
7 | ARG GROUP=appgroup
8 | ARG UID=1000
9 | ARG GID=1000
10 | 
11 | WORKDIR /app
12 | 
13 | ENV PYTHONUNBUFFERED=1
14 | ENV DEBIAN_FRONTEND=noninteractive
15 | ENV ASR_PORT=5001
16 | ENV CACHE_DIR=/cache
17 | ENV CONFIG_PATH=/app/config.yaml
18 | ENV LOG_FILE_BASE=/app/logs/service
19 | ENV LOG_LEVEL=info
20 | ENV USE_GPU=auto
21 | ENV HF_HOME=${CACHE_DIR}/huggingface
22 | ENV TRANSFORMERS_CACHE=${CACHE_DIR}/huggingface
23 | ENV PATH="/app/.venv/bin:$PATH"
24 | 
25 | # Create non-root user/group
26 | RUN groupadd -g ${GID} ${GROUP} && \
27 |     useradd -u ${UID} -g ${GID} -ms /bin/bash ${USER}
28 | 
29 | # Install system dependencies
30 | RUN apt-get update && apt-get install -y --no-install-recommends \
31 |     python${PYTHON_VERSION} \
32 |     python${PYTHON_VERSION}-venv \
33 |     python${PYTHON_VERSION}-dev \
34 |     build-essential \
35 |     libsndfile1 \
36 |     ffmpeg \
37 |     curl \
38 |     && apt-get clean && rm -rf /var/lib/apt/lists/*
39 | 
40 | # Set up Python virtual environment
41 | RUN python${PYTHON_VERSION} -m venv /app/.venv
42 | 
43 | # Create cache directory and set permissions
44 | RUN mkdir -p ${CACHE_DIR}/huggingface && chown ${USER}:${GROUP} ${CACHE_DIR} ${CACHE_DIR}/huggingface
45 | 
46 | # Copy requirements and install Python packages into venv
47 | COPY requirements.txt .
48 | RUN . 
/app/.venv/bin/activate && \ 49 | pip install --no-cache-dir --upgrade pip && \ 50 | pip install --no-cache-dir -r requirements.txt 51 | 52 | # Copy application code 53 | COPY app.py . 54 | 55 | # Create log directory and set permissions for app directory 56 | RUN mkdir -p /app/logs && chown ${USER}:${GROUP} /app/logs 57 | RUN chown -R ${USER}:${GROUP} /app 58 | 59 | # Switch to non-root user 60 | USER ${USER} 61 | 62 | # Expose the application port 63 | EXPOSE ${ASR_PORT} 64 | 65 | # Define health check 66 | HEALTHCHECK --interval=30s --timeout=10s --start-period=120s --retries=3 \ 67 | CMD curl --fail http://localhost:${ASR_PORT}/health || exit 1 68 | 69 | # Default command to run the application 70 | CMD ["sh", "-c", "uvicorn app:app --host 0.0.0.0 --port ${ASR_PORT} --log-level ${LOG_LEVEL}"] -------------------------------------------------------------------------------- /backend/asr/app.py: -------------------------------------------------------------------------------- 1 | # backend/asr/app.py 2 | import os 3 | import logging 4 | import io 5 | import time 6 | import numpy as np 7 | import soundfile as sf 8 | import torch 9 | import yaml 10 | from contextlib import asynccontextmanager 11 | from pathlib import Path 12 | 13 | # --- ADD Pydub IMPORT --- 14 | try: 15 | from pydub import AudioSegment 16 | pydub_available = True 17 | except ImportError: 18 | logging.warning("pydub library not found. Audio conversion will not be available.") 19 | pydub_available = False 20 | # --- END ADD Pydub IMPORT --- 21 | 22 | from fastapi import FastAPI, HTTPException, UploadFile, File, status 23 | from fastapi.responses import JSONResponse 24 | from faster_whisper import WhisperModel 25 | from huggingface_hub import login, logout 26 | from dotenv import load_dotenv 27 | from typing import Optional, Dict, Any 28 | 29 | # --- Constants & Environment Loading --- 30 | load_dotenv() 31 | CONFIG_PATH = os.getenv('CONFIG_PATH', '/app/config.yaml') 32 | CACHE_DIR = Path(os.getenv('CACHE_DIR', '/cache')) 33 | HF_CACHE_DIR = Path(os.getenv('HF_HOME', CACHE_DIR / "huggingface")) 34 | LOG_FILE_BASE = os.getenv('LOG_FILE_BASE', '/app/logs/service') 35 | LOG_LEVEL = os.getenv('LOG_LEVEL', 'INFO').upper() 36 | USE_GPU_ENV = os.getenv('USE_GPU', 'auto').lower() 37 | HF_TOKEN = os.getenv('HUGGING_FACE_TOKEN') 38 | SERVICE_NAME = "asr" 39 | LOG_PATH = f"{LOG_FILE_BASE}_{SERVICE_NAME}.log" 40 | 41 | # --- Logging Setup --- 42 | os.makedirs(os.path.dirname(LOG_PATH), exist_ok=True) 43 | HF_CACHE_DIR.mkdir(parents=True, exist_ok=True) 44 | logging.basicConfig( level=LOG_LEVEL, format="%(asctime)s - %(name)s:%(lineno)d - %(levelname)s - %(message)s", handlers=[ logging.StreamHandler(), logging.FileHandler(LOG_PATH) ]) 45 | logger = logging.getLogger(SERVICE_NAME) 46 | 47 | # --- Global Variables --- 48 | asr_model: Optional[WhisperModel] = None 49 | asr_config: Dict[str, Any] = {} 50 | effective_device: str = "cpu" 51 | model_load_info: Dict[str, Any] = {"status": "pending"} 52 | 53 | # --- Configuration Loading --- 54 | def load_configuration(): 55 | global asr_config, effective_device 56 | try: 57 | logger.info(f"Loading configuration from: {CONFIG_PATH}") 58 | if not os.path.exists(CONFIG_PATH): raise FileNotFoundError(f"Config file not found at {CONFIG_PATH}") 59 | with open(CONFIG_PATH, 'r') as f: config = yaml.safe_load(f) 60 | if not config or 'asr' not in config: raise ValueError("Config file is empty or missing 'asr' section.") 61 | asr_config = config['asr'] 62 | if not asr_config.get('model_name'): raise 
ValueError("Missing 'model_name' in asr configuration.") 63 | config_device = asr_config.get('device', 'auto') 64 | cuda_available = torch.cuda.is_available() 65 | logger.info(f"CUDA available: {cuda_available}, Torch version: {torch.__version__}") 66 | logger.info(f"USE_GPU environment variable: '{USE_GPU_ENV}'") 67 | logger.info(f"Configured ASR device: '{config_device}'") 68 | if USE_GPU_ENV == 'false': effective_device = "cpu"; logger.info("GPU usage explicitly disabled via environment variable.") 69 | elif config_device == "cpu": effective_device = "cpu"; logger.info("ASR device configured to CPU.") 70 | elif cuda_available and (config_device == "auto" or config_device.startswith("cuda")): effective_device = config_device if config_device.startswith("cuda") else "cuda"; logger.info(f"Attempting to use CUDA device '{effective_device}' for ASR.") 71 | else: effective_device = "cpu"; logger.warning(f"CUDA device '{config_device}' requested but not available or USE_GPU=false. Falling back to CPU.") if (config_device == "auto" or config_device.startswith("cuda")) and USE_GPU_ENV != 'false' else logger.info("Using CPU for ASR.") 72 | asr_config['effective_device'] = effective_device 73 | logger.info(f"ASR effective device set to: {effective_device}") 74 | except (FileNotFoundError, ValueError) as e: logger.critical(f"Configuration error: {e}. ASR service cannot start correctly.", exc_info=True); asr_config = {}; model_load_info.update({"status": "error", "error": f"Configuration error: {e}"}) 75 | except Exception as e: logger.critical(f"Unexpected error loading configuration: {e}. ASR service cannot start correctly.", exc_info=True); asr_config = {}; model_load_info.update({"status": "error", "error": f"Unexpected config error: {e}"}) 76 | 77 | # --- Model Loading / Downloading --- 78 | def load_asr_model(): 79 | global asr_model, model_load_info 80 | if not asr_config: logger.error("Skipping model load due to configuration errors."); return 81 | model_name = asr_config.get('model_name'); compute_type = asr_config.get('compute_type', 'int8'); device_to_load = asr_config.get('effective_device', 'cpu'); cache_path = HF_CACHE_DIR 82 | logger.info(f"Attempting to load/download ASR model: {model_name}"); logger.info(f"Target device: {device_to_load}, Compute type: {compute_type}"); logger.info(f"Using cache directory (HF_HOME): {cache_path}") 83 | model_load_info = {"status": "loading", "model_name": model_name, "device": device_to_load, "compute_type": compute_type}; start_time = time.monotonic() 84 | try: 85 | if HF_TOKEN: logger.info("Logging into Hugging Face Hub using provided token."); login(token=HF_TOKEN) 86 | asr_model = WhisperModel( model_name, device=device_to_load, compute_type=compute_type, download_root=str(cache_path)) 87 | load_time = time.monotonic() - start_time; model_load_info.update({"status": "loaded", "load_time_s": round(load_time, 2)}); logger.info(f"ASR Model '{model_name}' loaded successfully in {load_time:.2f} seconds.") 88 | except Exception as e: logger.critical(f"FATAL: Failed to load or download ASR model '{model_name}': {e}", exc_info=True); asr_model = None; load_time = time.monotonic() - start_time; model_load_info.update({"status": "error", "error": str(e), "load_time_s": round(load_time, 2)}); raise RuntimeError(f"ASR model loading failed: {e}") from e 89 | finally: 90 | if HF_TOKEN: 91 | try: logout(); logger.info("Logged out from Hugging Face Hub.") 92 | except Exception as logout_err: logger.warning(f"Could not log out from Hugging Face Hub: 
{logout_err}") 93 | 94 | # --- FastAPI Lifespan Event Handler --- 95 | @asynccontextmanager 96 | async def lifespan(app: FastAPI): 97 | logger.info(f"{SERVICE_NAME.upper()} Service starting up..."); model_load_info = {"status": "initializing"} 98 | load_configuration() 99 | if asr_config: 100 | try: load_asr_model() 101 | except RuntimeError as e: logger.critical(f"Lifespan startup failed due to model load error: {e}") 102 | else: logger.error("Skipping model load during startup due to config errors."); model_load_info = {"status": "error", "error": "Configuration failed"} 103 | yield 104 | logger.info(f"{SERVICE_NAME.upper()} Service shutting down..."); global asr_model 105 | if asr_model: logger.info("Releasing ASR model resources..."); del asr_model; asr_model = None 106 | if effective_device.startswith("cuda"): 107 | try: torch.cuda.empty_cache(); logger.info("Cleared PyTorch CUDA cache.") 108 | except Exception as e: logger.warning(f"Could not clear CUDA cache during shutdown: {e}") 109 | logger.info("ASR Service shutdown complete.") 110 | 111 | # --- FastAPI App Initialization --- 112 | app = FastAPI(lifespan=lifespan, title="ASR Service", version="1.1.0") 113 | 114 | # --- Audio Preprocessing Helper --- 115 | async def preprocess_audio(audio_bytes: bytes, filename: str) -> np.ndarray: 116 | """Reads audio bytes, converts to WAV using pydub if necessary, then mono float32 numpy array at 16kHz.""" 117 | logger.debug(f"Preprocessing audio from '{filename}' ({len(audio_bytes)} bytes)") 118 | start_time = time.monotonic() 119 | input_stream = io.BytesIO(audio_bytes) 120 | processed_stream = input_stream # Start with the original stream 121 | 122 | # --- CONVERSION STEP using pydub --- 123 | if pydub_available: 124 | try: 125 | logger.debug("Attempting to load audio with pydub...") 126 | # Load audio segment using pydub, format='*' tells it to detect 127 | audio_segment = AudioSegment.from_file(input_stream) 128 | 129 | # Ensure minimum frame rate and export to WAV format in memory 130 | # Whisper expects 16kHz, so set frame rate here if conversion is happening 131 | if audio_segment.frame_rate < 16000: 132 | logger.warning(f"Input audio frame rate {audio_segment.frame_rate}Hz is low, setting to 16000Hz during conversion.") 133 | audio_segment = audio_segment.set_frame_rate(16000) 134 | 135 | wav_stream = io.BytesIO() 136 | audio_segment.export(wav_stream, format="wav") 137 | wav_stream.seek(0) # Rewind stream to the beginning 138 | processed_stream = wav_stream # Use the converted WAV stream 139 | logger.info(f"Successfully converted audio '{filename}' to WAV format using pydub.") 140 | 141 | except Exception as pydub_err: 142 | logger.warning(f"Pydub failed to load/convert '{filename}': {pydub_err}. 
Falling back to soundfile with original data.", exc_info=True) 143 | # Reset stream to original bytes if pydub fails 144 | processed_stream = io.BytesIO(audio_bytes) 145 | processed_stream.seek(0) # Ensure stream is at the beginning 146 | else: 147 | logger.warning("Pydub not available, attempting direct load with soundfile.") 148 | # --- END CONVERSION STEP --- 149 | 150 | try: 151 | # Now read using soundfile (should be WAV data if conversion succeeded) 152 | audio_data, samplerate = sf.read(processed_stream, dtype='float32', always_2d=True) 153 | logger.debug(f"Read audio via soundfile: SR={samplerate}Hz, Shape={audio_data.shape}, Duration={audio_data.shape[0]/samplerate:.2f}s") 154 | 155 | # Convert to mono by averaging channels if stereo 156 | if audio_data.shape[1] > 1: 157 | audio_data = np.mean(audio_data, axis=1) 158 | logger.debug(f"Converted stereo to mono. New shape: {audio_data.shape}") 159 | else: 160 | audio_data = audio_data[:, 0] # Squeeze mono channel dim 161 | 162 | # Resample to 16kHz if necessary (e.g., if pydub wasn't used or original WAV wasn't 16k) 163 | if samplerate != 16000: 164 | logger.info(f"Resampling audio from {samplerate}Hz to 16000Hz using torchaudio...") 165 | try: 166 | import torchaudio 167 | import torchaudio.transforms as T 168 | audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) 169 | resampler = T.Resample(orig_freq=samplerate, new_freq=16000).to('cpu') 170 | resampled_tensor = resampler(audio_tensor.cpu()) 171 | audio_data = resampled_tensor.squeeze(0).numpy() 172 | logger.info(f"Resampled audio to 16kHz. New shape: {audio_data.shape}") 173 | samplerate = 16000 174 | except ImportError: 175 | logger.error("Torchaudio is required for resampling but not found/installed.") 176 | raise HTTPException(status_code=501, detail="Audio resampling required but torchaudio not available.") 177 | except Exception as resample_err: 178 | logger.error(f"Error during resampling: {resample_err}", exc_info=True) 179 | raise HTTPException(status_code=500, detail=f"Failed to resample audio: {resample_err}") 180 | 181 | if audio_data.dtype != np.float32: audio_data = audio_data.astype(np.float32) 182 | 183 | preprocess_time = time.monotonic() - start_time 184 | logger.debug(f"Audio preprocessing completed in {preprocess_time:.3f} seconds.") 185 | return audio_data 186 | 187 | except sf.SoundFileError as sf_err: 188 | # This error likely means the original format was bad OR pydub failed AND the original was also bad 189 | logger.error(f"Soundfile error processing audio '{filename}' after potential conversion attempt: {sf_err}", exc_info=True) 190 | raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Could not read or decode audio file: {sf_err}") 191 | except Exception as e: 192 | logger.error(f"Unexpected error preprocessing audio '{filename}' after potential conversion: {e}", exc_info=True) 193 | raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Internal error processing audio: {e}") 194 | 195 | 196 | # --- API Endpoints --- 197 | @app.post("/transcribe") 198 | async def transcribe_audio_endpoint(audio: UploadFile = File(..., description="Audio file (WAV, MP3, FLAC, WebM, Ogg, etc.)")): 199 | if not asr_model or model_load_info.get("status") != "loaded": 200 | error_detail = model_load_info.get("error", "Model not available or failed to load.") 201 | logger.error(f"Transcription request failed: {error_detail}") 202 | raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=f"ASR model unavailable: 
{error_detail}") 203 | 204 | req_start_time = time.monotonic() 205 | logger.info(f"Received transcription request for file: {audio.filename} ({audio.content_type})") 206 | 207 | try: 208 | audio_bytes = await audio.read() 209 | if not audio_bytes: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Received empty audio file.") 210 | 211 | # Preprocess audio: Handles conversion, mono, 16kHz, float32 numpy array 212 | audio_np = await preprocess_audio(audio_bytes, audio.filename or "uploaded_audio") 213 | audio_duration_sec = len(audio_np) / 16000.0 214 | 215 | beam_size = asr_config.get('beam_size', 5) 216 | language_code = asr_config.get('language', None) # Let Whisper detect by default unless specified 217 | vad_filter = asr_config.get('vad_filter', True) 218 | vad_parameters = asr_config.get('vad_parameters', {"threshold": 0.5}) 219 | 220 | logger.info(f"Starting transcription (beam={beam_size}, lang={language_code or 'auto'}, vad={vad_filter})...") 221 | transcribe_start_time = time.monotonic() 222 | 223 | segments_generator, info = asr_model.transcribe( audio_np, beam_size=beam_size, language=language_code, vad_filter=vad_filter, vad_parameters=vad_parameters) 224 | 225 | transcribed_text_parts = []; segment_count = 0 226 | try: 227 | for segment in segments_generator: transcribed_text_parts.append(segment.text); segment_count += 1 228 | except Exception as seg_err: logger.error(f"Error processing transcription segment: {seg_err}", exc_info=True) 229 | 230 | transcribed_text = " ".join(transcribed_text_parts).strip() 231 | transcribe_time = time.monotonic() - transcribe_start_time 232 | total_req_time = time.monotonic() - req_start_time 233 | 234 | logger.info(f"Transcription completed in {transcribe_time:.3f}s ({segment_count} segments). 
Total request time: {total_req_time:.3f}s.") 235 | logger.info(f"Detected lang: {info.language} (Prob: {info.language_probability:.2f}), Audio duration: {info.duration:.2f}s (processed: {audio_duration_sec:.2f}s)") 236 | if len(transcribed_text) < 200: logger.debug(f"Transcription result: '{transcribed_text}'") 237 | else: logger.debug(f"Transcription result (truncated): '{transcribed_text[:100]}...{transcribed_text[-100:]}'") 238 | 239 | return JSONResponse( status_code=status.HTTP_200_OK, content={ "text": transcribed_text, "language": info.language, "language_probability": info.language_probability, "audio_duration_ms": round(info.duration * 1000), "processing_time_ms": round(total_req_time * 1000), "transcription_time_ms": round(transcribe_time * 1000), }) 240 | 241 | except HTTPException as http_exc: logger.warning(f"HTTP exception during transcription: {http_exc.status_code} - {http_exc.detail}"); raise http_exc 242 | except Exception as e: logger.error(f"Unexpected error during transcription request for {audio.filename}: {e}", exc_info=True); raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Internal server error: {e}") 243 | 244 | @app.get("/health", status_code=status.HTTP_200_OK) 245 | async def health_check(): 246 | current_status = model_load_info.get("status", "unknown") 247 | response_content = { "service": SERVICE_NAME, "status": "ok" if current_status == "loaded" else "error", "model_status": current_status, "model_name": model_load_info.get("model_name", asr_config.get('model_name', 'N/A')), "device": model_load_info.get("device", effective_device), "compute_type": model_load_info.get("compute_type", asr_config.get('compute_type', 'N/A')), "load_info": model_load_info } 248 | if current_status != "loaded": return JSONResponse(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, content=response_content) 249 | return response_content 250 | 251 | # --- Main Execution Guard (for local debugging) --- 252 | if __name__ == "__main__": 253 | import uvicorn 254 | logger.info(f"Starting {SERVICE_NAME.upper()} service directly via __main__...") 255 | logger.info("Running startup sequence...") 256 | model_load_info = {"status": "initializing"} 257 | load_configuration() 258 | if asr_config: 259 | try: load_asr_model() 260 | except RuntimeError as e: logger.critical(f"Direct run failed: Model load error: {e}"); exit(1) 261 | else: logger.critical("Direct run failed: Configuration error."); exit(1) 262 | port = int(os.getenv('ASR_PORT', 5001)); log_level_param = LOG_LEVEL.lower() 263 | logger.info(f"Launching Uvicorn on port {port} with log level {log_level_param}...") 264 | uvicorn.run("app:app", host="0.0.0.0", port=port, log_level=log_level_param, reload=False) 265 | logger.info(f"{SERVICE_NAME.upper()} Service shutting down (direct run)...") -------------------------------------------------------------------------------- /backend/asr/requirements.txt: -------------------------------------------------------------------------------- 1 | # backend/asr/requirements.txt 2 | fastapi>=0.110.0,<0.112.0 3 | uvicorn[standard]>=0.29.0,<0.30.0 4 | python-dotenv>=1.0.0 5 | PyYAML>=6.0 6 | pydantic>=2.0.0,<3.0.0 7 | 8 | # ASR Core & Dependencies 9 | faster-whisper>=1.0.1,<1.1.0 10 | # Specify torch version compatible with TTS dependency (CSM requires 2.4.0) 11 | torch==2.4.0 --extra-index-url https://download.pytorch.org/whl/cu121 12 | torchaudio==2.4.0 --extra-index-url https://download.pytorch.org/whl/cu121 13 | ctranslate2>=4.0.0,<4.4.0 14 | 15 | # Audio Handling 16 
| soundfile>=0.12.1 17 | numpy>=1.24.0,<2.0.0 18 | pydub>=0.25.0 19 | 20 | # Model Downloading & Cache Management - Allow newer version required by transformers 4.49.0 21 | huggingface_hub>=0.20.0,<1.0 22 | 23 | # Health checks / internal comms (optional) 24 | httpx>=0.27.0 25 | 26 | # Match TTS requirement for consistency (indirect dependencies) 27 | transformers==4.49.0 28 | -------------------------------------------------------------------------------- /backend/llm/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:12.1.1-runtime-ubuntu22.04 2 | 3 | ARG PYTHON_VERSION=3.10 4 | ARG USER=appuser 5 | ARG GROUP=appgroup 6 | ARG UID=1000 7 | ARG GID=1000 8 | 9 | WORKDIR /app 10 | 11 | ENV PYTHONUNBUFFERED=1 12 | ENV DEBIAN_FRONTEND=noninteractive 13 | ENV LLM_PORT=5002 14 | ENV CACHE_DIR=/cache 15 | ENV CONFIG_PATH=/app/config.yaml 16 | ENV LOG_FILE_BASE=/app/logs/service 17 | ENV LOG_LEVEL=info 18 | ENV USE_GPU=auto 19 | ENV CMAKE_ARGS="-DLLAMA_CUBLAS=on" 20 | ENV FORCE_CMAKE=1 21 | ENV HF_HOME=${CACHE_DIR}/huggingface 22 | ENV TRANSFORMERS_CACHE=${CACHE_DIR}/huggingface 23 | ENV PATH="/app/.venv/bin:$PATH" 24 | 25 | RUN groupadd -g ${GID} ${GROUP} && \ 26 | useradd -u ${UID} -g ${GID} -ms /bin/bash ${USER} 27 | 28 | RUN apt-get update && apt-get install -y --no-install-recommends \ 29 | python${PYTHON_VERSION} \ 30 | python${PYTHON_VERSION}-venv \ 31 | python${PYTHON_VERSION}-dev \ 32 | build-essential \ 33 | cmake \ 34 | pkg-config \ 35 | curl \ 36 | && apt-get clean && rm -rf /var/lib/apt/lists/* 37 | 38 | RUN python${PYTHON_VERSION} -m venv /app/.venv 39 | 40 | RUN mkdir -p ${CACHE_DIR}/huggingface && chown ${USER}:${GROUP} ${CACHE_DIR} ${CACHE_DIR}/huggingface 41 | 42 | COPY requirements.txt . 43 | RUN . /app/.venv/bin/activate && \ 44 | pip install --no-cache-dir --upgrade pip && \ 45 | echo "Attempting pip install with CMAKE_ARGS='${CMAKE_ARGS}' FORCE_CMAKE='${FORCE_CMAKE}'" && \ 46 | (CMAKE_ARGS="${CMAKE_ARGS}" FORCE_CMAKE="${FORCE_CMAKE}" pip install --no-cache-dir -r /app/requirements.txt) || \ 47 | (echo "WARNING: llama-cpp-python GPU build failed. Falling back to CPU-only build..." && \ 48 | CMAKE_ARGS="" FORCE_CMAKE=0 pip install --no-cache-dir -r /app/requirements.txt) 49 | 50 | COPY app.py . 
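# Note: depending on the llama-cpp-python version resolved from requirements.txt, the CUDA
# build may expect CMAKE_ARGS="-DGGML_CUDA=on" rather than the older -DLLAMA_CUBLAS=on flag;
# if the GPU build silently falls back to the CPU-only install above, try swapping that flag.
# Whether GPU offload is actually available can be checked from a running container, for example:
#   docker compose exec llm python -c "from llama_cpp import llama_supports_gpu_offload; print(llama_supports_gpu_offload())"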
51 | 52 | RUN mkdir -p /app/logs && chown ${USER}:${GROUP} /app/logs 53 | RUN chown -R ${USER}:${GROUP} /app 54 | 55 | USER ${USER} 56 | 57 | EXPOSE ${LLM_PORT} 58 | 59 | HEALTHCHECK --interval=30s --timeout=15s --start-period=180s --retries=3 \ 60 | CMD curl --fail http://localhost:${LLM_PORT}/health || exit 1 61 | 62 | CMD ["sh", "-c", "uvicorn app:app --host 0.0.0.0 --port ${LLM_PORT} --log-level ${LOG_LEVEL}"] -------------------------------------------------------------------------------- /backend/llm/app.py: -------------------------------------------------------------------------------- 1 | # backend/llm/app.py 2 | import os 3 | import logging 4 | import time 5 | import yaml 6 | from contextlib import asynccontextmanager 7 | from pathlib import Path 8 | 9 | from fastapi import FastAPI, HTTPException, status 10 | from fastapi.responses import JSONResponse 11 | from pydantic import BaseModel, Field 12 | from typing import List, Dict, Optional, Any 13 | from llama_cpp import Llama, LlamaGrammar # type: ignore 14 | from huggingface_hub import hf_hub_download, login, logout 15 | from dotenv import load_dotenv 16 | 17 | # --- Constants & Environment Loading --- 18 | load_dotenv() 19 | 20 | CONFIG_PATH = os.getenv('CONFIG_PATH', '/app/config.yaml') 21 | CACHE_DIR = Path(os.getenv('CACHE_DIR', '/cache')) 22 | # Specific subdirectory within the cache for downloaded GGUF models 23 | LLM_CACHE_DIR = CACHE_DIR / "llm" 24 | 25 | LOG_FILE_BASE = os.getenv('LOG_FILE_BASE', '/app/logs/service') 26 | LOG_LEVEL = os.getenv('LOG_LEVEL', 'INFO').upper() 27 | USE_GPU_ENV = os.getenv('USE_GPU', 'auto').lower() 28 | HF_TOKEN = os.getenv('HUGGING_FACE_TOKEN') 29 | 30 | SERVICE_NAME = "llm" 31 | LOG_PATH = f"{LOG_FILE_BASE}_{SERVICE_NAME}.log" 32 | 33 | # --- Logging Setup --- 34 | os.makedirs(os.path.dirname(LOG_PATH), exist_ok=True) 35 | # Ensure LLM cache directory exists 36 | LLM_CACHE_DIR.mkdir(parents=True, exist_ok=True) 37 | 38 | logging.basicConfig( 39 | level=LOG_LEVEL, 40 | format="%(asctime)s - %(name)s:%(lineno)d - %(levelname)s - %(message)s", 41 | handlers=[ 42 | logging.StreamHandler(), 43 | logging.FileHandler(LOG_PATH) 44 | ] 45 | ) 46 | logger = logging.getLogger(SERVICE_NAME) 47 | 48 | # --- Global Variables --- 49 | llm_model: Optional[Llama] = None 50 | llm_config: Dict[str, Any] = {} 51 | effective_n_gpu_layers: int = 0 52 | model_load_info: Dict[str, Any] = {"status": "pending"} # Track loading status 53 | gguf_model_path: Optional[Path] = None # Store the resolved path to the GGUF file 54 | 55 | # --- Configuration Loading --- 56 | def load_configuration(): 57 | """Loads LLM settings from the YAML config file.""" 58 | global llm_config, effective_n_gpu_layers 59 | try: 60 | logger.info(f"Loading configuration from: {CONFIG_PATH}") 61 | if not os.path.exists(CONFIG_PATH): 62 | raise FileNotFoundError(f"Config file not found at {CONFIG_PATH}") 63 | with open(CONFIG_PATH, 'r') as f: 64 | config = yaml.safe_load(f) 65 | if not config or 'llm' not in config: 66 | raise ValueError("Config file is empty or missing 'llm' section.") 67 | 68 | llm_config = config['llm'] 69 | # Validate essential LLM config keys for downloading/loading 70 | if not llm_config.get('model_repo_id') or not llm_config.get('model_filename'): 71 | raise ValueError("Missing 'model_repo_id' or 'model_filename' in llm configuration.") 72 | 73 | # Determine effective GPU layers based on config and environment 74 | config_n_gpu_layers = llm_config.get('n_gpu_layers', 0) 75 | logger.info(f"Configured n_gpu_layers: 
{config_n_gpu_layers}") 76 | logger.info(f"USE_GPU environment variable: '{USE_GPU_ENV}'") 77 | 78 | if USE_GPU_ENV == 'false': 79 | effective_n_gpu_layers = 0 80 | logger.info("GPU usage explicitly disabled via environment variable (n_gpu_layers=0).") 81 | elif USE_GPU_ENV == 'auto' or USE_GPU_ENV == 'true': 82 | effective_n_gpu_layers = config_n_gpu_layers 83 | if effective_n_gpu_layers != 0: 84 | logger.info(f"GPU usage enabled/auto. Using configured n_gpu_layers: {effective_n_gpu_layers}. Availability checked at load time.") 85 | else: 86 | logger.info("GPU usage enabled/auto, but n_gpu_layers=0. Using CPU.") 87 | else: # Unrecognized USE_GPU value 88 | effective_n_gpu_layers = 0 89 | logger.warning(f"Unrecognized USE_GPU value '{USE_GPU_ENV}'. Assuming CPU (n_gpu_layers=0).") 90 | 91 | llm_config['effective_n_gpu_layers'] = effective_n_gpu_layers # Store effective value 92 | logger.info(f"LLM effective n_gpu_layers set to: {effective_n_gpu_layers}") 93 | 94 | except (FileNotFoundError, ValueError) as e: 95 | logger.critical(f"Configuration error: {e}. LLM service cannot start correctly.", exc_info=True) 96 | llm_config = {} # Prevent partial config use 97 | model_load_info.update({"status": "error", "error": f"Configuration error: {e}"}) 98 | except Exception as e: 99 | logger.critical(f"Unexpected error loading configuration: {e}. LLM service cannot start correctly.", exc_info=True) 100 | llm_config = {} 101 | model_load_info.update({"status": "error", "error": f"Unexpected config error: {e}"}) 102 | 103 | 104 | # --- Model Downloading --- 105 | def download_gguf_model_if_needed() -> Optional[Path]: 106 | """Checks for the GGUF model file and downloads it if missing.""" 107 | global gguf_model_path, model_load_info 108 | if not llm_config: # Config failed 109 | return None 110 | 111 | repo_id = llm_config['model_repo_id'] 112 | filename = llm_config['model_filename'] 113 | # Define the target path within our dedicated LLM cache subdir 114 | target_path = LLM_CACHE_DIR / filename 115 | gguf_model_path = target_path # Store globally for loading later 116 | 117 | if target_path.exists(): 118 | logger.info(f"GGUF model '{filename}' found locally at {target_path}.") 119 | model_load_info["download_status"] = "cached" 120 | return target_path 121 | 122 | logger.info(f"GGUF model '{filename}' not found locally. 
Attempting download from repo '{repo_id}'.") 123 | model_load_info.update({"status": "downloading", "repo_id": repo_id, "filename": filename}) 124 | start_time = time.monotonic() 125 | 126 | try: 127 | # Login to Hugging Face Hub if token is provided 128 | if HF_TOKEN: 129 | logger.info("Logging into Hugging Face Hub using provided token for download.") 130 | login(token=HF_TOKEN) 131 | 132 | # Download the specific file to our designated cache directory 133 | logger.info(f"Downloading {filename} from {repo_id} to {LLM_CACHE_DIR}...") 134 | downloaded_path_str = hf_hub_download( 135 | repo_id=repo_id, 136 | filename=filename, 137 | cache_dir=LLM_CACHE_DIR, # Use specific cache subdir 138 | local_dir=LLM_CACHE_DIR, # Force download into this dir 139 | local_dir_use_symlinks=False, # Avoid symlinks if problematic 140 | resume_download=True, 141 | # token=HF_TOKEN, # Handled by login() 142 | ) 143 | download_time = time.monotonic() - start_time 144 | downloaded_path = Path(downloaded_path_str) 145 | 146 | # Verify download path matches expected target path after potential internal caching by hf_hub 147 | if downloaded_path.resolve() != target_path.resolve(): 148 | # This case might occur if hf_hub places it in a 'snapshots' subdir within cache_dir 149 | logger.warning(f"Downloaded file path '{downloaded_path}' differs from target '{target_path}'. Ensuring file exists at target.") 150 | if not target_path.exists() and downloaded_path.exists(): 151 | # If target doesn't exist but download does, move it 152 | target_path.parent.mkdir(parents=True, exist_ok=True) 153 | downloaded_path.rename(target_path) 154 | elif not target_path.exists() and not downloaded_path.exists(): 155 | raise FileNotFoundError("Download reported success but target file not found.") 156 | # If target *does* exist, assume hf_hub handled it correctly (e.g., hard link or direct placement) 157 | 158 | logger.info(f"Successfully downloaded GGUF model '{filename}' to {target_path} in {download_time:.2f} seconds.") 159 | model_load_info["download_status"] = "downloaded" 160 | return target_path 161 | 162 | except Exception as e: 163 | logger.critical(f"FATAL: Failed to download GGUF model '{filename}' from '{repo_id}': {e}", exc_info=True) 164 | gguf_model_path = None # Ensure path is None on failure 165 | model_load_info.update({"status": "error", "download_status": "failed", "error": f"Download failed: {e}"}) 166 | # Re-raise critical error for Fail Fast strategy 167 | raise RuntimeError(f"GGUF model download failed: {e}") from e 168 | finally: 169 | if HF_TOKEN: 170 | try: logout() 171 | except Exception: pass # Ignore logout errors 172 | 173 | # --- Model Loading --- 174 | def load_llm_model(model_path: Path): 175 | """Loads the GGUF model using llama-cpp-python.""" 176 | global llm_model, model_load_info 177 | if not llm_config or not model_path or not model_path.exists(): 178 | error_msg = f"Cannot load LLM model - configuration missing, model path invalid, or file not found at '{model_path}'." 
179 | logger.error(error_msg) 180 | model_load_info.update({"status": "error", "error": error_msg}) 181 | raise ValueError(error_msg) # Raise error to signal failure 182 | 183 | n_ctx = llm_config.get('n_ctx', 2048) 184 | chat_format = llm_config.get('chat_format', 'llama-3') # Default to llama-3 format 185 | n_gpu_layers_to_load = llm_config.get('effective_n_gpu_layers', 0) 186 | 187 | logger.info(f"Attempting to load LLM model from: {model_path}") 188 | logger.info(f"Parameters: n_gpu_layers={n_gpu_layers_to_load}, n_ctx={n_ctx}, chat_format={chat_format}") 189 | model_load_info["status"] = "loading_model" 190 | start_time = time.monotonic() 191 | 192 | try: 193 | llm_model = Llama( 194 | model_path=str(model_path), # llama-cpp expects string path 195 | n_gpu_layers=n_gpu_layers_to_load, 196 | n_ctx=n_ctx, 197 | chat_format=chat_format, 198 | verbose=LOG_LEVEL == 'DEBUG', 199 | # seed=1337, # Optional for reproducibility 200 | # n_batch=512, # Adjust based on VRAM/performance 201 | ) 202 | load_time = time.monotonic() - start_time 203 | # Check actual GPU layers used after load 204 | actual_gpu_layers = -999 # Placeholder 205 | try: 206 | # Accessing internal context details might change between versions 207 | if llm_model and llm_model.ctx and hasattr(llm_model.ctx, "n_gpu_layers"): 208 | actual_gpu_layers = llm_model.ctx.n_gpu_layers 209 | elif llm_model and hasattr(llm_model, 'model') and hasattr(llm_model.model, 'n_gpu_layers'): # Older attribute access 210 | actual_gpu_layers = llm_model.model.n_gpu_layers() 211 | except Exception as e: 212 | logger.warning(f"Could not determine actual GPU layers used: {e}") 213 | 214 | 215 | offload_status = f"requested={n_gpu_layers_to_load}, actual={actual_gpu_layers if actual_gpu_layers != -999 else 'unknown'}" 216 | logger.info(f"LLM Model '{model_path.name}' loaded successfully in {load_time:.2f}s. GPU Layer Status: {offload_status}") 217 | model_load_info.update({"status": "loaded", "load_time_s": round(load_time, 2), "actual_gpu_layers": actual_gpu_layers if actual_gpu_layers != -999 else None}) 218 | 219 | except Exception as e: 220 | logger.critical(f"FATAL: Failed to load LLM model from '{model_path}': {e}", exc_info=True) 221 | llm_model = None 222 | model_load_info.update({"status": "error", "error": f"Model load failed: {e}"}) 223 | # Re-raise critical error for Fail Fast strategy 224 | raise RuntimeError(f"LLM model loading failed: {e}") from e 225 | 226 | # --- FastAPI Lifespan Event Handler --- 227 | @asynccontextmanager 228 | async def lifespan(app: FastAPI): 229 | # Startup Sequence 230 | logger.info(f"{SERVICE_NAME.upper()} Service starting up...") 231 | model_load_info = {"status": "initializing"} 232 | load_configuration() # Step 1: Load config 233 | 234 | if llm_config: # Proceed only if config loaded okay 235 | downloaded_path = None 236 | try: 237 | downloaded_path = download_gguf_model_if_needed() # Step 2: Download if needed 238 | if downloaded_path: 239 | load_llm_model(downloaded_path) # Step 3: Load model 240 | else: 241 | # This case should be caught by exceptions in download function 242 | logger.critical("Model path not available after download check, cannot load model.") 243 | model_load_info.update({"status": "error", "error": "Model file unavailable after download check"}) 244 | raise RuntimeError("GGUF model path unavailable after download check") 245 | 246 | except RuntimeError as e: 247 | # Critical error during download or load, logged already. 
248 | logger.critical(f"Lifespan startup failed due to critical error: {e}") 249 | # Let FastAPI start, healthcheck will fail 250 | else: 251 | logger.error("Skipping model download/load during startup due to config errors.") 252 | model_load_info = {"status": "error", "error": "Configuration failed"} 253 | 254 | yield # Application runs here 255 | 256 | # Shutdown Sequence 257 | logger.info(f"{SERVICE_NAME.upper()} Service shutting down...") 258 | global llm_model 259 | if llm_model: 260 | logger.info("Releasing LLM model resources...") 261 | # Explicitly delete the object to trigger llama.cpp cleanup if implemented (__del__) 262 | del llm_model 263 | llm_model = None 264 | import gc 265 | gc.collect() # Encourage garbage collection 266 | logger.info("LLM Service shutdown complete.") 267 | 268 | 269 | # --- FastAPI App Initialization --- 270 | app = FastAPI(lifespan=lifespan, title="LLM Service", version="1.1.0") 271 | 272 | # --- Pydantic Models for API --- 273 | class Message(BaseModel): 274 | role: str = Field(..., pattern="^(system|user|assistant)$") 275 | content: str 276 | 277 | class GenerateRequest(BaseModel): 278 | messages: List[Message] 279 | temperature: Optional[float] = Field(None, gt=0.0, le=2.0) 280 | max_tokens: Optional[int] = Field(None, gt=0) 281 | top_p: Optional[float] = Field(None, gt=0.0, lt=1.0) 282 | # stream: Optional[bool] = False # Future use 283 | 284 | class GenerateResponse(BaseModel): 285 | role: str = "assistant" 286 | content: str 287 | model: str # Model repo/filename used 288 | usage: Dict[str, int] # Token usage stats 289 | 290 | # --- API Endpoints --- 291 | @app.post("/generate", response_model=GenerateResponse) 292 | async def generate_completion(request: GenerateRequest): 293 | """Generates chat completion using the loaded Llama GGUF model.""" 294 | if not llm_model or model_load_info.get("status") != "loaded": 295 | error_detail = model_load_info.get("error", "Model not available or failed to load.") 296 | logger.error(f"Generation request failed: {error_detail}") 297 | raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=f"LLM model unavailable: {error_detail}") 298 | 299 | logger.info(f"Received generation request with {len(request.messages)} messages.") 300 | if logger.isEnabledFor(logging.DEBUG): 301 | logger.debug(f"Request messages: {[msg.model_dump() for msg in request.messages]}") 302 | 303 | req_start_time = time.monotonic() 304 | 305 | # Get generation parameters from request or config defaults 306 | temperature = request.temperature if request.temperature is not None else llm_config.get('temperature', 0.7) 307 | max_tokens = request.max_tokens if request.max_tokens is not None else llm_config.get('max_tokens', 512) 308 | top_p = request.top_p if request.top_p is not None else llm_config.get('top_p', 0.9) 309 | stream = False # For non-streaming response 310 | 311 | messages_dict_list = [msg.model_dump() for msg in request.messages] 312 | 313 | try: 314 | logger.info(f"Generating chat completion (temp={temperature}, max_tokens={max_tokens}, top_p={top_p})...") 315 | generation_start_time = time.monotonic() 316 | 317 | completion = llm_model.create_chat_completion( 318 | messages=messages_dict_list, 319 | temperature=temperature, 320 | max_tokens=max_tokens, 321 | top_p=top_p, 322 | stream=stream, 323 | # stop=["<|eot_id|>"] # Usually handled by chat_format="llama-3" 324 | ) 325 | 326 | generation_time = time.monotonic() - generation_start_time 327 | total_req_time = time.monotonic() - req_start_time 328 | 329 | if 
not completion or 'choices' not in completion or not completion['choices']: 330 | logger.error("LLM generation returned empty/invalid completion object.") 331 | raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="LLM returned empty response") 332 | 333 | response_message = completion['choices'][0]['message'] 334 | response_content = response_message.get('content', '').strip() 335 | response_role = response_message.get('role', 'assistant') 336 | model_identifier = f"{llm_config.get('model_repo_id', '?')}/{llm_config.get('model_filename', '?')}" 337 | usage_stats = completion.get('usage', {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}) 338 | 339 | logger.info(f"Generation successful in {generation_time:.3f}s. Total request time: {total_req_time:.3f}s.") 340 | logger.info(f"Usage - Prompt: {usage_stats.get('prompt_tokens', 0)}, Completion: {usage_stats.get('completion_tokens', 0)}, Total: {usage_stats.get('total_tokens', 0)}") 341 | # Limit logging long responses unless DEBUG 342 | logger.debug(f"Generated response content (first 100 chars): '{response_content[:100]}...'") 343 | 344 | return GenerateResponse( 345 | role=response_role, 346 | content=response_content, 347 | model=model_identifier, 348 | usage=usage_stats 349 | ) 350 | 351 | except Exception as e: 352 | logger.error(f"LLM generation failed unexpectedly: {type(e).__name__} - {e}", exc_info=True) 353 | raise HTTPException( 354 | status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, 355 | detail=f"Internal server error during LLM generation: {e}" 356 | ) 357 | 358 | @app.get("/health", status_code=status.HTTP_200_OK) 359 | async def health_check(): 360 | """Provides LLM service health status.""" 361 | current_status = model_load_info.get("status", "unknown") 362 | response_content = { 363 | "service": SERVICE_NAME, 364 | "status": "ok" if current_status == "loaded" else "error", 365 | "model_status": current_status, 366 | "model_repo_id": llm_config.get('model_repo_id', 'N/A'), 367 | "model_filename": llm_config.get('model_filename', 'N/A'), 368 | "model_file_path": str(gguf_model_path) if gguf_model_path else 'N/A', 369 | "gpu_layers_effective": llm_config.get('effective_n_gpu_layers', 'N/A'), 370 | "load_info": model_load_info # Detailed status/error/timing 371 | } 372 | 373 | if current_status != "loaded": 374 | return JSONResponse(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, content=response_content) 375 | 376 | return response_content 377 | 378 | 379 | # --- Main Execution Guard (for local debugging) --- 380 | if __name__ == "__main__": 381 | import uvicorn 382 | logger.info(f"Starting {SERVICE_NAME.upper()} service directly via __main__...") 383 | # Manually run lifespan startup steps 384 | logger.info("Running startup sequence...") 385 | model_load_info = {"status": "initializing"} 386 | load_configuration() 387 | if llm_config: 388 | downloaded_path = None 389 | try: 390 | downloaded_path = download_gguf_model_if_needed() 391 | if downloaded_path: 392 | load_llm_model(downloaded_path) 393 | else: 394 | logger.critical("Direct run failed: Model file unavailable after download check.") 395 | exit(1) 396 | except RuntimeError as e: 397 | logger.critical(f"Direct run failed: Critical error during startup: {e}") 398 | exit(1) # Exit if model fails in direct run 399 | else: 400 | logger.critical("Direct run failed: Configuration error.") 401 | exit(1) 402 | 403 | # Launch server 404 | port = int(os.getenv('LLM_PORT', 5002)) 405 | log_level_param = LOG_LEVEL.lower() 406 | 
logger.info(f"Launching Uvicorn on port {port} with log level {log_level_param}...") 407 | uvicorn.run("app:app", host="0.0.0.0", port=port, log_level=log_level_param, reload=False) 408 | logger.info(f"{SERVICE_NAME.upper()} Service shutting down (direct run)...") -------------------------------------------------------------------------------- /backend/llm/requirements.txt: -------------------------------------------------------------------------------- 1 | # backend/llm/requirements.txt 2 | fastapi>=0.110.0,<0.112.0 3 | uvicorn[standard]>=0.29.0,<0.30.0 4 | python-dotenv>=1.0.0 5 | PyYAML>=6.0 6 | pydantic>=2.0.0,<3.0.0 7 | 8 | # LLM Core - llama-cpp-python 9 | llama-cpp-python[server]>=0.2.75,<0.3.0 10 | 11 | # Model Downloading & Cache Management - Allow newer version needed for consistency 12 | huggingface_hub>=0.20.0,<1.0 13 | 14 | # Health checks / internal comms (optional) 15 | httpx>=0.27.0 -------------------------------------------------------------------------------- /backend/orchestrator/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.10-slim 2 | 3 | ARG PYTHON_VERSION=3.10 4 | ARG USER=appuser 5 | ARG GROUP=appgroup 6 | ARG UID=1000 7 | ARG GID=1000 8 | 9 | WORKDIR /app 10 | 11 | ENV PYTHONUNBUFFERED=1 12 | ENV DEBIAN_FRONTEND=noninteractive 13 | ENV ORCHESTRATOR_PORT=5000 14 | ENV ASR_SERVICE_URL="http://asr:5001" 15 | ENV LLM_SERVICE_URL="http://llm:5002" 16 | ENV TTS_SERVICE_URL="http://tts:5003" 17 | ENV CONFIG_PATH=/app/config.yaml 18 | ENV LOG_FILE_BASE=/app/logs/service 19 | ENV LOG_LEVEL=info 20 | ENV PATH="/app/.venv/bin:$PATH" 21 | 22 | RUN groupadd -g ${GID} ${GROUP} && \ 23 | useradd -u ${UID} -g ${GID} -ms /bin/bash ${USER} 24 | 25 | RUN apt-get update && apt-get install -y --no-install-recommends \ 26 | python3 \ 27 | python3-venv \ 28 | python3-dev \ 29 | build-essential \ 30 | curl \ 31 | && apt-get clean && rm -rf /var/lib/apt/lists/* 32 | 33 | RUN python3 -m venv /app/.venv 34 | 35 | COPY requirements.txt . 36 | RUN . /app/.venv/bin/activate && \ 37 | pip install --no-cache-dir --upgrade pip && \ 38 | pip install --no-cache-dir -r requirements.txt 39 | 40 | COPY orchestrator.py . 
41 | 42 | RUN mkdir -p /app/logs && chown ${USER}:${GROUP} /app/logs 43 | RUN chown -R ${USER}:${GROUP} /app 44 | 45 | USER ${USER} 46 | 47 | EXPOSE ${ORCHESTRATOR_PORT} 48 | 49 | HEALTHCHECK --interval=30s --timeout=5s --start-period=10s --retries=3 \ 50 | CMD curl --fail http://localhost:${ORCHESTRATOR_PORT}/health || exit 1 51 | 52 | CMD ["sh", "-c", "uvicorn orchestrator:app --host 0.0.0.0 --port ${ORCHESTRATOR_PORT} --log-level ${LOG_LEVEL}"] -------------------------------------------------------------------------------- /backend/orchestrator/orchestrator.py: -------------------------------------------------------------------------------- 1 | # backend/orchestrator/orchestrator.py 2 | 3 | import os 4 | import logging 5 | import time 6 | import yaml 7 | import base64 8 | import asyncio 9 | import json # For websocket messages 10 | import uuid # For unique utterance IDs 11 | from contextlib import asynccontextmanager 12 | from collections import deque 13 | from urllib.parse import urlparse, urlunparse 14 | 15 | import httpx 16 | import websockets # Keep import for call_tts_service_ws 17 | from fastapi import FastAPI, HTTPException, UploadFile, File, status, Depends 18 | from fastapi.responses import JSONResponse 19 | from fastapi.middleware.cors import CORSMiddleware 20 | from pydantic import BaseModel, Field 21 | from dotenv import load_dotenv 22 | from typing import Optional, Dict, Any, List, Tuple 23 | 24 | # --- Constants & Environment Loading --- 25 | load_dotenv() 26 | 27 | CONFIG_PATH = os.getenv('CONFIG_PATH', '/app/config.yaml') 28 | LOG_FILE_BASE = os.getenv('LOG_FILE_BASE', '/app/logs/service') 29 | LOG_LEVEL = os.getenv('LOG_LEVEL', 'INFO').upper() 30 | 31 | _ASR_BASE_URL_ENV = os.getenv('ASR_SERVICE_URL', f"http://asr:{os.getenv('ASR_PORT', 5001)}") 32 | _LLM_BASE_URL_ENV = os.getenv('LLM_SERVICE_URL', f"http://llm:{os.getenv('LLM_PORT', 5002)}") 33 | _TTS_BASE_URL_ENV = os.getenv('TTS_SERVICE_URL', f"http://tts:{os.getenv('TTS_PORT', 5003)}") 34 | 35 | SERVICE_NAME = "orchestrator" 36 | LOG_PATH = f"{LOG_FILE_BASE}_{SERVICE_NAME}.log" 37 | 38 | # --- Logging Setup --- 39 | os.makedirs(os.path.dirname(LOG_PATH), exist_ok=True) 40 | logging.basicConfig( 41 | level=LOG_LEVEL, 42 | format="%(asctime)s - %(name)s:%(lineno)d - %(levelname)s - %(message)s", 43 | handlers=[ 44 | logging.StreamHandler(), 45 | logging.FileHandler(LOG_PATH) 46 | ] 47 | ) 48 | logger = logging.getLogger(SERVICE_NAME) 49 | 50 | # --- Global Variables & State --- 51 | orchestrator_config: Dict[str, Any] = {} 52 | api_endpoints: Dict[str, str] = {} 53 | conversation_history: deque = deque(maxlen=10) 54 | http_client: Optional[httpx.AsyncClient] = None 55 | 56 | # --- Configuration Loading --- 57 | def load_configuration(): 58 | """Loads Orchestrator and API endpoint settings from YAML.""" 59 | global orchestrator_config, api_endpoints, conversation_history 60 | default_api_endpoints = { 61 | 'asr': f"{_ASR_BASE_URL_ENV}/transcribe", 62 | 'llm': f"{_LLM_BASE_URL_ENV}/generate", 63 | 'tts_ws': f"ws://{urlparse(_TTS_BASE_URL_ENV).netloc}/synthesize_stream" # Keep for call_tts_service_ws 64 | } 65 | try: 66 | logger.info(f"Loading configuration from: {CONFIG_PATH}") 67 | # ... (rest of config loading is fine) ... 68 | if not os.path.exists(CONFIG_PATH): 69 | logger.warning(f"Config file not found at {CONFIG_PATH}. Using defaults.") 70 | config = {} 71 | else: 72 | with open(CONFIG_PATH, 'r') as f: 73 | config = yaml.safe_load(f) 74 | if not config: 75 | logger.warning("Config file is empty or invalid YAML. 
Using defaults.") 76 | config = {} 77 | 78 | orchestrator_config = config.get('orchestrator', {}) 79 | api_endpoints_cfg = config.get('api_endpoints', {}) 80 | api_endpoints['asr'] = api_endpoints_cfg.get('asr', default_api_endpoints['asr']) 81 | api_endpoints['llm'] = api_endpoints_cfg.get('llm', default_api_endpoints['llm']) 82 | api_endpoints['tts_ws'] = api_endpoints_cfg.get('tts_ws', default_api_endpoints['tts_ws']) # Resolve TTS WS URL 83 | 84 | logger.info(f"ASR Endpoint: {api_endpoints['asr']}") 85 | logger.info(f"LLM Endpoint: {api_endpoints['llm']}") 86 | logger.info(f"TTS WebSocket Endpoint Configured (for no-speech handling): {api_endpoints['tts_ws']}") 87 | max_hist = orchestrator_config.get('max_history_turns', 5) * 2 88 | if max_hist <= 0: max_hist = 2 89 | conversation_history = deque(maxlen=max_hist) 90 | logger.info(f"Conversation history max length set to {max_hist} messages ({max_hist // 2} turns).") 91 | if not orchestrator_config.get('system_prompt'): 92 | logger.warning("System prompt not found in config, using default.") 93 | orchestrator_config['system_prompt'] = "You are a helpful voice assistant." 94 | else: 95 | logger.info("Loaded system prompt from config.") 96 | except Exception as e: 97 | logger.error(f"Unexpected error loading configuration: {e}. Using defaults.", exc_info=True) 98 | api_endpoints = default_api_endpoints 99 | orchestrator_config = {'max_history_turns': 5, 'system_prompt': 'You are a helpful voice assistant.'} 100 | conversation_history = deque(maxlen=10) 101 | 102 | # --- Backend Service Interaction Helpers --- 103 | async def call_asr_service(audio_file: UploadFile) -> str: 104 | # ... (Keep existing implementation) ... 105 | if not http_client: raise RuntimeError("HTTP client not initialized") 106 | files = {'audio': (audio_file.filename, await audio_file.read(), audio_file.content_type)} 107 | request_url = api_endpoints['asr'] 108 | logger.info(f"Sending audio to ASR service at {request_url}") 109 | try: 110 | response = await http_client.post(request_url, files=files, timeout=30.0) 111 | response.raise_for_status() 112 | data = response.json() 113 | transcript = data.get('text', '').strip() 114 | if not transcript: logger.warning("ASR returned an empty transcript."); return "" 115 | logger.info(f"ASR Response: '{transcript[:100]}...'") 116 | return transcript 117 | except httpx.TimeoutException: logger.error(f"ASR request timed out to {request_url}"); raise HTTPException(status_code=status.HTTP_504_GATEWAY_TIMEOUT, detail="ASR service request timed out.") 118 | except httpx.RequestError as e: logger.error(f"ASR request error to {request_url}: {e}"); raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=f"ASR service request failed: {e}") 119 | except httpx.HTTPStatusError as e: 120 | logger.error(f"ASR service error: Status {e.response.status_code}, Response: {e.response.text[:500]}") 121 | detail = f"ASR service error ({e.response.status_code})"; backend_detail=None 122 | try: backend_detail = e.response.json().get("detail") 123 | except Exception: pass 124 | if backend_detail: detail += f": {backend_detail}" 125 | status_code = e.response.status_code if e.response.status_code >= 500 else status.HTTP_502_BAD_GATEWAY 126 | raise HTTPException(status_code=status_code, detail=detail) 127 | except Exception as e: logger.error(f"Unexpected error calling ASR: {e}", exc_info=True); raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Internal error communicating with ASR: {e}") 128 | 129 | 130 | async 
def call_llm_service(history: List[Dict[str, str]]) -> str: 131 | # ... (Keep existing implementation) ... 132 | if not http_client: raise RuntimeError("HTTP client not initialized") 133 | request_url = api_endpoints['llm'] 134 | payload = {"messages": history} 135 | logger.info(f"Sending {len(history)} messages to LLM service at {request_url}") 136 | logger.debug(f"LLM Payload: {payload}") 137 | try: 138 | response = await http_client.post(request_url, json=payload, timeout=60.0) 139 | response.raise_for_status() 140 | data = response.json() 141 | assistant_response = data.get('content', '').strip() 142 | if not assistant_response: logger.warning("LLM returned an empty response."); return "Sorry, I seem to be speechless right now." 143 | logger.info(f"LLM Response: '{assistant_response[:100]}...'") 144 | return assistant_response 145 | except httpx.TimeoutException: logger.error(f"LLM request timed out to {request_url}"); raise HTTPException(status_code=status.HTTP_504_GATEWAY_TIMEOUT, detail="LLM service request timed out.") 146 | except httpx.RequestError as e: logger.error(f"LLM request error to {request_url}: {e}"); raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=f"LLM service request failed: {e}") 147 | except httpx.HTTPStatusError as e: 148 | logger.error(f"LLM service error: Status {e.response.status_code}, Response: {e.response.text[:500]}") 149 | detail = f"LLM service error ({e.response.status_code})"; backend_detail=None 150 | try: backend_detail = e.response.json().get("detail") 151 | except Exception: pass 152 | if backend_detail: detail += f": {backend_detail}" 153 | status_code = e.response.status_code if e.response.status_code >= 500 else status.HTTP_502_BAD_GATEWAY 154 | raise HTTPException(status_code=status_code, detail=detail) 155 | except Exception as e: logger.error(f"Unexpected error calling LLM: {e}", exc_info=True); raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Internal error communicating with LLM: {e}") 156 | 157 | 158 | # --- This function is now ONLY used for the no-speech case --- 159 | async def call_tts_service_ws(text: str, max_audio_length_ms: Optional[float] = None) -> bytes: 160 | """Calls the TTS service WebSocket endpoint to synthesize audio stream. 161 | ONLY intended for short, fixed responses like the no-speech case.""" 162 | # --- Add specific warning --- 163 | logger.warning("Orchestrator making direct TTS call (expected only for no-speech handling).") 164 | # --- End warning --- 165 | ws_url = api_endpoints['tts_ws'] 166 | utterance_id = str(uuid.uuid4()) # Still generate unique ID 167 | logger.info(f"[No-Speech TTS] Connecting to TTS WebSocket at {ws_url} for utterance_id: {utterance_id}") 168 | 169 | message_payload = { 170 | "type": "generate_chunk", 171 | "text_chunk": text, 172 | "utterance_id": utterance_id 173 | } 174 | if max_audio_length_ms is not None: 175 | message_payload["max_audio_length_ms"] = max_audio_length_ms 176 | logger.info(f"[No-Speech TTS] Requesting max audio length: {max_audio_length_ms}ms for {utterance_id}") 177 | else: 178 | logger.info(f"[No-Speech TTS] Sending text: '{text[:50]}...' 
for {utterance_id}") 179 | 180 | all_audio_bytes = bytearray() 181 | receive_timeout_seconds = 30.0 # Shorter timeout for no-speech 182 | websocket_connection = None 183 | 184 | try: 185 | # Use slightly shorter timeouts for this specific, short call 186 | async with websockets.connect( 187 | ws_url, open_timeout=15, close_timeout=10, ping_interval=None # No ping needed 188 | ) as websocket: 189 | websocket_connection = websocket 190 | logger.info(f"[No-Speech TTS] WebSocket connection established for {utterance_id}.") 191 | await websocket.send(json.dumps(message_payload)) 192 | logger.info(f"[No-Speech TTS] Sent 'generate_chunk' request for {utterance_id}.") 193 | 194 | last_message_time = time.monotonic() 195 | while True: 196 | try: 197 | wait_time = max(0, receive_timeout_seconds - (time.monotonic() - last_message_time)) 198 | if wait_time <= 0: 199 | if all_audio_bytes: logger.warning(f"[No-Speech TTS] WS receive timed out (inactivity). Assuming stream ended."); break 200 | else: logger.error(f"[No-Speech TTS] WS receive timed out (inactivity) with NO audio."); raise HTTPException(status_code=status.HTTP_504_GATEWAY_TIMEOUT, detail="TTS service timed out (no-speech).") 201 | 202 | message_json = await asyncio.wait_for(websocket.recv(), timeout=wait_time) 203 | last_message_time = time.monotonic() 204 | message = json.loads(message_json) 205 | msg_type = message.get("type") 206 | 207 | if msg_type == "audio_chunk": 208 | audio_b64 = message.get("audio_b64", "") 209 | if audio_b64: 210 | try: all_audio_bytes.extend(base64.b64decode(audio_b64)) 211 | except Exception as decode_err: logger.warning(f"[No-Speech TTS] Failed to decode chunk: {decode_err}") 212 | elif msg_type == "error": 213 | error_msg = message.get("message", "Unknown TTS error") 214 | logger.error(f"[No-Speech TTS] Received error from TTS WS: {error_msg}"); raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=f"TTS service error: {error_msg}") 215 | elif msg_type == "stream_end": logger.info(f"[No-Speech TTS] Received explicit 'stream_end'."); break 216 | # Ignore context_cleared for this specific call 217 | elif msg_type != "context_cleared": logger.warning(f"[No-Speech TTS] Received unknown message type '{msg_type}'.") 218 | 219 | except asyncio.TimeoutError: 220 | if all_audio_bytes: logger.warning(f"[No-Speech TTS] WS receive timed out (asyncio). Assuming stream ended."); break 221 | else: logger.error(f"[No-Speech TTS] WS receive timed out (asyncio) with NO audio."); raise HTTPException(status_code=status.HTTP_504_GATEWAY_TIMEOUT, detail="TTS service timed out (no-speech).") 222 | except websockets.exceptions.ConnectionClosedOK: logger.info(f"[No-Speech TTS] WS closed normally by server."); break 223 | except websockets.exceptions.ConnectionClosedError as e: logger.error(f"[No-Speech TTS] WS closed with error: {e}"); break # Break if we got some audio 224 | 225 | logger.info(f"[No-Speech TTS] Processing complete. Bytes: {len(all_audio_bytes)}.") 226 | # No context clear needed here as it's a one-off call 227 | return bytes(all_audio_bytes) 228 | 229 | # Keep specific exception handling for this call 230 | except websockets.exceptions.InvalidURI as e: logger.error(f"Invalid TTS WS URI: {ws_url}. 
Error: {e}"); raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Invalid TTS WS URI configured: {ws_url}") 231 | except websockets.exceptions.WebSocketException as e: logger.error(f"TTS WS connection failed: {e}"); raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=f"Failed to connect to TTS service WS: {e}") 232 | except HTTPException as http_exc: raise http_exc 233 | except Exception as e: logger.error(f"Unexpected error in call_tts_service_ws: {e}", exc_info=True); raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Internal error during TTS call: {e}") 234 | 235 | 236 | # --- Conversation History Management --- 237 | def update_history(user_text: str, assistant_text: str): 238 | if conversation_history.maxlen is None or conversation_history.maxlen <= 0: logger.warning("Conv history maxlen invalid."); return 239 | if user_text: conversation_history.append({"role": "user", "content": user_text}) 240 | if assistant_text: conversation_history.append({"role": "assistant", "content": assistant_text}) 241 | logger.debug(f"History updated. Len: {len(conversation_history)}/{conversation_history.maxlen}") 242 | 243 | def get_formatted_history() -> List[Dict[str, str]]: 244 | system_prompt = orchestrator_config.get('system_prompt', '') 245 | history = [] 246 | if system_prompt: history.append({"role": "system", "content": system_prompt}) 247 | history.extend(list(conversation_history)) 248 | return history 249 | 250 | # --- FastAPI Lifespan & App Setup --- 251 | @asynccontextmanager 252 | async def lifespan(app: FastAPI): 253 | logger.info(f"{SERVICE_NAME.upper()} Service starting up...") 254 | load_configuration() 255 | global http_client 256 | timeout = httpx.Timeout(connect=10.0, read=60.0, write=10.0, pool=10.0) 257 | limits = httpx.Limits(max_keepalive_connections=20, max_connections=100) 258 | http_client = httpx.AsyncClient(timeout=timeout, limits=limits, follow_redirects=True) 259 | logger.info("HTTP client initialized.") 260 | # Removed sleep, relying on docker-compose depends_on: condition: service_healthy 261 | # await asyncio.sleep(5) 262 | await check_backend_services() # Run checks after client init 263 | yield 264 | logger.info(f"{SERVICE_NAME.upper()} Service shutting down...") 265 | if http_client: await http_client.aclose(); logger.info("HTTP client closed.") 266 | logger.info("Orchestrator Service shutdown complete.") 267 | 268 | # --- check_backend_services includes retries --- 269 | async def check_backend_services(): 270 | if not http_client: return 271 | # Check ASR, LLM, and TTS HTTP health endpoints 272 | services_to_check = { 273 | "ASR": api_endpoints['asr'].replace('/transcribe', '/health'), 274 | "LLM": api_endpoints['llm'].replace('/generate', '/health'), 275 | "TTS_HTTP": urlunparse(('http', urlparse(api_endpoints['tts_ws']).netloc, '/health', '', '', '')), 276 | } 277 | logger.info("Checking backend service connectivity (with retries)...") 278 | all_services_ok = True # Track overall status 279 | max_retries = 3 280 | delay = 2.0 281 | 282 | for name, url in services_to_check.items(): 283 | service_ok = False # Track status for this specific service 284 | for attempt in range(max_retries): 285 | logger.info(f"Checking {name} at {url} (Attempt {attempt+1}/{max_retries})...") 286 | try: 287 | response = await http_client.get(url, timeout=5.0) 288 | if response.status_code < 400: 289 | logger.info(f"Backend service {name} health check successful (Status {response.status_code}).") 290 | 
service_ok = True 291 | break # Success for this service, move to next service 292 | else: 293 | logger.warning(f"Backend service {name} health check attempt {attempt+1}/{max_retries} failed: Status {response.status_code}. URL: {url}. Response: {response.text[:200]}") 294 | if response.status_code == 503: logger.warning(f" -> {name} service might still be loading/initializing.") 295 | # Continue retrying 296 | except httpx.RequestError as e: 297 | logger.error(f"Failed to connect to backend service {name} (Attempt {attempt+1}/{max_retries}) at {url}: {e}") 298 | # Continue retrying 299 | except Exception as e: 300 | logger.error(f"Unexpected error during {name} health check (Attempt {attempt+1}/{max_retries}) at {url}: {e}", exc_info=True) 301 | service_ok = False # Mark as failed on unexpected error 302 | break # Stop retrying for this service on unexpected error 303 | 304 | if not service_ok and attempt + 1 < max_retries: 305 | logger.info(f"Waiting {delay}s before retrying {name}...") 306 | await asyncio.sleep(delay) 307 | # After retries for a specific service 308 | if not service_ok: 309 | logger.error(f"Backend service {name} failed health check after {max_retries} attempts.") 310 | all_services_ok = False # Mark overall failure if any service fails 311 | 312 | # Final overall status log 313 | if not all_services_ok: 314 | logger.error("One or more critical backend services could not be reached or failed health check during startup.") 315 | else: 316 | logger.info("Initial backend service connectivity and health checks passed.") 317 | 318 | 319 | # --- FastAPI App Creation and CORS --- 320 | app = FastAPI(lifespan=lifespan, title="Voice Assistant Orchestrator", version="1.1.0") 321 | origins = ["*"] # Allow all origins for simplicity in local dev 322 | app.add_middleware( 323 | CORSMiddleware, 324 | allow_origins=origins, 325 | allow_credentials=True, 326 | allow_methods=["*"], 327 | allow_headers=["*"], 328 | ) 329 | 330 | # --- Pydantic Models for API --- 331 | class AssistResponse(BaseModel): 332 | user_transcript: str 333 | assistant_response: str 334 | assistant_audio_b64: str # Kept for API schema consistency, will be empty 335 | response_time_ms: float 336 | 337 | # --- API Endpoints --- 338 | @app.post("/assist", response_model=AssistResponse) 339 | async def handle_assist_request(audio: UploadFile = File(..., description="User audio input (WAV, MP3, etc.)")): 340 | """ 341 | Handles voice input, gets transcription and LLM response. 342 | Audio synthesis is now handled by the frontend connecting directly to TTS. 343 | """ 344 | overall_start_time = time.monotonic() 345 | logger.info(f"Received /assist request for file: {audio.filename}, size: {audio.size}") 346 | 347 | try: 348 | # 1. Call ASR Service 349 | asr_start_time = time.monotonic() 350 | user_transcript = await call_asr_service(audio) 351 | asr_time = (time.monotonic() - asr_start_time) * 1000 352 | 353 | # Handle no speech (Calls deprecated local TTS function) 354 | if not user_transcript: 355 | logger.info("ASR returned no transcript. Generating no-speech response.") 356 | no_speech_response = "Sorry, I didn't hear anything." 
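        # (Added illustrative comment, not in the original code) In this branch the
        # endpoint still returns a full AssistResponse; an example payload, with
        # assumed timing values, would look roughly like:
        #   {"user_transcript": "",
        #    "assistant_response": "Sorry, I didn't hear anything.",
        #    "assistant_audio_b64": "<base64-encoded short WAV, or empty on TTS failure>",
        #    "response_time_ms": 1234.5}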
357 | tts_no_speech_start_time = time.monotonic() 358 | try: 359 | # Call the specific function for this case 360 | no_speech_audio_bytes = await call_tts_service_ws(no_speech_response) 361 | tts_no_speech_time = (time.monotonic() - tts_no_speech_start_time) * 1000 362 | logger.info(f"No-speech TTS generation took {tts_no_speech_time:.0f}ms") 363 | no_speech_audio_b64 = base64.b64encode(no_speech_audio_bytes).decode('utf-8') if no_speech_audio_bytes else "" 364 | except Exception as no_speech_err: 365 | logger.error(f"Failed to generate no-speech audio via TTS: {no_speech_err}", exc_info=True) 366 | no_speech_audio_b64 = "" # Fallback to empty audio on error 367 | 368 | return AssistResponse( 369 | user_transcript="", 370 | assistant_response=no_speech_response, 371 | assistant_audio_b64=no_speech_audio_b64, # May contain short audio or be empty 372 | response_time_ms=(time.monotonic() - overall_start_time) * 1000 373 | ) 374 | 375 | # 2. Prepare history for LLM 376 | current_llm_input = get_formatted_history() 377 | current_llm_input.append({"role": "user", "content": user_transcript}) 378 | 379 | # 3. Call LLM Service 380 | llm_start_time = time.monotonic() 381 | assistant_response = await call_llm_service(current_llm_input) 382 | llm_time = (time.monotonic() - llm_start_time) * 1000 383 | 384 | # --- Step 4 REMOVED: No direct TTS call from orchestrator for normal responses --- 385 | 386 | # 5. Update History (Still relevant) 387 | update_history(user_text=user_transcript, assistant_text=assistant_response) 388 | 389 | # --- Step 6 REMOVED: No audio encoding needed here --- 390 | assistant_audio_b64 = "" # Explicitly set to empty string 391 | 392 | # 7. Log and return (Adjust log message) 393 | overall_time = (time.monotonic() - overall_start_time) * 1000 394 | logger.info(f"Assist request (ASR+LLM only) processed in {overall_time:.2f}ms (ASR: {asr_time:.0f}ms, LLM: {llm_time:.0f}ms)") 395 | 396 | # Return response without audio data 397 | return AssistResponse( 398 | user_transcript=user_transcript, 399 | assistant_response=assistant_response, 400 | assistant_audio_b64=assistant_audio_b64, # Will be empty 401 | response_time_ms=overall_time 402 | ) 403 | 404 | except HTTPException as http_exc: 405 | logger.error(f"Pipeline failed due to HTTPException: {http_exc.status_code} - {http_exc.detail}") 406 | # Log the detail coming from downstream services if available 407 | if http_exc.detail: logger.error(f" -> Detail: {http_exc.detail}") 408 | raise http_exc # Re-raise the exception to return proper HTTP error 409 | except Exception as e: 410 | logger.error(f"Unexpected error during /assist pipeline: {e}", exc_info=True) 411 | raise HTTPException( 412 | status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, 413 | detail=f"An internal orchestration error occurred: {e}" 414 | ) 415 | 416 | # --- Health Check Endpoint --- 417 | @app.get("/health", status_code=status.HTTP_200_OK) 418 | async def health_check(): 419 | # Basic health check only confirms the orchestrator itself is running 420 | return {"service": SERVICE_NAME, "status": "ok", "details": "Orchestrator is running."} 421 | 422 | # --- Reset History Endpoint --- 423 | @app.post("/reset_history", status_code=status.HTTP_200_OK) 424 | async def reset_conversation_history(): 425 | global conversation_history 426 | try: 427 | max_hist = orchestrator_config.get('max_history_turns', 5) * 2 428 | if max_hist <= 0: max_hist = 2 429 | conversation_history = deque(maxlen=max_hist) 430 | logger.info(f"Conversation history reset via API 
(maxlen={max_hist}).") 431 | return {"message": "Conversation history cleared."} 432 | except Exception as e: 433 | logger.error(f"Error during history reset: {e}", exc_info=True) 434 | raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to reset conversation history.") 435 | 436 | 437 | # --- Main Execution Guard --- 438 | # ... (Keep existing __main__ block as is) ... 439 | if __name__ == "__main__": 440 | import uvicorn 441 | logger.info(f"Starting {SERVICE_NAME.upper()} service directly via __main__...") 442 | load_configuration() 443 | timeout = httpx.Timeout(connect=10.0, read=60.0, write=10.0, pool=10.0) 444 | limits = httpx.Limits(max_keepalive_connections=20, max_connections=100) 445 | http_client = httpx.AsyncClient(timeout=timeout, limits=limits, follow_redirects=True) 446 | port = int(os.getenv('ORCHESTRATOR_PORT', 5000)) 447 | log_level_param = LOG_LEVEL.lower() 448 | logger.info(f"Launching Uvicorn on host 0.0.0.0, port {port} with log level {log_level_param}...") 449 | try: 450 | # Manual checks are less critical with depends_on, but can be run for direct execution testing 451 | # async def run_checks(): 452 | # await asyncio.sleep(15) 453 | # await check_backend_services() 454 | # # asyncio.run(run_checks()) 455 | 456 | uvicorn.run("orchestrator:app", host="0.0.0.0", port=port, log_level=log_level_param, reload=False) 457 | except Exception as main_err: 458 | logger.critical(f"Failed to start Uvicorn directly: {main_err}", exc_info=True) 459 | exit(1) 460 | finally: 461 | if http_client and not http_client.is_closed: 462 | asyncio.run(http_client.aclose()) 463 | logger.info(f"{SERVICE_NAME.upper()} HTTP client closed (direct run shutdown).") 464 | logger.info(f"{SERVICE_NAME.upper()} Service shutting down (direct run)...") 465 | 466 | # --- END OF FILE --- 467 | -------------------------------------------------------------------------------- /backend/orchestrator/requirements.txt: -------------------------------------------------------------------------------- 1 | # backend/orchestrator/requirements.txt 2 | fastapi>=0.110.0,<0.112.0 3 | uvicorn[standard]>=0.29.0,<0.30.0 4 | httpx>=0.27.0 5 | PyYAML>=6.0 6 | python-dotenv>=1.0.0 7 | pydantic>=2.0.0,<3.0.0 8 | websockets>=12.0 # Explicitly add websockets client library -------------------------------------------------------------------------------- /backend/tts/Dockerfile: -------------------------------------------------------------------------------- 1 | # backend/tts/Dockerfile 2 | ARG BASE_IMAGE=nvidia/cuda:12.1.1-runtime-ubuntu22.04 3 | FROM ${BASE_IMAGE} 4 | 5 | ARG PYTHON_VERSION=3.10 6 | ARG USER_ID=1000 7 | ARG GROUP_ID=1000 8 | ARG USERNAME=appuser 9 | ARG GROUPNAME=appgroup 10 | 11 | WORKDIR /app 12 | 13 | # Create non-root user and group 14 | RUN groupadd -g ${GROUP_ID} ${GROUPNAME} && \ 15 | useradd -u ${USER_ID} -g ${GROUP_ID} -ms /bin/bash ${USERNAME} 16 | 17 | # Install Python, venv, pip, essential build tools, git, AND curl 18 | RUN apt-get update && apt-get install -y --no-install-recommends \ 19 | python${PYTHON_VERSION} \ 20 | python${PYTHON_VERSION}-venv \ 21 | python${PYTHON_VERSION}-dev \ 22 | build-essential \ 23 | git \ 24 | curl \ 25 | && rm -rf /var/lib/apt/lists/* 26 | 27 | # --- START: Added websocat installation --- 28 | ARG WEBSOCAT_VERSION=1.13.0 29 | ARG TARGETARCH # Docker buildx automatically sets this (e.g., amd64, arm64) 30 | 31 | # Install wget and ca-certificates if needed for download 32 | RUN apt-get update && apt-get install -y --no-install-recommends 
wget ca-certificates && rm -rf /var/lib/apt/lists/* 33 | 34 | # Download and install websocat based on architecture 35 | RUN case ${TARGETARCH} in \ 36 | amd64) WEBSOCAT_ARCH="x86_64-unknown-linux-musl" ;; \ 37 | arm64) WEBSOCAT_ARCH="aarch64-unknown-linux-musl" ;; \ 38 | # Add other architectures if needed, e.g. arm/v7 \ 39 | *) echo "Unsupported architecture: ${TARGETARCH}"; exit 1 ;; \ 40 | esac && \ 41 | wget -q "https://github.com/vi/websocat/releases/download/v${WEBSOCAT_VERSION}/websocat.${WEBSOCAT_ARCH}" -O /usr/local/bin/websocat && \ 42 | chmod +x /usr/local/bin/websocat && \ 43 | # Verify installation (optional) 44 | websocat --version 45 | # --- END: Added websocat installation --- 46 | 47 | 48 | # Create and activate virtual environment 49 | ENV VENV_PATH=/app/.venv 50 | RUN python${PYTHON_VERSION} -m venv ${VENV_PATH} 51 | ENV PATH="${VENV_PATH}/bin:$PATH" 52 | 53 | # Set cache directory for Hugging Face (and potentially pip) 54 | ENV HF_HOME=/cache/huggingface 55 | ENV PIP_CACHE_DIR=/cache/pip 56 | RUN mkdir -p ${HF_HOME} ${PIP_CACHE_DIR} && \ 57 | chown ${USERNAME}:${GROUPNAME} /cache ${HF_HOME} ${PIP_CACHE_DIR} 58 | 59 | # Install FFmpeg libraries needed for PyAV (av package) 60 | RUN apt-get update && apt-get install -y --no-install-recommends \ 61 | ffmpeg \ 62 | libavcodec-dev \ 63 | libavformat-dev \ 64 | libavutil-dev \ 65 | libswresample-dev \ 66 | && rm -rf /var/lib/apt/lists/* 67 | 68 | # Copy and install Python requirements 69 | COPY requirements.txt . 70 | RUN . ${VENV_PATH}/bin/activate && \ 71 | pip install --no-cache-dir --upgrade pip setuptools wheel && \ 72 | pip install --no-cache-dir -r requirements.txt 73 | 74 | # Copy application code 75 | COPY csm_utils ./csm_utils 76 | COPY generator.py models.py ./ 77 | COPY app.py . 
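# (Added illustrative comment, not in the original Dockerfile) For a one-off manual
# build/run outside docker-compose, something along these lines should work, though
# the compose file (with the config.yaml/log volumes and shared model cache) is the
# intended path; the token value below is a placeholder:
#   docker build -t voice-assistant-tts ./backend/tts
#   docker run --gpus all -p 5003:5003 -e HUGGING_FACE_TOKEN=hf_yourTokenHere voice-assistant-tts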
78 | 79 | # Create logs directory and set permissions 80 | ENV LOG_DIR=/app/logs 81 | RUN mkdir -p ${LOG_DIR} && chown ${USERNAME}:${GROUPNAME} ${LOG_DIR} 82 | 83 | # Change ownership of the app directory 84 | RUN chown -R ${USERNAME}:${GROUPNAME} /app 85 | 86 | # Switch to non-root user 87 | USER ${USERNAME}:${GROUPNAME} 88 | 89 | # Expose port (will be mapped by docker-compose) 90 | EXPOSE 5003 91 | 92 | # Default command (can be overridden by docker-compose) 93 | CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "5003"] 94 | 95 | -------------------------------------------------------------------------------- /backend/tts/app.py: -------------------------------------------------------------------------------- 1 | # backend/tts/app.py 2 | # --- PASTE THIS ENTIRE BLOCK INTO YOUR FILE --- 3 | 4 | import os 5 | import logging 6 | import io 7 | import time 8 | import base64 9 | import asyncio # Ensure asyncio is imported 10 | import json 11 | import numpy as np 12 | import soundfile as sf 13 | import torch 14 | import yaml 15 | from contextlib import asynccontextmanager 16 | from pathlib import Path 17 | from typing import Optional, Dict, Any, List 18 | 19 | from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, status as fastapi_status 20 | from fastapi.responses import JSONResponse 21 | from pydantic import BaseModel, Field 22 | from huggingface_hub import login, logout 23 | from dotenv import load_dotenv 24 | from starlette.websockets import WebSocketState 25 | 26 | # --- CSM Library Imports --- 27 | CSM_LOADED = False 28 | try: 29 | import generator 30 | import models 31 | load_csm_1b = generator.load_csm_1b 32 | Segment = generator.Segment 33 | Generator = generator.Generator 34 | CSM_LOADED = True 35 | logging.info("Successfully prepared direct imports for copied generator/models.") 36 | except ImportError as e: 37 | logging.error(f"FATAL: Failed direct import of copied 'generator'/'models'. 
Error: {e}", exc_info=True) 38 | # Define dummy versions so the rest of the code doesn't raise NameError immediately 39 | def load_csm_1b(**kwargs): raise NotImplementedError("CSM library import failed") 40 | class Segment: pass 41 | class Generator: pass 42 | 43 | # --- Constants & Environment Loading --- 44 | load_dotenv() 45 | CONFIG_PATH = os.getenv('CONFIG_PATH', '/app/config.yaml') 46 | CACHE_DIR = Path(os.getenv('CACHE_DIR', '/cache')) 47 | HF_HOME_ENV = os.getenv('HF_HOME') 48 | HF_CACHE_DIR = Path(HF_HOME_ENV) if HF_HOME_ENV else CACHE_DIR / "huggingface" 49 | os.environ['HF_HOME'] = str(HF_CACHE_DIR) 50 | LOG_FILE_BASE = os.getenv('LOG_FILE_BASE', '/app/logs/service') 51 | LOG_LEVEL = os.getenv('LOG_LEVEL', 'info').upper() 52 | USE_GPU_ENV = os.getenv('USE_GPU', 'auto').lower() 53 | HF_TOKEN = os.getenv('HUGGING_FACE_TOKEN') 54 | TTS_SPEAKER_ID = int(os.getenv('TTS_SPEAKER_ID', '4')) 55 | TTS_MODEL_REPO_ID_DEFAULT = 'senstella/csm-expressiva-1b' 56 | TTS_MODEL_REPO_ID_ENV = os.getenv('TTS_MODEL_REPO_ID', TTS_MODEL_REPO_ID_DEFAULT) 57 | SERVICE_NAME = "tts_streaming" 58 | LOG_PATH = f"{LOG_FILE_BASE}_{SERVICE_NAME}.log" 59 | 60 | # --- Logging Setup --- 61 | os.makedirs(os.path.dirname(LOG_PATH), exist_ok=True) 62 | HF_CACHE_DIR.mkdir(parents=True, exist_ok=True) 63 | logging.basicConfig( 64 | level=LOG_LEVEL, 65 | format="%(asctime)s - %(name)s:%(lineno)d - %(levelname)s - %(message)s", 66 | handlers=[logging.StreamHandler(), logging.FileHandler(LOG_PATH)] 67 | ) 68 | logger = logging.getLogger(SERVICE_NAME) 69 | 70 | # --- Global Variables --- 71 | tts_generator: Optional[Generator] = None 72 | tts_config: Dict[str, Any] = {} 73 | effective_device: str = "cpu" 74 | model_load_info: Dict[str, Any] = {"status": "pending"} 75 | utterance_context_cache: Dict[str, List[Segment]] = {} 76 | 77 | # --- Configuration Loading --- 78 | # ... (Keep existing load_configuration function as is) ... 79 | def load_configuration(): 80 | global tts_config, effective_device, TTS_SPEAKER_ID, model_load_info 81 | try: 82 | logger.info(f"Loading configuration from: {CONFIG_PATH}") 83 | if not os.path.exists(CONFIG_PATH): 84 | logger.warning(f"Config file not found at {CONFIG_PATH}. Using defaults/env vars.") 85 | config = {'tts': {}} 86 | else: 87 | with open(CONFIG_PATH, 'r') as f: config = yaml.safe_load(f) 88 | if not config or 'tts' not in config: 89 | logger.warning(f"Config file {CONFIG_PATH} is missing 'tts' section. 
Using defaults.") 90 | config['tts'] = {} 91 | 92 | tts_config = config.get('tts', {}) 93 | 94 | # Determine Model Repo ID (Env Var > Config > Default) 95 | config_model_id = tts_config.get('model_repo_id') 96 | if TTS_MODEL_REPO_ID_ENV != TTS_MODEL_REPO_ID_DEFAULT: 97 | tts_config['model_repo_id'] = TTS_MODEL_REPO_ID_ENV 98 | logger.info(f"Using TTS_MODEL_REPO_ID from environment: {TTS_MODEL_REPO_ID_ENV}") 99 | elif config_model_id: 100 | tts_config['model_repo_id'] = config_model_id 101 | logger.info(f"Using model_repo_id from config.yaml: {config_model_id}") 102 | else: 103 | tts_config['model_repo_id'] = TTS_MODEL_REPO_ID_DEFAULT 104 | logger.info(f"Using default TTS_MODEL_REPO_ID: {TTS_MODEL_REPO_ID_DEFAULT}") 105 | 106 | # Determine Speaker ID (Env Var > Config > Default) 107 | config_speaker_id_str = str(tts_config.get('speaker_id', '4')) # Default to 4 108 | env_speaker_id_str = os.getenv('TTS_SPEAKER_ID', config_speaker_id_str) 109 | try: 110 | TTS_SPEAKER_ID = int(env_speaker_id_str) 111 | except ValueError: 112 | logger.warning(f"Invalid Speaker ID '{env_speaker_id_str}'. Using default 4.") 113 | TTS_SPEAKER_ID = 4 114 | tts_config['effective_speaker_id'] = TTS_SPEAKER_ID 115 | 116 | logger.info(f"Final Effective TTS Model Repo ID: {tts_config['model_repo_id']}") 117 | logger.info(f"Final Effective TTS Speaker ID: {TTS_SPEAKER_ID}") 118 | 119 | # Determine Device (Env Var USE_GPU > Config 'device' > Auto) 120 | config_device = tts_config.get('device', 'auto').lower() 121 | cuda_available = torch.cuda.is_available() 122 | logger.info(f"CUDA available via torch check: {cuda_available}, Torch: {torch.__version__}") 123 | 124 | if cuda_available and (USE_GPU_ENV == 'true' or (USE_GPU_ENV == 'auto' and config_device != 'cpu')): 125 | if config_device.startswith("cuda"): 126 | effective_device = config_device # Use specific cuda device from config (e.g., "cuda:0") 127 | else: 128 | effective_device = "cuda" # Default to "cuda" 129 | logger.info(f"CUDA is available and requested/auto. Using CUDA device: '{effective_device}'.") 130 | else: 131 | effective_device = "cpu" 132 | # CSM requires CUDA, so treat this as a fatal error if CUDA was expected 133 | if USE_GPU_ENV == 'true' or (USE_GPU_ENV == 'auto' and config_device != 'cpu'): 134 | logger.critical(f"FATAL: CUDA requested/required but unavailable/disabled (USE_GPU={USE_GPU_ENV}, cuda_available={cuda_available}). Cannot load CSM model.") 135 | model_load_info.update({"status": "error", "error": "CUDA unavailable/disabled, required by CSM."}) 136 | else: 137 | logger.warning("CUDA not available or disabled. 
TTS service cannot load CSM model.") 138 | model_load_info.update({"status": "error", "error": "CUDA unavailable/disabled, required by CSM."}) 139 | 140 | tts_config['effective_device'] = effective_device 141 | logger.info(f"TTS effective device target: {effective_device}") 142 | 143 | # Store other config values with defaults 144 | tts_config['max_audio_length_ms'] = float(tts_config.get('max_audio_length_ms', 90000)) 145 | tts_config['temperature'] = float(tts_config.get('temperature', 0.9)) 146 | tts_config['top_k'] = int(tts_config.get('top_k', 50)) 147 | logger.info(f"Generation params: max_len={tts_config['max_audio_length_ms']}ms, temp={tts_config['temperature']}, top_k={tts_config['top_k']}") 148 | 149 | except Exception as e: 150 | logger.critical(f"Unexpected error loading configuration: {e}", exc_info=True) 151 | tts_config = {} # Reset config on error 152 | model_load_info.update({"status": "error", "error": f"Config error: {e}"}) 153 | 154 | # --- Model Loading --- 155 | # ... (Keep existing load_tts_model function as is) ... 156 | def load_tts_model(): 157 | global tts_generator, model_load_info, tts_config 158 | if not CSM_LOADED: 159 | error_msg = model_load_info.get("error", "CSM library components failed import.") 160 | logger.critical(f"Cannot load model: {error_msg}") 161 | raise RuntimeError(error_msg) 162 | 163 | if not tts_config or model_load_info.get("status") == "error": 164 | error_msg = model_load_info.get("error", "TTS configuration missing or invalid.") 165 | logger.critical(f"Cannot load model: {error_msg}") 166 | model_load_info.setdefault("status", "error") 167 | model_load_info.setdefault("error", error_msg) 168 | raise RuntimeError(error_msg) 169 | 170 | device_to_load = tts_config.get('effective_device') 171 | if not device_to_load or not device_to_load.startswith("cuda"): 172 | error_msg = model_load_info.get("error", f"CSM requires CUDA, but effective device is '{device_to_load}'.") 173 | logger.critical(f"FATAL: {error_msg}") 174 | model_load_info.update({"status": "error", "error": error_msg}) 175 | raise RuntimeError(error_msg) 176 | 177 | model_repo_id = tts_config.get('model_repo_id') 178 | if not model_repo_id: 179 | error_msg = "TTS model_repo_id missing from config." 
180 | logger.critical(f"FATAL: {error_msg}") 181 | model_load_info.update({"status": "error", "error": error_msg}) 182 | raise RuntimeError(error_msg) 183 | 184 | logger.info(f"Attempting to load TTS model: {model_repo_id} on device: {device_to_load}") 185 | logger.info(f"Using cache directory (HF_HOME): {HF_CACHE_DIR}") 186 | model_load_info = {"status": "loading", "model_repo_id": model_repo_id, "device": device_to_load} 187 | start_time = time.monotonic() 188 | 189 | try: 190 | logger.info(f"Calling generator.load_csm_1b(model_id='{model_repo_id}', device='{device_to_load}')") 191 | tts_generator = load_csm_1b( model_id=model_repo_id, device=device_to_load ) 192 | 193 | if tts_generator is None: 194 | raise RuntimeError(f"generator.load_csm_1b returned None for model '{model_repo_id}'.") 195 | if not isinstance(tts_generator, Generator): 196 | raise TypeError(f"load_csm_1b did not return a Generator instance (got {type(tts_generator)}).") 197 | 198 | load_time = time.monotonic() - start_time 199 | logger.info(f"Model load call completed in {load_time:.2f}s.") 200 | 201 | actual_sample_rate = getattr(tts_generator, 'sample_rate', None) 202 | if actual_sample_rate: 203 | logger.info(f"Detected generator sample rate: {actual_sample_rate} Hz") 204 | tts_config['actual_sample_rate'] = actual_sample_rate 205 | else: 206 | logger.warning(f"Could not get sample rate from generator. Using default 24000 Hz for encoding.") 207 | tts_config['actual_sample_rate'] = 24000 # Fallback default 208 | 209 | model_load_info.update({"status": "loaded", "load_time_s": round(load_time, 2), "sample_rate": tts_config['actual_sample_rate']}) 210 | logger.info(f"TTS Model '{model_repo_id}' loaded successfully on {device_to_load}.") 211 | 212 | try: 213 | actual_device = next(tts_generator._model.parameters()).device 214 | logger.info(f"Model confirmed on device: {actual_device}") 215 | if str(actual_device) != device_to_load: 216 | logger.warning(f"Model loaded on {actual_device} but target was {device_to_load}.") 217 | except Exception as dev_check_err: 218 | logger.warning(f"Could not confirm model device post-load: {dev_check_err}") 219 | 220 | except Exception as e: 221 | logger.critical(f"FATAL: Model loading failed for '{model_repo_id}': {e}", exc_info=True) 222 | tts_generator = None 223 | load_time = time.monotonic() - start_time 224 | model_load_info.update({"status": "error", "error": f"Model loading failed: {e}", "load_time_s": round(load_time, 2)}) 225 | raise RuntimeError(f"TTS model loading failed: {e}") from e 226 | 227 | # --- FastAPI Lifespan --- 228 | # ... (Keep existing lifespan function as is) ... 
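# (Added reference comment, not in the original file) The lifespan handler below
# calls load_configuration(), which reads the 'tts' section of config.yaml; a
# minimal section consistent with the keys used above might look like this
# (values are assumptions, and the TTS_MODEL_REPO_ID / TTS_SPEAKER_ID / USE_GPU
# environment variables take precedence, as implemented in load_configuration):
#
#   tts:
#     model_repo_id: senstella/csm-expressiva-1b
#     speaker_id: 4
#     device: auto            # "auto", "cpu", "cuda" or "cuda:0"
#     max_audio_length_ms: 90000
#     temperature: 0.9
#     top_k: 50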
229 | @asynccontextmanager 230 | async def lifespan(app: FastAPI): 231 | """Manages application startup and shutdown, including model loading and optional warm-up.""" 232 | logger.info(f"{SERVICE_NAME.upper()} Service starting up...") 233 | global model_load_info, tts_config, tts_generator 234 | model_load_info = {"status": "initializing"} 235 | startup_error = None 236 | 237 | try: 238 | # --- Startup sequence logic --- 239 | if not CSM_LOADED: 240 | startup_error = model_load_info.get("error", "CSM library import failed at startup.") 241 | model_load_info.update({"status": "error", "error": startup_error}) 242 | raise RuntimeError(startup_error) 243 | 244 | load_configuration() 245 | if model_load_info.get("status") == "error": 246 | startup_error = model_load_info.get("error", "Config loading failed.") 247 | raise RuntimeError(startup_error) 248 | 249 | if tts_config.get('effective_device','cpu').startswith('cuda'): 250 | load_tts_model() 251 | if model_load_info.get("status") != "loaded": 252 | startup_error = model_load_info.get("error", "Model loading status not 'loaded' after successful call.") 253 | logger.error(f"Inconsistent state: load_tts_model completed but status is {model_load_info.get('status')}") 254 | model_load_info["status"] = "error" 255 | model_load_info.setdefault("error", startup_error) 256 | raise RuntimeError(startup_error) 257 | else: 258 | startup_error = model_load_info.get("error", "CUDA device not configured or available.") 259 | logger.critical(f"Lifespan: {startup_error}") 260 | raise RuntimeError(startup_error) 261 | 262 | if tts_generator and model_load_info.get("status") == "loaded": 263 | logger.info("Attempting TTS model warm-up with dummy inference...") 264 | warmup_speaker_id = tts_config.get('effective_speaker_id', 4) 265 | try: 266 | await asyncio.to_thread( 267 | tts_generator.generate, 268 | text="Ready.", 269 | speaker=warmup_speaker_id, 270 | context=[], 271 | max_audio_length_ms=1000 272 | ) 273 | logger.info("TTS model warm-up inference completed successfully.") 274 | except AttributeError: 275 | logger.warning("tts_generator does not have a 'generate' method, skipping lifespan warmup.") 276 | except Exception as warmup_err: 277 | logger.warning(f"TTS model warm-up failed: {warmup_err}", exc_info=True) 278 | 279 | logger.info("Lifespan startup sequence completed successfully.") 280 | 281 | except Exception as e: 282 | startup_error = str(e) 283 | logger.critical(f"Lifespan startup failed: {startup_error}", exc_info=True) 284 | model_load_info["status"] = "error" 285 | model_load_info.setdefault("error", startup_error if startup_error else "Unknown startup error") 286 | raise RuntimeError(f"Critical startup error: {startup_error}") from e 287 | 288 | logger.info("Yielding control to FastAPI application...") 289 | yield 290 | logger.info("FastAPI application finished.") 291 | 292 | # --- Shutdown Logic --- 293 | logger.info(f"{SERVICE_NAME.upper()} Service shutting down...") 294 | if 'utterance_context_cache' in globals(): 295 | logger.info(f"Clearing utterance context cache ({len(globals()['utterance_context_cache'])} items)...") 296 | globals()['utterance_context_cache'].clear() 297 | logger.info("Utterance context cache cleared.") 298 | 299 | if 'tts_generator' in globals() and globals()['tts_generator']: 300 | logger.info("Releasing TTS generator instance...") 301 | try: 302 | del globals()['tts_generator'] 303 | globals()['tts_generator'] = None 304 | logger.info("TTS generator instance deleted.") 305 | except Exception as del_err: 306 | 
logger.warning(f"Error deleting generator instance: {del_err}") 307 | 308 | if 'effective_device' in globals() and globals()['effective_device'] and globals()['effective_device'].startswith("cuda"): 309 | try: 310 | logger.info("Attempting to clear CUDA cache...") 311 | torch.cuda.empty_cache() 312 | logger.info("Cleared CUDA cache.") 313 | except Exception as e: 314 | logger.warning(f"CUDA cache clear failed during shutdown: {e}") 315 | 316 | logger.info("TTS Service shutdown complete.") 317 | 318 | # --- FastAPI App Initialization --- 319 | app = FastAPI(lifespan=lifespan, title="TTS Streaming Service (CSM - senstella)", version="1.3.1") 320 | 321 | 322 | # --- Helper Functions --- 323 | # Ensure PCM_16 fix is present 324 | def numpy_to_base64_wav(audio_np: np.ndarray, samplerate: int) -> str: 325 | """Converts a NumPy audio array to a base64 encoded WAV string using PCM_16.""" 326 | if not isinstance(audio_np, np.ndarray): 327 | logger.warning(f"Input not NumPy array (type: {type(audio_np)}). Returning empty string.") 328 | return "" 329 | if audio_np.size == 0: 330 | logger.warning("Attempted encode empty NumPy array. Returning empty string.") 331 | return "" 332 | try: 333 | if audio_np.dtype != np.float32: 334 | audio_np = audio_np.astype(np.float32) 335 | 336 | max_val = np.max(np.abs(audio_np)) 337 | if max_val > 1.0: 338 | logger.debug(f"Normalizing audio before writing WAV (max abs val: {max_val:.3f}).") 339 | audio_np = audio_np / max_val 340 | elif max_val < 1e-6: 341 | logger.warning(f"Audio data seems silent (max abs val: {max_val:.3e}).") 342 | 343 | buffer = io.BytesIO() 344 | # Use PCM_16 for better browser compatibility 345 | logger.debug(f"Writing audio to WAV buffer (PCM_16, {samplerate}Hz)...") 346 | sf.write(buffer, audio_np, samplerate, format='WAV', subtype='PCM_16') 347 | buffer.seek(0) 348 | wav_bytes = buffer.getvalue() 349 | b64_string = base64.b64encode(wav_bytes).decode('utf-8') 350 | logger.debug(f"Encoded {len(wav_bytes)} bytes WAV ({len(audio_np)/samplerate:.2f}s) to base64 string.") 351 | return b64_string 352 | except Exception as e: 353 | logger.error(f"Error encoding numpy audio to WAV base64: {e}", exc_info=True) 354 | return "" 355 | 356 | # --- Main WebSocket Endpoint --- 357 | @app.websocket("/synthesize_stream") 358 | async def synthesize_stream_endpoint(websocket: WebSocket): 359 | # ... (Keep existing synthesize_stream_endpoint function as is) ... 
360 | """Handles WebSocket for streaming TTS synthesis.""" 361 | client_host = websocket.client.host if websocket.client else "unknown" 362 | client_port = websocket.client.port if websocket.client else "unknown" 363 | logger.info(f"WebSocket connection request from {client_host}:{client_port} to /synthesize_stream") 364 | 365 | await websocket.accept() 366 | logger.info(f"WebSocket connection accepted for {client_host}:{client_port} on /synthesize_stream") 367 | 368 | if not tts_generator or model_load_info.get("status") != "loaded": 369 | err_msg = model_load_info.get("error", "TTS model is not ready or failed to load.") 370 | logger.error(f"Rejecting WebSocket connection (post-accept): {err_msg}") 371 | try: 372 | await websocket.send_json({"type": "error", "message": err_msg}) 373 | except Exception as send_err: 374 | logger.warning(f"Could not send error message before closing WS: {send_err}") 375 | await websocket.close(code=fastapi_status.WS_1011_INTERNAL_ERROR, reason="TTS Service Not Ready") 376 | return 377 | 378 | current_utt = None 379 | loop = asyncio.get_running_loop() 380 | 381 | try: 382 | while True: 383 | try: 384 | raw_data = await websocket.receive_text() 385 | msg = json.loads(raw_data) 386 | logger.debug(f"Received WS message: {msg}") 387 | except WebSocketDisconnect: 388 | logger.info(f"WebSocket client {client_host}:{client_port} disconnected.") 389 | break 390 | except json.JSONDecodeError: 391 | logger.warning("Received invalid JSON over WebSocket.") 392 | try: await websocket.send_json({"type":"error", "message":"Invalid JSON format"}) 393 | except: pass 394 | continue 395 | except Exception as e: 396 | logger.error(f"Unexpected error receiving WebSocket message: {e}", exc_info=True) 397 | try: await websocket.send_json({"type":"error", "message":f"Server error receiving message: {e}"}) 398 | except: pass 399 | continue 400 | 401 | msg_type = msg.get("type") 402 | 403 | if msg_type == "clear_context": 404 | utterance_id_to_clear = msg.get("utterance_id") 405 | if utterance_id_to_clear and utterance_id_to_clear in utterance_context_cache: 406 | del utterance_context_cache[utterance_id_to_clear] 407 | logger.info(f"Cleared context cache for utterance_id: {utterance_id_to_clear}") 408 | try: await websocket.send_json({"type":"context_cleared", "utterance_id": utterance_id_to_clear}) 409 | except: pass 410 | else: 411 | logger.warning(f"Received clear_context for unknown/missing utterance_id: {utterance_id_to_clear}") 412 | continue 413 | 414 | if msg_type != "generate_chunk": 415 | logger.warning(f"Received unknown message type: {msg_type}") 416 | try: await websocket.send_json({"type":"error", "message":f"Unknown message type: {msg_type}"}) 417 | except: pass 418 | continue 419 | 420 | text_chunk = msg.get("text_chunk","").strip() 421 | utterance_id = msg.get("utterance_id") 422 | max_len_ms = float(msg.get("max_audio_length_ms", tts_config.get("max_audio_length_ms", 90000))) 423 | 424 | if not text_chunk or not utterance_id: 425 | logger.warning(f"Missing text_chunk or utterance_id in generate_chunk message.") 426 | try: await websocket.send_json({"type":"error","message":"Missing text_chunk or utterance_id"}) 427 | except: pass 428 | continue 429 | 430 | if utterance_id != current_utt: 431 | if current_utt and current_utt in utterance_context_cache: 432 | logger.warning(f"Switching utterance context from {current_utt} to {utterance_id} without explicit clear.") 433 | current_utt = utterance_id 434 | if current_utt not in utterance_context_cache: 435 | 
utterance_context_cache[current_utt] = [] 436 | logger.info(f"Initialized new context cache for utterance_id: {current_utt}") 437 | else: 438 | logger.info(f"Reusing existing context cache for utterance_id: {current_utt}") 439 | 440 | context_for_csm = utterance_context_cache.get(current_utt, []) 441 | effective_sr = tts_config.get("actual_sample_rate", 24000) 442 | temp = float(tts_config.get("temperature", 0.9)) 443 | topk = int(tts_config.get("top_k", 50)) 444 | logger.info(f"Generating chunk for utt_id={current_utt}, speaker={TTS_SPEAKER_ID}, temp={temp}, topk={topk}, context_len={len(context_for_csm)}") 445 | logger.debug(f"Text chunk: '{text_chunk[:50]}...'") 446 | 447 | def on_chunk_generated(chunk: torch.Tensor, utt_id=current_utt, txt_chunk=text_chunk): 448 | if not loop.is_running(): 449 | logger.warning(f"Event loop closed before sending chunk for {utt_id}. Skipping.") 450 | return 451 | try: 452 | if chunk is None or chunk.numel() == 0: 453 | logger.warning(f"Generator produced an empty chunk for {utt_id}. Skipping send.") 454 | return 455 | 456 | b64 = numpy_to_base64_wav(chunk.numpy(), effective_sr) 457 | if not b64: 458 | logger.error(f"Failed to encode audio chunk to base64 for {utt_id}") 459 | return 460 | 461 | coro = websocket.send_json({ 462 | "type": "audio_chunk", 463 | "utterance_id": utt_id, 464 | "audio_b64": b64, 465 | "text_chunk": txt_chunk, 466 | }) 467 | future = asyncio.run_coroutine_threadsafe(coro, loop) 468 | 469 | def log_send_exception(fut): 470 | try: fut.result(timeout=0) 471 | except Exception as send_exc: logger.error(f"Error sending audio chunk for {utt_id} via WebSocket: {send_exc}", exc_info=False) 472 | future.add_done_callback(log_send_exception) 473 | 474 | except Exception as e_inner: 475 | logger.error(f"Error in on_chunk_generated callback for {utt_id}: {e_inner}", exc_info=True) 476 | 477 | generation_exception = None 478 | try: 479 | logger.debug(f"Starting generate_stream in thread for {current_utt}...") 480 | await asyncio.to_thread( 481 | lambda: list( 482 | tts_generator.generate_stream( 483 | text=text_chunk, speaker=TTS_SPEAKER_ID, context=context_for_csm, 484 | max_audio_length_ms=max_len_ms, temperature=temp, topk=topk, 485 | on_chunk_generated=on_chunk_generated 486 | ) 487 | ) 488 | ) 489 | logger.debug(f"generate_stream thread finished for {current_utt}.") 490 | 491 | try: 492 | if websocket.client_state == WebSocketState.CONNECTED: 493 | await websocket.send_json({"type": "stream_end", "utterance_id": current_utt}) 494 | logger.info(f"Sent explicit stream_end for {current_utt}.") 495 | else: 496 | logger.warning(f"WebSocket closed before explicit stream_end could be sent for {current_utt}.") 497 | except Exception as send_exc: 498 | logger.warning(f"Could not send stream_end message for {current_utt}: {send_exc}") 499 | 500 | except Exception as gen_exc: 501 | logger.error(f"Error during TTS generation thread for {current_utt}: {gen_exc}", exc_info=True) 502 | generation_exception = gen_exc 503 | 504 | if generation_exception: 505 | try: 506 | if websocket.client_state == WebSocketState.CONNECTED: 507 | await websocket.send_json({"type":"error", "message": f"TTS generation failed: {generation_exception}", "utterance_id": current_utt}) 508 | except Exception as send_err_exc: 509 | logger.error(f"Failed to send generation error message to client for {current_utt}: {send_err_exc}") 510 | 511 | if not generation_exception: 512 | try: 513 | if current_utt in utterance_context_cache: 514 | utterance_context_cache[current_utt].append( 
515 | Segment(text=text_chunk, speaker=TTS_SPEAKER_ID, audio=None) 516 | ) 517 | logger.debug(f"Appended text segment to context for {current_utt}. New context length: {len(utterance_context_cache[current_utt])}") 518 | else: 519 | logger.warning(f"Context {current_utt} was cleared before text segment could be appended.") 520 | except Exception as append_err: 521 | logger.error(f"Error appending segment to context cache for {current_utt}: {append_err}") 522 | 523 | logger.info(f"Exited main WebSocket loop for {client_host}:{client_port}.") 524 | 525 | except WebSocketDisconnect: 526 | logger.info(f"WebSocket client {client_host}:{client_port} disconnected (caught in outer block).") 527 | except Exception as e: 528 | logger.error(f"Unhandled error in WebSocket handler for {client_host}:{client_port}: {e}", exc_info=True) 529 | finally: 530 | logger.info(f"Performing final WebSocket cleanup for {client_host}:{client_port}, last utterance: {current_utt}") 531 | if current_utt and current_utt in utterance_context_cache: 532 | try: 533 | del utterance_context_cache[current_utt] 534 | logger.info(f"Cleared context cache for final utterance: {current_utt}") 535 | except KeyError: 536 | logger.info(f"Context cache already cleared for final utterance: {current_utt}") 537 | except Exception as cache_del_err: 538 | logger.error(f"Error deleting context cache for {current_utt}: {cache_del_err}") 539 | 540 | if websocket.client_state == WebSocketState.CONNECTED: 541 | try: 542 | await websocket.close(code=fastapi_status.WS_1000_NORMAL_CLOSURE) 543 | logger.info(f"Closed WebSocket connection in finally block for {client_host}:{client_port}") 544 | except Exception as close_err: 545 | logger.error(f"Error closing WebSocket in finally block: {close_err}") 546 | else: 547 | logger.info(f"WebSocket connection already closed for {client_host}:{client_port}") 548 | 549 | # --- WebSocket Health Check Endpoint --- 550 | @app.websocket("/ws_health") 551 | async def websocket_health_check(websocket: WebSocket): 552 | """Accepts a WebSocket connection and immediately closes it.""" 553 | await websocket.accept() 554 | # Optional: Log the health check connection attempt 555 | # logger.debug(f"WebSocket health check connection accepted from {websocket.client.host}:{websocket.client.port}") 556 | await websocket.close(code=fastapi_status.WS_1000_NORMAL_CLOSURE) 557 | # logger.debug("WebSocket health check connection closed.") 558 | 559 | # --- HTTP Health Check Endpoint --- 560 | @app.get("/health", status_code=fastapi_status.HTTP_200_OK) 561 | async def health_check(): 562 | # ... (Keep existing health_check implementation as is) ... 
563 | current_status = model_load_info.get("status", "unknown") 564 | is_healthy = (current_status == "loaded" and tts_generator is not None) 565 | response_content = { 566 | "service": SERVICE_NAME, "status": "ok" if is_healthy else "error", 567 | "details": { "model_status": current_status, 568 | "model_repo_id": tts_config.get('model_repo_id', 'N/A'), 569 | "effective_device": tts_config.get('effective_device', 'N/A'), 570 | "target_speaker_id": tts_config.get('effective_speaker_id', 'N/A'), 571 | "actual_sample_rate": tts_config.get('actual_sample_rate', 'N/A'), 572 | "generator_object_present": (tts_generator is not None), 573 | "load_info": model_load_info } 574 | } 575 | status_code = fastapi_status.HTTP_200_OK if is_healthy else fastapi_status.HTTP_503_SERVICE_UNAVAILABLE 576 | 577 | if not is_healthy: 578 | logger.warning(f"Health check reporting unhealthy: Status='{current_status}', Generator Present={tts_generator is not None}") 579 | if model_load_info.get("error"): 580 | response_content["details"]["error_message"] = model_load_info.get("error") 581 | else: 582 | try: 583 | _ = tts_generator.device 584 | _ = tts_generator.sample_rate 585 | logger.debug(f"Health check: Generator instance accessed successfully (device: {tts_generator.device}).") 586 | except Exception as gen_check_err: 587 | logger.error(f"Health check failed accessing generator properties: {gen_check_err}", exc_info=True) 588 | response_content["status"] = "error" 589 | response_content["details"]["error_details"] = f"Generator access error: {gen_check_err}" 590 | model_load_info["status"] = "error_runtime_access" 591 | model_load_info["error"] = f"Generator access error: {gen_check_err}" 592 | status_code = fastapi_status.HTTP_503_SERVICE_UNAVAILABLE 593 | 594 | return JSONResponse(status_code=status_code, content=response_content) 595 | 596 | # --- Main Execution Guard --- 597 | # ... (Keep existing __main__ block as is) ... 
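# (Added illustrative comment, not in the original file) With the defaults below
# (TTS_PORT=5003), the two health endpoints can be probed by hand much like the
# docker-compose healthcheck does:
#   curl --fail http://localhost:5003/health
#   websocat -q --one-message ws://localhost:5003/ws_health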
598 | if __name__ == "__main__": 599 | logger.info(f"Starting {SERVICE_NAME.upper()} service directly via __main__...") 600 | main_startup_error = None 601 | try: 602 | logger.info("Running startup sequence (direct run)...") 603 | if not CSM_LOADED: 604 | main_startup_error = model_load_info.get("error", "CSM import failed.") 605 | raise RuntimeError(main_startup_error) 606 | load_configuration() 607 | if model_load_info.get("status") == "error": 608 | main_startup_error = model_load_info.get("error", "Config loading failed.") 609 | raise RuntimeError(main_startup_error) 610 | if tts_config.get('effective_device','cpu').startswith('cuda'): 611 | load_tts_model() # This will raise if it fails 612 | if model_load_info.get("status") != "loaded": 613 | main_startup_error = model_load_info.get("error", "Model loading status not 'loaded'.") 614 | raise RuntimeError(main_startup_error) 615 | else: 616 | main_startup_error = model_load_info.get("error", "CUDA device not configured or available.") 617 | model_load_info.update({"status": "error", "error": main_startup_error}) # Ensure status reflects this 618 | raise RuntimeError(main_startup_error) 619 | 620 | if tts_generator and model_load_info.get("status") == "loaded": 621 | logger.info("Attempting direct run warm-up...") 622 | try: 623 | tts_generator.generate(text="Ready.", speaker=tts_config.get('effective_speaker_id', 4), context=[], max_audio_length_ms=1000) 624 | logger.info("Direct run warm-up successful.") 625 | except Exception as warmup_err: 626 | logger.warning(f"Direct run warm-up failed: {warmup_err}") 627 | 628 | port = int(os.getenv('TTS_PORT', 5003)) 629 | host = os.getenv('TTS_HOST', "0.0.0.0") 630 | log_level_param = LOG_LEVEL.lower() 631 | logger.info(f"Direct run startup successful. Launching Uvicorn on {host}:{port}...") 632 | import uvicorn 633 | uvicorn.run("app:app", host=host, port=port, log_level=log_level_param, reload=False) 634 | 635 | except RuntimeError as e: 636 | logger.critical(f"Direct run failed during startup: {e}", exc_info=False) 637 | exit(1) 638 | except ImportError as e: 639 | logger.critical(f"Direct run failed: Missing dependency? 
{e}", exc_info=True) 640 | exit(1) 641 | except Exception as e: 642 | logger.critical(f"Unexpected error during direct run: {e}", exc_info=True) 643 | exit(1) 644 | finally: 645 | logger.info(f"{SERVICE_NAME.upper()} Service shutting down (direct run)...") 646 | if 'tts_generator' in globals() and globals()['tts_generator']: 647 | del globals()['tts_generator'] 648 | globals()['tts_generator'] = None 649 | if 'utterance_context_cache' in globals(): 650 | globals()['utterance_context_cache'].clear() 651 | if 'effective_device' in globals() and globals()['effective_device'] and globals()['effective_device'].startswith("cuda"): 652 | try: torch.cuda.empty_cache() 653 | except: pass 654 | logger.info(f"{SERVICE_NAME.upper()} Service shutdown complete (direct run).") 655 | 656 | # --- END OF FILE --- 657 | -------------------------------------------------------------------------------- /backend/tts/csm_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReisCook/VoiceAssistant/543f1f934555ee3f08ee74f3edbe78b9321b3ac9/backend/tts/csm_utils/__init__.py -------------------------------------------------------------------------------- /backend/tts/csm_utils/loader.py: -------------------------------------------------------------------------------- 1 | # backend/tts/csm_utils/loader.py 2 | import json 3 | import torch 4 | # No longer need hf_hub_download here as encodec_model_24khz handles it internally (?) 5 | # from huggingface_hub import hf_hub_download 6 | # No longer need safe_open 7 | # from safetensors import safe_open 8 | 9 | # --- Import changed to base encodec library --- 10 | # Using the correct import path based on encodec library structure 11 | from encodec.model import EncodecModel 12 | # ------------------------------------------------------------------- 13 | 14 | # Using the factory function, repo/filenames might not be needed explicitly 15 | # DEFAULT_AUDIO_TOKENIZER_REPO = "facebook/encodec_24khz" 16 | # DEFAULT_AUDIO_TOKENIZER_CHECKPOINT = "encodec_24khz-d7cc33bc.th" 17 | 18 | # --- load_ckpt function not needed for this loading method --- 19 | # def load_ckpt(path): ... 20 | 21 | def get_mimi(device="cpu"): 22 | """ 23 | Loads the Encodec audio tokenizer model using the base encodec library's 24 | recommended factory function encodec_model_24khz(). 25 | """ 26 | # The repo_id argument is removed as the factory function targets the specific model 27 | print(f"Initializing EncodecModel (base encodec library's 24khz factory) on device: {device}") 28 | try: 29 | # --- Use the standard factory function from the encodec library --- 30 | # This function typically returns the model architecture with pre-trained weights loaded. 
31 | model = EncodecModel.encodec_model_24khz() 32 | print("Instantiated base EncodecModel (24khz factory).") 33 | 34 | # --- Optional: Set target bandwidth (common practice) --- 35 | # Bandwidths can be 1.5, 3.0, 6.0, 12.0, 24.0 kbps 36 | # 6.0 kbps is a common default balancing quality and size 37 | target_bandwidth = 6.0 38 | model.set_target_bandwidth(target_bandwidth) 39 | print(f"Set target bandwidth to {target_bandwidth} kbps.") 40 | 41 | # --- Move model to the target device and set to eval mode --- 42 | model = model.to(device) 43 | model.eval() 44 | print(f"Encodec model configured and moved to {device}.") 45 | return model 46 | except Exception as e: 47 | print(f"ERROR loading Encodec model using base 'encodec' library factory: {e}") 48 | raise 49 | 50 | # Add any other helper functions needed by generator.py if they were originally in moshi.models.loaders -------------------------------------------------------------------------------- /backend/tts/models.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from dataclasses import dataclass 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torchtune 7 | from huggingface_hub import PyTorchModelHubMixin 8 | from torchtune.models import llama3_2 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | def llama3_2_1B() -> torchtune.modules.transformer.TransformerDecoder: 13 | return llama3_2.llama3_2( 14 | vocab_size=128_256, 15 | num_layers=16, 16 | num_heads=32, 17 | num_kv_heads=8, 18 | embed_dim=2048, 19 | max_seq_len=2048, 20 | intermediate_dim=8192, 21 | attn_dropout=0.0, 22 | norm_eps=1e-5, 23 | rope_base=500_000, 24 | scale_factor=32, 25 | ) 26 | 27 | def llama3_2_100M() -> torchtune.modules.transformer.TransformerDecoder: 28 | return llama3_2.llama3_2( 29 | vocab_size=128_256, 30 | num_layers=4, 31 | num_heads=8, 32 | num_kv_heads=2, 33 | embed_dim=1024, 34 | max_seq_len=2048, 35 | intermediate_dim=8192, 36 | attn_dropout=0.0, 37 | norm_eps=1e-5, 38 | rope_base=500_000, 39 | scale_factor=32, 40 | ) 41 | 42 | FLAVORS = { 43 | "llama-1B": llama3_2_1B, 44 | "llama-100M": llama3_2_100M, 45 | } 46 | 47 | def _prepare_transformer(model): 48 | embed_dim = model.tok_embeddings.embedding_dim 49 | model.tok_embeddings = nn.Identity() 50 | model.output = nn.Identity() 51 | return model, embed_dim 52 | 53 | def _create_causal_mask(seq_len: int, device: torch.device): 54 | return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)) 55 | 56 | def _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor): 57 | """ 58 | Args: 59 | mask: (max_seq_len, max_seq_len) 60 | input_pos: (batch_size, seq_len) 61 | 62 | Returns: 63 | (batch_size, seq_len, max_seq_len) 64 | """ 65 | r = mask[input_pos, :] 66 | return r 67 | 68 | def _multinomial_sample_one_no_sync(probs): # Does multinomial sampling without a cuda synchronization 69 | q = torch.empty_like(probs).exponential_(1) 70 | return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int) 71 | 72 | def sample_topk(logits: torch.Tensor, topk: int, temperature: float): 73 | logits = logits / temperature 74 | 75 | filter_value: float = -float("Inf") 76 | indices_to_remove = logits < torch.topk(logits, topk)[0][..., -1, None] 77 | scores_processed = logits.masked_fill(indices_to_remove, filter_value) 78 | scores_processed = torch.nn.functional.log_softmax(scores_processed, dim=-1) 79 | probs = torch.nn.functional.softmax(scores_processed, dim=-1) 80 | 81 | sample_token = 
_multinomial_sample_one_no_sync(probs) 82 | return sample_token 83 | 84 | @dataclass 85 | class ModelArgs: 86 | backbone_flavor: str 87 | decoder_flavor: str 88 | text_vocab_size: int 89 | audio_vocab_size: int 90 | audio_num_codebooks: int 91 | 92 | 93 | class Model( 94 | nn.Module, 95 | PyTorchModelHubMixin, 96 | repo_url="https://github.com/SesameAILabs/csm", 97 | pipeline_tag="text-to-speech", 98 | license="apache-2.0", 99 | ): 100 | def __init__(self, config: ModelArgs): 101 | super().__init__() 102 | self.config = config 103 | 104 | self.backbone, backbone_dim = _prepare_transformer(FLAVORS[config.backbone_flavor]()) 105 | self.decoder, decoder_dim = _prepare_transformer(FLAVORS[config.decoder_flavor]()) 106 | 107 | self.text_embeddings = nn.Embedding(config.text_vocab_size, backbone_dim) 108 | self.audio_embeddings = nn.Embedding(config.audio_vocab_size * config.audio_num_codebooks, backbone_dim) 109 | 110 | self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False) 111 | self.codebook0_head = nn.Linear(backbone_dim, config.audio_vocab_size, bias=False) 112 | self.audio_head = nn.Parameter(torch.empty(config.audio_num_codebooks - 1, decoder_dim, config.audio_vocab_size)) 113 | 114 | def setup_caches(self, max_batch_size: int) -> torch.Tensor: 115 | """Setup KV caches and return a causal mask.""" 116 | dtype = next(self.parameters()).dtype 117 | device = next(self.parameters()).device 118 | 119 | with device: 120 | self.backbone.setup_caches(max_batch_size, dtype) 121 | self.decoder.setup_caches(max_batch_size, dtype, decoder_max_seq_len=self.config.audio_num_codebooks) 122 | 123 | self.register_buffer("backbone_causal_mask", _create_causal_mask(self.backbone.max_seq_len, device)) 124 | self.register_buffer("decoder_causal_mask", _create_causal_mask(self.config.audio_num_codebooks, device)) 125 | 126 | def generate_frame( 127 | self, 128 | tokens: torch.Tensor, 129 | tokens_mask: torch.Tensor, 130 | input_pos: torch.Tensor, 131 | temperature: float, 132 | topk: int, 133 | ) -> torch.Tensor: 134 | """ 135 | Args: 136 | tokens: (batch_size, seq_len, audio_num_codebooks+1) 137 | tokens_mask: (batch_size, seq_len, audio_num_codebooks+1) 138 | input_pos: (batch_size, seq_len) positions for each token 139 | mask: (batch_size, seq_len, max_seq_len 140 | 141 | Returns: 142 | (batch_size, audio_num_codebooks) sampled tokens 143 | """ 144 | dtype = next(self.parameters()).dtype 145 | b, s, _ = tokens.size() 146 | 147 | assert self.backbone.caches_are_enabled(), "backbone caches are not enabled" 148 | curr_backbone_mask = _index_causal_mask(self.backbone_causal_mask, input_pos) 149 | embeds = self._embed_tokens(tokens) 150 | masked_embeds = embeds * tokens_mask.unsqueeze(-1) 151 | h = masked_embeds.sum(dim=2) 152 | h = self.backbone(h, input_pos=input_pos, mask=curr_backbone_mask).to(dtype=dtype) 153 | 154 | last_h = h[:, -1, :] 155 | c0_logits = self.codebook0_head(last_h) 156 | c0_sample = sample_topk(c0_logits, topk, temperature) 157 | c0_embed = self._embed_audio(0, c0_sample) 158 | 159 | curr_h = torch.cat([last_h.unsqueeze(1), c0_embed], dim=1) 160 | curr_sample = c0_sample.clone() 161 | curr_pos = torch.arange(0, curr_h.size(1), device=curr_h.device).unsqueeze(0).repeat(curr_h.size(0), 1) 162 | 163 | # Decoder caches must be reset every frame. 
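        # (Added explanatory comment) The decoder KV cache only spans the
        # audio_num_codebooks positions of a single frame (setup_caches passes
        # decoder_max_seq_len=self.config.audio_num_codebooks), whereas the backbone
        # cache persists across the whole token sequence, hence the per-frame reset.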
164 | self.decoder.reset_caches() 165 | for i in range(1, self.config.audio_num_codebooks): 166 | curr_decoder_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos) 167 | decoder_h = self.decoder(self.projection(curr_h), input_pos=curr_pos, mask=curr_decoder_mask).to( 168 | dtype=dtype 169 | ) 170 | ci_logits = torch.mm(decoder_h[:, -1, :], self.audio_head[i - 1]) 171 | ci_sample = sample_topk(ci_logits, topk, temperature) 172 | ci_embed = self._embed_audio(i, ci_sample) 173 | 174 | curr_h = ci_embed 175 | curr_sample = torch.cat([curr_sample, ci_sample], dim=1) 176 | curr_pos = curr_pos[:, -1:] + 1 177 | 178 | return curr_sample 179 | 180 | def reset_caches(self): 181 | self.backbone.reset_caches() 182 | self.decoder.reset_caches() 183 | 184 | def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor: 185 | return self.audio_embeddings(tokens + codebook * self.config.audio_vocab_size) 186 | 187 | def _embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor: 188 | text_embeds = self.text_embeddings(tokens[:, :, -1]).unsqueeze(-2) 189 | 190 | audio_tokens = tokens[:, :, :-1] + ( 191 | self.config.audio_vocab_size * torch.arange(self.config.audio_num_codebooks, device=tokens.device) 192 | ) 193 | audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape( 194 | tokens.size(0), tokens.size(1), self.config.audio_num_codebooks, -1 195 | ) 196 | 197 | return torch.cat([audio_embeds, text_embeds], dim=-2) 198 | 199 | -------------------------------------------------------------------------------- /backend/tts/requirements.txt: -------------------------------------------------------------------------------- 1 | # backend/tts/requirements.txt 2 | 3 | torch==2.4.0 4 | torchaudio==2.4.0 5 | tokenizers==0.21.0 6 | transformers==4.49.0 7 | huggingface_hub==0.28.1 8 | moshi==0.2.2 9 | torchtune==0.4.0 10 | torchao==0.9.0 11 | fastapi>=0.110.0,<0.112.0 12 | uvicorn[standard]>=0.29.0,<0.30.0 13 | python-dotenv>=1.0.0 14 | PyYAML>=6.0 15 | pydantic>=2.0.0,<3.0.0 16 | websockets>=12.0 17 | encodec>=0.1.1 # Needed for Mimi/audio tokenizer 18 | einops>=0.7.0 19 | librosa>=0.10.0 20 | soundfile>=0.12.1 21 | numpy>=1.24.0,<2.0.0 22 | httpx>=0.27.0 23 | safetensors>=0.4.1 24 | av>=10.0.0 25 | 26 | # NOTE: Ensure your Python version is compatible (e.g., 3.10 as mentioned in CSM README) 27 | # NOTE: Ensure CUDA version matches (e.g., CUDA 12.1+ for torch 2.4.0+cu121) 28 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | # voice-assistant/docker-compose.yml 2 | services: 3 | asr: 4 | build: 5 | context: ./backend/asr 6 | dockerfile: Dockerfile 7 | container_name: voice-assistant-asr 8 | ports: 9 | - "${ASR_PORT}:${ASR_PORT}" 10 | volumes: 11 | - ./shared/config.yaml:/app/config.yaml:ro 12 | - ./shared/logs:/app/logs 13 | - model_cache:${CACHE_DIR:-/cache} 14 | environment: 15 | - ASR_PORT=${ASR_PORT} 16 | - CACHE_DIR=${CACHE_DIR:-/cache} 17 | - USE_GPU=${USE_GPU} 18 | - LOG_LEVEL=${LOG_LEVEL:-info} 19 | - PYTHONUNBUFFERED=${PYTHONUNBUFFERED:-1} 20 | - CONFIG_PATH=/app/config.yaml 21 | - LOG_FILE_BASE=/app/logs/service 22 | - HUGGING_FACE_TOKEN=${HUGGING_FACE_TOKEN} 23 | - NVIDIA_VISIBLE_DEVICES=${NVIDIA_VISIBLE_DEVICES:-all} 24 | - NVIDIA_DRIVER_CAPABILITIES=compute,utility 25 | deploy: 26 | resources: 27 | reservations: 28 | devices: 29 | - driver: nvidia 30 | capabilities: [gpu] 31 | count: 1 32 | runtime: nvidia 33 | networks: 34 | - 
voice-assistant-net 35 | restart: unless-stopped 36 | healthcheck: 37 | test: ["CMD", "curl", "--fail", "http://localhost:${ASR_PORT}/health"] 38 | interval: 30s 39 | timeout: 10s 40 | retries: 5 41 | start_period: 120s 42 | 43 | llm: 44 | build: 45 | context: ./backend/llm 46 | dockerfile: Dockerfile 47 | container_name: voice-assistant-llm 48 | ports: 49 | - "${LLM_PORT}:${LLM_PORT}" 50 | volumes: 51 | - ./shared/config.yaml:/app/config.yaml:ro 52 | - ./shared/logs:/app/logs 53 | - model_cache:${CACHE_DIR:-/cache} 54 | environment: 55 | - LLM_PORT=${LLM_PORT} 56 | - CACHE_DIR=${CACHE_DIR:-/cache} 57 | - USE_GPU=${USE_GPU} 58 | - LOG_LEVEL=${LOG_LEVEL:-info} 59 | - PYTHONUNBUFFERED=${PYTHONUNBUFFERED:-1} 60 | - CONFIG_PATH=/app/config.yaml 61 | - LOG_FILE_BASE=/app/logs/service 62 | - HUGGING_FACE_TOKEN=${HUGGING_FACE_TOKEN} 63 | - NVIDIA_VISIBLE_DEVICES=${NVIDIA_VISIBLE_DEVICES:-all} 64 | - NVIDIA_DRIVER_CAPABILITIES=compute,utility 65 | - CMAKE_ARGS=${CMAKE_ARGS:--DLLAMA_CUBLAS=on} 66 | - FORCE_CMAKE=${FORCE_CMAKE:-1} 67 | deploy: 68 | resources: 69 | reservations: 70 | devices: 71 | - driver: nvidia 72 | capabilities: [gpu] 73 | count: 1 74 | runtime: nvidia 75 | networks: 76 | - voice-assistant-net 77 | restart: unless-stopped 78 | healthcheck: 79 | test: ["CMD", "curl", "--fail", "http://localhost:${LLM_PORT}/health"] 80 | interval: 30s 81 | timeout: 15s 82 | retries: 5 83 | start_period: 180s 84 | 85 | tts: 86 | build: 87 | context: ./backend/tts 88 | # Ensure the Dockerfile includes websocat installation 89 | dockerfile: Dockerfile 90 | container_name: voice-assistant-tts 91 | ports: 92 | - "${TTS_PORT}:${TTS_PORT}" 93 | volumes: 94 | - ./shared/config.yaml:/app/config.yaml:ro 95 | - ./shared/logs:/app/logs 96 | - model_cache:${CACHE_DIR:-/cache} 97 | environment: 98 | - TTS_PORT=${TTS_PORT} 99 | - CACHE_DIR=${CACHE_DIR:-/cache} 100 | - USE_GPU=${USE_GPU} # Should be true or auto for CSM 101 | - LOG_LEVEL=${LOG_LEVEL:-info} 102 | - PYTHONUNBUFFERED=${PYTHONUNBUFFERED:-1} 103 | - CONFIG_PATH=/app/config.yaml 104 | - LOG_FILE_BASE=/app/logs/service 105 | - HUGGING_FACE_TOKEN=${HUGGING_FACE_TOKEN} # REQUIRED 106 | - TTS_SPEAKER_ID=${TTS_SPEAKER_ID:-4} # Default to 4 for expressiva model 107 | - NVIDIA_VISIBLE_DEVICES=${NVIDIA_VISIBLE_DEVICES:-all} 108 | - NVIDIA_DRIVER_CAPABILITIES=compute,utility 109 | - HF_HOME=${CACHE_DIR:-/cache}/huggingface # Explicitly set HF_HOME 110 | deploy: 111 | resources: 112 | reservations: 113 | devices: 114 | - driver: nvidia 115 | capabilities: [gpu] 116 | count: 1 # CSM requires GPU 117 | runtime: nvidia 118 | networks: 119 | - voice-assistant-net 120 | restart: unless-stopped 121 | # --- UPDATED HEALTHCHECK SECTION for TTS using /ws_health --- 122 | healthcheck: 123 | test: > 124 | sh -c ' 125 | curl --fail -s http://localhost:${TTS_PORT}/health && 126 | websocat -q --one-message ws://localhost:${TTS_PORT}/ws_health 127 | ' 128 | # -q: quiet mode 129 | # --one-message: waits for exactly one message OR EOF (which /ws_health sends immediately after accept) 130 | interval: 15s # Check slightly more often 131 | timeout: 10s # Overall timeout for the sh -c command 132 | retries: 6 # Increase retries slightly 133 | start_period: 300s # Keep long start period for model loading 134 | # --- END OF UPDATED HEALTHCHECK --- 135 | 136 | orchestrator: 137 | build: 138 | context: ./backend/orchestrator 139 | dockerfile: Dockerfile 140 | container_name: voice-assistant-orchestrator 141 | ports: 142 | - "${ORCHESTRATOR_PORT}:${ORCHESTRATOR_PORT}" 143 | volumes: 144 | 
- ./shared/config.yaml:/app/config.yaml:ro 145 | - ./shared/logs:/app/logs 146 | environment: 147 | - ORCHESTRATOR_PORT=${ORCHESTRATOR_PORT} 148 | - ASR_SERVICE_URL=http://asr:${ASR_PORT} 149 | - LLM_SERVICE_URL=http://llm:${LLM_PORT} 150 | - TTS_SERVICE_URL=http://tts:${TTS_PORT} 151 | - LOG_LEVEL=${LOG_LEVEL:-info} 152 | - PYTHONUNBUFFERED=${PYTHONUNBUFFERED:-1} 153 | - CONFIG_PATH=/app/config.yaml 154 | - LOG_FILE_BASE=/app/logs/service 155 | depends_on: # This correctly waits for the updated healthcheck now 156 | asr: 157 | condition: service_healthy 158 | llm: 159 | condition: service_healthy 160 | tts: 161 | condition: service_healthy 162 | networks: 163 | - voice-assistant-net 164 | restart: unless-stopped 165 | healthcheck: 166 | test: ["CMD", "curl", "--fail", "http://localhost:${ORCHESTRATOR_PORT}/health"] 167 | interval: 30s 168 | timeout: 5s 169 | retries: 3 170 | start_period: 10s # Orchestrator itself starts fast 171 | 172 | networks: 173 | voice-assistant-net: 174 | driver: bridge 175 | 176 | volumes: 177 | model_cache: 178 | driver: local 179 | -------------------------------------------------------------------------------- /frontend/.gitignore: -------------------------------------------------------------------------------- 1 | # Logs 2 | logs 3 | *.log 4 | npm-debug.log* 5 | yarn-debug.log* 6 | yarn-error.log* 7 | pnpm-debug.log* 8 | lerna-debug.log* 9 | 10 | node_modules 11 | dist 12 | dist-ssr 13 | *.local 14 | 15 | # Editor directories and files 16 | .vscode/* 17 | !.vscode/extensions.json 18 | .idea 19 | .DS_Store 20 | *.suo 21 | *.ntvs* 22 | *.njsproj 23 | *.sln 24 | *.sw? 25 | -------------------------------------------------------------------------------- /frontend/.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": ["tauri-apps.tauri-vscode", "rust-lang.rust-analyzer"] 3 | } 4 | -------------------------------------------------------------------------------- /frontend/README.md: -------------------------------------------------------------------------------- 1 | # Tauri + React + Typescript 2 | 3 | This template should help get you started developing with Tauri, React and Typescript in Vite. 4 | 5 | ## Recommended IDE Setup 6 | 7 | - [VS Code](https://code.visualstudio.com/) + [Tauri](https://marketplace.visualstudio.com/items?itemName=tauri-apps.tauri-vscode) + [rust-analyzer](https://marketplace.visualstudio.com/items?itemName=rust-lang.rust-analyzer) 8 | -------------------------------------------------------------------------------- /frontend/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 |
4 | 5 | 6 | 7 |