├── .DS_Store ├── .env.example ├── .gitignore ├── Makefile ├── README.md ├── alembic.ini ├── app ├── .DS_Store ├── __init__.py ├── api │ ├── __init__.py │ └── routes │ │ └── v1 │ │ ├── endpoints.py │ │ └── voice.py ├── config.py ├── core │ └── voice.py ├── db │ ├── __init__.py │ ├── database.py │ ├── migrations │ │ ├── env.py │ │ ├── script.py.mako │ │ ├── utils.py │ │ └── versions │ │ │ ├── 20240101_add_rate_limits.py │ │ │ └── 20240101_initial.py │ └── models.py ├── dependencies.py ├── main.py ├── schemas │ ├── constants.py │ ├── models.py │ └── requests.py ├── services │ ├── __init__.py │ ├── audio.py │ ├── chat_state.py │ └── llm.py ├── utils │ ├── __init__.py │ ├── errors.py │ └── logger.py └── websocket │ ├── __init__.py │ ├── base_handler.py │ ├── connection.py │ ├── handlers │ ├── __init__.py │ ├── audio.py │ ├── conversation.py │ ├── main.py │ ├── response.py │ └── session.py │ ├── redis.py │ └── types.py ├── docker-compose.yml ├── document.json ├── echoollama.png ├── pilot ├── .DS_Store ├── Readme.md ├── echo.analyzer.ipynb ├── test_audio.wav └── vad_animation.gif ├── prometheus └── prometheus.yml ├── requirements.txt └── scripts └── set_env.sh /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theboringhumane/echoOLlama/2936731e53e6c08231ba67141e4c0419a4809559/.DS_Store -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | # Server Configuration 2 | PORT=8000 3 | HOST=0.0.0.0 4 | DEBUG=True 5 | ENVIRONMENT=development # development, staging, production 6 | SECRET_KEY=your-super-secret-key-here 7 | 8 | # Database Configuration 9 | DATABASE_URL=postgresql://user:password@localhost:5432/db_name 10 | DB_MAX_CONNECTIONS=20 11 | DB_TIMEOUT=30 12 | 13 | # Redis Configuration 14 | REDIS_HOST=localhost 15 | REDIS_PORT=6379 16 | REDIS_DB=0 17 | REDIS_PASSWORD=your-redis-password 18 | REDIS_SSL=False 19 | 20 | # WebSocket Configuration 21 | WS_HEARTBEAT_INTERVAL=30 # seconds 22 | WS_CONNECTION_TIMEOUT=60 # seconds 23 | WS_MAX_CONNECTIONS=1000 24 | WS_RATE_LIMIT=100 # requests per minute 25 | 26 | # Audio Processing 27 | WHISPER_MODEL_SIZE=base # tiny, base, small, medium, large 28 | WHISPER_DEVICE=cuda # cuda, cpu 29 | WHISPER_COMPUTE_TYPE=float16 # float16, float32 30 | WHISPER_LANGUAGE=en 31 | WHISPER_BEAM_SIZE=5 32 | WHISPER_VAD_FILTER=True 33 | 34 | # Text-to-Speech 35 | TTS_ENGINE=openai # openai, azure, elevenlabs 36 | TTS_MODEL=tts-1 # tts-1, tts-1-hd 37 | TTS_OPENAI_API_KEY=your-openai-api-key 38 | TTS_OPENAI_API_BASE_URL=https://api.openai.com/v1 39 | TTS_DEFAULT_VOICE=alloy # alloy, echo, fable, onyx, nova, shimmer 40 | 41 | # Cache Configuration 42 | SPEECH_CACHE_DIR=./cache/speech 43 | SPEECH_CACHE_MAX_SIZE=1024 # MB 44 | SPEECH_CACHE_TTL=86400 # seconds (24 hours) 45 | 46 | # Logging 47 | LOG_LEVEL=INFO # DEBUG, INFO, WARNING, ERROR, CRITICAL 48 | LOG_FORMAT=json # json, text 49 | LOG_FILE_PATH=./logs/app.log 50 | LOG_ROTATION=1d # 1d, 1w, 1m 51 | LOG_RETENTION=30d # how long to keep logs 52 | 53 | # Security 54 | CORS_ORIGINS=http://localhost:3000,https://yourdomain.com 55 | ALLOWED_HOSTS=localhost,127.0.0.1 56 | SSL_CERT_PATH=/path/to/cert.pem 57 | SSL_KEY_PATH=/path/to/key.pem 58 | 59 | # Rate Limiting 60 | RATE_LIMIT_REQUESTS=100 # requests per window 61 | RATE_LIMIT_WINDOW=60 # seconds 62 | RATE_LIMIT_TOKENS=1000 # tokens per minute 
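The variables above are consumed by `app/config.py` through pydantic-settings. As a quick orientation before the source files below, here is a minimal, illustrative sketch of how a handful of them can be loaded and type-checked. The class name and field subset are hypothetical; the repository's real loader (`app/config.py`, shown later) is more elaborate:

```python
# Minimal sketch, assuming pydantic-settings v2 (the library app/config.py uses).
# ExampleSettings and this field subset are illustrative, not part of the repo.
from pydantic_settings import BaseSettings, SettingsConfigDict


class ExampleSettings(BaseSettings):
    model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")

    PORT: int = 8000
    HOST: str = "0.0.0.0"
    DEBUG: bool = True
    WS_RATE_LIMIT: int = 100  # requests per minute


settings = ExampleSettings()
print(settings.model_dump())  # values from .env override the defaults above
```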
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # db files 59 | *.db 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # templates 135 | .github/templates/* 136 | .idea/ 137 | app.log 138 | data/cache/ 139 | app/prompts 140 | airley-ai 141 | postgres_data 142 | redis_data 143 | ollama_models 144 | speech_cache 145 | grafana_data 146 | pilot/models -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: migrations migrate rollback 2 | 3 | # Set Python path 4 | export PYTHONPATH := $(shell pwd) 5 | 6 | migrations: 7 | . scripts/set_env.sh && alembic revision --autogenerate -m "$(message)" 8 | 9 | migrate: 10 | . scripts/set_env.sh && alembic upgrade head 11 | 12 | rollback: 13 | . scripts/set_env.sh && alembic downgrade -1 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🦙 `echoOLlama`: Reverse-engineered OpenAI’s [Realtime API](https://platform.openai.com/docs/guides/realtime) 2 | > 🌟 Talk to your local LLM models in human voice and get responses in realtime! 3 | 4 | ![🦙 EchoOLlama Banner](https://github.com/user-attachments/assets/d2422917-b03a-48aa-88c8-d40f0884bd5e) 5 | 6 | > ⚠️ **Active Development Alert!** ⚠️ 7 | > 8 | > We're cooking up something amazing! While the core functionality is taking shape, some features are still in the oven. Perfect for experiments, but maybe hold off on that production deployment for now! 😉 9 | 10 | ## 🎯 What's `echoOLlama`? 11 | `echoOLlama` is a cool project that lets you talk to AI models using your voice, just like you'd talk to a real person! 🗣️ 12 | 13 | Here's what makes it special: 14 | 15 | - 🎤 You can speak naturally and the AI understands you 16 | - 🤖 It works with local AI models (through Ollama) so your data stays private 17 | - ⚡ Super fast responses in real-time 18 | - 🔊 The AI talks back to you with a natural voice 19 | - 🔄 Works just like OpenAI's API but with your own models 20 | 21 | Think of it like having a smart assistant that runs completely on your computer. You can have natural conversations with it, ask questions, get help with tasks - all through voice! And because it uses local AI models, you don't need to worry about your conversations being stored in the cloud. 
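To make that concrete, here is a rough sketch of what a realtime client looks like against this server. It connects to the `/realtime` WebSocket endpoint defined in `app/main.py` (which expects a `model` query parameter). The event payload is an assumption: the client/server event schema is still being built out (see the roadmap below), so this simply mimics OpenAI-Realtime-style JSON events:

```python
# Hedged sketch of a realtime client; assumes the server is running locally
# and accepts OpenAI-Realtime-style JSON events (schema not yet finalized).
import asyncio
import json

import websockets  # pip install websockets


async def main() -> None:
    uri = "ws://localhost:8000/realtime?model=llama3.1"
    async with websockets.connect(uri) as ws:
        # Ask the server to start generating a response; the event name
        # mirrors OpenAI's Realtime API, which this project reimplements.
        await ws.send(json.dumps({"type": "response.create"}))
        # Print the first few server events as they stream in.
        for _ in range(5):
            print(json.loads(await ws.recv()))


asyncio.run(main())
```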
22 | 23 | Perfect for developers who want to: 24 | - Build voice-enabled AI applications 25 | - Create custom AI assistants 26 | - Experiment with local language models 27 | - Have private AI conversations 28 | 29 | 30 | ### 🎉 What's Working Now: 31 | 32 | ![🦙 EchoOLlama Banner](https://github.com/user-attachments/assets/5ce20abf-6982-4b6b-a824-58f7d91ef7cd) 33 | 34 | - ✅ Connection handling and session management 35 | - ✅ Real-time event streaming 36 | - ✅ Redis-based session storage 37 | - ✅ Basic database interactions 38 | - ✅ OpenAI compatibility layer 39 | - ✅ Core WebSocket infrastructure 40 | 41 | ### 🚧 On the Roadmap: 42 | - 📝 Message processing pipeline (In Progress) 43 | - 🤖 Advanced response generation with client events 44 | - 🎯 Function calling implementation with client events 45 | - 🔊 Audio transcription service connection with client events 46 | - 🗣️ Text-to-speech integration with client events 47 | - 📊 Usage analytics dashboard 48 | - 🔐 Enhanced authentication system 49 | 50 | ## 🌟 Features & Capabilities 51 | 52 | ### 🎮 Core Services 53 | - **Real-time Chat** 💬 54 | - Streaming responses via websockets 55 | - Multi-model support via Ollama 56 | - Session persistence 57 | - 🎤 Audio Transcription (FASTER_Whisper) 58 | - 🗣️ Text-to-Speech (OpenedAI/Speech) 59 | 60 | - **Coming Soon** 🔜 61 | - 🔧 Function Calling System 62 | - 📊 Advanced Analytics 63 | 64 | ### 🛠️ Technical Goodies 65 | - ⚡ Lightning-fast response times 66 | - 🔒 Built-in rate limiting 67 | - 📈 Usage tracking ready 68 | - ⚖️ Load balancing for scale 69 | - 🎯 100% OpenAI API compatibility 70 | 71 | 72 | 73 | ## 🏗️ System Architecture 74 | 75 | 76 | *(System architecture diagram: echoOLlama)* 77 | 78 | 79 | > An interactive version of the diagram is available on Excalidraw. 80 | 81 | ## 💻 Tech Stack Spotlight 82 | ### 🎯 Backend Champions 83 | - 🚀 FastAPI - Lightning-fast API framework 84 | - 📝 Redis - Blazing-fast caching & session management 85 | - 🐘 PostgreSQL - Rock-solid data storage 86 | 87 | ### 🤖 AI Powerhouse 88 | - 🦙 Ollama - Local LLM inference 89 | - 🎤 faster_whisper - Speech recognition (coming soon) 90 | - 🗣️ OpenedAI TTS - Voice synthesis (coming soon) 91 | 92 | ## 🚀 Get Started in 3, 2, 1... 93 | 94 | 1. **Clone & Setup** 📦 95 | ```bash 96 | git clone https://github.com/iamharshdev/EchoOLlama.git 97 | cd EchoOLlama 98 | python -m venv .venv 99 | source .venv/bin/activate # or `.venv\Scripts\activate` on Windows 100 | pip install -r requirements.txt 101 | ``` 102 | 103 | 2. **Environment Setup** ⚙️ 104 | ```bash 105 | cp .env.example .env 106 | # Update .env with your config - check .env.example for all options! 107 | make migrate # create db and apply migrations 108 | ``` 109 | 110 | 3. **Launch Time** 🚀 *(then try the quick sanity check below)* 111 | ```bash 112 | # Fire up the services 113 | docker-compose up -d 114 | 115 | # Start the API server 116 | uvicorn app.main:app --reload 117 | ``` 118 | 119 | ## 🤝 Join the EchoOLlama Family 120 | Got ideas? Found a bug? Want to contribute? Check out our [CONTRIBUTING.md](CONTRIBUTING.md) guide and become part of something awesome! We love pull requests! 🎉 121 | 122 | ## 💡 Project Status Updates 123 | - 🟢 **Working**: Connection handling, session management, event streaming 124 | - 🟡 **In Progress**: Message processing, response generation 125 | - 🔴 **Planned**: Audio services, function calling, analytics 126 | 127 | ## 📜 License 128 | MIT Licensed - Go wild! See [LICENSE](LICENSE) for the legal stuff. 
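## 🧪 Quick Sanity Check

Once the services and API server are up, you can exercise the REST surface directly. The sketch below assumes the defaults from the steps above (uvicorn on `localhost:8000`, an Ollama model named `llama3.1` already pulled); the request and response shapes follow `app/api/routes/v1/endpoints.py`:

```python
# Hedged example: assumes uvicorn is serving app.main:app on port 8000 and
# the "llama3.1" model is available through the local Ollama daemon.
import requests

resp = requests.post(
    "http://localhost:8000/api/chat",
    json={
        "model": "llama3.1",
        "messages": [{"role": "user", "content": "Say hello in one sentence."}],
        "stream": False,
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["message"]["content"])  # ChatResponse.message.content
```

The same route streams chunks as server-sent events when `"stream": true` is passed, mirroring the Ollama and OpenAI chat APIs.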
129 | 130 | --- 131 | *Built with 💖 by the community, for the community* 132 | 133 | *PS: Star ⭐ us on GitHub if you like what we're building!* 134 | -------------------------------------------------------------------------------- /alembic.ini: -------------------------------------------------------------------------------- 1 | [alembic] 2 | script_location = app/db/migrations 3 | sqlalchemy.url = driver://user:pass@localhost/dbname 4 | file_template = %%(year)d%%(month).2d%%(day).2d_%%(hour).2d%%(minute).2d_%%(rev)s_%%(slug)s 5 | 6 | [loggers] 7 | keys = root,sqlalchemy,alembic 8 | 9 | [handlers] 10 | keys = console 11 | 12 | [formatters] 13 | keys = generic 14 | 15 | [logger_root] 16 | level = WARN 17 | handlers = console 18 | qualname = 19 | 20 | [logger_sqlalchemy] 21 | level = WARN 22 | handlers = 23 | qualname = sqlalchemy.engine 24 | 25 | [logger_alembic] 26 | level = INFO 27 | handlers = 28 | qualname = alembic 29 | 30 | [handler_console] 31 | class = StreamHandler 32 | args = (sys.stderr,) 33 | level = NOTSET 34 | formatter = generic 35 | 36 | [formatter_generic] 37 | format = %(levelname)-5.5s [%(name)s] %(message)s 38 | datefmt = %H:%M:%S -------------------------------------------------------------------------------- /app/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theboringhumane/echoOLlama/2936731e53e6c08231ba67141e4c0419a4809559/app/.DS_Store -------------------------------------------------------------------------------- /app/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theboringhumane/echoOLlama/2936731e53e6c08231ba67141e4c0419a4809559/app/__init__.py -------------------------------------------------------------------------------- /app/api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theboringhumane/echoOLlama/2936731e53e6c08231ba67141e4c0419a4809559/app/api/__init__.py -------------------------------------------------------------------------------- /app/api/routes/v1/endpoints.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | from fastapi import APIRouter, Depends, HTTPException 4 | from fastapi.responses import StreamingResponse 5 | 6 | from app.dependencies import get_llm_service 7 | from app.schemas.requests import ( 8 | GenerateRequest, 9 | ChatRequest, 10 | PullRequest, 11 | GenerateResponse, 12 | ChatResponse, 13 | PullResponse 14 | ) 15 | from app.services.llm import LLMService, ModelProvider 16 | from app.utils.logger import logger 17 | 18 | router = APIRouter() 19 | 20 | 21 | @router.post("/generate") 22 | async def generate_response( 23 | request: GenerateRequest, 24 | llm_service: LLMService = Depends(get_llm_service) 25 | ): 26 | """ 27 | Generate a response using the specified model 28 | 📝 File: endpoints.py, Line: 27, Function: generate_response 29 | """ 30 | try: 31 | if request.stream: 32 | return StreamingResponse( 33 | llm_service.generate_response( 34 | messages=request.messages, 35 | temperature=request.temperature, 36 | stream=True, 37 | provider=ModelProvider.OLLAMA if "llama" in request.model.lower() else ModelProvider.OPENAI 38 | ), 39 | media_type="text/event-stream" 40 | ) 41 | 42 | response = await llm_service.generate_response( 43 | messages=request.messages, 44 | temperature=request.temperature, 45 | stream=False, 46 | 
provider=ModelProvider.OLLAMA if "llama" in request.model.lower() else ModelProvider.OPENAI 47 | ) 48 | return GenerateResponse( 49 | model=request.model, 50 | response=response.choices[0].message.content, 51 | done=True 52 | ) 53 | 54 | except Exception as e: 55 | logger.error(f"❌ endpoints.py: Generate response failed: {str(e)}") 56 | raise HTTPException(status_code=500, detail=str(e)) 57 | 58 | 59 | @router.post("/chat") 60 | async def chat_with_model( 61 | request: ChatRequest, 62 | llm_service: LLMService = Depends(get_llm_service) 63 | ): 64 | """ 65 | Have a conversation with the model 66 | 📝 File: endpoints.py, Line: 64, Function: chat_with_model 67 | """ 68 | try: 69 | if request.stream: 70 | return StreamingResponse( 71 | llm_service.chat_stream( 72 | request=request, 73 | provider=ModelProvider.OLLAMA if "llama" in request.model.lower() else ModelProvider.OPENAI 74 | ), 75 | media_type="text/event-stream" 76 | ) 77 | 78 | response = await llm_service.generate_response( 79 | messages=[m.model_dump() for m in request.messages], 80 | temperature=0.8, 81 | stream=False, 82 | provider=ModelProvider.OLLAMA if "llama" in request.model.lower() else ModelProvider.OPENAI 83 | ) 84 | 85 | return ChatResponse( 86 | model=request.model, 87 | message={ 88 | "role": "assistant", 89 | "content": response.choices[0].message.content 90 | }, 91 | done=True 92 | ) 93 | 94 | except Exception as e: 95 | logger.error(f"❌ endpoints.py: Chat with model failed: {str(e)}") 96 | raise HTTPException(status_code=500, detail=str(e)) 97 | 98 | 99 | @router.get("/models", response_model=Dict) 100 | async def list_models( 101 | llm_service: LLMService = Depends(get_llm_service) 102 | ): 103 | """ 104 | List all available models 105 | 📝 File: endpoints.py, Line: 105, Function: list_models 106 | """ 107 | try: 108 | # Get models from both providers 109 | ollama_models = await llm_service.ollama_client.list() 110 | return ollama_models 111 | 112 | except Exception as e: 113 | logger.error(f"❌ endpoints.py: List models failed: {str(e)}") 114 | raise HTTPException(status_code=500, detail=str(e)) 115 | 116 | 117 | @router.post("/model/pull", response_model=PullResponse) 118 | async def pull_model( 119 | request: PullRequest, 120 | llm_service: LLMService = Depends(get_llm_service) 121 | ): 122 | """ 123 | Pull a model from the provider 124 | 📝 File: endpoints.py, Line: 144, Function: pull_model 125 | """ 126 | try: 127 | if "ollama" in request.provider.lower(): 128 | # Only Ollama supports model pulling 129 | await llm_service.ollama_client.pull(request.name) 130 | return PullResponse(status="success", model=request.name) 131 | else: 132 | raise HTTPException(status_code=400, detail="Model pulling only supported for Ollama models") 133 | 134 | except Exception as e: 135 | logger.error(f"❌ endpoints.py: Pull model failed: {str(e)}") 136 | raise HTTPException(status_code=500, detail=str(e)) 137 | 138 | 139 | @router.delete("/model/{model_name}") 140 | async def delete_model( 141 | model_name: str, 142 | provider: str = "ollama", 143 | llm_service: LLMService = Depends(get_llm_service) 144 | ): 145 | """ 146 | Delete a model 147 | 📝 File: endpoints.py, Line: 166, Function: delete_model 148 | """ 149 | try: 150 | if "ollama" in provider.lower(): 151 | # Only Ollama supports model deletion 152 | await llm_service.ollama_client.delete(model_name) 153 | return {"status": "success", "message": f"Model {model_name} deleted"} 154 | else: 155 | raise HTTPException(status_code=400, detail="Model deletion only supported for Ollama 
models") 156 | 157 | except Exception as e: 158 | logger.error(f"❌ endpoints.py: Delete model failed: {str(e)}") 159 | raise HTTPException(status_code=500, detail=str(e)) 160 | -------------------------------------------------------------------------------- /app/api/routes/v1/voice.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from fastapi import File, UploadFile, APIRouter, HTTPException, Depends 3 | from fastapi.responses import FileResponse, StreamingResponse 4 | from app.services.audio import AudioService 5 | from app.dependencies import get_audio_service 6 | from app.utils.logger import logger 7 | 8 | router = APIRouter() 9 | 10 | 11 | @router.post("/transcribe") 12 | async def transcribe_audio( 13 | file: UploadFile = File(...), 14 | language: str = 'en', 15 | task: str = "transcribe", 16 | beam_size: int = 5, 17 | vad_filter: bool = True, 18 | audio_service: AudioService = Depends(get_audio_service) 19 | ) -> StreamingResponse: 20 | """ 21 | Transcribe audio file to text using Faster Whisper and stream the results 22 | 📝 File: voice.py, Line: 20, Function: transcribe_audio 23 | """ 24 | try: 25 | # Read file content and encode to base64 26 | content = await file.read() 27 | event_id = f"transcribe_{file.filename}" 28 | 29 | logger.info(f"🎤 voice.py: Starting transcription for file {file.filename}") 30 | 31 | return StreamingResponse( 32 | audio_service.transcribe_audio( 33 | audio_data=content, 34 | event_id=event_id, 35 | language=language, 36 | task=task, 37 | beam_size=beam_size, 38 | vad_filter=vad_filter 39 | ), 40 | media_type="text/event-stream" 41 | ) 42 | except Exception as e: 43 | logger.error(f"❌ voice.py: Error in transcription generation: {str(e)}") 44 | raise HTTPException(status_code=500, detail=str(e)) 45 | 46 | 47 | @router.get("/speech") 48 | async def speech( 49 | input: str, 50 | voice: Optional[str] = 'alloy', 51 | model: Optional[str] = 'tts-1', 52 | response_format: Optional[str] = 'mp3', 53 | audio_service: AudioService = Depends(get_audio_service) 54 | ): 55 | """ 56 | Generate speech from text using TTS service 57 | 📝 File: voice.py, Line: 50, Function: speech 58 | """ 59 | try: 60 | logger.info(f"🎤 voice.py: Starting speech generation for text: {input[:30]}...") 61 | 62 | file_path, cache_key = await audio_service.generate_speech( 63 | text=input, 64 | voice=voice, 65 | model=model, 66 | response_format=response_format 67 | ) 68 | 69 | return FileResponse( 70 | file_path, 71 | media_type="audio/mpeg", 72 | headers={ 73 | "Content-Disposition": f"attachment; filename={cache_key}.mp3", 74 | "Accept-Ranges": "bytes", 75 | "Cache-Control": "no-cache", 76 | } 77 | ) 78 | 79 | except Exception as e: 80 | logger.error(f"❌ voice.py: Error in speech generation: {str(e)}") 81 | raise HTTPException(status_code=500, detail=str(e)) 82 | -------------------------------------------------------------------------------- /app/config.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Any, Tuple, Type 3 | 4 | from app.utils.logger import logger 5 | import os 6 | from pydantic_settings import BaseSettings, EnvSettingsSource, PydanticBaseSettingsSource, SettingsConfigDict 7 | from pydantic.fields import FieldInfo 8 | 9 | class CustomSource(EnvSettingsSource): 10 | def prepare_field_value( 11 | self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool 12 | ) -> Any: 13 | if field_name == 'RATE_LIMIT_REQUESTS': 14 | 
return int(value) if value else 0 15 | if field_name == 'RATE_LIMIT_TOKENS': 16 | return int(value) if value else 0 17 | 18 | if field_name == 'USE_CUDA': 19 | print(f"🔍 config.py: USE_CUDA: {value}") 20 | return value.lower() == 'true' if value else False 21 | 22 | if field_name == 'DEBUG': 23 | return value.lower() == 'true' if value else False 24 | 25 | if field_name == 'CORS_ORIGINS': 26 | return [origin.strip() for origin in value.split(',')] if value else [] 27 | 28 | return json.loads(value) if value_is_complex and value else value # only complex fields hold JSON; plain strings (e.g. LOG_LEVEL=INFO) pass through 29 | 30 | class Settings(BaseSettings): 31 | model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8") 32 | # API Settings 33 | APP_NAME: str = "OLLAMAGATE" 34 | API_VERSION: str = "v1" 35 | DEBUG: bool = True 36 | LOG_LEVEL: str = "INFO" 37 | CORS_ORIGINS: list[str] = ["http://localhost:3000", "http://localhost:3001"] 38 | 39 | # Database 40 | DB_HOST: str = 'localhost' 41 | DB_PORT: int = 5432 42 | DB_USER: str = "ollamagateuser" 43 | DB_PASSWORD: str = "ollamagate" 44 | DB_NAME: str = "ollamagate" 45 | 46 | # Redis 47 | REDIS_HOST: str = "localhost" 48 | REDIS_PORT: int = 6379 49 | REDIS_DB: int = 0 50 | 51 | # LLM Settings 52 | OLLAMA_API_BASE_URL: str = "http://localhost:11434" 53 | OPENAI_API_KEY: str = "sk-default" 54 | GPT_MODEL: str = "gpt-4-turbo-preview" 55 | OLLAMA_MODEL: str = "llama3.1" 56 | 57 | # Audio & Speech Settings 58 | AUDIO_STORAGE_PATH: str = "/tmp/audio_buffers" 59 | MAX_AUDIO_SIZE_MB: int = 10 60 | DATA_DIR: str = f"{os.getcwd()}/data" 61 | CACHE_DIR: str = os.path.join(DATA_DIR, "cache") 62 | SPEECH_CACHE_DIR: str = os.path.join(CACHE_DIR, "audio", "speech") 63 | 64 | # Speech-to-Text Settings 65 | STT_MODEL_CHOICE: str = "whisper" 66 | WHISPER_MODEL_SIZE: str = "base" 67 | WHISPER_DEVICE: str = "cpu" 68 | WHISPER_COMPUTE_TYPE: str = "float16" 69 | STT_MODEL: Any = None 70 | USE_CUDA: bool = False 71 | 72 | # Text-to-Speech Settings 73 | TTS_ENGINE: str = "openai" 74 | TTS_MODEL: str = "tts-1" 75 | TTS_OPENAI_API_KEY: str = 'sk-111111111' 76 | TTS_OPENAI_API_BASE_URL: str = "http://localhost:8000/v1" 77 | 78 | # WebSocket Settings 79 | WS_HEARTBEAT_INTERVAL: int = 30 80 | 81 | # Rate Limiting 82 | RATE_LIMIT_REQUESTS: int = 1000 83 | RATE_LIMIT_TOKENS: int = 50000 84 | 85 | # Session 86 | SESSION_EXPIRATION_TIME: int = 60 * 60 # 1 hour 87 | 88 | def __init__(self, **kwargs): 89 | super().__init__(**kwargs) 90 | self._setup_cuda() 91 | self.setup_cache_dir() 92 | 93 | def _setup_cuda(self): 94 | """Setup CUDA if available and requested""" 95 | if self.USE_CUDA: 96 | try: 97 | import torch 98 | assert torch.cuda.is_available(), "CUDA not available" 99 | self.WHISPER_DEVICE = "cuda" 100 | self.WHISPER_COMPUTE_TYPE = "float16" 101 | logger.info("🚀 config.py: CUDA enabled successfully") 102 | except Exception as e: 103 | cuda_error = ( 104 | "Error when testing CUDA but USE_CUDA_DOCKER is true. 
" 105 | f"Resetting USE_CUDA_DOCKER to false: {e}" 106 | ) 107 | logger.warning(f"⚠️ config.py: {cuda_error}") 108 | os.environ["USE_CUDA_DOCKER"] = "false" 109 | self.USE_CUDA = "false" 110 | self.WHISPER_DEVICE = "cpu" 111 | self.WHISPER_COMPUTE_TYPE = "float32" 112 | else: 113 | self.WHISPER_DEVICE = "cpu" 114 | self.WHISPER_COMPUTE_TYPE = "float32" 115 | 116 | def setup_cache_dir(self): 117 | """Setup cache directories""" 118 | try: 119 | os.makedirs(self.SPEECH_CACHE_DIR, exist_ok=True) 120 | os.makedirs(self.AUDIO_STORAGE_PATH, exist_ok=True) 121 | logger.info("📁 config.py: Cache directories setup complete") 122 | except Exception as e: 123 | logger.error(f"❌ config.py: Error setting up cache directories: {str(e)}") 124 | raise e 125 | 126 | @property 127 | def is_cuda_enabled(self) -> bool: 128 | """Check if CUDA is enabled""" 129 | return self.WHISPER_DEVICE == "cuda" 130 | 131 | @property 132 | def cache_dirs(self) -> dict: 133 | """Get all cache directories""" 134 | return { 135 | "data": self.DATA_DIR, 136 | "cache": self.CACHE_DIR, 137 | "speech": self.SPEECH_CACHE_DIR, 138 | "audio": self.AUDIO_STORAGE_PATH 139 | } 140 | 141 | @classmethod 142 | def settings_customise_sources( 143 | cls, 144 | settings_cls: Type[BaseSettings], 145 | init_settings: PydanticBaseSettingsSource, 146 | env_settings: PydanticBaseSettingsSource, 147 | dotenv_settings: PydanticBaseSettingsSource, 148 | file_secret_settings: PydanticBaseSettingsSource, 149 | ) -> Tuple[PydanticBaseSettingsSource, ...]: 150 | return (CustomSource(settings_cls),) 151 | 152 | 153 | settings = Settings() 154 | 155 | 156 | logger.info(f"🔍 config.py: Settings: {settings.model_dump()}") -------------------------------------------------------------------------------- /app/core/voice.py: -------------------------------------------------------------------------------- 1 | import json 2 | from app.config import Settings 3 | from fastapi import APIRouter, HTTPException 4 | from faster_whisper import WhisperModel 5 | from typing import Generator 6 | import asyncio 7 | import os 8 | from app.utils.logger import logger 9 | 10 | router = APIRouter() 11 | 12 | config = Settings() 13 | 14 | 15 | def get_stt_model(): 16 | if config.STT_MODEL_CHOICE == "whisper": 17 | try: 18 | stt_model = WhisperModel( 19 | model_size_or_path=config.WHISPER_MODEL_SIZE, 20 | device=config.WHISPER_DEVICE, 21 | compute_type=config.WHISPER_COMPUTE_TYPE, 22 | cpu_threads=4, 23 | num_workers=2 24 | ) 25 | return stt_model 26 | except Exception as e: 27 | logger.error(f"❌ voice.py: Failed to initialize Whisper model: {str(e)}") 28 | raise 29 | else: 30 | raise HTTPException(status_code=500, detail="STT_MODEL_CHOICE not supported") 31 | 32 | 33 | async def generate_transcription(temp_path: str, language: str, task: str, beam_size: int, vad_filter: bool) -> \ 34 | Generator[str, None, None]: # type: ignore 35 | try: 36 | stt_model = get_stt_model() 37 | if stt_model is None: 38 | raise HTTPException(status_code=500, detail="STT_MODEL not initialized") 39 | 40 | # Use faster-whisper's streaming API 41 | segments, info = await asyncio.to_thread( 42 | stt_model.transcribe, 43 | temp_path, 44 | language=language, 45 | task=task, 46 | beam_size=beam_size, 47 | vad_filter=vad_filter, 48 | initial_prompt=None 49 | ) 50 | 51 | logger.info(f"ℹ️ voice.py: Detected language: {info.language} with probability {info.language_probability:.2f}") 52 | 53 | # Stream each segment as it's transcribed 54 | for segment in segments: 55 | yield f"{segment.text}\n" 56 | logger.info(f"🎯 
voice.py: Transcribed segment: {segment.text[:30]}...") 57 | except Exception as e: 58 | logger.error(f"❌ voice.py: Error in transcription generation: {str(e)}") 59 | yield f"data: {json.dumps({'error': str(e)})}\n\n" 60 | finally: 61 | # Cleanup temporary file 62 | if os.path.exists(temp_path): 63 | os.remove(temp_path) 64 | logger.info(f"🧹 voice.py: Cleaned up temporary file {temp_path}") 65 | -------------------------------------------------------------------------------- /app/db/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theboringhumane/echoOLlama/2936731e53e6c08231ba67141e4c0419a4809559/app/db/__init__.py -------------------------------------------------------------------------------- /app/db/database.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Dict 2 | from sqlalchemy import update 3 | from sqlalchemy.orm import sessionmaker 4 | from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, AsyncEngine 5 | from sqlalchemy.future import select 6 | from app.config import settings 7 | from app.db.models import Session, Conversation, ConversationItem, Response, MessageRole, ResponseStatus, Base, \ 8 | RateLimit 9 | from app.utils.logger import logger 10 | import uuid 11 | 12 | 13 | class Database: 14 | def __init__(self): 15 | self.engine = None 16 | self.SessionLocal = None 17 | self._session: Optional[AsyncSession] = None 18 | 19 | async def connect(self): 20 | """Initialize database connection""" 21 | if not self.engine: 22 | logger.info("🔌 File: database.py, Function: connect; Connecting to database") 23 | print( 24 | f'postgresql+asyncpg://{settings.DB_USER}:{settings.DB_PASSWORD}@{settings.DB_HOST}:{settings.DB_PORT}/{settings.DB_NAME}') 25 | self.engine: AsyncEngine = create_async_engine( 26 | f'postgresql+asyncpg://{settings.DB_USER}:{settings.DB_PASSWORD}@{settings.DB_HOST}:{settings.DB_PORT}/{settings.DB_NAME}', 27 | echo=True, 28 | pool_size=20, 29 | max_overflow=0 30 | ) 31 | self.SessionLocal = sessionmaker(self.engine, expire_on_commit=False, class_=AsyncSession) 32 | 33 | async def disconnect(self): 34 | """Close database connection""" 35 | if self.engine: 36 | logger.info("🔌 File: database.py, Function: disconnect; Disconnecting from database") 37 | await self.engine.dispose() 38 | 39 | async def create_session(self, session_data: Dict) -> Session: 40 | """Create a new realtime session""" 41 | async with self.SessionLocal() as session: 42 | logger.info("📝 File: database.py, Function: create_session; Creating new session") 43 | new_session = Session(**session_data) 44 | session.add(new_session) 45 | await session.commit() 46 | await session.refresh(new_session) 47 | return new_session 48 | 49 | async def update_session(self, session_data: Dict) -> Session: 50 | """Update session data""" 51 | async with self.SessionLocal() as session: 52 | logger.info(f"🔄 File: database.py, Function: update_session; Updating session {session_data['id']}") 53 | await session.execute(update(Session).where(Session.id == session_data['id']).values(**session_data)) 54 | await session.commit() 55 | 56 | async def get_session(self, session_id: str) -> Optional[Session]: 57 | """Get session by ID""" 58 | async with self.SessionLocal() as session: 59 | logger.info(f"🔍 File: database.py, Function: get_session; Fetching session {session_id}") 60 | result = await session.execute(select(Session).where(Session.id == session_id)) 61 | 
return result.scalar_one_or_none() 62 | 63 | async def create_conversation(self, session_id: str) -> Conversation: 64 | """Create a new conversation for a session""" 65 | async with self.SessionLocal() as session: 66 | logger.info( 67 | f"💬 File: database.py, Function: create_conversation; Creating conversation for session {session_id}") 68 | conversation = Conversation(session_id=session_id) 69 | session.add(conversation) 70 | await session.commit() 71 | await session.refresh(conversation) 72 | return conversation 73 | 74 | async def create_conversation_item(self, conversation_id: str, role: MessageRole, 75 | content: Dict, audio_start_ms: int = None, 76 | audio_end_ms: int = None) -> ConversationItem: 77 | """Create a new conversation item""" 78 | async with self.SessionLocal() as session: 79 | logger.info( 80 | f"📝 File: database.py, Function: create_conversation_item; Adding item to conversation {conversation_id}") 81 | item = ConversationItem( 82 | conversation_id=conversation_id, 83 | role=role, 84 | content=content, 85 | audio_start_ms=audio_start_ms, 86 | audio_end_ms=audio_end_ms 87 | ) 88 | session.add(item) 89 | await session.commit() 90 | await session.refresh(item) 91 | return item 92 | 93 | async def create_response(self, conversation_id: str) -> Response: 94 | """Create a new response""" 95 | async with self.SessionLocal() as session: 96 | logger.info( 97 | f"🤖 File: database.py, Function: create_response; Creating response for conversation {conversation_id}") 98 | response = Response( 99 | conversation_id=conversation_id, 100 | status=ResponseStatus.IN_PROGRESS 101 | ) 102 | session.add(response) 103 | await session.commit() 104 | await session.refresh(response) 105 | return response 106 | 107 | async def update_response(self, response_id: str, 108 | status: ResponseStatus = None, 109 | usage_stats: Dict = None, 110 | status_details: Dict = None) -> Response: 111 | """Update response status and usage statistics""" 112 | async with self.SessionLocal() as session: 113 | logger.info(f"📊 File: database.py, Function: update_response; Updating response {response_id}") 114 | result = await session.execute(select(Response).where(Response.id == response_id)) 115 | response = result.scalar_one_or_none() 116 | 117 | if response: 118 | if status: 119 | response.status = status 120 | if usage_stats: 121 | response.total_tokens = usage_stats.get('total_tokens') 122 | response.input_tokens = usage_stats.get('input_tokens') 123 | response.output_tokens = usage_stats.get('output_tokens') 124 | response.input_token_details = usage_stats.get('input_token_details') 125 | response.output_token_details = usage_stats.get('output_token_details') 126 | if status_details: 127 | response.status_details = status_details 128 | 129 | await session.commit() 130 | await session.refresh(response) 131 | 132 | return response 133 | 134 | async def get_conversation_items(self, conversation_id: str) -> List[ConversationItem]: 135 | """Get all items in a conversation""" 136 | async with self.SessionLocal() as session: 137 | logger.info( 138 | f"📜 File: database.py, Function: get_conversation_items; Fetching items for conversation {conversation_id}") 139 | result = await session.execute( 140 | select(ConversationItem) 141 | .where(ConversationItem.conversation_id == conversation_id) 142 | .order_by(ConversationItem.id) 143 | ) 144 | return result.scalars().all() 145 | 146 | async def create_rate_limit(self, session_id: str, name: str, limit: int, 147 | remaining: int, reset_seconds: float) -> RateLimit: 148 | 
"""Create or update a rate limit for a session""" 149 | async with self.SessionLocal() as session: 150 | logger.info( 151 | f"⚡ File: database.py, Function: create_rate_limit; Creating/updating rate limit for session {session_id}") 152 | 153 | # Check if rate limit exists 154 | result = await session.execute( 155 | select(RateLimit) 156 | .where(RateLimit.session_id == session_id) 157 | .where(RateLimit.name == name) 158 | ) 159 | rate_limit = result.scalar_one_or_none() 160 | 161 | if rate_limit: 162 | # Update existing rate limit 163 | rate_limit.limit = limit 164 | rate_limit.remaining = remaining 165 | rate_limit.reset_seconds = reset_seconds 166 | else: 167 | # Create new rate limit 168 | rate_limit = RateLimit( 169 | id=f"rl_{uuid.uuid4().hex}", 170 | session_id=session_id, 171 | name=name, 172 | limit=limit, 173 | remaining=remaining, 174 | reset_seconds=reset_seconds 175 | ) 176 | session.add(rate_limit) 177 | 178 | await session.commit() 179 | await session.refresh(rate_limit) 180 | return rate_limit 181 | 182 | async def get_session_rate_limits(self, session_id: str) -> List[RateLimit]: 183 | """Get all rate limits for a session""" 184 | async with self.SessionLocal() as session: 185 | logger.info( 186 | f"📊 File: database.py, Function: get_session_rate_limits; Fetching rate limits for session {session_id}") 187 | result = await session.execute( 188 | select(RateLimit) 189 | .where(RateLimit.session_id == session_id) 190 | .order_by(RateLimit.name) 191 | ) 192 | return result.scalars().all() 193 | 194 | async def reset_rate_limits(self, session_id: str, name: str) -> None: 195 | """Reset rate limits for a session""" 196 | async with self.SessionLocal() as session: 197 | logger.info( 198 | f"📊 File: database.py, Function: reset_rate_limits; Resetting rate limits for session {session_id}") 199 | await session.execute(await session.update(RateLimit).where(RateLimit.session_id == session_id).where(RateLimit.name == name).values(remaining=0)) 200 | await session.commit() 201 | 202 | async def update_rate_limits(self, session_id: str, rate_limits: List[Dict]) -> List[RateLimit]: 203 | """Update rate limits from server event""" 204 | async with self.SessionLocal() as session: 205 | logger.info( 206 | f"📈 File: database.py, Function: update_rate_limits; Updating rate limits for session {session_id}") 207 | 208 | updated_limits = [] 209 | 210 | for limit_data in rate_limits: 211 | rate_limit = await self.create_rate_limit( 212 | session_id=session_id, 213 | name=limit_data['name'], 214 | limit=limit_data['limit'], 215 | remaining=limit_data['remaining'], 216 | reset_seconds=limit_data['reset_seconds'] 217 | ) 218 | updated_limits.append(rate_limit) 219 | 220 | return updated_limits 221 | 222 | 223 | db = Database() 224 | -------------------------------------------------------------------------------- /app/db/migrations/env.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from logging.config import fileConfig 3 | from sqlalchemy.engine import Connection 4 | from alembic import context 5 | 6 | from app.config import settings 7 | from app.db.models import Base 8 | from app.db.database import db 9 | from app.utils.logger import logger 10 | 11 | # this is the Alembic Config object 12 | config = context.config 13 | 14 | # Setup logging 15 | fileConfig(config.config_file_name) 16 | 17 | # Set target metadata 18 | target_metadata = Base.metadata 19 | 20 | 21 | # Override sqlalchemy.url with our async database URL 22 | def get_url(): 23 | 
return f'postgresql+asyncpg://{settings.DB_USER}:{settings.DB_PASSWORD}@{settings.DB_HOST}:{settings.DB_PORT}/{settings.DB_NAME}' 24 | 25 | 26 | config.set_main_option("sqlalchemy.url", get_url()) 27 | 28 | 29 | async def run_migrations_online() -> None: 30 | """Run migrations in 'online' mode using the Database instance.""" 31 | logger.info("🔄 File: env.py, Function: run_migrations_online; Running migrations") 32 | 33 | # Use the engine from our Database instance 34 | await db.connect() 35 | 36 | async with db.engine.connect() as connection: 37 | await connection.run_sync(do_run_migrations) 38 | 39 | await db.disconnect() 40 | 41 | 42 | def do_run_migrations(connection: Connection) -> None: 43 | """Execute migrations""" 44 | logger.info("🔄 File: env.py, Function: do_run_migrations; Executing migrations") 45 | 46 | context.configure( 47 | connection=connection, 48 | target_metadata=target_metadata, 49 | compare_type=True, 50 | include_schemas=True, 51 | # Add transaction support 52 | transaction_per_migration=True, 53 | # Add retry logic for better reliability 54 | retry_on_errors=True 55 | ) 56 | 57 | try: 58 | with context.begin_transaction(): 59 | context.run_migrations() 60 | logger.info("✅ File: env.py, Function: do_run_migrations; Migrations completed successfully") 61 | except Exception as e: 62 | logger.error(f"❌ File: env.py, Function: do_run_migrations; Migration failed: {str(e)}") 63 | raise 64 | 65 | 66 | def run_migrations_offline() -> None: 67 | """Run migrations in 'offline' mode.""" 68 | logger.info("🔄 File: env.py, Function: run_migrations_offline; Running offline migrations") 69 | 70 | url = config.get_main_option("sqlalchemy.url") 71 | context.configure( 72 | url=url, 73 | target_metadata=target_metadata, 74 | literal_binds=True, 75 | dialect_opts={"paramstyle": "named"}, 76 | compare_type=True, 77 | include_schemas=True 78 | ) 79 | 80 | try: 81 | with context.begin_transaction(): 82 | context.run_migrations() 83 | logger.info("✅ File: env.py, Function: run_migrations_offline; Offline migrations completed successfully") 84 | except Exception as e: 85 | logger.error(f"❌ File: env.py, Function: run_migrations_offline; Offline migration failed: {str(e)}") 86 | raise 87 | 88 | 89 | print( 90 | "🔧 File: env.py; Alembic configuration loaded. Running in offline mode:", context.is_offline_mode() 91 | ) 92 | if context.is_offline_mode(): 93 | run_migrations_offline() 94 | else: 95 | asyncio.run(run_migrations_online()) 96 | -------------------------------------------------------------------------------- /app/db/migrations/script.py.mako: -------------------------------------------------------------------------------- 1 | """${message} 2 | 3 | Revision ID: ${up_revision} 4 | Revises: ${down_revision | comma,n} 5 | Create Date: ${create_date} 6 | 7 | """ 8 | from typing import Sequence, Union 9 | 10 | from alembic import op 11 | import sqlalchemy as sa 12 | ${imports if imports else ""} 13 | 14 | # revision identifiers, used by Alembic. 
15 | revision: str = ${repr(up_revision)} 16 | down_revision: Union[str, None] = ${repr(down_revision)} 17 | branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} 18 | depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} 19 | 20 | 21 | def upgrade() -> None: 22 | ${upgrades if upgrades else "pass"} 23 | 24 | 25 | def downgrade() -> None: 26 | ${downgrades if downgrades else "pass"} -------------------------------------------------------------------------------- /app/db/migrations/utils.py: -------------------------------------------------------------------------------- 1 | from alembic import command 2 | from alembic.config import Config 3 | 4 | from app.utils.logger import logger 5 | 6 | 7 | class MigrationManager: 8 | def __init__(self, alembic_cfg_path: str = "alembic.ini"): 9 | self.alembic_cfg = Config(alembic_cfg_path) 10 | 11 | async def create_migration(self, message: str): 12 | """Create a new migration""" 13 | logger.info(f"📝 File: utils.py, Function: create_migration; Creating migration: {message}") 14 | try: 15 | command.revision(self.alembic_cfg, autogenerate=True, message=message) # alembic's command API is synchronous, so these calls are not awaited 16 | logger.info("✅ File: utils.py, Function: create_migration; Migration created successfully") 17 | except Exception as e: 18 | logger.error(f"❌ File: utils.py, Function: create_migration; Failed to create migration: {str(e)}") 19 | raise 20 | 21 | async def upgrade(self, revision: str = "head"): 22 | """Upgrade to a later version""" 23 | logger.info(f"⬆️ File: utils.py, Function: upgrade; Upgrading to {revision}") 24 | try: 25 | command.upgrade(self.alembic_cfg, revision) 26 | logger.info("✅ File: utils.py, Function: upgrade; Upgrade completed successfully") 27 | except Exception as e: 28 | logger.error(f"❌ File: utils.py, Function: upgrade; Upgrade failed: {str(e)}") 29 | raise 30 | 31 | async def downgrade(self, revision: str = "-1"): 32 | """Revert to a previous version""" 33 | logger.info(f"⬇️ File: utils.py, Function: downgrade; Downgrading to {revision}") 34 | try: 35 | command.downgrade(self.alembic_cfg, revision) 36 | logger.info("✅ File: utils.py, Function: downgrade; Downgrade completed successfully") 37 | except Exception as e: 38 | logger.error(f"❌ File: utils.py, Function: downgrade; Downgrade failed: {str(e)}") 39 | raise 40 | 41 | async def show_current(self): 42 | """Show current revision""" 43 | logger.info("ℹ️ File: utils.py, Function: show_current; Showing current revision") 44 | try: 45 | command.current(self.alembic_cfg) 46 | except Exception as e: 47 | logger.error(f"❌ File: utils.py, Function: show_current; Failed to show current revision: {str(e)}") 48 | raise 49 | 50 | 51 | migration_manager = MigrationManager() 52 | -------------------------------------------------------------------------------- /app/db/migrations/versions/20240101_add_rate_limits.py: -------------------------------------------------------------------------------- 1 | """add rate limits table 2 | 3 | Revision ID: add_rate_limits 4 | Revises: initial 5 | Create Date: 2024-01-01 00:00:00.000000 6 | 7 | """ 8 | import uuid 9 | from alembic import op 10 | import sqlalchemy as sa 11 | 12 | # revision identifiers 13 | revision = 'add_rate_limits' 14 | down_revision = 'initial' # Update this to your previous migration 15 | branch_labels = None 16 | depends_on = None 17 | 18 | def upgrade() -> None: 19 | # Create rate_limits table 20 | op.create_table( 21 | 'rate_limits', 22 | sa.Column('id', sa.String(), default=lambda: str(uuid.uuid4()), nullable=False), 
23 | sa.Column('name', sa.String(), nullable=False), 24 | sa.Column('limit', sa.Integer(), nullable=False), 25 | sa.Column('remaining', sa.Integer(), nullable=False), 26 | sa.Column('reset_seconds', sa.Float(), nullable=False), 27 | sa.Column('session_id', sa.String(), nullable=False), 28 | sa.Column('created_at', sa.TIMESTAMP(timezone=True), server_default=sa.text('now()'), nullable=False), 29 | sa.Column('updated_at', sa.TIMESTAMP(timezone=True), server_default=sa.text('now()'), nullable=False), 30 | sa.ForeignKeyConstraint(['session_id'], ['sessions.id'], ondelete='CASCADE'), 31 | sa.PrimaryKeyConstraint('id'), 32 | sa.UniqueConstraint('session_id', 'name', name='uq_rate_limit_session_name'), 33 | if_not_exists=True, 34 | ) 35 | 36 | # Add indexes for rate_limits 37 | op.create_index( 38 | 'idx_rate_limits_session_id', 39 | 'rate_limits', 40 | ['session_id'], 41 | unique=False, 42 | if_not_exists=True 43 | ) 44 | op.create_index( 45 | 'idx_rate_limits_name', 46 | 'rate_limits', 47 | ['name'], 48 | unique=False, 49 | if_not_exists=True 50 | ) 51 | op.create_index( 52 | 'idx_rate_limits_session_name', 53 | 'rate_limits', 54 | ['session_id', 'name'], 55 | unique=True, 56 | if_not_exists=True 57 | ) 58 | 59 | def downgrade() -> None: 60 | # Drop index first 61 | op.drop_index('idx_rate_limits_session_id') 62 | op.drop_index('idx_rate_limits_name') 63 | op.drop_index('idx_rate_limits_session_name') 64 | 65 | # Drop table 66 | op.drop_table('rate_limits') -------------------------------------------------------------------------------- /app/db/migrations/versions/20240101_initial.py: -------------------------------------------------------------------------------- 1 | """initial migration 2 | 3 | Revision ID: initial 4 | Revises: 5 | Create Date: 2024-01-01 00:00:00.000000 6 | 7 | """ 8 | from alembic import op 9 | import sqlalchemy as sa 10 | from sqlalchemy.dialects import postgresql 11 | import uuid 12 | from app.db.models import MessageRole, ResponseStatus 13 | 14 | # revision identifiers 15 | revision = 'initial' 16 | down_revision = None 17 | branch_labels = None 18 | depends_on = None 19 | 20 | 21 | def upgrade() -> None: 22 | # Create enum types 23 | # Check if enums exist before creating them 24 | with op.get_context().autocommit_block(): 25 | op.execute('DROP TYPE IF EXISTS message_role CASCADE') 26 | op.execute('DROP TYPE IF EXISTS response_status CASCADE') 27 | 28 | # Create sessions table with updated defaults and columns 29 | op.create_table( 30 | 'sessions', 31 | sa.Column('id', sa.String(), default=lambda: str(uuid.uuid4()), nullable=False), 32 | sa.Column('object_type', sa.String(), server_default='realtime.session', nullable=True), 33 | sa.Column('model', sa.String(), server_default='llama3.1', nullable=True), 34 | sa.Column('modalities', postgresql.JSON(astext_type=sa.Text()), 35 | server_default='["text", "audio"]', nullable=True), 36 | sa.Column('instructions', sa.String(), server_default='', nullable=True), 37 | sa.Column('voice', sa.String(), server_default='alloy', nullable=True), 38 | sa.Column('input_audio_format', sa.String(), server_default='pcm16', nullable=True), 39 | sa.Column('output_audio_format', sa.String(), server_default='pcm16', nullable=True), 40 | sa.Column('input_audio_transcription', postgresql.JSON(astext_type=sa.Text()), 41 | server_default='{"model": "whisper-1", "language": "en"}', nullable=True), 42 | sa.Column('turn_detection', postgresql.JSON(astext_type=sa.Text()), 43 | server_default='''{"type": "server_vad", "threshold": 0.5, 44 | 
"prefix_padding_ms": 300, "silence_duration_ms": 500}''', 45 | nullable=True), 46 | sa.Column('tools', postgresql.JSON(astext_type=sa.Text()), nullable=True), 47 | sa.Column('tool_choice', sa.String(), server_default='auto', nullable=True), 48 | sa.Column('temperature', sa.Float(), server_default='0.7', nullable=True), 49 | sa.Column('max_response_output_tokens', sa.String(), server_default='inf', nullable=True), 50 | sa.Column('created_at', sa.TIMESTAMP(timezone=True), 51 | server_default=sa.text('now()'), nullable=False), 52 | sa.Column('updated_at', sa.TIMESTAMP(timezone=True), 53 | server_default=sa.text('now()'), nullable=False), 54 | sa.PrimaryKeyConstraint('id'), 55 | if_not_exists=True, 56 | ) 57 | 58 | # Add index for sessions lookup 59 | op.create_index( 60 | 'idx_sessions_created_at', 61 | 'sessions', 62 | ['created_at'], 63 | if_not_exists=True 64 | ) 65 | 66 | 67 | # Create conversations table 68 | op.create_table( 69 | 'conversations', 70 | sa.Column('id', sa.String(), default=lambda: str(uuid.uuid4()), nullable=False), 71 | sa.Column('object_type', sa.String(), nullable=True), 72 | sa.Column('session_id', sa.String(), nullable=True), 73 | sa.ForeignKeyConstraint(['session_id'], ['sessions.id'], ), 74 | sa.PrimaryKeyConstraint('id'), 75 | sa.Column('created_at', sa.TIMESTAMP(timezone=True), server_default=sa.text('now()'), nullable=False), 76 | sa.Column('updated_at', sa.TIMESTAMP(timezone=True), server_default=sa.text('now()'), nullable=False), 77 | if_not_exists=True, 78 | ) 79 | 80 | # Create conversation_items table 81 | op.create_table( 82 | 'conversation_items', 83 | sa.Column('id', sa.String(), default=lambda: str(uuid.uuid4()), nullable=False), 84 | sa.Column('conversation_id', sa.String(), nullable=True), 85 | sa.Column('role', MessageRole.as_pg_enum(), nullable=True), 86 | sa.Column('content', postgresql.JSON(astext_type=sa.Text()), nullable=True), 87 | sa.Column('audio_start_ms', sa.Integer(), nullable=True), 88 | sa.Column('audio_end_ms', sa.Integer(), nullable=True), 89 | sa.ForeignKeyConstraint(['conversation_id'], ['conversations.id'], ), 90 | sa.PrimaryKeyConstraint('id'), 91 | sa.Column('created_at', sa.TIMESTAMP(timezone=True), server_default=sa.text('now()'), nullable=False), 92 | sa.Column('updated_at', sa.TIMESTAMP(timezone=True), server_default=sa.text('now()'), nullable=False), 93 | if_not_exists=True, 94 | ) 95 | 96 | # Create responses table 97 | op.create_table( 98 | 'responses', 99 | sa.Column('id', sa.String(), default=lambda: str(uuid.uuid4()), nullable=False), 100 | sa.Column('object_type', sa.String(), nullable=True), 101 | sa.Column('status', ResponseStatus.as_pg_enum(), nullable=True), 102 | sa.Column('conversation_id', sa.String(), nullable=True), 103 | sa.Column('total_tokens', sa.Integer(), nullable=True), 104 | sa.Column('input_tokens', sa.Integer(), nullable=True), 105 | sa.Column('output_tokens', sa.Integer(), nullable=True), 106 | sa.Column('input_token_details', postgresql.JSON(astext_type=sa.Text()), nullable=True), 107 | sa.Column('output_token_details', postgresql.JSON(astext_type=sa.Text()), nullable=True), 108 | sa.Column('status_details', postgresql.JSON(astext_type=sa.Text()), nullable=True), 109 | sa.ForeignKeyConstraint(['conversation_id'], ['conversations.id'], ), 110 | sa.Column('created_at', sa.TIMESTAMP(timezone=True), server_default=sa.text('now()'), nullable=False), 111 | sa.Column('updated_at', sa.TIMESTAMP(timezone=True), server_default=sa.text('now()'), nullable=False), 112 | sa.PrimaryKeyConstraint('id'), 113 | 
if_not_exists=True, 114 | ) 115 | 116 | 117 | def downgrade() -> None: 118 | # Drop tables 119 | op.drop_table('responses') 120 | op.drop_table('conversation_items') 121 | op.drop_table('conversations') 122 | op.drop_table('sessions') 123 | 124 | # Drop enum types (IF EXISTS: rate_limit_type is never created by upgrade) 125 | op.execute('DROP TYPE IF EXISTS response_status') 126 | op.execute('DROP TYPE IF EXISTS rate_limit_type') 127 | op.execute('DROP TYPE IF EXISTS message_role') 128 | -------------------------------------------------------------------------------- /app/db/models.py: -------------------------------------------------------------------------------- 1 | import enum 2 | from typing import Any, Type, TypeVar 3 | import uuid 4 | 5 | from pydantic import BaseModel 6 | from sqlalchemy import TIMESTAMP, Column, Float, String, Integer, ForeignKey, JSON, Enum, Boolean, func 7 | from sqlalchemy.orm import relationship 8 | from sqlalchemy.ext.declarative import declarative_base 9 | from app.utils.logger import logger 10 | from sqlalchemy.dialects import postgresql 11 | 12 | Base = declarative_base() 13 | 14 | 15 | class MessageRole(enum.Enum): 16 | SYSTEM = "system" 17 | USER = "user" 18 | ASSISTANT = "assistant" 19 | FUNCTION = "function" 20 | 21 | @classmethod 22 | def as_pg_enum(cls): 23 | return postgresql.ENUM( 24 | 'system', 'user', 'assistant', 'function', 25 | name='message_role', 26 | create_type=True 27 | ) 28 | 29 | 30 | class ResponseStatus(enum.Enum): 31 | COMPLETED = "completed" 32 | CANCELLED = "cancelled" 33 | FAILED = "failed" 34 | INCOMPLETE = "incomplete" 35 | IN_PROGRESS = "in_progress" 36 | 37 | @classmethod 38 | def as_pg_enum(cls): 39 | return postgresql.ENUM( 40 | 'completed', 'cancelled', 'failed', 'incomplete', 'in_progress', 41 | name='response_status', 42 | create_type=True # let SQLAlchemy create the enum type together with the table 43 | ) 44 | 45 | 46 | class Session(Base): 47 | __tablename__ = 'sessions' 48 | 49 | id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4())) 50 | object_type = Column(String, default="realtime.session") 51 | model = Column(String, default="llama3.1") 52 | modalities = Column(JSON, default=["text", "audio"]) # Array of strings 53 | instructions = Column(String, default="") 54 | voice = Column(String, default='alloy') 55 | input_audio_format = Column(String, default='pcm16') 56 | output_audio_format = Column(String, default='pcm16') 57 | input_audio_transcription = Column(JSON, default={ 58 | 'model': 'whisper-1', 59 | 'language': 'en', 60 | }) 61 | turn_detection = Column(JSON, nullable=True, default={ 62 | "type": "server_vad", 63 | "threshold": 0.5, 64 | "prefix_padding_ms": 300, 65 | "silence_duration_ms": 500 66 | }) 67 | tools = Column(JSON) # Array of tools 68 | tool_choice = Column(String, default="auto") 69 | temperature = Column(Float, default=0.7) 70 | max_response_output_tokens = Column(String, default="inf") 71 | created_at = Column(TIMESTAMP(timezone=True), server_default=func.now()) 72 | updated_at = Column(TIMESTAMP(timezone=True), server_default=func.now(), onupdate=func.now()) 73 | 74 | # Relationships 75 | conversations = relationship("Conversation", back_populates="session") 76 | rate_limits = relationship( 77 | "RateLimit", 78 | back_populates="session", 79 | cascade="all, delete-orphan", 80 | lazy="selectin" # For better performance when loading related rate limits 81 | ) 82 | 83 | def __repr__(self): 84 | logger.info("🔍 File: models.py, Line: 45, Function: __repr__, Value:", f"Session(id={self.id})") 85 | return f"Session(id={self.id})" 86 | 87 | 88 | class Conversation(Base): 89 | 
__tablename__ = 'conversations' 90 | 91 | id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4())) 92 | object_type = Column(String, default="realtime.conversation") 93 | session_id = Column(String, ForeignKey('sessions.id')) 94 | 95 | # Relationships 96 | session = relationship("Session", back_populates="conversations") 97 | items = relationship("ConversationItem", back_populates="conversation") 98 | created_at = Column(TIMESTAMP(timezone=True), server_default=func.now()) 99 | updated_at = Column(TIMESTAMP(timezone=True), server_default=func.now(), onupdate=func.now()) 100 | 101 | def __repr__(self): 102 | logger.info("🔍 File: models.py, Line: 60, Function: __repr__, Value:", f"Conversation(id={self.id})") 103 | return f"Conversation(id={self.id})" 104 | 105 | 106 | class ConversationItem(Base): 107 | __tablename__ = 'conversation_items' 108 | 109 | id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4())) 110 | conversation_id = Column(String, ForeignKey('conversations.id')) 111 | role = Column(MessageRole.as_pg_enum()) 112 | content = Column(JSON) # Can store both text and audio content 113 | audio_start_ms = Column(Integer, nullable=True) 114 | audio_end_ms = Column(Integer, nullable=True) 115 | created_at = Column(TIMESTAMP(timezone=True), server_default=func.now()) 116 | updated_at = Column(TIMESTAMP(timezone=True), server_default=func.now(), onupdate=func.now()) 117 | 118 | # Relationships 119 | conversation = relationship("Conversation", back_populates="items") 120 | 121 | def __repr__(self): 122 | logger.info("🔍 File: models.py, Line: 77, Function: __repr__, Value:", f"ConversationItem(id={self.id})") 123 | return f"ConversationItem(id={self.id})" 124 | 125 | 126 | class Response(Base): 127 | __tablename__ = 'responses' 128 | 129 | id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4())) 130 | object_type = Column(String, default="realtime.response") 131 | status = Column(ResponseStatus.as_pg_enum()) 132 | conversation_id = Column(String, ForeignKey('conversations.id')) 133 | 134 | # Usage statistics 135 | total_tokens = Column(Integer) 136 | input_tokens = Column(Integer) 137 | output_tokens = Column(Integer) 138 | input_token_details = Column(JSON) 139 | output_token_details = Column(JSON) 140 | 141 | # Status details 142 | status_details = Column(JSON) 143 | created_at = Column(TIMESTAMP(timezone=True), server_default=func.now()) 144 | updated_at = Column(TIMESTAMP(timezone=True), server_default=func.now(), onupdate=func.now()) 145 | 146 | def __repr__(self): 147 | logger.info("🔍 File: models.py, Line: 98, Function: __repr__, Value:", f"Response(id={self.id})") 148 | return f"Response(id={self.id})" 149 | 150 | 151 | class RateLimit(Base): 152 | __tablename__ = 'rate_limits' 153 | 154 | id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4())) 155 | name = Column(String, nullable=False) # 'requests', 'tokens' 156 | limit = Column(Integer, nullable=False) 157 | remaining = Column(Integer, nullable=False) 158 | reset_seconds = Column(Float, nullable=False) 159 | session_id = Column(String, ForeignKey('sessions.id'), nullable=False) 160 | created_at = Column(TIMESTAMP(timezone=True), server_default=func.now()) 161 | updated_at = Column(TIMESTAMP(timezone=True), server_default=func.now(), onupdate=func.now()) 162 | 163 | # Relationships 164 | session = relationship("Session", back_populates="rate_limits") 165 | 166 | def __repr__(self): 167 | logger.info("🔍 File: models.py, Line: 130, Function: __repr__, Value:", 168 | 
f"RateLimit(id={self.id}, name={self.name}, remaining={self.remaining}/{self.limit})") 169 | return f"RateLimit(id={self.id}, name={self.name}, remaining={self.remaining}/{self.limit})" 170 | 171 | 172 | model_t = TypeVar('T', bound=BaseModel) 173 | 174 | def to_pydantic(db_object: Any, pydantic_model: Type[model_t]) -> model_t: 175 | return pydantic_model(**db_object.__dict__) -------------------------------------------------------------------------------- /app/dependencies.py: -------------------------------------------------------------------------------- 1 | from app.services.llm import LLMService 2 | from app.services.audio import AudioService 3 | 4 | _llm_service = None 5 | _audio_service = None 6 | 7 | def get_llm_service() -> LLMService: 8 | """ 9 | Get or create LLM service instance 10 | 📝 File: dependencies.py, Line: 9, Function: get_llm_service 11 | """ 12 | global _llm_service 13 | if _llm_service is None: 14 | _llm_service = LLMService() 15 | return _llm_service 16 | 17 | def get_audio_service() -> AudioService: 18 | """ 19 | Get or create Audio service instance 20 | 📝 File: dependencies.py, Line: 19, Function: get_audio_service 21 | """ 22 | global _audio_service 23 | if _audio_service is None: 24 | _audio_service = AudioService() 25 | return _audio_service -------------------------------------------------------------------------------- /app/main.py: -------------------------------------------------------------------------------- 1 | from contextlib import asynccontextmanager 2 | from fastapi import FastAPI, WebSocket, WebSocketDisconnect 3 | 4 | from app.db.database import db 5 | from app.api.routes.v1 import endpoints, voice 6 | from fastapi.middleware.cors import CORSMiddleware 7 | from app.config import settings 8 | import uvicorn 9 | 10 | from app.websocket.connection import WebSocketConnection 11 | from app.utils.logger import logger 12 | 13 | origins = settings.CORS_ORIGINS 14 | 15 | @asynccontextmanager 16 | async def lifespan(app: FastAPI): 17 | """Lifespan context manager for application startup/shutdown events""" 18 | try: 19 | await db.connect() 20 | logger.info("📁 File: main.py, Line: 10, Function: lifespan; Status: Application started") 21 | yield 22 | finally: 23 | await db.disconnect() 24 | logger.info("📁 File: main.py, Line: 14, Function: lifespan; Status: Application shutdown") 25 | 26 | 27 | app = FastAPI( 28 | title=settings.APP_NAME, 29 | description="A FastAPI wrapper for the Ollama API that matches the official API structure", 30 | version="1.0.0", 31 | docs_url="/api/docs", 32 | redoc_url="/api/redoc", 33 | debug=settings.LOG_LEVEL == "debug", 34 | lifespan=lifespan 35 | ) 36 | 37 | app.add_middleware( 38 | CORSMiddleware, 39 | allow_origins=origins, 40 | allow_credentials=True, 41 | allow_methods=["*"], 42 | allow_headers=["*"], 43 | ) 44 | app.include_router(endpoints.router, prefix="/api") 45 | app.include_router(voice.router, prefix="/api/voice") 46 | 47 | @app.websocket_route("/realtime") 48 | async def websocket_endpoint(websocket: WebSocket): 49 | """WebSocket endpoint for real-time chat""" 50 | 51 | # Get the requested protocols from the client 52 | requested_protocols = websocket.headers.get("Sec-WebSocket-Protocol", "").split(", ") 53 | 54 | logger.info(f"🔌 Requested protocols: {requested_protocols}") 55 | 56 | connection = WebSocketConnection(websocket, db, subprotocol=requested_protocols) 57 | try: 58 | logger.info("🔌 WebSocket client connected") 59 | connection.set_model(websocket.query_params["model"]) 60 | await 
--------------------------------------------------------------------------------
/app/main.py:
--------------------------------------------------------------------------------
1 | from contextlib import asynccontextmanager
2 | from fastapi import FastAPI, WebSocket, WebSocketDisconnect
3 | 
4 | from app.db.database import db
5 | from app.api.routes.v1 import endpoints, voice
6 | from fastapi.middleware.cors import CORSMiddleware
7 | from app.config import settings
8 | import uvicorn
9 | 
10 | from app.websocket.connection import WebSocketConnection
11 | from app.utils.logger import logger
12 | 
13 | origins = settings.CORS_ORIGINS
14 | 
15 | @asynccontextmanager
16 | async def lifespan(app: FastAPI):
17 |     """Lifespan context manager for application startup/shutdown events"""
18 |     try:
19 |         await db.connect()
20 |         logger.info("📁 File: main.py, Line: 10, Function: lifespan; Status: Application started")
21 |         yield
22 |     finally:
23 |         await db.disconnect()
24 |         logger.info("📁 File: main.py, Line: 14, Function: lifespan; Status: Application shutdown")
25 | 
26 | 
27 | app = FastAPI(
28 |     title=settings.APP_NAME,
29 |     description="A FastAPI wrapper for the Ollama API that matches the official API structure",
30 |     version="1.0.0",
31 |     docs_url="/api/docs",
32 |     redoc_url="/api/redoc",
33 |     debug=settings.LOG_LEVEL == "debug",
34 |     lifespan=lifespan
35 | )
36 | 
37 | app.add_middleware(
38 |     CORSMiddleware,
39 |     allow_origins=origins,
40 |     allow_credentials=True,
41 |     allow_methods=["*"],
42 |     allow_headers=["*"],
43 | )
44 | app.include_router(endpoints.router, prefix="/api")
45 | app.include_router(voice.router, prefix="/api/voice")
46 | 
47 | @app.websocket("/realtime")
48 | async def websocket_endpoint(websocket: WebSocket):
49 |     """WebSocket endpoint for real-time chat"""
50 | 
51 |     # Get the requested protocols from the client
52 |     requested_protocols = websocket.headers.get("Sec-WebSocket-Protocol", "").split(", ")
53 | 
54 |     logger.info(f"🔌 Requested protocols: {requested_protocols}")
55 | 
56 |     connection = WebSocketConnection(websocket, db, subprotocol=requested_protocols)
57 |     try:
58 |         logger.info("🔌 WebSocket client connected")
59 |         if model := websocket.query_params.get("model"):
60 |             connection.set_model(model)
61 |         await connection.handle_connection()
62 |     except WebSocketDisconnect as e:
63 |         logger.info(f"🔌 {e.code} WebSocket client disconnected: {e.reason}")
64 |     except Exception as e:
65 |         logger.error(f"❌ WebSocket error: {str(e)}")
66 |     finally:
67 |         await connection.cleanup()
68 | 
69 | 
70 | @app.get("/health")
71 | async def health_check():
72 |     """Health check endpoint"""
73 |     return {"status": "healthy"}
74 | 
75 | 
76 | if __name__ == "__main__":
77 |     uvicorn.run(app, host="0.0.0.0", port=9000)
78 | 
--------------------------------------------------------------------------------
/app/schemas/constants.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | 
3 | class ObjectTypes(str, Enum):
4 |     SESSION = "realtime.session"
5 |     CONVERSATION = "realtime.conversation"
6 |     RESPONSE = "realtime.response"
7 | 
8 | class DefaultValues:
9 |     MODEL = "llama3.1"
10 |     VOICE = "alloy"
11 |     AUDIO_FORMAT = "pcm16"
12 |     TOOL_CHOICE = "auto"
13 |     TEMPERATURE = 0.7
14 |     MAX_TOKENS = "inf"
15 | 
16 |     MODALITIES = ["text", "audio"]
17 |     AUDIO_TRANSCRIPTION = {
18 |         'model': 'whisper-1',
19 |         'language': 'en',
20 |     }
21 |     TURN_DETECTION = {
22 |         "type": "server_vad",
23 |         "threshold": 0.5,
24 |         "prefix_padding_ms": 300,
25 |         "silence_duration_ms": 500
26 |     }
27 | 
--------------------------------------------------------------------------------
/app/schemas/models.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | 
3 | from datetime import datetime
4 | from enum import Enum
5 | from typing import Any, Dict, Generic, List, Optional, TypeVar, Union
6 | from uuid import uuid4
7 | 
8 | from pydantic import BaseModel, Field
9 | 
10 | 
11 | from app.schemas.constants import DefaultValues, ObjectTypes
12 | 
13 | 
14 | class BaseModelM(BaseModel):
15 |     def model_dump(self, **kwargs):
16 |         data = super().model_dump(**kwargs)
17 |         data.pop('created_at', None)
18 |         data.pop('updated_at', None)
19 |         return data
20 | 
21 | # created_at/updated_at are excluded from serialized output via model_dump above
22 | 
23 | 
24 | 
25 | 
26 | 
27 | 
28 | # Enums
29 | class MessageRole(str, Enum):
30 |     SYSTEM = "system"
31 |     USER = "user"
32 |     ASSISTANT = "assistant"
33 |     FUNCTION = "function"
34 | 
35 | 
36 | class ResponseStatus(str, Enum):
37 |     COMPLETED = "completed"
38 |     CANCELLED = "cancelled"
39 |     FAILED = "failed"
40 |     INCOMPLETE = "incomplete"
41 |     IN_PROGRESS = "in_progress"
42 | 
43 | 
44 | # Base Models
45 | class TimestampedModel(BaseModelM):
46 |     created_at: datetime
47 |     updated_at: datetime
48 | 
49 |     class Config:
50 |         from_attributes = True
51 | 
52 | 
53 | class IdentifiedModel(BaseModelM):
54 |     id: str = Field(default_factory=lambda: str(uuid4()))
55 | 
56 |     class Config:
57 |         from_attributes = True
58 | 
59 | 
60 | class BaseDBModel(IdentifiedModel, TimestampedModel):
61 |     pass
62 | 
63 | 
64 | # Rate Limit Models
65 | class RateLimitBase(BaseModelM):
66 |     name: str
67 |     limit: int
68 |     remaining: int
69 |     reset_seconds: float
70 | 
71 | 
72 | class RateLimitCreate(RateLimitBase):
73 |     session_id: str
74 | 
75 | 
76 | class RateLimit(RateLimitBase, BaseDBModel):
77 |     session_id: str
78 | 
79 | 
80 | # Session Models
81 | class SessionBase(BaseModelM):
82 |     object_type: str = ObjectTypes.SESSION
83 |     model: str = DefaultValues.MODEL
84 |     modalities: List[str] = Field(default_factory=lambda: DefaultValues.MODALITIES.copy())
85 | 
instructions: str = "" 86 | voice: str = DefaultValues.VOICE 87 | input_audio_format: str = DefaultValues.AUDIO_FORMAT 88 | output_audio_format: str = DefaultValues.AUDIO_FORMAT 89 | input_audio_transcription: Dict[str, str] = Field(default_factory=dict) 90 | turn_detection: Dict[str, Union[str, float, int]] = Field( 91 | default_factory=lambda: DefaultValues.TURN_DETECTION.copy() 92 | ) 93 | tools: Optional[List[Dict[str, Any]]] = None 94 | tool_choice: str = DefaultValues.TOOL_CHOICE 95 | temperature: float = DefaultValues.TEMPERATURE 96 | max_response_output_tokens: str = DefaultValues.MAX_TOKENS 97 | 98 | 99 | class SessionCreate(SessionBase): 100 | pass 101 | 102 | 103 | class SessionSchema(SessionBase, BaseDBModel): 104 | rate_limits: List[RateLimit] = Field(default_factory=list) 105 | 106 | 107 | # Conversation Models 108 | class ConversationItemBase(BaseModelM): 109 | role: MessageRole 110 | content: Dict[str, Any] 111 | audio_start_ms: Optional[int] = None 112 | audio_end_ms: Optional[int] = None 113 | 114 | 115 | class ConversationItemCreate(ConversationItemBase): 116 | conversation_id: str 117 | 118 | 119 | class ConversationItem(ConversationItemBase, BaseDBModel): 120 | conversation_id: str 121 | 122 | 123 | class ConversationBase(BaseModelM): 124 | object_type: str = ObjectTypes.CONVERSATION 125 | session_id: str 126 | 127 | 128 | class ConversationCreate(ConversationBase): 129 | pass 130 | 131 | 132 | class Conversation(ConversationBase, BaseDBModel): 133 | items: List[ConversationItem] = Field(default_factory=list) 134 | 135 | 136 | # Response Models 137 | class ResponseBase(BaseModelM): 138 | object_type: str = ObjectTypes.RESPONSE 139 | status: ResponseStatus 140 | conversation_id: str 141 | total_tokens: Optional[int] = None 142 | input_tokens: Optional[int] = None 143 | output_tokens: Optional[int] = None 144 | input_token_details: Optional[Dict[str, Any]] = None 145 | output_token_details: Optional[Dict[str, Any]] = None 146 | status_details: Optional[Dict[str, Any]] = None 147 | 148 | 149 | class ResponseCreate(ResponseBase): 150 | pass 151 | 152 | 153 | class Response(ResponseBase, BaseDBModel): 154 | pass 155 | 156 | 157 | # Generic API Response Models 158 | T = TypeVar('T') 159 | 160 | 161 | class BaseAPIResponse(BaseModelM): 162 | message: str 163 | 164 | 165 | class DataResponse(BaseAPIResponse, Generic[T]): 166 | data: T 167 | 168 | 169 | # Specific API Response Models 170 | class SessionResponse(DataResponse[SessionSchema]): 171 | message: str = "Session created successfully" 172 | 173 | 174 | class ConversationResponse(DataResponse[Conversation]): 175 | message: str = "Conversation created successfully" 176 | 177 | 178 | class ResponseResponse(DataResponse[Response]): 179 | message: str = "Response created successfully" 180 | 181 | 182 | class RateLimitResponse(DataResponse[RateLimit]): 183 | message: str = "Rate limit updated successfully" 184 | 185 | 186 | # Error Models 187 | class ErrorDetail(BaseModelM): 188 | loc: Optional[List[str]] = None 189 | msg: str 190 | type: str 191 | 192 | 193 | class ErrorResponse(BaseModelM): 194 | error: str 195 | details: Optional[Union[List[ErrorDetail], Dict[str, Any]]] = None 196 | status_code: int = 400 197 | 198 | 199 | # WebSocket Models 200 | class WSMessage(BaseModelM): 201 | type: str 202 | data: Dict[str, Any] 203 | 204 | 205 | class WSResponse(BaseModelM): 206 | type: str 207 | status: str 208 | data: Optional[Dict[str, Any]] = None 209 | error: Optional[str] = None 210 | 
--------------------------------------------------------------------------------
/app/schemas/requests.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel, Field
2 | from typing import List, Optional, Dict, Any
3 | 
4 | 
5 | class GenerateRequest(BaseModel):
6 |     model: str
7 |     prompt: str
8 |     system: Optional[str] = None
9 |     template: Optional[str] = None
10 |     context: Optional[List[int]] = None
11 |     options: Optional[Dict[str, Any]] = None
12 |     format: Optional[str] = None
13 |     tools: Optional[List[str]] = None
14 |     stream: bool = False
15 | 
16 | 
17 | class GenerateResponse(BaseModel):
18 |     model: str
19 |     created_at: str = Field(..., alias="created_at")
20 |     response: str
21 |     done: bool
22 |     context: Optional[List[int]] = None
23 |     total_duration: Optional[int] = Field(None, alias="total_duration")
24 |     load_duration: Optional[int] = Field(None, alias="load_duration")
25 |     prompt_eval_duration: Optional[int] = Field(None, alias="prompt_eval_duration")
26 |     eval_duration: Optional[int] = Field(None, alias="eval_duration")
27 |     prompt_eval_count: Optional[int] = Field(None, alias="prompt_eval_count")
28 |     eval_count: Optional[int] = Field(None, alias="eval_count")
29 | 
30 | 
31 | class ChatMessage(BaseModel):
32 |     role: str
33 |     content: str
34 |     images: Optional[List[str]] = None
35 | 
36 | 
37 | class ChatRequest(BaseModel):
38 |     model: str
39 |     messages: List[ChatMessage]
40 |     system: Optional[str] = None
41 |     format: Optional[str] = None
42 |     options: Optional[Dict[str, Any]] = None
43 |     template: Optional[str] = None
44 |     stream: bool = False
45 |     tools: Optional[List[str]] = None
46 | 
47 | 
48 | class ChatResponse(BaseModel):
49 |     model: str
50 |     created_at: str = Field(..., alias="created_at")
51 |     message: ChatMessage
52 |     done: bool
53 |     total_duration: Optional[int] = Field(None, alias="total_duration")
54 |     load_duration: Optional[int] = Field(None, alias="load_duration")
55 |     prompt_eval_duration: Optional[int] = Field(None, alias="prompt_eval_duration")
56 |     eval_duration: Optional[int] = Field(None, alias="eval_duration")
57 |     prompt_eval_count: Optional[int] = Field(None, alias="prompt_eval_count")
58 |     eval_count: Optional[int] = Field(None, alias="eval_count")
59 |     encode: Optional[str] = Field(None, alias="encode")
60 | 
61 | 
62 | class ModelInfo(BaseModel):
63 |     name: str
64 |     modified_at: Optional[str] = Field(None, alias="modified_at")
65 |     size: int
66 |     model: str
67 |     digest: str
68 |     details: Dict[str, Any]
69 |     provider: str
70 | 
71 | 
72 | class PullRequest(BaseModel):
73 |     name: str
74 |     insecure: Optional[bool] = None
75 |     stream: Optional[bool] = None
76 |     provider: Optional[str] = 'ollama'
77 | 
78 | 
79 | class PullResponse(BaseModel):
80 |     status: str
81 |     digest: Optional[str] = None
82 |     total: Optional[int] = None
83 |     completed: Optional[int] = None
84 | 
85 | 
86 | class CreateRequest(BaseModel):
87 |     name: str
88 |     modelfile: str
89 |     stream: Optional[bool] = None
90 | 
91 | 
92 | class CreateResponse(BaseModel):
93 |     status: str
94 | 
95 | 
96 | class DeleteRequest(BaseModel):
97 |     name: str
98 | 
99 | 
100 | class SpeechRequest(BaseModel):
101 |     model: Optional[str] = Field('tts-1', alias="model")
102 |     input: str = Field(..., alias="input")
103 |     voice: Optional[str] = Field('alloy', alias="voice")
104 |     response_format: Optional[str] = Field('mp3', alias="response_format")
105 | 
--------------------------------------------------------------------------------
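For reference, a minimal chat payload that validates against ChatRequest above (model name and content are placeholders):

    from app.schemas.requests import ChatMessage, ChatRequest

    request = ChatRequest(
        model="llama3.1",
        messages=[ChatMessage(role="user", content="Hello!")],
        stream=True,
    )
    print(request.model_dump())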
/app/services/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theboringhumane/echoOLlama/2936731e53e6c08231ba67141e4c0419a4809559/app/services/__init__.py
--------------------------------------------------------------------------------
/app/services/audio.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import hashlib
3 | import json
4 | import os
5 | import tempfile
6 | from datetime import datetime
7 | from typing import AsyncGenerator, Optional, Tuple
8 | 
9 | import openai
10 | from faster_whisper import WhisperModel
11 | 
12 | from app.config import settings
13 | from app.core.voice import get_stt_model
14 | from app.utils.errors import AudioProcessingError
15 | from app.utils.logger import logger
16 | 
17 | 
18 | class AudioService:
19 |     """Service for handling audio processing, STT and TTS operations"""
20 | 
21 |     def __init__(self):
22 |         """
23 |         Initialize audio service with models and cache
24 |         📝 File: audio.py, Line: 22, Function: __init__
25 |         """
26 |         self.stt_model: Optional[WhisperModel] = None
27 |         self.temp_dir = os.path.join(tempfile.gettempdir(), "audio_processing")
28 |         self.openai_client = openai.OpenAI(
29 |             api_key=settings.TTS_OPENAI_API_KEY,
30 |             base_url=settings.TTS_OPENAI_API_BASE_URL,
31 |         ) if settings.TTS_ENGINE == "openai" else None
32 |         os.makedirs(self.temp_dir, exist_ok=True)
33 | 
34 |         logger.info("🎙️ audio.py: AudioService initialized")
35 | 
36 |     async def initialize_stt(self) -> None:
37 |         """
38 |         Initialize STT model with proper configuration
39 |         📝 File: audio.py, Line: 37, Function: initialize_stt
40 |         """
41 |         if not self.stt_model:
42 |             try:
43 |                 self.stt_model = get_stt_model()
44 |                 logger.info("✅ audio.py: STT model initialized")
45 |             except Exception as e:
46 |                 logger.error(f"❌ audio.py: STT model initialization failed: {str(e)}")
47 |                 raise AudioProcessingError("Failed to initialize STT model")
48 | 
49 |     async def transcribe_audio(
50 |         self,
51 |         audio_data: bytes,
52 |         event_id: str,
53 |         language: str = 'en',
54 |         task: str = 'transcribe',
55 |         beam_size: int = 5,
56 |         vad_filter: bool = True
57 |     ) -> AsyncGenerator[str, None]:
58 |         """
59 |         Transcribe audio with streaming support
60 |         📝 File: audio.py, Line: 60, Function: transcribe_audio
61 |         """
62 |         temp_path: Optional[str] = None  # may stay unset if saving the buffer fails
63 |         try:
64 |             await self.initialize_stt()
65 |             temp_path = self._save_audio_buffer(audio_data, event_id)
66 | 
67 |             segments, info = await asyncio.to_thread(
68 |                 self.stt_model.transcribe,
69 |                 temp_path,
70 |                 language=language,
71 |                 task=task,
72 |                 beam_size=beam_size,
73 |                 vad_filter=vad_filter,
74 |                 initial_prompt=None,
75 |             )
76 | 
77 |             logger.info(
78 |                 f"ℹ️ audio.py: Detected language: {info.language} "
79 |                 f"with probability {info.language_probability:.2f}"
80 |             )
81 | 
82 |             for segment in segments:
83 |                 yield segment.text
84 |                 logger.info(f"🎯 audio.py: Transcribed segment: {segment.text[:30]}...")
85 | 
86 |         except Exception as e:
87 |             logger.error(f"❌ audio.py: Transcription failed: {str(e)}")
88 |             raise AudioProcessingError(f"Transcription failed: {str(e)}")
89 |         finally:
90 |             if temp_path and os.path.exists(temp_path):
91 |                 os.remove(temp_path)
92 | 
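# Usage sketch: transcribe_audio is an async generator, so callers consume it
# with `async for` (wav_bytes below is a placeholder for real PCM/WAV data):
#
#     service = AudioService()
#     async for segment in service.transcribe_audio(wav_bytes, event_id="evt_1"):
#         print(segment)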
93 |     async def generate_speech(
94 |         self,
95 |         text: str,
96 |         voice: str = 'alloy',
97 |         model: str = 'tts-1',
98 |         response_format: str = 'mp3'
99 |     ) -> Tuple[str, str]:
100 |         """
101 |         Generate speech using OpenAI's TTS API with caching
102 |         📝 File: audio.py, Line: 98, Function: generate_speech
103 |         """
104 |         try:
105 |             # Create cache key
106 |             body = {
107 |                 "model": model,
108 |                 "input": text,
109 |                 "voice": voice,
110 |                 "response_format": response_format
111 |             }
112 |             cache_key = hashlib.sha256(json.dumps(body).encode()).hexdigest()
113 | 
114 |             # Setup cache paths
115 |             settings.setup_cache_dir()
116 |             file_path = os.path.join(settings.SPEECH_CACHE_DIR, f"{cache_key}.mp3")
117 |             file_body_path = os.path.join(settings.SPEECH_CACHE_DIR, f"{cache_key}.json")
118 | 
119 |             # Return cached file if exists
120 |             if os.path.isfile(file_path):
121 |                 return file_path, cache_key
122 | 
123 |             # Generate new audio
124 |             if settings.TTS_ENGINE == "openai":
125 |                 with self.openai_client.audio.speech.with_streaming_response.create(
126 |                     model=model,  # honor the caller's model so it matches the cache key
127 |                     voice=voice,
128 |                     input=text
129 |                 ) as response:
130 |                     with open(file_path, "wb") as f:
131 |                         for chunk in response.iter_bytes():
132 |                             f.write(chunk)
133 | 
134 |                 # Save request body for cache
135 |                 with open(file_body_path, "w") as f:
136 |                     json.dump(body, f)
137 | 
138 |                 return file_path, cache_key
139 |             else:
140 |                 raise AudioProcessingError("Unsupported TTS engine")
141 | 
142 |         except Exception as e:
143 |             logger.error(f"❌ audio.py: Speech generation failed: {str(e)}")
144 |             raise AudioProcessingError(f"Failed to generate speech: {str(e)}")
145 | 
146 |     def _save_audio_buffer(self, audio_bytes: bytes, event_id: str) -> str:
147 |         """
148 |         Save audio buffer to temporary file
149 |         📝 File: audio.py, Line: 146, Function: _save_audio_buffer
150 |         """
151 |         try:
152 |             temp_path = os.path.join(
153 |                 self.temp_dir,
154 |                 f"{event_id}_{datetime.now().timestamp()}.wav"
155 |             )
156 | 
157 |             with open(temp_path, "wb") as f:
158 |                 f.write(audio_bytes)
159 | 
160 |             return temp_path
161 | 
162 |         except Exception as e:
163 |             logger.error(f"❌ audio.py: Failed to save audio buffer: {str(e)}")
164 |             raise AudioProcessingError("Failed to save audio data")
165 | 
166 |     async def commit_audio(self, event_id: str) -> None:
167 |         pass  # placeholder: buffered-audio commit is not implemented yet
168 | 
169 |     async def cleanup(self) -> None:
170 |         """
171 |         Cleanup resources and temporary files
172 |         📝 File: audio.py, Line: 167, Function: cleanup
173 |         """
174 |         try:
175 |             if os.path.exists(self.temp_dir):
176 |                 for file in os.listdir(self.temp_dir):
177 |                     os.remove(os.path.join(self.temp_dir, file))
178 |                 os.rmdir(self.temp_dir)
179 | 
180 |             self.stt_model = None
181 |             logger.info("🧹 audio.py: Audio service cleanup completed")
182 | 
183 |         except Exception as e:
184 |             logger.error(f"❌ audio.py: Cleanup failed: {str(e)}")
185 | 
--------------------------------------------------------------------------------
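The speech cache keys on a SHA-256 of the request body, so an identical text/voice/model/format combination always resolves to the same file. A sketch of the lookup in isolation ("./cache/speech" stands in for settings.SPEECH_CACHE_DIR):

    import hashlib, json, os

    body = {"model": "tts-1", "input": "Hello!", "voice": "alloy", "response_format": "mp3"}
    cache_key = hashlib.sha256(json.dumps(body).encode()).hexdigest()
    cached_path = os.path.join("./cache/speech", f"{cache_key}.mp3")
    is_cache_hit = os.path.isfile(cached_path)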
/app/services/chat_state.py:
--------------------------------------------------------------------------------
1 | import json
2 | from typing import Dict, Any
3 | from redis import Redis
4 | from app.db.database import Database
5 | from app.db.models import to_pydantic
6 | from app.schemas.models import SessionSchema
7 | from app.utils.logger import logger
8 | 
9 | 
10 | class ChatStateManager:
11 |     def __init__(self, redis: Redis, db: Database):
12 |         self.redis = redis
13 |         self.db = db
14 | 
15 |     async def get_chat_state(self, session_id: str) -> Dict[str, Any]:
16 |         """
17 |         Get chat state from Redis and DB
18 |         📝 File: chat_state.py, Line: 15, Function: get_chat_state
19 |         """
20 |         try:
21 |             # Try Redis first
22 |             state = json.loads(self.redis.get(f"chat_state:{session_id}") or "{}")
23 | 
24 |             logger.info(f"📨 chat_state.py: Chat state: {state}")
25 | 
26 |             if not state or state == {}:
27 |                 # Fallback to DB
28 |                 session = await self.db.get_session(session_id)
29 |                 logger.info(f"📨 chat_state.py: Session: {session}")
30 |                 if session:
31 |                     state = to_pydantic(session, SessionSchema).model_dump()
32 |                     logger.info(f"📨 chat_state.py: Chat state from DB: {state}")
33 |                     # Cache in Redis
34 |                     self.redis.set(
35 |                         f"chat_state:{session_id}",
36 |                         json.dumps(state)
37 |                     )
38 |                     self.redis.expire(
39 |                         f"chat_state:{session_id}",
40 |                         3600
41 |                     )
42 | 
43 |             return state
44 | 
45 |         except Exception as e:
46 |             logger.error(f"❌ chat_state.py: Failed to get chat state: {str(e)}")
47 |             raise
48 | 
49 |     async def update_chat_state(
50 |         self,
51 |         session_id: str,
52 |         updates: Dict[str, Any]
53 |     ) -> None:
54 |         """
55 |         Update chat state in Redis and DB
56 |         📝 File: chat_state.py, Line: 54, Function: update_chat_state
57 |         """
58 |         try:
59 |             # Update Redis (state is stored as a JSON string, so merge and re-set)
60 |             state = json.loads(self.redis.get(f"chat_state:{session_id}") or "{}")
61 |             state.update(updates)
62 |             self.redis.set(f"chat_state:{session_id}", json.dumps(state))
63 | 
64 |             # Update DB
65 |             await self.db.update_session(
66 |                 session_id=session_id,
67 |                 updates=updates
68 |             )
69 | 
70 |             logger.info(f"✅ chat_state.py: Updated chat state for {session_id}")
71 | 
72 |         except Exception as e:
73 |             logger.error(f"❌ chat_state.py: Failed to update chat state: {str(e)}")
74 |             raise
75 | 
76 |     async def persist_state(self, session_id: str) -> None:
77 |         """
78 |         Persist chat state to database
79 |         📝 File: chat_state.py, Line: 72, Function: persist_state
80 |         """
81 |         try:
82 |             # Get chat state from Redis (stored as a JSON string)
83 |             state = json.loads(self.redis.get(f"chat_state:{session_id}") or "{}")
84 | 
85 |             # Update database
86 |             await self.db.update_session(
87 |                 session_id=session_id,
88 |                 updates=state
89 |             )
90 | 
91 |             logger.info(f"✅ chat_state.py: Persisted chat state for {session_id}")
92 | 
93 |         except Exception as e:
94 |             logger.error(f"❌ chat_state.py: Failed to persist chat state: {str(e)}")
95 |             raise
96 | 
--------------------------------------------------------------------------------
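ChatStateManager is a read-through cache: Redis holds each session's state as a JSON string with a one-hour TTL, and the database is the fallback and system of record. The pattern in isolation (sketch; the client is the same synchronous redis.Redis used above):

    import json
    from redis import Redis

    r = Redis(decode_responses=True)
    key = "chat_state:session-123"              # hypothetical session id
    state = json.loads(r.get(key) or "{}")      # cache-hit path
    if not state:
        state = {"model": "llama3.1"}           # stand-in for the DB fallback
        r.set(key, json.dumps(state))
        r.expire(key, 3600)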
/app/services/llm.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict, Any, Optional, AsyncGenerator, Union
2 | import json
3 | 
4 | import ollama
5 | from enum import Enum
6 | from fastapi import HTTPException
7 | from app.config import settings
8 | from app.schemas.requests import (ChatRequest, ChatResponse)
9 | from app.utils.logger import logger
10 | 
11 | 
12 | class ModelProvider(Enum):
13 |     """
14 |     Enum for model providers
15 |     📝 File: llm.py, Line: 15, Function: ModelProvider
16 |     """
17 |     OLLAMA = "ollama"
18 |     OPENAI = "openai"
19 | 
20 | 
21 | class LLMService:
22 |     """
23 |     Unified service for LLM providers (currently only Ollama is wired up; OpenAI is reserved)
24 |     """
25 | 
26 |     def __init__(self):
27 |         """
28 |         Initialize LLM service with the async Ollama client
29 |         📝 File: llm.py, Line: 26, Function: __init__
30 |         """
31 |         # Initialize the async Ollama client
32 |         self.ollama_client = ollama.AsyncClient(
33 |             host=settings.OLLAMA_API_BASE_URL
34 |         )
35 |         self.model = settings.OLLAMA_MODEL
36 | 
37 |         logger.info("🤖 llm.py: LLM service initialized with Ollama client")
38 | 
39 |     async def generate_response(
40 |         self,
41 |         messages: List[Dict[str, Any]],
42 |         temperature: float = 0.8,
43 |         tools: Optional[List[Dict]] = None,
44 |         stream: bool = True,
45 |         provider: ModelProvider = ModelProvider.OLLAMA,
46 |         model: Optional[str] = None
47 |     ) -> Union[AsyncGenerator[Any, None], Any]:
48 |         """
49 |         Generate response using specified provider
50 |         📝 File: llm.py, Line: 54, Function: generate_response
51 |         """
52 |         try:
53 |             client = self.ollama_client if provider == ModelProvider.OLLAMA else None
54 |             if client is None:
55 |                 raise ValueError("Invalid provider")
56 |             model = model if model is not None else self.model if provider == ModelProvider.OLLAMA else None
57 | 
58 |             if model is None:
59 |                 raise ValueError("Model not specified")
60 | 
61 |             # messages are chat-style dicts, so use the chat endpoint
62 |             # (generate() expects a plain prompt string)
63 |             response = await client.chat(
64 |                 model=model,
65 |                 messages=messages,
66 |                 tools=tools,
67 |                 stream=stream
68 |             )
69 | 
70 |             logger.info(f"🤖 llm.py: Generated response with {len(messages)} messages using {provider.value}")
71 |             return response
72 | 
73 |         except Exception as e:
74 |             logger.error(f"❌ llm.py: LLM response generation failed: {str(e)}")
75 |             raise HTTPException(status_code=500, detail=str(e))
76 | 
77 |     async def chat_stream(
78 |         self,
79 |         request: ChatRequest,
80 |         provider: ModelProvider = ModelProvider.OLLAMA
81 |     ) -> AsyncGenerator[str, None]:
82 |         """
83 |         Stream chat responses from the configured provider
84 |         📝 File: llm.py, Line: 82, Function: chat_stream
85 |         """
86 |         try:
87 |             client = self.ollama_client if provider == ModelProvider.OLLAMA else None
88 | 
89 |             if client is None:
90 |                 raise ValueError("Invalid provider")
91 | 
92 |             streaming_allowed: bool = request.stream
93 | 
94 |             request_dict = {
95 |                 "model": request.model,
96 |                 "tools": request.tools,
97 |                 "messages": [m.model_dump() for m in request.messages]
98 |             }
99 |             if streaming_allowed:
100 |                 stream = await client.chat(
101 |                     **request_dict,
102 |                     stream=True
103 |                 )
104 | 
105 |                 async for chunk in stream:
106 |                     if chunk['message']['content']:
107 |                         yield chunk['message']['content']
108 |             else:
109 |                 response = await client.chat(
110 |                     **request_dict,
111 |                     stream=False
112 |                 )
113 |                 yield response['message']['content']
114 | 
115 |         except Exception as e:
116 |             logger.error(f"❌ llm.py: Chat stream failed: {str(e)}")
117 |             raise HTTPException(status_code=500, detail=str(e))
118 | 
119 |     async def process_function_call(
120 |         self,
121 |         function_call: Dict[str, Any],
122 |         available_functions: Dict[str, callable]
123 |     ) -> Any:
124 |         """
125 |         Process function calls from LLM
126 |         📝 File: llm.py, Line: 116, Function: process_function_call
127 |         """
128 |         try:
129 |             function_name = function_call["name"]
130 |             function_args = json.loads(function_call["arguments"])
131 | 
132 |             if function_name not in available_functions:
133 |                 raise ValueError(f"Unknown function: {function_name}")
134 | 
135 |             function_to_call = available_functions[function_name]
136 |             function_response = await function_to_call(**function_args)
137 | 
138 |             logger.info(f"🔧 llm.py: Executed function {function_name}")
139 |             return function_response
140 | 
141 |         except Exception as e:
142 |             logger.error(f"❌ llm.py: Function call processing failed: {str(e)}")
143 |             raise HTTPException(status_code=500, detail=str(e))
144 | 
145 |     def set_default_model(self, model: str) -> None:
146 |         """
147 |         Set default model for Ollama
148 |         📝 File: llm.py, Line: 137, Function: set_default_model
149 |         """
150 |         self.model = model
151 |         logger.info(f"🤖 llm.py: Default model set to {model}")
--------------------------------------------------------------------------------
/app/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theboringhumane/echoOLlama/2936731e53e6c08231ba67141e4c0419a4809559/app/utils/__init__.py
--------------------------------------------------------------------------------
/app/utils/errors.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Dict, Any
2 | from fastapi import HTTPException
3 | 
4 | 
5 | class WebSocketError(Exception):
6 |     """Base WebSocket error"""
7 | 
8 |     def __init__(
9 |         self,
10 |         message: str,
11 |         code: int = 1000,
12 |         data: Optional[Dict[str, Any]] = None
13 |     ):
14 |         self.message = message
15 |         self.code = code
16 |         self.data = data or {}
17 |         super().__init__(self.message)
18 | 
19 | 
20 | class SessionError(WebSocketError):
21 |     """Session-related errors"""
22 |     pass
23 | 
24 | 
25 | class AudioError(WebSocketError):
26 |     """Audio processing errors"""
27 |     pass
28 | 
29 | 
30 | class LLMError(WebSocketError):
31 |     """LLM-related errors"""
32 |     pass
33 | 
34 | 
35 | class RateLimitError(WebSocketError):
36 |     """Rate limiting errors"""
37 |     pass
38 | 
39 | 
40 | class AudioProcessingError(WebSocketError):
41 |     """Audio processing errors"""
42 |     pass
43 | 
44 | 
45 | def handle_websocket_error(error: Exception) -> Dict[str, Any]:
46 |     """Convert exceptions to WebSocket error responses"""
47 |     if isinstance(error, WebSocketError):
48 |         return {
49 |             "type": "error",
50 |             "code": error.code,
51 |             "message": str(error),
52 |             "data": error.data
53 |         }
54 |     elif isinstance(error, HTTPException):
55 |         return {
56 |             "type": "error",
57 |             "code": error.status_code,
58 |             "message": error.detail
59 |         }
60 |     else:
61 |         return {
62 |             "type": "error",
63 |             "code": 500,
64 |             "message": "Internal server error"
65 |         }
66 | 
--------------------------------------------------------------------------------
/app/utils/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import sys
3 | from datetime import datetime
4 | 
5 | class CustomFormatter(logging.Formatter):
6 |     """Custom formatter with colors and emojis"""
7 | 
8 |     COLORS = {
9 |         'DEBUG': '\033[94m',    # Blue
10 |         'INFO': '\033[92m',     # Green
11 |         'WARNING': '\033[93m',  # Yellow
12 |         'ERROR': '\033[91m',    # Red
13 |         'CRITICAL': '\033[95m', # Magenta
14 |         'RESET': '\033[0m'      # Reset
15 |     }
16 | 
17 |     EMOJIS = {
18 |         'DEBUG': '🔍',
19 |         'INFO': '✨',
20 |         'WARNING': '⚠️',
21 |         'ERROR': '❌',
22 |         'CRITICAL': '💥'
23 |     }
24 | 
25 |     def format(self, record):
26 |         if not record.exc_info:
27 |             level = record.levelname
28 |             msg = record.getMessage()
29 | 
30 |             # Add timestamp
31 |             timestamp = datetime.fromtimestamp(record.created).strftime('%Y-%m-%d %H:%M:%S')
32 | 
33 |             # Add file and line info
34 |             file_info = f"{record.filename}:{record.lineno}"
35 | 
36 |             # Format with color and emoji
37 |             color = self.COLORS.get(level, '')
38 |             emoji = self.EMOJIS.get(level, '')
39 |             reset = self.COLORS['RESET']
40 | 
41 |             return f"{color}{timestamp} {emoji} [{level}] {file_info} - {msg}{reset}"
42 |         return super().format(record)
43 | 
44 | def setup_logger():
45 |     """Setup application logger"""
46 |     logger = logging.getLogger()
47 |     logger.setLevel(logging.INFO)
48 | 
49 |     # Console handler
50 |     console_handler = logging.StreamHandler(sys.stdout)
51 |     console_handler.setFormatter(CustomFormatter())
52 |     logger.addHandler(console_handler)
53 | 
54 |     # File handler
55 |     file_handler = logging.FileHandler('app.log')
56 |     file_handler.setFormatter(logging.Formatter(
57 |         '%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d - %(message)s'
58 |     ))
59 |     logger.addHandler(file_handler)
60 | 
61 |     return logger
62 | 
63 | logger = setup_logger()
64 | 
--------------------------------------------------------------------------------
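Because setup_logger() runs at import time and configures the root logger, consumers only ever import the module-level instance:

    from app.utils.logger import logger

    logger.info("✨ example.py: service started")   # colorized console output + app.log file
    logger.error("❌ example.py: something failed")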
/app/websocket/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theboringhumane/echoOLlama/2936731e53e6c08231ba67141e4c0419a4809559/app/websocket/__init__.py
--------------------------------------------------------------------------------
/app/websocket/base_handler.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | from typing import Dict, Any
3 | from fastapi import WebSocket
4 | from redis import Redis
5 | 
6 | from app.db.database import Database
7 | from app.services.llm import LLMService
8 | from app.utils.logger import logger
9 | 
10 | 
11 | class BaseHandler:
12 |     """Base class for all handlers"""
13 | 
14 |     def __init__(self, websocket: WebSocket, redis: Redis, llm: LLMService, db: Database):
15 |         self.websocket = websocket
16 |         self.redis = redis
17 |         self.llm = llm
18 |         self.db = db
19 | 
20 |     async def send_event(self, event_type: str, data: Dict[str, Any]) -> None:
21 |         """Send WebSocket event with logging"""
22 |         try:
23 |             await self.websocket.send_json({
24 |                 "type": event_type,
25 |                 "data": data,
26 |                 "timestamp": datetime.now().isoformat()
27 |             })
28 |             logger.info(f"📤 base_handler.py: Sent event {event_type}")
29 |         except Exception as e:
30 |             logger.error(f"❌ base_handler.py: Failed to send event {event_type}: {str(e)}")
31 |             raise
32 | 
33 |     async def cleanup(self) -> None:
34 |         """Cleanup resources"""
35 |         pass
36 | 
--------------------------------------------------------------------------------
/app/websocket/connection.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from typing import Dict, List, Optional, Any
3 | from fastapi.websockets import WebSocket, WebSocketDisconnect, WebSocketState
4 | from app.db.models import to_pydantic
5 | from app.schemas.models import SessionSchema
6 | from app.services.chat_state import ChatStateManager
7 | from app.db.database import Database
8 | from app.utils.errors import WebSocketError, handle_websocket_error
9 | from app.config import settings
10 | from datetime import datetime, timedelta
11 | import json
12 | import uuid
13 | from app.utils.logger import logger
14 | from app.websocket.handlers.main import WebSocketHandler
15 | from app.websocket.types import WebSocketEvent, MessageType
16 | 
17 | 
18 | class WebSocketConnection:
19 |     """
20 |     WebSocket connection manager with enhanced error handling, state management,
21 |     and real-time audio/text processing capabilities
22 |     """
23 | 
24 |     def __init__(self, websocket: WebSocket, db: Database, subprotocol: Optional[List[str]] = None):
25 |         """
26 |         Initialize WebSocket connection with necessary services
27 |         📝 File: connection.py, Line: 22, Function: __init__
28 |         """
29 |         self.websocket = websocket
30 |         self.db = db
31 |         self.client_id = str(uuid.uuid4())
32 |         self.subprotocol = subprotocol
33 |         self.handler = WebSocketHandler(websocket, db)
34 |         self.chat_state = ChatStateManager(self.handler.redis, db)
35 |         self.heartbeat_task: Optional[asyncio.Task] = None
36 |         self.is_connected = False
37 |         self.current_session_id: Optional[str] = None
38 | 
39 |     async def handle_connection(self) -> None:
40 |         """
41 |         Handle WebSocket connection lifecycle with heartbeat and state management
42 |         📝 File: connection.py, Line: 35, Function: handle_connection
43 |         """
44 |         try:
45 |             await self.websocket.accept(
46 |                 subprotocol=self.subprotocol[0] if self.subprotocol and self.subprotocol[0] else None
47 |             )
48 |             self.is_connected = True
49 | 
50 |             session = None
51 | 
52 |             # Initialize session with retry logic
53 |             retry_count = 3
54 |             for attempt in range(retry_count):
55 |                 try:
56 |                     self.current_session_id = await self._initialize_session()
57 |                     session = await self.db.get_session(self.current_session_id)
58 |                     break
59 |                 except Exception as e:
60 |                     if attempt == retry_count - 1:
61 |                         raise
62 |                     await asyncio.sleep(1)
63 | 
64 |             logger.info(
65 |                 f"🔌 connection.py: New WebSocket connection established - "
66 |                 f"Client ID: {self.client_id}, Session: {self.current_session_id}"
67 |             )
68 | 
69 |             # Start heartbeat
70 |             self.heartbeat_task = asyncio.create_task(self._heartbeat())
71 | 
72 |             # Send connection confirmation
73 |             await self._send_connection_confirmed(to_pydantic(session, SessionSchema).model_dump())
74 | 
75 |             while self.is_connected and self.websocket.client_state == WebSocketState.CONNECTED:
76 |                 try:
77 |                     message = await self._receive_message()
78 |                     if message:
79 |                         await self.handle_message(message)
80 |                 except asyncio.TimeoutError:
81 |                     continue
82 |                 except Exception as e:
83 |                     logger.error(f"❌ connection.py: Message processing error: {str(e)}")
84 |                     await self._send_error(str(e))
85 | 
86 |         except WebSocketDisconnect as e:
87 |             logger.info(f"🔌 connection.py: Client {self.client_id} disconnected: {e}")
88 |         except Exception as e:
89 |             logger.error(f"❌ connection.py: Connection error: {str(e)}")
90 |         finally:
91 |             await self._cleanup()
92 | 
93 |     def set_model(self, model: str) -> None:
94 |         """
95 |         Set model for the current session
96 |         📝 File: connection.py, Line: 64, Function: set_model
97 |         """
98 |         self.handler.set_model(model)
99 | 
100 |     async def handle_message(self, message: Dict[str, Any]) -> None:
101 |         """
102 |         Route messages to appropriate handlers with rate limiting and validation
103 |         📝 File: connection.py, Line: 71, Function: handle_message
104 |         """
105 |         try:
106 |             # Validate message structure and session
107 |             self._validate_message(message)
108 |             if not self.current_session_id:
109 |                 raise WebSocketError("No active session", code=4003)
110 | 
111 |             message_type = message["type"]
112 |             logger.info(f"📨 connection.py: Handling message type: {message_type}")
113 | 
114 |             # Check rate limits
115 |             await self._check_rate_limits()
116 | 
117 |             # Enrich message with session data
118 |             enriched_message = await self._enrich_message(message)
119 | 
120 |             # Create WebSocketEvent instance
121 |             event = WebSocketEvent(
122 |                 event_id=enriched_message.get("event_id", f"event_{str(uuid.uuid4())}"),
123 |                 type=MessageType(message_type),
124 |                 data=enriched_message
125 |             )
126 | 
127 |             # Handle the message through main handler
128 |             logger.info(f"📨 connection.py: Handling message: {event}")
129 |             await self.handler.handle_message(event)
130 | 
131 |         except WebSocketError as e:
132 |             await self._send_error(e.message, e.code)
133 |         except Exception as e:
134 |             logger.error(f"❌ connection.py: Message handling error: {str(e)}")
135 |             await self._send_error("Internal server error", 500)
136 | 
137 |     async def _initialize_session(self) -> str:
138 |         """
139 |         Initialize session in database and cache with enhanced configuration
140 |         📝 File: connection.py, Line: 124, Function: _initialize_session
141 |         """
142 |         session_id = str(uuid.uuid4())
143 |         session_data = {
144 |             "id": session_id,
145 |             "model": self.handler.llm_service.model,
146 |             "modalities": ['text']
147 |         }
148 | 
149 |         # Add audio modality if enabled
150 |         if settings.TTS_ENGINE:
151 |             session_data["modalities"].append("audio")
152 | 
153 |         await self.db.create_session(session_data)
154 |         return session_id
155 | 
156 |     async def _receive_message(self) -> Optional[Dict[str, Any]]:
157 |         """
158 |         Receive and parse WebSocket message with timeout and validation
159 |         📝 File: connection.py, Line: 140, Function: _receive_message
160 |         """
161 |         if self.websocket.client_state != WebSocketState.CONNECTED:
162 |             return None
163 |         try:
164 |             message = await self.websocket.receive_json()
165 |             # Basic JSON schema validation
166 |             if not isinstance(message, dict):
167 |                 raise WebSocketError("Invalid message format", code=4000)
168 | 
169 |             return message
170 |         except json.JSONDecodeError as e:
171 |             logger.error(f"❌ connection.py: Invalid JSON received: {str(e)}")
172 |             await self._send_error("Invalid JSON format")
173 |             return None
174 |         except asyncio.TimeoutError:
175 |             return None
176 |         except Exception as e:
177 |             logger.error(f"❌ connection.py: Message receive error: {str(e)}")
178 |             return None
179 | 
180 |     async def _heartbeat(self) -> None:
181 |         """
182 |         Send periodic heartbeat with enhanced monitoring
183 |         📝 File: connection.py, Line: 157, Function: _heartbeat
184 |         """
185 |         while self.is_connected:
186 |             try:
187 |                 current_time = datetime.now()
188 |                 await self.websocket.send_text(json.dumps({
189 |                     "type": "heartbeat",
190 |                     "timestamp": current_time.isoformat(),
191 |                     "session_id": self.current_session_id
192 |                 }))
193 | 
194 |                 await asyncio.sleep(settings.WS_HEARTBEAT_INTERVAL)
195 |             except Exception as e:
196 |                 logger.error(f"❌ connection.py: Heartbeat error: {str(e)}")
197 |                 break
198 | 
199 |     async def _check_rate_limits(self) -> None:
200 |         """
201 |         Check and update rate limits
202 |         📝 File: connection.py, Line: 172, Function: _check_rate_limits
203 |         """
204 |         rate_limits = await self.db.get_session_rate_limits(self.client_id)
205 |         logger.info(f"📨 connection.py: Rate limits: {rate_limits}")
206 |         if len(rate_limits) == 0:
207 |             return
208 |         for limit_type, limit in rate_limits.items():
209 |             if limit["remaining"] <= 0:
210 |                 raise WebSocketError(f"Rate limit exceeded for {limit_type} ({limit['reset_seconds']} seconds)", code=4029)
211 | 
212 |         await self.db.update_rate_limits(self.client_id, self.current_session_id)
213 | 
214 |     def _validate_message(self, message: Dict[str, Any]) -> None:
215 |         """
216 |         Validate message structure
217 |         📝 File: connection.py, Line: 182, Function: _validate_message
218 |         """
219 |         if not isinstance(message, dict) or "type" not in message:
220 |             raise WebSocketError("Message type is required", code=4001)
221 | 
222 |     async def _enrich_message(self, message: Dict[str, Any]) -> Dict[str, Any]:
223 |         """
224 |         Enrich message with session and state data
225 |         📝 File: connection.py, Line: 192, Function: _enrich_message
226 |         """
227 |         state = await self.chat_state.get_chat_state(self.current_session_id)
228 |         return {
229 |             **message,
230 |             "session_id": self.current_session_id,
231 |             "client_id": self.client_id,
232 |             "state": state
233 |         }
234 | 
235 |     async def _send_connection_confirmed(self, session: Dict[str, Any]) -> None:
236 |         """
237 |         Send connection confirmation with session details
238 |         📝 File: connection.py, Line: 205, Function: _send_connection_confirmed
239 |         """
240 |         logger.info(f"📨 connection.py: Sending connection confirmed message")
241 |         logger.info(f"📨 connection.py: WebSocket client state: {self.websocket.client_state}")
242 |         if self.websocket.client_state != WebSocketState.CONNECTED:
243 |             return
244 |         try:
245 |             await self.websocket.send_json({
"type": "session.created", 248 | "event_id": f"event_{str(uuid.uuid4())}", 249 | "session": {**session, "expires_at": (datetime.now() + timedelta(seconds=settings.SESSION_EXPIRATION_TIME)).isoformat()} 250 | }) 251 | except WebSocketDisconnect as e: 252 | logger.error(f"❌ connection.py: Line 250: {e.code} WebSocket disconnected: {e.reason}") 253 | except Exception as e: 254 | logger.error(f"❌ connection.py: Line 252: Failed to send connection confirmed message: {str(e)}") 255 | 256 | async def _send_error(self, message: str, code: int = 400) -> None: 257 | """ 258 | Send error message to client 259 | 📝 File: connection.py, Line: 227, Function: _send_error 260 | """ 261 | try: 262 | error_response = handle_websocket_error(WebSocketError(message, code)) 263 | await self.websocket.send_json(error_response) 264 | except Exception as e: 265 | logger.error(f"❌ connection.py: Failed to send error message: {str(e)}") 266 | 267 | async def _cleanup(self) -> None: 268 | """ 269 | Clean up resources on connection close 270 | 📝 File: connection.py, Line: 238, Function: _cleanup 271 | """ 272 | try: 273 | self.is_connected = False 274 | if self.heartbeat_task: 275 | self.heartbeat_task.cancel() 276 | await self.handler.cleanup() 277 | logger.info(f"🧹 connection.py: Cleanup completed for client {self.client_id}") 278 | except Exception as e: 279 | logger.error(f"❌ connection.py: Cleanup failed: {str(e)}") 280 | 281 | async def cleanup(self): 282 | return await self._cleanup() 283 | -------------------------------------------------------------------------------- /app/websocket/handlers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theboringhumane/echoOLlama/2936731e53e6c08231ba67141e4c0419a4809559/app/websocket/handlers/__init__.py -------------------------------------------------------------------------------- /app/websocket/handlers/audio.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any 2 | from fastapi import HTTPException 3 | 4 | from app.services.audio import AudioService 5 | from app.websocket.types import MessageType 6 | from app.websocket.base_handler import BaseHandler 7 | from app.utils.errors import AudioProcessingError 8 | 9 | from app.utils.logger import logger 10 | 11 | 12 | class AudioHandler(BaseHandler): 13 | """Handles audio-related WebSocket events""" 14 | 15 | def __init__(self, audio_service: AudioService, *args, **kwargs): 16 | super().__init__(*args, **kwargs) 17 | self.audio_service = audio_service 18 | 19 | async def handle_audio_append(self, message: Dict[str, Any]) -> None: 20 | """ 21 | Handle audio buffer append with transcription 22 | 📝 File: audio.py, Line: 20, Function: handle_audio_append 23 | """ 24 | try: 25 | audio_data = message.get("audio") 26 | event_id = message.get("event_id", "default") 27 | 28 | if not audio_data: 29 | raise ValueError("No audio data provided") 30 | 31 | # Process audio through service 32 | async for transcription in self.audio_service.transcribe_audio( 33 | audio_data=audio_data, 34 | event_id=event_id 35 | ): 36 | if isinstance(transcription, dict) and "error" in transcription: 37 | raise AudioProcessingError(transcription["error"]) 38 | 39 | await self.redis.rpush( 40 | f"transcriptions:{event_id}", 41 | transcription 42 | ) 43 | 44 | await self.send_event(MessageType.AUDIO_TRANSCRIBED.value, { 45 | "event_id": event_id, 46 | "status": "completed" 47 | }) 48 | 49 | except Exception as e: 50 | 
logger.error(f"❌ audio.py: Audio processing failed: {str(e)}") 51 | raise HTTPException(status_code=400, detail=str(e)) 52 | 53 | async def handle_speech_generate(self, message: Dict[str, Any]) -> None: 54 | """ 55 | Handle speech generation request 56 | 📝 File: audio.py, Line: 54, Function: handle_speech_generate 57 | """ 58 | try: 59 | text = message.get("text") 60 | voice = message.get("voice", "alloy") 61 | event_id = message.get("event_id", "default") 62 | 63 | if not text: 64 | raise ValueError("No text provided") 65 | 66 | # Generate speech 67 | file_path, cache_key = await self.audio_service.generate_speech( 68 | text=text, 69 | voice=voice 70 | ) 71 | 72 | # Send response with file path 73 | await self.send_event(MessageType.SPEECH_GENERATED.value, { 74 | "event_id": event_id, 75 | "file_path": file_path, 76 | "cache_key": cache_key 77 | }) 78 | 79 | except Exception as e: 80 | logger.error(f"❌ audio.py: Speech generation failed: {str(e)}") 81 | raise HTTPException(status_code=400, detail=str(e)) 82 | 83 | async def handle_audio_commit(self, message: Dict[str, Any]) -> None: 84 | """ 85 | Handle audio buffer commit 86 | 📝 File: audio.py, Line: 74, Function: handle_audio_commit 87 | """ 88 | try: 89 | event_id = message.get("event_id", "default") 90 | 91 | # Commit audio buffer 92 | await self.audio_service.commit_audio(event_id) 93 | 94 | except Exception as e: 95 | logger.error(f"❌ audio.py: Audio commit failed: {str(e)}") 96 | raise HTTPException(status_code=400, detail=str(e)) 97 | 98 | async def cleanup(self) -> None: 99 | """ 100 | Cleanup audio service resources 101 | 📝 File: audio.py, Line: 82, Function: cleanup 102 | """ 103 | try: 104 | await self.audio_service.cleanup() 105 | logger.info("🧹 audio.py: Audio handler cleanup completed") 106 | except Exception as e: 107 | logger.error(f"❌ audio.py: Cleanup failed: {str(e)}") 108 | -------------------------------------------------------------------------------- /app/websocket/handlers/conversation.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any 2 | from fastapi import HTTPException 3 | from datetime import datetime 4 | import json 5 | from app.websocket.types import MessageType 6 | from app.websocket.base_handler import BaseHandler 7 | from app.utils.logger import logger 8 | 9 | 10 | class ConversationHandler(BaseHandler): 11 | """Handles conversation-related events""" 12 | 13 | async def handle_conversation_create(self, message: Dict[str, Any]) -> None: 14 | """ 15 | Handle conversation item creation 16 | 📝 File: conversation.py, Line: 15, Function: handle_conversation_create 17 | """ 18 | try: 19 | item = message.get("item", {}) 20 | event_id = message.get("event_id", "default") 21 | 22 | # Validate conversation item 23 | self._validate_conversation_item(item) 24 | 25 | # Add metadata 26 | item_with_metadata = self._add_item_metadata(item, event_id) 27 | 28 | # Store conversation item 29 | await self._store_conversation_item(item_with_metadata, event_id) 30 | 31 | await self.send_event(MessageType.CONVERSATION_CREATE.value, { 32 | "event_id": event_id, 33 | "item": item_with_metadata 34 | }) 35 | 36 | logger.info(f"💬 conversation.py: Item created for conversation {event_id}") 37 | 38 | except Exception as e: 39 | logger.error(f"❌ conversation.py: Item creation failed: {str(e)}") 40 | raise HTTPException(status_code=400, detail=str(e)) 41 | 42 | async def handle_conversation_truncate(self, message: Dict[str, Any]) -> None: 43 | """ 44 | Handle conversation 
/app/websocket/handlers/conversation.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Any
2 | from fastapi import HTTPException
3 | from datetime import datetime
4 | import json
5 | from app.websocket.types import MessageType
6 | from app.websocket.base_handler import BaseHandler
7 | from app.utils.logger import logger
8 | 
9 | 
10 | class ConversationHandler(BaseHandler):
11 |     """Handles conversation-related events"""
12 | 
13 |     async def handle_conversation_create(self, message: Dict[str, Any]) -> None:
14 |         """
15 |         Handle conversation item creation
16 |         📝 File: conversation.py, Line: 15, Function: handle_conversation_create
17 |         """
18 |         try:
19 |             item = message.get("item", {})
20 |             event_id = message.get("event_id", "default")
21 | 
22 |             # Validate conversation item
23 |             self._validate_conversation_item(item)
24 | 
25 |             # Add metadata
26 |             item_with_metadata = self._add_item_metadata(item, event_id)
27 | 
28 |             # Store conversation item
29 |             await self._store_conversation_item(item_with_metadata, event_id)
30 | 
31 |             await self.send_event(MessageType.CONVERSATION_CREATE.value, {
32 |                 "event_id": event_id,
33 |                 "item": item_with_metadata
34 |             })
35 | 
36 |             logger.info(f"💬 conversation.py: Item created for conversation {event_id}")
37 | 
38 |         except Exception as e:
39 |             logger.error(f"❌ conversation.py: Item creation failed: {str(e)}")
40 |             raise HTTPException(status_code=400, detail=str(e))
41 | 
42 |     async def handle_conversation_truncate(self, message: Dict[str, Any]) -> None:
43 |         """
44 |         Handle conversation truncation
45 |         📝 File: conversation.py, Line: 42, Function: handle_conversation_truncate
46 |         """
47 |         try:
48 |             event_id = message.get("event_id", "default")
49 |             before_id = message.get("before_id")
50 | 
51 |             if not before_id:
52 |                 raise ValueError("before_id is required for truncation")
53 | 
54 |             conv_key = f"conversation:{event_id}"
55 |             items = self.redis.lrange(conv_key, 0, -1)
56 | 
57 |             # Find index of before_id
58 |             truncate_index = None
59 |             for i, item in enumerate(items):
60 |                 item_data = json.loads(item)
61 |                 if item_data.get("id") == before_id:
62 |                     truncate_index = i
63 |                     break
64 | 
65 |             if truncate_index is None:
66 |                 raise ValueError(f"Item with id {before_id} not found")
67 | 
68 |             # Keep only the items before truncate_index (drop the whole list if it is the head)
69 |             if truncate_index == 0:
70 |                 self.redis.delete(conv_key)
71 |             else:
72 |                 self.redis.ltrim(conv_key, 0, truncate_index - 1)
73 | 
74 |             await self.send_event(MessageType.CONVERSATION_TRUNCATE.value, {
75 |                 "event_id": event_id,
76 |                 "before_id": before_id
77 |             })
78 | 
79 |             logger.info(f"✂️ conversation.py: Conversation {event_id} truncated before {before_id}")
80 | 
81 |         except Exception as e:
82 |             logger.error(f"❌ conversation.py: Truncation failed: {str(e)}")
83 |             raise HTTPException(status_code=400, detail=str(e))
84 | 
85 |     def _validate_conversation_item(self, item: Dict[str, Any]) -> None:
86 |         """Validate conversation item structure"""
87 |         required_fields = ["type", "role", "content"]
88 |         if not all(k in item for k in required_fields):
89 |             raise ValueError(f"Missing required fields: {required_fields}")
90 | 
91 |         valid_roles = ["user", "assistant", "system"]
92 |         if item["role"] not in valid_roles:
93 |             raise ValueError(f"Invalid role. Must be one of: {valid_roles}")
94 | 
95 |     def _add_item_metadata(self, item: Dict[str, Any], event_id: str) -> Dict[str, Any]:
96 |         """Add metadata to conversation item"""
97 |         return {
98 |             **item,
99 |             "id": f"msg_{datetime.now().timestamp()}",
100 |             "created_at": datetime.now().isoformat(),
101 |             "event_id": event_id
102 |         }
103 | 
104 |     async def _store_conversation_item(self, item: Dict[str, Any], event_id: str) -> None:
105 |         """Store conversation item in Redis"""
106 |         conv_key = f"conversation:{event_id}"
107 |         self.redis.rpush(conv_key, json.dumps(item))
108 |         self.redis.expire(conv_key, 86400)  # 24 hour TTL
109 | 
110 |     def set_model(self, model):
111 |         self.llm.set_default_model(model)
112 | 
--------------------------------------------------------------------------------
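Conversation items live in one Redis list per event id (conversation:{event_id}) with a 24-hour TTL, which is what makes truncation a simple LTRIM. The storage layout, sketched in isolation:

    import json
    from redis import Redis

    r = Redis(decode_responses=True)
    key = "conversation:evt_1"   # hypothetical event id
    r.rpush(key, json.dumps({"id": "msg_1", "role": "user", "content": "hi"}))
    r.expire(key, 86400)
    items = [json.loads(raw) for raw in r.lrange(key, 0, -1)]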
/app/websocket/handlers/main.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Any
2 | from fastapi import WebSocket
3 | from redis import Redis
4 | from datetime import datetime
5 | import asyncio
6 | 
7 | from app.websocket.types import MessageType, WebSocketEvent
8 | from app.websocket.handlers.session import SessionHandler
9 | from app.websocket.handlers.audio import AudioHandler
10 | from app.websocket.handlers.conversation import ConversationHandler
11 | from app.websocket.handlers.response import ResponseHandler
12 | from app.db.database import Database
13 | from app.services.chat_state import ChatStateManager
14 | from app.services.llm import LLMService
15 | from app.services.audio import AudioService
16 | from app.utils.errors import WebSocketError
17 | from app.config import settings
18 | 
19 | from app.utils.logger import logger
20 | 
21 | 
22 | class WebSocketHandler:
23 |     """Main WebSocket handler that orchestrates all sub-handlers"""
24 | 
25 |     def __init__(self, websocket: WebSocket, db: Database):
26 |         """
27 |         Initialize WebSocket handler with all necessary sub-handlers and services
28 |         📝 File: main.py, Line: 28, Function: __init__
29 |         """
30 |         self.websocket = websocket
31 |         self.db = db
32 |         self.redis = Redis(
33 |             host=settings.REDIS_HOST,
34 |             port=settings.REDIS_PORT,
35 |             db=settings.REDIS_DB,
36 |             decode_responses=True
37 |         )
38 | 
39 |         # Initialize services
40 |         self.llm_service = LLMService()
41 |         self.audio_service = AudioService()
42 | 
43 |         # Initialize sub-handlers
44 |         self.session_handler = SessionHandler(websocket, self.redis, self.llm_service, db)
45 |         self.audio_handler = AudioHandler(self.audio_service, websocket, self.redis, self.llm_service, db)
46 |         self.conversation_handler = ConversationHandler(websocket, self.redis, self.llm_service, db)
47 |         self.response_handler = ResponseHandler(
48 |             websocket,
49 |             self.redis,
50 |             self.llm_service,
51 |             db
52 |         )
53 | 
54 |         # Initialize chat state manager
55 |         self.chat_state = ChatStateManager(self.redis, db)
56 | 
57 |         logger.info("✨ main.py: WebSocket handler initialized with all services")
58 | 
59 |     async def handle_message(self, event: WebSocketEvent) -> None:
60 |         """
61 |         Main message handling method with validation and routing
62 |         📝 File: main.py, Line: 61, Function: handle_message
63 |         """
64 |         try:
65 |             # Validate session and rate limits
66 |             await self._validate_session(event.data.get("session_id"))
67 |             await self._check_rate_limits(event.data.get("client_id"))
68 | 
69 |             # Get appropriate handler
70 |             handler = self.handlers.get(event.type.value)
71 |             if not handler:
72 |                 raise WebSocketError(f"Unknown event type: {event.type}", code=4000)
73 | 
74 |             # Handle the message
75 |             await handler(event)
76 | 
77 |             logger.info(f"✅ main.py: Successfully handled {event.type.value} event")
78 | 
79 |         except Exception as e:
80 |             logger.error(f"❌ main.py: Error handling message: {str(e)}")
81 |             raise
82 | 
83 |     async def cleanup(self) -> None:
84 |         """
85 |         Enhanced cleanup with state persistence
86 |         📝 File: main.py, Line: 138, Function: cleanup
87 |         """
88 |         try:
89 |             # Save final state
90 |             if hasattr(self, 'current_session_id'):
91 |                 await self.chat_state.persist_state(self.current_session_id)
92 | 
93 |             # Cleanup handlers
94 |             cleanup_tasks = [
95 |                 self.session_handler.cleanup(),
96 |                 self.audio_handler.cleanup(),
97 |                 self.conversation_handler.cleanup(),
98 |                 self.response_handler.cleanup()
99 |             ]
100 | 
101 |             # Run cleanup tasks concurrently
102 |             await asyncio.gather(*cleanup_tasks, return_exceptions=True)
103 | 
104 |             # Close connections
105 |             self.redis.close()
106 |             await self.db.disconnect()
107 | 
108 |             logger.info("🧹 main.py: All handlers and connections cleaned up successfully")
109 | 
110 |         except Exception as e:
111 |             logger.error(f"❌ main.py: Cleanup failed: {str(e)}")
112 |             raise
113 | 
114 |     async def _validate_session(self, session_id: str) -> None:
115 |         """
116 |         Enhanced session validation with caching
117 |         📝 File: main.py, Line: 167, Function: _validate_session
118 |         """
119 |         cache_key = f"session_valid:{session_id}"
120 | 
121 |         # Check cache first (synchronous redis client)
122 |         is_valid = self.redis.get(cache_key)
123 |         if is_valid:
124 |             return
125 | 
126 |         # Validate from database
127 |         session = await self.db.get_session(session_id)
128 |         if not session:
129 |             raise WebSocketError("Session not found", code=4004)
130 |         if getattr(session, "status", "active") != "active":
131 |             raise WebSocketError("Session is not active", code=4005)
132 | 
133 |         # Cache validation result
134 |         self.redis.setex(cache_key, 300, "1")  # Cache for 5 minutes
135 | 
136 |     async def _check_rate_limits(self, client_id: str) -> None:
137 |         """
138 |         Enhanced rate limiting with token bucket algorithm
139 |         📝 File: main.py, Line: 188, Function: _check_rate_limits
140 |         """
141 |         rate_limits = await self.db.get_session_rate_limits(client_id)
142 |         current_time = datetime.now().timestamp()
143 | 
144 |         for limit_type, limit in rate_limits.items():
145 |             if current_time >= limit["reset_seconds"]:
146 |                 # Reset limits once the reset time (epoch seconds) has passed
147 |                 await self.db.reset_rate_limits(client_id, limit_type)
148 |             elif limit["remaining"] <= 0:
149 |                 raise WebSocketError(
150 |                     f"Rate limit exceeded for {limit_type}",
151 |                     code=4029,
152 |                     data={"reset_in": limit["reset_seconds"] - current_time}
153 |                 )
154 | 
155 |     @property
156 |     def handlers(self) -> Dict[str, callable]:
157 |         """Handler mapping with type checking"""
158 |         return {
159 |             MessageType.SESSION_UPDATE.value: self.session_handler.handle_session_update,
160 |             MessageType.AUDIO_APPEND.value: self.audio_handler.handle_audio_append,
161 |             MessageType.AUDIO_COMMIT.value: self.audio_handler.handle_audio_commit,
162 |             MessageType.CONVERSATION_CREATE.value: self.conversation_handler.handle_conversation_create,
163 |             MessageType.CONVERSATION_TRUNCATE.value: self.conversation_handler.handle_conversation_truncate,
164 |             # AUDIO_CLEAR and CONVERSATION_DELETE have no handlers yet; unknown
165 |             # event types surface as WebSocketError(code=4000) in handle_message
166 |             MessageType.RESPONSE_CREATE.value: self.response_handler.handle_response_create,
167 |             MessageType.RESPONSE_CANCEL.value: self.response_handler.handle_response_cancel,
168 |         }
169 | 
170 |     def set_model(self, model):
171 |         self.conversation_handler.set_model(model)
172 | 
--------------------------------------------------------------------------------
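Routing is a plain dict lookup from a MessageType value to a bound coroutine, so dispatch reduces to roughly this (sketch inside an async context; the payload is illustrative):

    event = WebSocketEvent(
        event_id="event_1",
        type=MessageType.SESSION_UPDATE,
        data={"session_id": "s1", "client_id": "c1", "state": {}},
    )
    handler = ws_handler.handlers.get(event.type.value)  # ws_handler: WebSocketHandler
    if handler:
        await handler(event)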
/app/websocket/handlers/response.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Any, List
2 | 
3 | from fastapi import WebSocket
4 | from redis import Redis
5 | from app.services.llm import LLMService
6 | from app.db.database import Database
7 | from app.websocket.base_handler import BaseHandler
8 | from app.websocket.types import ContentPart, WebSocketEvent
9 | from app.utils.logger import logger
10 | 
11 | 
12 | class ResponseHandler(BaseHandler):
13 |     def __init__(self, websocket: WebSocket, redis: Redis, llm: LLMService, db: Database):
14 |         super().__init__(websocket, redis, llm, db)
15 | 
16 |     async def handle_response_create(self, event: WebSocketEvent) -> None:
17 |         pass  # stub: awaited from main.py's handler map, so it must be a coroutine taking the event
18 | 
19 |     async def handle_response_cancel(self, event: WebSocketEvent) -> None:
20 |         pass  # stub: cancellation not implemented yet
21 | 
22 |     async def _process_response(
23 |         self,
24 |         event_id: str,
25 |         response_id: str,
26 |         config: Dict[str, Any]
27 |     ) -> None:
28 |         """
29 |         Process response with LLM integration
30 |         📝 File: response.py, Line: 22, Function: _process_response
31 |         """
32 |         try:
33 |             # Get session from database
34 |             session = await self.db.get_session(event_id)
35 |             if not session:
36 |                 raise ValueError("Session not found")
37 | 
38 |             # Get conversation history
39 |             messages = await self._get_conversation_history(event_id)
40 | 
41 |             # Generate response
42 |             async for chunk in await self.llm.generate_response(
43 |                 messages=messages,
44 |                 temperature=config.get("temperature", 0.8),
45 |                 tools=config.get("tools", []),
46 |                 stream=True
47 |             ):
48 |                 if chunk.choices[0].delta.content:
49 |                     # Store in database
50 |                     await self.db.create_message(
51 |                         session_id=event_id,
52 |                         role="assistant",
53 |                         content=chunk.choices[0].delta.content,
54 |                         content_type="text",
55 |                         metadata={"response_id": response_id}
56 |                     )
57 | 
58 |                     # Send to websocket
59 |                     await self._send_content_part(
60 |                         event_id,
61 |                         response_id,
62 |                         ContentPart(
63 |                             type="text",
64 |                             text=chunk.choices[0].delta.content
65 |                         )
66 |                     )
67 | 
68 |                 elif chunk.choices[0].delta.function_call:
69 |                     # Handle function calls
70 |                     await self._handle_function_call(
71 |                         event_id,
72 |                         response_id,
73 |                         chunk.choices[0].delta.function_call
74 |                     )
75 | 
76 |             # Update rate limits
77 |             await self._update_rate_limits(event_id)
78 | 
79 |         except Exception as e:
80 |             logger.error(f"❌ response.py: Response processing failed: {str(e)}")
81 |             await self._handle_error(event_id, response_id, str(e))
82 | 
83 |     async def _get_conversation_history(
84 |         self,
85 |         session_id: str,
86 |         limit: int = 50
87 |     ) -> List[Dict[str, Any]]:
88 |         """
89 |         Get conversation history from database
90 |         📝 File: response.py, Line: 83, Function: _get_conversation_history
91 |         """
92 |         messages = await self.db.get_session_messages(
93 |             session_id=session_id,
94 |             limit=limit
95 |         )
96 | 
97 |         return [
98 |             {
99 |                 "role": msg["role"],
100 |                 "content": msg["content"],
101 |                 **({"function_call": msg["function_call"]}
102 |                    if msg.get("function_call") else {})
103 |             }
104 |             for msg in messages
105 |         ]
106 | 
--------------------------------------------------------------------------------
/app/websocket/handlers/session.py:
--------------------------------------------------------------------------------
1 | import json
2 | from fastapi import HTTPException
3 | from app.websocket.types import SessionConfig, MessageType, WebSocketEvent
4 | from app.websocket.base_handler import BaseHandler
5 | 
6 | from app.utils.logger import logger
7 | 
8 | 
9 | class SessionHandler(BaseHandler):
10 |     """Handles session-related events"""
11 | 
12 |     async def handle_session_update(self, event: WebSocketEvent) -> None:
13 |         """
14 |         Handle session.update event
15 |         📝 File: session.py, Line: 12, Function: handle_session_update
16 |         """
17 |         try:
18 |             session_data = event.data.get("state", {})
19 |             event_id = event.data.get("event_id", "default")
20 | 
21 |             # Validate and create session config
22 |             config = SessionConfig(**session_data)
23 | 
24 |             logger.info(f"🔄 session.py: Updating session {config.id}")
25 | 
26 |             # Store session config with TTL
27 |             session_key = f"session:{config.id}"
28 |             await self.redis.set(session_key, json.dumps(config.__dict__))  # Redis stores strings, not dicts
29 |             await self.redis.expire(session_key, 3600)  # 1 hour TTL
30 | 
31 |             await self.send_event(MessageType.SESSION_UPDATED.value, {
32 |                 "event_id": event_id,
33 |                 "type": MessageType.SESSION_UPDATED.value,
34 |                 "session": config.__dict__
35 |             })
36 | 
37 |             await self.db.update_session(config.__dict__)  # assumed async, like the other db calls
38 | 
39 |         except Exception as e:
40 |             logger.error(f"❌ session.py: Session update failed: {str(e)}")
41 |             raise HTTPException(status_code=400, detail=str(e))
42 | 
--------------------------------------------------------------------------------
/app/websocket/redis.py:
--------------------------------------------------------------------------------
1 | # websocket/redis.py
2 | import redis
3 | from app.config import settings
4 | 
5 | redis_client = redis.Redis(host=settings.REDIS_HOST, port=settings.REDIS_PORT, db=settings.REDIS_DB)  # from settings, so it also resolves under docker-compose (REDIS_HOST=redis)
--------------------------------------------------------------------------------
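session.py above stores the session config as JSON with a one-hour TTL. The read side is symmetric; a minimal sketch assuming the same session:{id} key layout and an awaitable client with decode_responses=True — get_cached_session itself is illustrative, not part of the codebase:

import json
from typing import Optional

from app.websocket.types import SessionConfig


async def get_cached_session(redis, session_id: str) -> Optional[SessionConfig]:
    """Counterpart to SessionHandler.handle_session_update's write path."""
    raw = await redis.get(f"session:{session_id}")  # same key layout as session.py
    if raw is None:
        return None  # expired (1 hour TTL) or never stored
    return SessionConfig(**json.loads(raw))  # the dataclass round-trips via its __dict__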
/app/websocket/types.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from dataclasses import dataclass
3 | from typing import Dict, Any, List, Optional, Union
4 | from datetime import datetime
5 | 
6 | class MessageType(Enum):
7 |     """WebSocket message types from documentation"""
8 |     # Session events
9 |     SESSION_UPDATE = "session.update"
10 |     SESSION_CREATED = "session.created"
11 |     SESSION_UPDATED = "session.updated"
12 | 
13 |     # Audio buffer events
14 |     AUDIO_APPEND = "input_audio_buffer.append"
15 |     AUDIO_COMMIT = "input_audio_buffer.commit"
16 |     AUDIO_CLEAR = "input_audio_buffer.clear"
17 |     AUDIO_COMMITTED = "input_audio_buffer.committed"
18 |     AUDIO_CLEARED = "input_audio_buffer.cleared"
19 | 
20 |     # Conversation events
21 |     CONVERSATION_CREATE = "conversation.item.create"
22 |     CONVERSATION_TRUNCATE = "conversation.item.truncate"
23 |     CONVERSATION_DELETE = "conversation.item.delete"
24 | 
25 |     # Response events
26 |     RESPONSE_CREATE = "response.create"
27 |     RESPONSE_CANCEL = "response.cancel"
28 |     RESPONSE_CONTENT_PART_ADDED = "response.content_part.added"
29 |     RESPONSE_CONTENT_PART_DONE = "response.content_part.done"
30 |     RESPONSE_FUNCTION_CALL_ARGS_DONE = "response.function_call_arguments.done"
31 | 
32 |     # Rate limit events
33 |     RATE_LIMITS_UPDATED = "rate_limits.updated"
34 | 
35 | @dataclass
36 | class SessionConfig:
37 |     """Session configuration from documentation"""
38 |     modalities: List[str]
39 |     id: str
40 |     voice: str
41 |     instructions: Optional[str] = None
42 |     input_audio_format: str = "pcm16"
43 |     output_audio_format: str = "pcm16"
44 |     input_audio_transcription: Optional[Dict] = None
45 |     turn_detection: Optional[Dict] = None
46 |     tools: Optional[List[Dict]] = None
47 |     tool_choice: str = "auto"
48 |     temperature: float = 0.8
49 |     max_response_output_tokens: Union[int, str] = "inf"
50 | 
51 | @dataclass
52 | class WebSocketEvent:
53 |     """Base WebSocket event structure"""
54 |     event_id: str
55 |     type: MessageType
56 |     data: Dict[str, Any]
57 |     timestamp: Optional[str] = None
58 | 
59 |     def __post_init__(self):
60 |         if not self.timestamp:
61 |             self.timestamp = datetime.now().isoformat()
62 | 
63 | @dataclass
64 | class ContentPart:
65 |     """Content part structure for responses"""
66 |     type: str  # "text" or "audio"
67 |     text: Optional[str] = None
68 |     audio: Optional[str] = None  # Base64 encoded
69 |     transcript: Optional[str] = None
--------------------------------------------------------------------------------
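These dataclasses are plain containers; a quick illustration of how an inbound frame maps onto them (the payload values here are made up):

from app.websocket.types import MessageType, WebSocketEvent

# A raw frame as a client might send it (illustrative payload)
frame = {
    "event_id": "evt_123",
    "type": "session.update",
    "data": {"session_id": "sess_1", "state": {"modalities": ["text"], "id": "sess_1", "voice": "alloy"}},
}

event = WebSocketEvent(
    event_id=frame["event_id"],
    type=MessageType(frame["type"]),  # Enum lookup by wire value
    data=frame["data"],
)
assert event.type is MessageType.SESSION_UPDATE
assert event.timestamp is not None  # filled in by __post_init__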
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: '3.8'
2 | 
3 | services:
4 |   api:
5 |     build:
6 |       context: .
7 |       dockerfile: Dockerfile
8 |     container_name: ollamagate-api
9 |     ports:
10 |       - "9009:9009"
11 |     environment:
12 |       - ENVIRONMENT=development
13 |       - DEBUG=True
14 |       - REDIS_HOST=redis
15 |       - REDIS_PORT=6379
16 |       - REDIS_DB=0
17 |       - OPENAI_API_KEY=${OPENAI_API_KEY}
18 |       - OPENAI_API_BASE_URL=https://api.openai.com/v1
19 |       - OLLAMA_API_BASE_URL=http://ollama:11434/v1
20 |       - GPT_MODEL=gpt-4-turbo-preview
21 |       - OLLAMA_MODEL=llama2
22 |       - LLM_TIMEOUT=60.0
23 |       - LLM_MAX_RETRIES=3
24 |       - TTS_ENGINE=openai
25 |       - TTS_MODEL=tts-1
26 |       - TTS_OPENAI_API_KEY=${OPENAI_API_KEY}
27 |       - TTS_OPENAI_API_BASE_URL=https://api.openai.com/v1
28 |       - SPEECH_CACHE_DIR=/app/cache/speech
29 |       - WHISPER_MODEL_SIZE=base
30 |       - WHISPER_DEVICE=cuda
31 |       - WHISPER_COMPUTE_TYPE=float16
32 |     volumes:
33 |       - ./app:/app
34 |       - speech_cache:/app/cache/speech
35 |     depends_on:
36 |       - redis
37 |       - ollama
38 |     deploy:
39 |       resources:
40 |         reservations:
41 |           devices:
42 |             - driver: nvidia
43 |               count: 1
44 |               capabilities: [gpu]
45 | 
46 |   redis:
47 |     image: redis:alpine
48 |     container_name: ollamagate-redis
49 |     ports:
50 |       - "6379:6379"
51 |     volumes:
52 |       - redis_data:/data
53 |     command: redis-server --appendonly yes
54 |     healthcheck:
55 |       test: ["CMD", "redis-cli", "ping"]
56 |       interval: 10s
57 |       timeout: 5s
58 |       retries: 3
59 | 
60 |   ollama:
61 |     image: ollama/ollama:latest
62 |     container_name: ollamagate-ollama
63 |     ports:
64 |       - "11434:11434"
65 |     volumes:
66 |       - ollama_models:/root/.ollama
67 |     deploy:
68 |       resources:
69 |         reservations:
70 |           devices:
71 |             - driver: nvidia
72 |               count: 1
73 |               capabilities: [gpu]
74 |     healthcheck:
75 |       test: ["CMD", "ollama", "list"]  # the ollama image ships without curl, so the CLI doubles as the liveness probe
76 |       interval: 30s
77 |       timeout: 10s
78 |       retries: 3
79 | 
80 |   prometheus:
81 |     image: prom/prometheus:latest
82 |     container_name: ollamagate-prometheus
83 |     ports:
84 |       - "9090:9090"
85 |     volumes:
86 |       - ./prometheus:/etc/prometheus
87 |       - prometheus_data:/prometheus
88 |     command:
89 |       - '--config.file=/etc/prometheus/prometheus.yml'
90 |       - '--storage.tsdb.path=/prometheus'
91 |       - '--web.console.libraries=/usr/share/prometheus/console_libraries'
92 |       - '--web.console.templates=/usr/share/prometheus/consoles'
93 | 
94 |   grafana:
95 |     image: grafana/grafana:latest
96 |     container_name: ollamagate-grafana
97 |     ports:
98 |       - "3000:3000"
99 |     volumes:
100 |       - grafana_data:/var/lib/grafana
101 |     environment:
102 |       - GF_SECURITY_ADMIN_PASSWORD=${GRAFANA_ADMIN_PASSWORD:-admin}
103 |     depends_on:
104 |       - prometheus
105 | 
106 |   postgres:
107 |     image: postgres:16
108 |     container_name: ollamagate-postgres
109 |     ports:
110 |       - "5432:5432"
111 |     environment:
112 |       - POSTGRES_USER=ollamagateuser
113 |       - POSTGRES_PASSWORD=ollamagate
114 |       - POSTGRES_DB=ollamagate
115 |     volumes:
116 |       - postgres_data:/var/lib/postgresql/data  # named volume, declared below
117 | 
118 |   openedai-speech:
119 |     image: ghcr.io/matatonic/openedai-speech-min:dev
120 |     container_name: openedai-speech
121 |     ports:
122 |       - "8000:8000"
123 | 
124 | volumes:
125 |   redis_data:
126 |   ollama_models:
127 |   speech_cache:
128 |   prometheus_data:
129 |   grafana_data:
130 |   postgres_data:
131 | 
132 | networks:
133 |   default:
134 |     name: ai_chat_network
135 |     driver: bridge
--------------------------------------------------------------------------------
/echoollama.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theboringhumane/echoOLlama/2936731e53e6c08231ba67141e4c0419a4809559/echoollama.png
--------------------------------------------------------------------------------
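The api service's environment block is what app/config.py's settings object is expected to consume. A minimal pydantic-settings sketch of that shape — the field names follow the compose file, but the real Settings class is outside this listing, so treat the subset and defaults here as assumptions:

from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
    """Reads the variables docker-compose injects; names mirror the compose file."""
    model_config = SettingsConfigDict(env_file=".env", extra="ignore")

    ENVIRONMENT: str = "development"
    DEBUG: bool = False
    REDIS_HOST: str = "localhost"
    REDIS_PORT: int = 6379
    REDIS_DB: int = 0
    OLLAMA_API_BASE_URL: str = "http://localhost:11434/v1"
    GPT_MODEL: str = "gpt-4-turbo-preview"
    LLM_TIMEOUT: float = 60.0
    LLM_MAX_RETRIES: int = 3


settings = Settings()  # process environment overrides both the defaults and .env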
/pilot/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theboringhumane/echoOLlama/2936731e53e6c08231ba67141e4c0419a4809559/pilot/.DS_Store
--------------------------------------------------------------------------------
/pilot/Readme.md:
--------------------------------------------------------------------------------
1 | # Pilot runs
2 | 
3 | This folder contains the Jupyter notebooks used for the pilot runs.
4 | 
5 | ## Pilot run 1
6 | 
7 | Audio transcription with Whisper, and the decisions that can be made from the transcription, such as whether to reply or not.
--------------------------------------------------------------------------------
/pilot/test_audio.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theboringhumane/echoOLlama/2936731e53e6c08231ba67141e4c0419a4809559/pilot/test_audio.wav
--------------------------------------------------------------------------------
/pilot/vad_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theboringhumane/echoOLlama/2936731e53e6c08231ba67141e4c0419a4809559/pilot/vad_animation.gif
--------------------------------------------------------------------------------
/prometheus/prometheus.yml:
--------------------------------------------------------------------------------
1 | global:
2 |   scrape_interval: 15s
3 |   evaluation_interval: 15s
4 | 
5 | scrape_configs:
6 |   - job_name: 'api'
7 |     static_configs:
8 |       - targets: ['api:8000']
9 | 
10 |   - job_name: 'redis'
11 |     static_configs:
12 |       - targets: ['redis:6379']  # plain Redis serves no /metrics; a redis_exporter sidecar would be needed
13 | 
14 |   - job_name: 'ollama'
15 |     static_configs:
16 |       - targets: ['ollama:11434']
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | alembic==1.13.3
2 | annotated-types==0.7.0
3 | anyio==4.6.2.post1
4 | async-timeout==5.0.0
5 | asyncpg==0.30.0
6 | av==12.3.0
7 | certifi==2024.8.30
8 | charset-normalizer==3.4.0
9 | click==8.1.7
10 | coloredlogs==15.0.1
11 | ctranslate2==4.5.0
12 | distro==1.9.0
13 | dnspython==2.7.0
14 | email_validator==2.2.0
15 | exceptiongroup==1.2.2
16 | fastapi==0.115.4
17 | fastapi-cli==0.0.5
18 | faster-whisper==1.0.3
19 | filelock==3.16.1
20 | flatbuffers==24.3.25
21 | fsspec==2024.10.0
22 | greenlet==3.1.1
23 | h11==0.14.0
24 | httpcore==1.0.6
25 | httptools==0.6.4
26 | httpx==0.27.2
27 | huggingface-hub==0.26.2
28 | humanfriendly==10.0
29 | idna==3.10
30 | Jinja2==3.1.4
31 | jiter==0.6.1
32 | Mako==1.3.6
33 | markdown-it-py==3.0.0
34 | MarkupSafe==3.0.2
35 | mdurl==0.1.2
36 | mpmath==1.3.0
37 | networkx==3.2.1
38 | numpy==2.0.2
39 | ollama==0.3.3
40 | onnxruntime==1.19.2
41 | openai==1.53.0
42 | packaging==24.1
43 | protobuf==5.28.3
44 | pydantic==2.9.2
45 | pydantic-settings==2.6.0
46 | pydantic_core==2.23.4
47 | pydub==0.25.1
48 | Pygments==2.18.0
49 | python-dotenv==1.0.1
50 | python-multipart==0.0.16
51 | PyYAML==6.0.2
52 | redis==5.2.0
53 | requests==2.32.3
54 | rich==13.9.3
55 | shellingham==1.5.4
56 | sniffio==1.3.1
57 | SQLAlchemy==2.0.36
58 | starlette==0.41.2
59 | sympy==1.13.1
60 | tokenizers==0.20.1
61 | torch==2.5.1
62 | tqdm==4.66.6
63 | typer==0.12.5
64 | typing_extensions==4.12.2
65 | urllib3==2.2.3
66 | uvicorn==0.32.0
67 | uvloop==0.21.0
68 | watchfiles==0.24.0
69 | websockets==13.1
70 | 
--------------------------------------------------------------------------------
/scripts/set_env.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export PYTHONPATH="$PYTHONPATH:$(pwd)"  # quoted so paths with spaces survive word splitting
--------------------------------------------------------------------------------
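To close the loop, here is a minimal client exercising the session.update event defined in app/websocket/types.py, using the websockets package pinned above. The host, port, and path are assumptions — the actual route lives in app/api/routes/v1/voice.py, outside this listing:

import asyncio
import json

import websockets  # pinned above as websockets==13.1


async def main() -> None:
    # Endpoint is an assumption; check app/api/routes/v1/voice.py for the real path.
    async with websockets.connect("ws://localhost:9009/v1/voice") as ws:
        # session.update is the first event a client sends (MessageType.SESSION_UPDATE)
        await ws.send(json.dumps({
            "event_id": "evt_1",
            "type": "session.update",
            "data": {
                "session_id": "sess_1",
                "state": {"modalities": ["text"], "id": "sess_1", "voice": "alloy"},
            },
        }))
        print(json.loads(await ws.recv()))  # expect a session.updated event back


asyncio.run(main())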