├── multimodal_rag
│   ├── vector_stores
│   │   ├── __init__.py
│   │   ├── multimodal_chroma_store.py
│   │   ├── faiss_store.py
│   │   └── chroma_store.py
│   ├── __init__.py
│   ├── processors
│   │   ├── __init__.py
│   │   ├── image_processor.py
│   │   ├── audio_processor.py
│   │   └── document_processor.py
│   ├── enhanced_system.py
│   ├── base.py
│   └── system.py
├── .gitignore
├── docker
│   ├── .env.example
│   ├── QUICKSTART.md
│   ├── docker-compose.lite.yml
│   ├── start.sh
│   ├── start.bat
│   ├── Dockerfile.lite
│   ├── docker-compose.yml
│   ├── nginx.conf
│   ├── docker-compose.prod.yml
│   ├── Dockerfile
│   └── README.md
├── LICENSE
├── requirements.txt
├── config.yaml
├── start.py
├── setup.py
├── clean_db.py
├── README.md
├── config_examples.py
└── config_schema.py
/multimodal_rag/vector_stores/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Vector stores package initialization. 3 | """ 4 | 5 | from .chroma_store import ChromaVectorStore 6 | 7 | # Optional imports with fallback 8 | try: 9 | from .faiss_store import FAISSVectorStore 10 | FAISS_AVAILABLE = True 11 | except ImportError: 12 | FAISSVectorStore = None 13 | FAISS_AVAILABLE = False 14 | 15 | __all__ = ['ChromaVectorStore'] 16 | if FAISS_AVAILABLE: 17 | __all__.append('FAISSVectorStore') -------------------------------------------------------------------------------- /multimodal_rag/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Initialization file for the multimodal_rag package. 3 | """ 4 | 5 | from .base import ( 6 | DocumentChunk, 7 | ProcessingResult, 8 | RetrievalResult, 9 | BaseProcessor, 10 | BaseVectorStore, 11 | BaseEmbedding, 12 | BaseLLM, 13 | QueryRequest, 14 | QueryResponse 15 | ) 16 | 17 | __version__ = "1.0.0" 18 | __author__ = "SmartRAG Team" 19 | 20 | __all__ = [ 21 | "DocumentChunk", 22 | "ProcessingResult", 23 | "RetrievalResult", 24 | "BaseProcessor", 25 | "BaseVectorStore", 26 | "BaseEmbedding", 27 | "BaseLLM", 28 | "QueryRequest", 29 | "QueryResponse" 30 | ] -------------------------------------------------------------------------------- /multimodal_rag/processors/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Processor package initialization.
3 | """ 4 | 5 | from .document_processor import ( 6 | TextProcessor, 7 | PDFProcessor, 8 | DOCXProcessor, 9 | DocumentProcessorManager 10 | ) 11 | 12 | from .image_processor import ( 13 | ImageProcessor, 14 | ImageProcessorManager 15 | ) 16 | 17 | from .audio_processor import ( 18 | AudioProcessor, 19 | AudioProcessorManager 20 | ) 21 | 22 | __all__ = [ 23 | 'TextProcessor', 24 | 'PDFProcessor', 25 | 'DOCXProcessor', 26 | 'DocumentProcessorManager', 27 | 'ImageProcessor', 28 | 'ImageProcessorManager', 29 | 'AudioProcessor', 30 | 'AudioProcessorManager' 31 | ] -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | .Python 7 | env/ 8 | venv/ 9 | ENV/ 10 | build/ 11 | develop-eggs/ 12 | dist/ 13 | downloads/ 14 | eggs/ 15 | .eggs/ 16 | lib/ 17 | lib64/ 18 | parts/ 19 | sdist/ 20 | var/ 21 | wheels/ 22 | *.egg-info/ 23 | .installed.cfg 24 | *.egg 25 | 26 | # Virtual environments 27 | venv/ 28 | ENV/ 29 | env/ 30 | 31 | # IDE 32 | .vscode/ 33 | .idea/ 34 | *.swp 35 | *.swo 36 | *~ 37 | 38 | # OS 39 | .DS_Store 40 | Thumbs.db 41 | 42 | # Project specific 43 | vector_db/ 44 | file_storage.db 45 | *.db 46 | *.sqlite3 47 | temp_uploads/ 48 | user_data/streamlit_uploaded_files.json 49 | logs/ 50 | *.log 51 | 52 | # Test files 53 | test_*.png 54 | test_*.jpg 55 | test_*.pdf 56 | test_*.wav 57 | test_*.txt 58 | 59 | # Streamlit 60 | .streamlit/ 61 | 62 | # Model cache 63 | models/ 64 | .cache/ 65 | huggingface/ 66 | 67 | # Data files 68 | data/ 69 | *.h5 70 | *.pkl 71 | *.pickle 72 | 73 | # Docker 74 | .dockerignore 75 | docker-compose.override.yml 76 | -------------------------------------------------------------------------------- /docker/.env.example: -------------------------------------------------------------------------------- 1 | # Production Deployment Environment Variables 2 | # Copy this file to .env and update with your production values 3 | 4 | # Application 5 | SMARTRAG_VERSION=1.0.0 6 | ENVIRONMENT=production 7 | 8 | # Database (PostgreSQL) 9 | DB_PASSWORD=change_me_in_production_use_strong_password 10 | POSTGRES_DB=smartrag 11 | POSTGRES_USER=smartrag 12 | 13 | # Redis Cache 14 | REDIS_PASSWORD=change_me_in_production_use_strong_password 15 | 16 | # Security 17 | MAX_FILE_SIZE_MB=50 18 | MAX_UPLOAD_SIZE=52428800 19 | 20 | # Ollama Configuration 21 | OLLAMA_HOST=http://localhost:11434 22 | SMARTRAG_LLM_MODEL=llama3.1:8b 23 | SMARTRAG_EMBEDDING_MODEL=nomic-embed-text 24 | 25 | # Logging 26 | LOG_LEVEL=INFO 27 | SMARTRAG_DEBUG=false 28 | 29 | # Resource Limits 30 | MEMORY_LIMIT=8G 31 | CPU_LIMIT=4 32 | 33 | # SSL/TLS (for production) 34 | # SSL_CERT_PATH=/path/to/cert.pem 35 | # SSL_KEY_PATH=/path/to/key.pem 36 | 37 | # Optional: External Services 38 | # SENTRY_DSN=your_sentry_dsn_here 39 | # PROMETHEUS_ENABLED=true 40 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 SmartRAG Contributors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, 
and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Core RAG and ML libraries 2 | chromadb>=0.4.0 3 | transformers>=4.35.0 # For BLIP image captioning and model utilities 4 | torch>=2.0.0 5 | torchvision>=0.15.0 6 | sentence-transformers>=2.2.0 # For text embeddings 7 | accelerate>=0.20.0 # For model optimization 8 | 9 | # Configuration management 10 | pydantic>=2.0.0 # For config validation and schema 11 | pydantic-settings>=2.0.0 # For environment variable support 12 | 13 | # Traditional system requirements 14 | ollama>=0.1.0 # For Llama 3.1 8B local inference 15 | pytesseract>=0.3.10 # For OCR with Tesseract 16 | opencv-python>=4.8.0 # For image processing 17 | 18 | # Image and vision 19 | Pillow>=10.0.0 20 | requests>=2.31.0 21 | 22 | # Document processing 23 | PyPDF2>=3.0.1 24 | python-docx>=1.0.0 25 | pdfplumber>=0.9.0 26 | python-pptx>=0.6.21 27 | 28 | # Audio processing (keeping Whisper) 29 | openai-whisper>=20231117 30 | pydub>=0.25.1 31 | librosa>=0.10.0 32 | 33 | # Web interface 34 | fastapi>=0.104.0 35 | uvicorn>=0.24.0 36 | python-multipart>=0.0.6 37 | streamlit>=1.28.0 38 | 39 | # Utilities 40 | numpy>=1.24.0 41 | tqdm>=4.65.0 42 | pyyaml>=6.0.0 43 | requests>=2.31.0 -------------------------------------------------------------------------------- /docker/QUICKSTART.md: -------------------------------------------------------------------------------- 1 | # 🐳 SmartRAG Docker Quick Start 2 | 3 | This folder contains all Docker-related files for SmartRAG deployment. 
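If you only need a quick sanity check before picking a deployment option, starting the default stack and probing Streamlit's health endpoint (the same endpoint the container healthchecks poll) is usually enough; note that the first run needs several minutes to download models before the check passes:

```bash
cd docker
docker-compose up -d
# /_stcore/health is the endpoint the Docker healthchecks already use
curl -f http://localhost:8501/_stcore/health && echo "SmartRAG UI is up"
```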
4 | 5 | ## 📁 Contents 6 | 7 | ``` 8 | docker/ 9 | ├── README.md # Comprehensive Docker deployment guide 10 | ├── Dockerfile # Full stack image (Streamlit + Ollama) 11 | ├── Dockerfile.lite # Lightweight image (Streamlit only) 12 | ├── docker-compose.yml # Full stack orchestration 13 | ├── docker-compose.lite.yml # Lightweight orchestration 14 | ├── .dockerignore # Docker build ignore rules 15 | ├── start.bat # Windows quick start script 16 | └── start.sh # Linux/Mac quick start script 17 | ``` 18 | 19 | ## 🚀 Quick Start 20 | 21 | ### Windows 22 | 23 | ```powershell 24 | cd docker 25 | .\start.bat 26 | ``` 27 | 28 | ### Linux/Mac 29 | 30 | ```bash 31 | cd docker 32 | chmod +x start.sh 33 | ./start.sh 34 | ``` 35 | 36 | ### Manual Start 37 | 38 | ```bash 39 | cd docker 40 | docker-compose up -d 41 | ``` 42 | 43 | Access the application at: **http://localhost:8501** 44 | 45 | ## 📚 Full Documentation 46 | 47 | See [README.md](README.md) in this folder for: 48 | 49 | - Detailed deployment options 50 | - Configuration guide 51 | - Troubleshooting 52 | - Production setup 53 | - Backup and restore 54 | 55 | ## ⚡ Quick Commands 56 | 57 | ```bash 58 | # Start services 59 | docker-compose up -d 60 | 61 | # View logs 62 | docker-compose logs -f 63 | 64 | # Stop services 65 | docker-compose down 66 | 67 | # Stop and remove volumes 68 | docker-compose down -v 69 | 70 | # Rebuild 71 | docker-compose up -d --build 72 | ``` 73 | 74 | ## 🔧 Requirements 75 | 76 | - Docker Desktop installed 77 | - 8GB+ RAM available 78 | - 20GB free disk space 79 | - Ports 8501 and 11434 available 80 | 81 | --- 82 | 83 | **For complete instructions, see [README.md](README.md)** 84 | -------------------------------------------------------------------------------- /multimodal_rag/vector_stores/multimodal_chroma_store.py: -------------------------------------------------------------------------------- 1 | """ 2 | Multimodal ChromaDB vector store with enhanced capabilities. 
3 | """ 4 | 5 | import logging 6 | from typing import Dict, Any, List, Optional, Union 7 | from pathlib import Path 8 | 9 | from ..base import DocumentChunk, RetrievalResult 10 | from .chroma_store import ChromaVectorStore 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class MultimodalRetrievalResult(RetrievalResult): 16 | """Enhanced retrieval result with multimodal capabilities.""" 17 | 18 | def __init__(self, *args, **kwargs): 19 | self.visual_matches = kwargs.pop('visual_matches', []) 20 | super().__init__(*args, **kwargs) 21 | 22 | 23 | class MultimodalChromaVectorStore(ChromaVectorStore): 24 | """Enhanced ChromaDB vector store with multimodal support.""" 25 | 26 | def __init__(self, config: Dict[str, Any]): 27 | super().__init__(config) 28 | logger.info("Multimodal ChromaDB vector store initialized") 29 | 30 | def similarity_search_with_visual(self, query: str, visual_query_features: Optional[List[float]] = None, 31 | top_k: int = 5) -> MultimodalRetrievalResult: 32 | """Enhanced similarity search with visual features support.""" 33 | 34 | # For now, just use the regular similarity search 35 | # In a full implementation, this would combine text and visual embeddings 36 | regular_result = self.similarity_search(query, top_k) 37 | 38 | # Convert to multimodal result 39 | return MultimodalRetrievalResult( 40 | chunks=regular_result.chunks, 41 | scores=regular_result.scores, 42 | query=query, 43 | total_results=regular_result.total_results, 44 | retrieval_time=regular_result.retrieval_time, 45 | visual_matches=[] # Placeholder for visual matches 46 | ) -------------------------------------------------------------------------------- /multimodal_rag/enhanced_system.py: -------------------------------------------------------------------------------- 1 | """ 2 | Enhanced system - simplified version that uses SimpleRAGSystem. 
3 | """ 4 | 5 | import logging 6 | from typing import Dict, Any, Union 7 | from .base import QueryRequest, QueryResponse 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class MultimodalQueryRequest(QueryRequest): 13 | """Extended query request with multimodal capabilities.""" 14 | 15 | def __init__(self, query: str = "", *args, **kwargs): 16 | self.search_type: str = kwargs.pop('search_type', 'hybrid') 17 | self.visual_query_image: str = kwargs.pop('visual_query_image', None) 18 | 19 | if not query and args: 20 | query = args[0] 21 | args = args[1:] 22 | 23 | super().__init__(query, *args, **kwargs) 24 | 25 | 26 | class EnhancedMultimodalRAGSystem: 27 | """Enhanced multimodal RAG system - simplified version.""" 28 | 29 | def __init__(self, config_dict: Dict[str, Any]): 30 | """Initialize the enhanced system.""" 31 | from .system import SimpleRAGSystem 32 | 33 | self.config = config_dict 34 | self._simple_system = SimpleRAGSystem(config_dict) 35 | logger.info("Enhanced Multimodal RAG System (simplified) initialized") 36 | 37 | def is_available(self) -> bool: 38 | """Check if the system is available.""" 39 | return self._simple_system.is_available() 40 | 41 | def ingest_file(self, file_path): 42 | """Ingest a file into the system.""" 43 | return self._simple_system.ingest_file(file_path) 44 | 45 | def query(self, query: Union[str, QueryRequest]) -> QueryResponse: 46 | """Process a query and return response.""" 47 | return self._simple_system.query(query) 48 | 49 | def get_system_status(self) -> Dict[str, Any]: 50 | """Get current system status.""" 51 | status = self._simple_system.get_system_status() 52 | status['system_type'] = 'enhanced_traditional_simplified' 53 | return status 54 | -------------------------------------------------------------------------------- /docker/docker-compose.lite.yml: -------------------------------------------------------------------------------- 1 | # Docker Compose for Lightweight setup (Ollama on host) 2 | # Optimized for production with security hardening 3 | version: '3.8' 4 | 5 | services: 6 | smartrag-lite: 7 | build: 8 | context: .. 
9 | dockerfile: docker/Dockerfile.lite 10 | args: 11 | - BUILDKIT_INLINE_CACHE=1 12 | image: smartrag:lite 13 | container_name: smartrag-lite 14 | ports: 15 | - "8501:8501" 16 | volumes: 17 | # Use named volumes for better management 18 | - smartrag_lite_vector_db:/app/vector_db 19 | - smartrag_lite_user_data:/app/user_data 20 | - smartrag_lite_uploads:/app/temp_uploads 21 | - smartrag_lite_logs:/app/logs 22 | - smartrag_lite_db:/app/db 23 | environment: 24 | - STREAMLIT_SERVER_PORT=8501 25 | - STREAMLIT_SERVER_ADDRESS=0.0.0.0 26 | - OLLAMA_HOST=http://host.docker.internal:11434 27 | - PYTHONUNBUFFERED=1 28 | - MAX_FILE_SIZE_MB=50 29 | # Optional: Add custom config 30 | # - SMARTRAG_LLM_MODEL=llama3.1:8b 31 | # - SMARTRAG_TEMPERATURE=0.7 32 | extra_hosts: 33 | - "host.docker.internal:host-gateway" 34 | restart: unless-stopped 35 | healthcheck: 36 | test: | 37 | curl -f http://localhost:8501/_stcore/health || exit 1 38 | interval: 30s 39 | timeout: 10s 40 | retries: 3 41 | start_period: 40s 42 | security_opt: 43 | - no-new-privileges:true 44 | deploy: 45 | resources: 46 | limits: 47 | cpus: '2' 48 | memory: 4G 49 | reservations: 50 | cpus: '1' 51 | memory: 2G 52 | logging: 53 | driver: "json-file" 54 | options: 55 | max-size: "10m" 56 | max-file: "3" 57 | networks: 58 | - smartrag-lite-network 59 | 60 | volumes: 61 | smartrag_lite_vector_db: 62 | driver: local 63 | smartrag_lite_user_data: 64 | driver: local 65 | smartrag_lite_uploads: 66 | driver: local 67 | smartrag_lite_logs: 68 | driver: local 69 | smartrag_lite_db: 70 | driver: local 71 | 72 | networks: 73 | smartrag-lite-network: 74 | driver: bridge 75 | -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | system: 2 | name: "SmartRAG System" 3 | version: "2.0.0" 4 | offline_mode: true 5 | debug: false 6 | log_level: "INFO" 7 | 8 | models: 9 | # LLM for response generation (offline with Ollama) 10 | llm_type: "ollama" 11 | llm_model: "llama3.1:8b" # Llama 3.1 8B model (available in Ollama) 12 | ollama_host: "http://localhost:11434" 13 | 14 | # Text embedding model (using Ollama's embedding model) 15 | embedding_model: "nomic-embed-text" 16 | embedding_dimension: 768 # nomic-embed-text produces 768-dimensional vectors 17 | 18 | # Vision model for image understanding (BLIP) 19 | vision_model: "Salesforce/blip-image-captioning-base" 20 | 21 | # Speech-to-text model (keeping audio processing) 22 | whisper_model: "base" 23 | 24 | vector_store: 25 | type: "chromadb" 26 | persist_directory: "./vector_db" 27 | collection_name: "multimodal_documents" 28 | embedding_dimension: 768 # Must match models.embedding_dimension 29 | ollama_host: "http://localhost:11434" 30 | 31 | processing: 32 | # Text chunking settings 33 | chunk_size: 1000 34 | chunk_overlap: 200 35 | 36 | # Traditional image processing settings (with OCR via Tesseract) 37 | max_image_size: [1024, 1024] # Larger size for better OCR 38 | ocr_enabled: true # Enable Tesseract OCR 39 | store_original_images: true 40 | image_preprocessing: "resize" # resize, crop, none 41 | 42 | # Audio processing settings 43 | audio_sample_rate: 16000 44 | max_audio_duration: 300 # seconds 45 | 46 | # Batch processing 47 | batch_size: 32 # Higher for traditional processing 48 | 49 | retrieval: 50 | top_k: 5 51 | similarity_threshold: 0.7 52 | rerank_enabled: false 53 | 54 | generation: 55 | # LLM generation parameters 56 | max_tokens: 2048 57 | temperature: 0.7 58 | top_p: 0.9 59 
| top_k: 50 60 | do_sample: true 61 | max_new_tokens: 1024 62 | 63 | ui: 64 | title: "SmartRAG - Multimodal AI Assistant" 65 | page_icon: "🤖" 66 | layout: "wide" 67 | theme: "dark" 68 | show_recent_uploads: true 69 | max_upload_size_mb: 200 70 | 71 | supported_formats: 72 | documents: [".pdf", ".docx", ".doc", ".txt", ".md", ".rtf"] 73 | images: [".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"] 74 | audio: [".mp3", ".wav", ".m4a", ".ogg", ".flac"] 75 | 76 | storage: 77 | data_directory: "./data" 78 | logs_directory: "./logs" 79 | cache_directory: "./cache" 80 | temp_uploads_directory: "./temp_uploads" 81 | -------------------------------------------------------------------------------- /docker/start.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # SmartRAG Docker Quick Start Script for Linux/Mac 3 | 4 | set -e 5 | 6 | echo "========================================" 7 | echo "SmartRAG Docker Deployment" 8 | echo "========================================" 9 | echo 10 | 11 | # Check if Docker is installed 12 | if ! command -v docker &> /dev/null; then 13 | echo "ERROR: Docker is not installed" 14 | echo "Please install Docker from https://docs.docker.com/get-docker/" 15 | exit 1 16 | fi 17 | 18 | echo "Docker is installed. Checking Docker Compose..." 19 | if ! command -v docker-compose &> /dev/null; then 20 | echo "ERROR: Docker Compose is not installed" 21 | echo "Please install Docker Compose from https://docs.docker.com/compose/install/" 22 | exit 1 23 | fi 24 | 25 | echo 26 | echo "Select deployment option:" 27 | echo "1. Full Stack (App + Ollama in Docker) - Recommended" 28 | echo "2. Lightweight (App in Docker, Ollama on host)" 29 | echo 30 | read -p "Enter your choice (1 or 2): " choice 31 | 32 | if [ "$choice" == "1" ]; then 33 | echo 34 | echo "Starting Full Stack deployment..." 35 | echo "This will take 10-15 minutes on first run (downloading models)" 36 | echo 37 | docker-compose up -d 38 | 39 | echo 40 | echo "========================================" 41 | echo "SmartRAG is starting!" 42 | echo "========================================" 43 | echo 44 | echo "Access the application at: http://localhost:8501" 45 | echo 46 | echo "To view logs: docker-compose logs -f" 47 | echo "To stop: docker-compose down" 48 | echo 49 | 50 | elif [ "$choice" == "2" ]; then 51 | echo 52 | echo "Starting Lightweight deployment..." 53 | echo 54 | echo "IMPORTANT: Make sure Ollama is running on your host machine!" 55 | echo "If not running, open another terminal and run: ollama serve" 56 | echo 57 | read -p "Press Enter to continue..." 58 | 59 | docker-compose -f docker-compose.lite.yml up -d 60 | 61 | echo 62 | echo "========================================" 63 | echo "SmartRAG Lightweight is running!" 64 | echo "========================================" 65 | echo 66 | echo "Access the application at: http://localhost:8501" 67 | echo 68 | echo "To view logs: docker-compose -f docker-compose.lite.yml logs -f" 69 | echo "To stop: docker-compose -f docker-compose.lite.yml down" 70 | echo 71 | 72 | else 73 | echo "Invalid choice. Please run the script again." 
74 | exit 1 75 | fi 76 | -------------------------------------------------------------------------------- /docker/start.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | REM SmartRAG Docker Quick Start Script for Windows 3 | 4 | echo ======================================== 5 | echo SmartRAG Docker Deployment 6 | echo ======================================== 7 | echo. 8 | 9 | REM Check if Docker is installed 10 | docker --version >nul 2>&1 11 | if %errorlevel% neq 0 ( 12 | echo ERROR: Docker is not installed or not in PATH 13 | echo Please install Docker Desktop from https://www.docker.com/products/docker-desktop 14 | pause 15 | exit /b 1 16 | ) 17 | 18 | echo Docker is installed. Checking Docker Compose... 19 | docker-compose --version >nul 2>&1 20 | if %errorlevel% neq 0 ( 21 | echo ERROR: Docker Compose is not installed 22 | echo Please install Docker Compose 23 | pause 24 | exit /b 1 25 | ) 26 | 27 | echo. 28 | echo Select deployment option: 29 | echo 1. Full Stack (App + Ollama in Docker) - Recommended 30 | echo 2. Lightweight (App in Docker, Ollama on host) 31 | echo. 32 | set /p choice="Enter your choice (1 or 2): " 33 | 34 | if "%choice%"=="1" ( 35 | echo. 36 | echo Starting Full Stack deployment... 37 | echo This will take 10-15 minutes on first run (downloading models) 38 | echo. 39 | docker-compose up -d 40 | if %errorlevel% equ 0 ( 41 | echo. 42 | echo ======================================== 43 | echo SmartRAG is starting! 44 | echo ======================================== 45 | echo. 46 | echo Access the application at: http://localhost:8501 47 | echo. 48 | echo To view logs: docker-compose logs -f 49 | echo To stop: docker-compose down 50 | echo. 51 | ) else ( 52 | echo. 53 | echo ERROR: Failed to start SmartRAG 54 | echo Check the logs with: docker-compose logs 55 | ) 56 | ) else if "%choice%"=="2" ( 57 | echo. 58 | echo Starting Lightweight deployment... 59 | echo. 60 | echo IMPORTANT: Make sure Ollama is running on your host machine! 61 | echo If not running, open another terminal and run: ollama serve 62 | echo. 63 | pause 64 | docker-compose -f docker-compose.lite.yml up -d 65 | if %errorlevel% equ 0 ( 66 | echo. 67 | echo ======================================== 68 | echo SmartRAG Lightweight is running! 69 | echo ======================================== 70 | echo. 71 | echo Access the application at: http://localhost:8501 72 | echo. 73 | echo To view logs: docker-compose -f docker-compose.lite.yml logs -f 74 | echo To stop: docker-compose -f docker-compose.lite.yml down 75 | echo. 76 | ) else ( 77 | echo. 78 | echo ERROR: Failed to start SmartRAG Lightweight 79 | echo Check the logs with: docker-compose -f docker-compose.lite.yml logs 80 | ) 81 | ) else ( 82 | echo Invalid choice. Please run the script again. 83 | pause 84 | exit /b 1 85 | ) 86 | 87 | pause 88 | -------------------------------------------------------------------------------- /start.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | SmartRAG - Application Launcher 4 | Validates environment and starts the SmartRAG web interface 5 | """ 6 | 7 | import subprocess 8 | import sys 9 | import os 10 | from pathlib import Path 11 | 12 | def check_environment(): 13 | """Pre-flight checks before starting.""" 14 | errors = [] 15 | 16 | # Check if we're in the right directory 17 | if not Path("chatbot_app.py").exists(): 18 | errors.append("❌ chatbot_app.py not found. 
Run this script from the project root directory.") 19 | 20 | # Check if config exists 21 | if not Path("config.yaml").exists(): 22 | errors.append("❌ config.yaml not found. Configuration file is required.") 23 | 24 | # Check if requirements are installed 25 | try: 26 | import streamlit 27 | except ImportError: 28 | errors.append("❌ Streamlit not installed. Run: pip install -r requirements.txt") 29 | 30 | return errors 31 | 32 | def main(): 33 | """Start the SmartRAG web application.""" 34 | print("🚀 Starting SmartRAG Multimodal ChatBot...") 35 | print("=" * 60) 36 | 37 | # Run pre-flight checks 38 | errors = check_environment() 39 | if errors: 40 | print("\n⚠️ Pre-flight checks failed:\n") 41 | for error in errors: 42 | print(f" {error}") 43 | print("\nFix the issues above and try again.") 44 | sys.exit(1) 45 | 46 | print("✅ Environment validated") 47 | print(f"📁 Working directory: {os.getcwd()}") 48 | print("🌐 Web interface: http://localhost:8501") 49 | print("🤖 Model: Llama 3.1 8B via Ollama (local)") 50 | print("🔒 Privacy: Completely OFFLINE") 51 | print("") 52 | print("Instructions:") 53 | print(" 1. The app will open in your default browser") 54 | print(" 2. Upload documents using the sidebar") 55 | print(" 3. Start chatting with your documents!") 56 | print(" 4. Press Ctrl+C to stop the server") 57 | print("") 58 | print("=" * 60) 59 | print("") 60 | 61 | try: 62 | # Start the Streamlit chatbot app 63 | subprocess.run( 64 | [sys.executable, "-m", "streamlit", "run", "chatbot_app.py", 65 | "--server.port=8501", 66 | "--server.address=localhost"], 67 | check=True 68 | ) 69 | except KeyboardInterrupt: 70 | print("\n\n🛑 SmartRAG stopped by user") 71 | except subprocess.CalledProcessError as e: 72 | print(f"\n❌ Error starting application: {e}") 73 | print("\n🔧 Troubleshooting:") 74 | print(" 1. Check if Ollama is running: ollama list") 75 | print(" 2. Verify Python packages: pip install -r requirements.txt") 76 | print(" 3. Check logs in ./logs/ for detailed errors") 77 | sys.exit(1) 78 | except Exception as e: 79 | print(f"\n❌ Unexpected error: {e}") 80 | print("\n🔧 Make sure you have:") 81 | print(" 1. Installed requirements: pip install -r requirements.txt") 82 | print(" 2. Ollama running with required models") 83 | sys.exit(1) 84 | 85 | if __name__ == "__main__": 86 | main() -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Setup script for SmartRAG multimodal RAG system.
3 | """ 4 | 5 | from setuptools import setup, find_packages 6 | from pathlib import Path 7 | 8 | # Read README file 9 | readme_file = Path(__file__).parent / "README.md" 10 | long_description = readme_file.read_text(encoding="utf-8") if readme_file.exists() else "" 11 | 12 | # Read requirements 13 | requirements_file = Path(__file__).parent / "requirements.txt" 14 | if requirements_file.exists(): 15 | requirements = requirements_file.read_text().strip().split('\n') 16 | requirements = [req.strip() for req in requirements if req.strip() and not req.startswith('#')] 17 | else: 18 | requirements = [ 19 | "langchain>=0.1.0", 20 | "chromadb>=0.4.0", 21 | "sentence-transformers>=2.2.0", 22 | "transformers>=4.30.0", 23 | "torch>=2.0.0", 24 | "PyPDF2>=3.0.1", 25 | "python-docx>=1.0.0", 26 | "Pillow>=10.0.0", 27 | "requests>=2.31.0", 28 | "numpy>=1.24.0", 29 | "pandas>=2.0.0", 30 | "tqdm>=4.65.0", 31 | "pyyaml>=6.0.0", 32 | "click>=8.1.0" 33 | ] 34 | 35 | setup( 36 | name="smartrag", 37 | version="1.0.0", 38 | author="SmartRAG Team", 39 | author_email="team@smartrag.com", 40 | description="Multimodal Retrieval-Augmented Generation system for documents, images, and audio", 41 | long_description=long_description, 42 | long_description_content_type="text/markdown", 43 | url="https://github.com/your-org/smartrag", 44 | packages=find_packages(), 45 | classifiers=[ 46 | "Development Status :: 4 - Beta", 47 | "Intended Audience :: Developers", 48 | "Intended Audience :: Science/Research", 49 | "License :: OSI Approved :: MIT License", 50 | "Operating System :: OS Independent", 51 | "Programming Language :: Python :: 3", 52 | "Programming Language :: Python :: 3.8", 53 | "Programming Language :: Python :: 3.9", 54 | "Programming Language :: Python :: 3.10", 55 | "Programming Language :: Python :: 3.11", 56 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 57 | "Topic :: Software Development :: Libraries :: Python Modules", 58 | "Topic :: Text Processing :: Indexing", 59 | ], 60 | python_requires=">=3.8", 61 | install_requires=requirements, 62 | extras_require={ 63 | "all": [ 64 | "pdfplumber>=0.9.0", 65 | "pytesseract>=0.3.10", 66 | "opencv-python>=4.8.0", 67 | "openai-whisper>=20231117", 68 | "pydub>=0.25.1", 69 | "librosa>=0.10.0", 70 | "faiss-cpu>=1.7.4" 71 | ], 72 | "audio": [ 73 | "openai-whisper>=20231117", 74 | "pydub>=0.25.1", 75 | "librosa>=0.10.0", 76 | "SpeechRecognition>=3.10.0" 77 | ], 78 | "image": [ 79 | "pytesseract>=0.3.10", 80 | "opencv-python>=4.8.0" 81 | ], 82 | "pdf": [ 83 | "pdfplumber>=0.9.0" 84 | ], 85 | "dev": [ 86 | "pytest>=7.4.0", 87 | "black>=23.0.0", 88 | "flake8>=6.0.0", 89 | "coverage>=7.0.0" 90 | ] 91 | }, 92 | entry_points={ 93 | "console_scripts": [ 94 | "smartrag=cli:main", 95 | ], 96 | }, 97 | include_package_data=True, 98 | package_data={ 99 | "smartrag": ["config/*.yaml"], 100 | }, 101 | keywords=[ 102 | "rag", "retrieval", "generation", "multimodal", "llm", 103 | "document-processing", "semantic-search", "ai", "nlp" 104 | ], 105 | project_urls={ 106 | "Bug Reports": "https://github.com/your-org/smartrag/issues", 107 | "Source": "https://github.com/your-org/smartrag", 108 | "Documentation": "https://github.com/your-org/smartrag/wiki", 109 | }, 110 | ) -------------------------------------------------------------------------------- /docker/Dockerfile.lite: -------------------------------------------------------------------------------- 1 | # Lightweight Dockerfile for SmartRAG (without Ollama in container) 2 | # Use this if you want to run Ollama 
separately on the host 3 | # Optimized with multi-stage build and security hardening 4 | 5 | # ============================================ 6 | # Stage 1: Builder 7 | # ============================================ 8 | FROM python:3.10-slim as builder 9 | 10 | WORKDIR /build 11 | 12 | # Install build dependencies 13 | RUN apt-get update && apt-get install -y --no-install-recommends \ 14 | gcc \ 15 | g++ \ 16 | && rm -rf /var/lib/apt/lists/* 17 | 18 | # Copy and install Python dependencies 19 | COPY requirements.txt . 20 | RUN pip install --no-cache-dir --user -r requirements.txt 21 | 22 | # Pre-download models 23 | RUN mkdir -p /build/.cache/huggingface 24 | ENV HF_HOME=/build/.cache/huggingface 25 | RUN python3 -c "from transformers import BlipProcessor, BlipForConditionalGeneration, CLIPProcessor, CLIPModel; \ 26 | BlipProcessor.from_pretrained('Salesforce/blip-image-captioning-base'); \ 27 | BlipForConditionalGeneration.from_pretrained('Salesforce/blip-image-captioning-base'); \ 28 | CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32'); \ 29 | CLIPModel.from_pretrained('openai/clip-vit-base-patch32'); \ 30 | print('✓ Models cached')" 31 | 32 | RUN python3 -c "import whisper; whisper.load_model('base'); print('✓ Whisper cached')" 33 | 34 | # ============================================ 35 | # Stage 2: Runtime 36 | # ============================================ 37 | FROM python:3.10-slim 38 | 39 | LABEL maintainer="SmartRAG Team" 40 | LABEL version="1.0-lite" 41 | LABEL description="Lightweight SmartRAG (Ollama on host)" 42 | 43 | # Install runtime system dependencies only 44 | RUN apt-get update && apt-get install -y --no-install-recommends \ 45 | tesseract-ocr \ 46 | tesseract-ocr-eng \ 47 | ffmpeg \ 48 | libsm6 \ 49 | libxext6 \ 50 | libxrender-dev \ 51 | libgomp1 \ 52 | curl \ 53 | ca-certificates \ 54 | tini \ 55 | && rm -rf /var/lib/apt/lists/* \ 56 | && apt-get clean 57 | 58 | # Create non-root user 59 | RUN groupadd -r smartrag --gid=1000 && \ 60 | useradd -r -g smartrag --uid=1000 --create-home --shell /bin/bash smartrag 61 | 62 | WORKDIR /app 63 | 64 | # Copy packages and models from builder 65 | COPY --from=builder --chown=smartrag:smartrag /root/.local /home/smartrag/.local 66 | COPY --from=builder --chown=smartrag:smartrag /build/.cache /home/smartrag/.cache 67 | 68 | ENV PATH=/home/smartrag/.local/bin:$PATH 69 | ENV HF_HOME=/home/smartrag/.cache/huggingface 70 | ENV TRANSFORMERS_CACHE=/home/smartrag/.cache/huggingface 71 | 72 | # Copy application code 73 | COPY --chown=smartrag:smartrag . . 
74 | 75 | # Create necessary directories with proper permissions 76 | RUN mkdir -p vector_db temp_uploads user_data logs data && \ 77 | chown -R smartrag:smartrag vector_db temp_uploads user_data logs data 78 | 79 | # Expose Streamlit port 80 | EXPOSE 8501 81 | 82 | # Set environment variables 83 | ENV STREAMLIT_SERVER_PORT=8501 \ 84 | STREAMLIT_SERVER_ADDRESS=0.0.0.0 \ 85 | STREAMLIT_SERVER_HEADLESS=true \ 86 | STREAMLIT_BROWSER_GATHER_USAGE_STATS=false \ 87 | OLLAMA_HOST=http://host.docker.internal:11434 \ 88 | PYTHONUNBUFFERED=1 \ 89 | PYTHONDONTWRITEBYTECODE=1 \ 90 | MAX_FILE_SIZE_MB=50 \ 91 | MAX_UPLOAD_SIZE=52428800 92 | 93 | # Switch to non-root user 94 | USER smartrag 95 | 96 | # Health check with Ollama connectivity test 97 | HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \ 98 | CMD curl -f http://localhost:8501/_stcore/health || exit 1 99 | 100 | # Use tini for proper signal handling 101 | ENTRYPOINT ["/usr/bin/tini", "--"] 102 | 103 | # Run Streamlit with production settings 104 | CMD ["streamlit", "run", "chatbot_app.py", \ 105 | "--server.port=8501", \ 106 | "--server.address=0.0.0.0", \ 107 | "--server.headless=true", \ 108 | "--server.fileWatcherType=none", \ 109 | "--browser.gatherUsageStats=false"] 110 | -------------------------------------------------------------------------------- /docker/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3.8' 2 | 3 | services: 4 | smartrag: 5 | build: 6 | context: .. 7 | dockerfile: docker/Dockerfile 8 | args: 9 | - BUILDKIT_INLINE_CACHE=1 10 | image: smartrag:latest 11 | container_name: smartrag-app 12 | ports: 13 | - "8501:8501" # Streamlit UI 14 | - "11434:11434" # Ollama API 15 | volumes: 16 | # Persist vector database 17 | - smartrag_vector_db:/app/vector_db 18 | # Persist user data 19 | - smartrag_user_data:/app/user_data 20 | # Persist uploaded files 21 | - smartrag_uploads:/app/temp_uploads 22 | # Persist logs 23 | - smartrag_logs:/app/logs 24 | # Persist SQLite database 25 | - smartrag_db:/app/db 26 | # Ollama models persistence (updated path for non-root user) 27 | - ollama_models:/home/smartrag/.ollama 28 | environment: 29 | - STREAMLIT_SERVER_PORT=8501 30 | - STREAMLIT_SERVER_ADDRESS=0.0.0.0 31 | - OLLAMA_HOST=http://localhost:11434 32 | - PYTHONUNBUFFERED=1 33 | - MAX_FILE_SIZE_MB=50 34 | # Optional: Add your custom config overrides 35 | # - SMARTRAG_LLM_MODEL=llama3.1:8b 36 | # - SMARTRAG_DEBUG=false 37 | restart: unless-stopped 38 | healthcheck: 39 | test: | 40 | curl -f http://localhost:8501/_stcore/health && \ 41 | curl -f http://localhost:11434/api/tags > /dev/null || exit 1 42 | interval: 30s 43 | timeout: 10s 44 | retries: 3 45 | start_period: 120s # Increased for model downloads 46 | shm_size: '2gb' # Shared memory for ML models 47 | security_opt: 48 | - no-new-privileges:true 49 | deploy: 50 | resources: 51 | limits: 52 | cpus: '4' 53 | memory: 8G # Adjust based on your system 54 | reservations: 55 | cpus: '2' 56 | memory: 4G 57 | logging: 58 | driver: "json-file" 59 | options: 60 | max-size: "10m" 61 | max-file: "3" 62 | networks: 63 | - smartrag-network 64 | 65 | # Optional: Separate PostgreSQL for production (uncomment if needed) 66 | # postgres: 67 | # image: postgres:15-alpine 68 | # container_name: smartrag-db 69 | # environment: 70 | # - POSTGRES_DB=smartrag 71 | # - POSTGRES_USER=smartrag 72 | # - POSTGRES_PASSWORD=changeme # Use secrets in production! 
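  #       # In production, docker-compose.prod.yml reads this from the .env file as ${DB_PASSWORD}; see .env.example for the expected variables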
73 | # volumes: 74 | # - postgres_data:/var/lib/postgresql/data 75 | # ports: 76 | # - "5432:5432" 77 | # restart: unless-stopped 78 | # healthcheck: 79 | # test: ["CMD-SHELL", "pg_isready -U smartrag"] 80 | # interval: 10s 81 | # timeout: 5s 82 | # retries: 5 83 | # networks: 84 | # - smartrag-network 85 | 86 | # Optional: Redis for caching (uncomment if needed) 87 | # redis: 88 | # image: redis:7-alpine 89 | # container_name: smartrag-cache 90 | # command: redis-server --appendonly yes --maxmemory 1gb --maxmemory-policy allkeys-lru 91 | # volumes: 92 | # - redis_data:/data 93 | # ports: 94 | # - "6379:6379" 95 | # restart: unless-stopped 96 | # healthcheck: 97 | # test: ["CMD", "redis-cli", "ping"] 98 | # interval: 10s 99 | # timeout: 3s 100 | # retries: 5 101 | # networks: 102 | # - smartrag-network 103 | 104 | # Optional: Nginx reverse proxy (uncomment for production) 105 | # nginx: 106 | # image: nginx:alpine 107 | # container_name: smartrag-proxy 108 | # ports: 109 | # - "80:80" 110 | # - "443:443" 111 | # volumes: 112 | # - ./nginx.conf:/etc/nginx/nginx.conf:ro 113 | # - ./ssl:/etc/nginx/ssl:ro 114 | # depends_on: 115 | # - smartrag 116 | # restart: unless-stopped 117 | # networks: 118 | # - smartrag-network 119 | 120 | volumes: 121 | ollama_models: 122 | driver: local 123 | smartrag_vector_db: 124 | driver: local 125 | smartrag_user_data: 126 | driver: local 127 | smartrag_uploads: 128 | driver: local 129 | smartrag_logs: 130 | driver: local 131 | smartrag_db: 132 | driver: local 133 | # postgres_data: 134 | # driver: local 135 | # redis_data: 136 | # driver: local 137 | 138 | networks: 139 | smartrag-network: 140 | driver: bridge 141 | -------------------------------------------------------------------------------- /clean_db.py: -------------------------------------------------------------------------------- 1 | """ 2 | Clean up old test data from vector database 3 | """ 4 | 5 | import shutil 6 | from pathlib import Path 7 | import sys 8 | 9 | def check_running_processes(): 10 | """Check for processes that might be using the database.""" 11 | try: 12 | import psutil 13 | 14 | print("🔍 Checking for running processes...") 15 | python_processes = [] 16 | 17 | for proc in psutil.process_iter(['pid', 'name', 'cmdline']): 18 | try: 19 | if proc.info['name'] and 'python' in proc.info['name'].lower(): 20 | if proc.info['cmdline']: 21 | cmdline_str = ' '.join(proc.info['cmdline']).lower() 22 | if any(keyword in cmdline_str for keyword in ['enhanced_app', 'smartrag', 'fastapi']): 23 | python_processes.append(proc) 24 | except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess): 25 | pass 26 | 27 | if python_processes: 28 | print(f"⚠️ Found {len(python_processes)} Python processes that might be using the database:") 29 | for proc in python_processes: 30 | try: 31 | cmdline = ' '.join(proc.info['cmdline'][:3]) if proc.info['cmdline'] else 'Unknown' 32 | print(f" PID: {proc.pid} - {cmdline}") 33 | except: 34 | print(f" PID: {proc.pid} - Process info unavailable") 35 | 36 | print("\n🛑 Please stop the web app first, then run this script again.") 37 | print(" You can stop it by pressing Ctrl+C in the terminal running the enhanced_app.py") 38 | return False 39 | 40 | print("✅ No conflicting processes found") 41 | return True 42 | 43 | except ImportError: 44 | print("⚠️ psutil not available - skipping process check") 45 | print(" Make sure to stop any running web apps manually") 46 | return True 47 | 48 | def clean_vector_database(): 49 | """Remove the entire vector database to start 
fresh.""" 50 | print("🧹 Cleaning Vector Database") 51 | print("=" * 40) 52 | 53 | # Check for running processes first 54 | if not check_running_processes(): 55 | return False 56 | 57 | vector_db_path = Path("./vector_db") 58 | 59 | if vector_db_path.exists(): 60 | try: 61 | # Remove the entire vector database directory 62 | shutil.rmtree(vector_db_path) 63 | print("✅ Successfully removed vector database directory") 64 | print("📁 Path removed:", vector_db_path.absolute()) 65 | 66 | # Verify it's gone 67 | if not vector_db_path.exists(): 68 | print("✅ Cleanup confirmed - vector database is now empty") 69 | else: 70 | print("❌ Cleanup failed - directory still exists") 71 | return False 72 | 73 | except Exception as e: 74 | print(f"❌ Error during cleanup: {e}") 75 | print("\n🔧 Manual cleanup steps:") 76 | print("1. Stop all Python processes") 77 | print("2. Delete the 'vector_db' folder manually") 78 | print("3. Restart the Streamlit app") 79 | return False 80 | else: 81 | print("ℹ️ Vector database directory doesn't exist - already clean") 82 | 83 | # Also offer to clean user data 84 | user_data_path = Path("user_data") 85 | if user_data_path.exists(): 86 | try: 87 | response = input("\n🤔 Do you also want to clear user data (uploaded files history and query history)? (y/N): ") 88 | if response.lower() in ['y', 'yes']: 89 | shutil.rmtree(user_data_path) 90 | print("✅ User data cleared!") 91 | else: 92 | print("ℹ️ User data preserved") 93 | except: 94 | print("ℹ️ User data preserved") 95 | 96 | print("\n🎯 Next steps:") 97 | print("1. Start your Streamlit app: python run_streamlit.py") 98 | print("2. Upload your files") 99 | print("3. The system will create a fresh, clean vector database") 100 | return True 101 | 102 | if __name__ == "__main__": 103 | clean_vector_database() -------------------------------------------------------------------------------- /docker/nginx.conf: -------------------------------------------------------------------------------- 1 | # Nginx reverse proxy configuration for SmartRAG 2 | # Production-ready with security headers and rate limiting 3 | 4 | events { 5 | worker_connections 1024; 6 | } 7 | 8 | http { 9 | # Basic settings 10 | sendfile on; 11 | tcp_nopush on; 12 | tcp_nodelay on; 13 | keepalive_timeout 65; 14 | types_hash_max_size 2048; 15 | client_max_body_size 50M; # Match MAX_FILE_SIZE_MB 16 | 17 | # Security headers 18 | add_header X-Frame-Options "SAMEORIGIN" always; 19 | add_header X-Content-Type-Options "nosniff" always; 20 | add_header X-XSS-Protection "1; mode=block" always; 21 | add_header Referrer-Policy "no-referrer-when-downgrade" always; 22 | 23 | # Rate limiting 24 | limit_req_zone $binary_remote_addr zone=smartrag_limit:10m rate=10r/s; 25 | limit_conn_zone $binary_remote_addr zone=addr:10m; 26 | 27 | # Upstream Streamlit app 28 | upstream smartrag_app { 29 | server smartrag:8501; 30 | keepalive 32; 31 | } 32 | 33 | # Upstream Ollama API 34 | upstream ollama_api { 35 | server smartrag:11434; 36 | keepalive 32; 37 | } 38 | 39 | # HTTP server (redirect to HTTPS in production) 40 | server { 41 | listen 80; 42 | server_name _; 43 | 44 | # Uncomment for HTTPS redirect in production 45 | # return 301 https://$host$request_uri; 46 | 47 | # For development, serve directly 48 | location / { 49 | limit_req zone=smartrag_limit burst=20 nodelay; 50 | limit_conn addr 10; 51 | 52 | proxy_pass http://smartrag_app; 53 | proxy_http_version 1.1; 54 | 55 | proxy_set_header Host $host; 56 | proxy_set_header X-Real-IP $remote_addr; 57 | proxy_set_header X-Forwarded-For 
$proxy_add_x_forwarded_for; 58 | proxy_set_header X-Forwarded-Proto $scheme; 59 | 60 | # WebSocket support for Streamlit 61 | proxy_set_header Upgrade $http_upgrade; 62 | proxy_set_header Connection "upgrade"; 63 | 64 | proxy_read_timeout 86400; 65 | proxy_buffering off; 66 | } 67 | 68 | # Health check endpoint (no rate limit) 69 | location /_stcore/health { 70 | proxy_pass http://smartrag_app/_stcore/health; 71 | access_log off; 72 | } 73 | 74 | # Ollama API (optional, comment out for security) 75 | location /api/ollama/ { 76 | limit_req zone=smartrag_limit burst=5 nodelay; 77 | 78 | rewrite ^/api/ollama/(.*) /api/$1 break; 79 | proxy_pass http://ollama_api; 80 | proxy_http_version 1.1; 81 | 82 | proxy_set_header Host $host; 83 | proxy_set_header X-Real-IP $remote_addr; 84 | proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; 85 | } 86 | } 87 | 88 | # HTTPS server (uncomment for production with SSL) 89 | # server { 90 | # listen 443 ssl http2; 91 | # server_name your-domain.com; 92 | # 93 | # ssl_certificate /etc/nginx/ssl/cert.pem; 94 | # ssl_certificate_key /etc/nginx/ssl/key.pem; 95 | # ssl_protocols TLSv1.2 TLSv1.3; 96 | # ssl_ciphers HIGH:!aNULL:!MD5; 97 | # ssl_prefer_server_ciphers on; 98 | # 99 | # # HSTS (uncomment after testing) 100 | # # add_header Strict-Transport-Security "max-age=31536000; includeSubDomains" always; 101 | # 102 | # location / { 103 | # limit_req zone=smartrag_limit burst=20 nodelay; 104 | # limit_conn addr 10; 105 | # 106 | # proxy_pass http://smartrag_app; 107 | # proxy_http_version 1.1; 108 | # 109 | # proxy_set_header Host $host; 110 | # proxy_set_header X-Real-IP $remote_addr; 111 | # proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; 112 | # proxy_set_header X-Forwarded-Proto https; 113 | # 114 | # proxy_set_header Upgrade $http_upgrade; 115 | # proxy_set_header Connection "upgrade"; 116 | # 117 | # proxy_read_timeout 86400; 118 | # proxy_buffering off; 119 | # } 120 | # 121 | # location /_stcore/health { 122 | # proxy_pass http://smartrag_app/_stcore/health; 123 | # access_log off; 124 | # } 125 | # } 126 | } 127 | -------------------------------------------------------------------------------- /docker/docker-compose.prod.yml: -------------------------------------------------------------------------------- 1 | # Production Docker Compose with full stack 2 | # Includes: SmartRAG, PostgreSQL, Redis, Nginx 3 | version: '3.8' 4 | 5 | services: 6 | # Main application 7 | smartrag: 8 | build: 9 | context: .. 
10 | dockerfile: docker/Dockerfile 11 | args: 12 | - BUILDKIT_INLINE_CACHE=1 13 | image: smartrag:production 14 | container_name: smartrag-prod 15 | expose: 16 | - "8501" 17 | - "11434" 18 | volumes: 19 | - smartrag_vector_db:/app/vector_db 20 | - smartrag_user_data:/app/user_data 21 | - smartrag_uploads:/app/temp_uploads 22 | - smartrag_logs:/app/logs 23 | - ollama_models:/home/smartrag/.ollama 24 | environment: 25 | - STREAMLIT_SERVER_PORT=8501 26 | - STREAMLIT_SERVER_ADDRESS=0.0.0.0 27 | - OLLAMA_HOST=http://localhost:11434 28 | - PYTHONUNBUFFERED=1 29 | - MAX_FILE_SIZE_MB=50 30 | # Database connection (if using PostgreSQL) 31 | # - DATABASE_URL=postgresql://smartrag:${DB_PASSWORD}@postgres:5432/smartrag 32 | # Redis cache (if using Redis) 33 | # - REDIS_URL=redis://redis:6379/0 34 | restart: unless-stopped 35 | healthcheck: 36 | test: | 37 | curl -f http://localhost:8501/_stcore/health && \ 38 | curl -f http://localhost:11434/api/tags > /dev/null || exit 1 39 | interval: 30s 40 | timeout: 10s 41 | retries: 3 42 | start_period: 120s 43 | shm_size: '2gb' 44 | security_opt: 45 | - no-new-privileges:true 46 | deploy: 47 | resources: 48 | limits: 49 | cpus: '4' 50 | memory: 8G 51 | reservations: 52 | cpus: '2' 53 | memory: 4G 54 | logging: 55 | driver: "json-file" 56 | options: 57 | max-size: "10m" 58 | max-file: "5" 59 | networks: 60 | - smartrag-prod-network 61 | depends_on: 62 | postgres: 63 | condition: service_healthy 64 | redis: 65 | condition: service_healthy 66 | 67 | # PostgreSQL database 68 | postgres: 69 | image: postgres:15-alpine 70 | container_name: smartrag-postgres 71 | environment: 72 | - POSTGRES_DB=smartrag 73 | - POSTGRES_USER=smartrag 74 | - POSTGRES_PASSWORD=${DB_PASSWORD:-changeme_in_production} 75 | - POSTGRES_INITDB_ARGS=--encoding=UTF-8 --lc-collate=en_US.UTF-8 --lc-ctype=en_US.UTF-8 76 | volumes: 77 | - postgres_data:/var/lib/postgresql/data 78 | expose: 79 | - "5432" 80 | restart: unless-stopped 81 | healthcheck: 82 | test: ["CMD-SHELL", "pg_isready -U smartrag"] 83 | interval: 10s 84 | timeout: 5s 85 | retries: 5 86 | security_opt: 87 | - no-new-privileges:true 88 | networks: 89 | - smartrag-prod-network 90 | logging: 91 | driver: "json-file" 92 | options: 93 | max-size: "10m" 94 | max-file: "3" 95 | 96 | # Redis cache 97 | redis: 98 | image: redis:7-alpine 99 | container_name: smartrag-redis 100 | command: > 101 | redis-server 102 | --appendonly yes 103 | --maxmemory 1gb 104 | --maxmemory-policy allkeys-lru 105 | --requirepass ${REDIS_PASSWORD:-changeme_in_production} 106 | volumes: 107 | - redis_data:/data 108 | expose: 109 | - "6379" 110 | restart: unless-stopped 111 | healthcheck: 112 | test: ["CMD", "redis-cli", "--raw", "incr", "ping"] 113 | interval: 10s 114 | timeout: 3s 115 | retries: 5 116 | security_opt: 117 | - no-new-privileges:true 118 | networks: 119 | - smartrag-prod-network 120 | logging: 121 | driver: "json-file" 122 | options: 123 | max-size: "10m" 124 | max-file: "3" 125 | 126 | # Nginx reverse proxy 127 | nginx: 128 | image: nginx:alpine 129 | container_name: smartrag-nginx 130 | ports: 131 | - "80:80" 132 | - "443:443" 133 | volumes: 134 | - ./nginx.conf:/etc/nginx/nginx.conf:ro 135 | # Uncomment for SSL certificates 136 | # - ./ssl:/etc/nginx/ssl:ro 137 | depends_on: 138 | - smartrag 139 | restart: unless-stopped 140 | security_opt: 141 | - no-new-privileges:true 142 | networks: 143 | - smartrag-prod-network 144 | logging: 145 | driver: "json-file" 146 | options: 147 | max-size: "10m" 148 | max-file: "3" 149 | 150 | # Optional: Prometheus 
for monitoring 151 | # prometheus: 152 | # image: prom/prometheus:latest 153 | # container_name: smartrag-prometheus 154 | # volumes: 155 | # - ./prometheus.yml:/etc/prometheus/prometheus.yml:ro 156 | # - prometheus_data:/prometheus 157 | # command: 158 | # - '--config.file=/etc/prometheus/prometheus.yml' 159 | # - '--storage.tsdb.path=/prometheus' 160 | # expose: 161 | # - "9090" 162 | # restart: unless-stopped 163 | # networks: 164 | # - smartrag-prod-network 165 | 166 | volumes: 167 | ollama_models: 168 | driver: local 169 | smartrag_vector_db: 170 | driver: local 171 | smartrag_user_data: 172 | driver: local 173 | smartrag_uploads: 174 | driver: local 175 | smartrag_logs: 176 | driver: local 177 | postgres_data: 178 | driver: local 179 | redis_data: 180 | driver: local 181 | # prometheus_data: 182 | # driver: local 183 | 184 | networks: 185 | smartrag-prod-network: 186 | driver: bridge 187 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SmartRAG - Intelligent Multimodal RAG System 2 | 3 | A production-ready RAG system enabling intelligent conversations with documents, images, and audio files. Built with local-first AI models for complete privacy and offline operation. 4 | 5 | ![SmartRAG Interface](https://github.com/user-attachments/assets/7b413c33-3208-405b-a4f9-b18381807216) 6 | 7 | ## Quick Start 8 | 9 | ```bash 10 | # Standard deployment 11 | docker-compose up -d 12 | 13 | # Access at http://localhost:8501 14 | ``` 15 | 16 | ## Core Features 17 | 18 | **Multimodal Processing** 19 | 20 | - Documents: PDF, DOCX, TXT, MD with intelligent chunking 21 | - Images: OCR + visual understanding via BLIP 22 | - Audio: Automatic transcription with Whisper 23 | 24 | **Local AI Stack** 25 | 26 | - Ollama (Llama 3.1 8B) for generation 27 | - Nomic Embed Text (768-dim) for embeddings 28 | - ChromaDB for vector storage 29 | - Complete offline operation 30 | 31 | **Production Ready** 32 | 33 | - Docker deployment with multi-stage builds 34 | - Non-root user execution 35 | - Health checks and auto-healing 36 | - Resource management and monitoring 37 | - Security hardening included 38 | 39 | ## Technology Stack 40 | 41 | | Component | Technology | 42 | | -------------- | --------------------------- | 43 | | **LLM** | Llama 3.1 8B via Ollama | 44 | | **Embeddings** | Nomic Embed Text (768-dim) | 45 | | **Vector DB** | ChromaDB / FAISS | 46 | | **Vision** | BLIP + CLIP + Tesseract OCR | 47 | | **Audio** | OpenAI Whisper (base) | 48 | | **UI** | Streamlit | 49 | | **Storage** | SQLite3 | 50 | 51 | ## Architecture 52 | 53 | image 54 | 55 | ## Installation 56 | 57 | ### Docker (Recommended) 58 | 59 | ```bash 60 | git clone https://github.com/itanishqshelar/SmartRAG.git 61 | cd SmartRAG/docker 62 | 63 | # Development 64 | docker-compose up -d 65 | 66 | # Production with full stack (PostgreSQL, Redis, Nginx) 67 | docker-compose -f docker-compose.prod.yml up -d 68 | ``` 69 | 70 | ### Local Setup 71 | 72 | ```bash 73 | # Install dependencies 74 | pip install -r requirements.txt 75 | 76 | # Install Ollama and models 77 | ollama pull llama3.1:8b 78 | ollama pull nomic-embed-text 79 | 80 | # Install system dependencies 81 | # macOS: brew install tesseract ffmpeg 82 | # Ubuntu: apt-get install tesseract-ocr ffmpeg 83 | # Windows: Download from GitHub releases 84 | 85 | # Run application 86 | streamlit run chatbot_app.py 87 | ``` 88 | 89 | ## Configuration 90 | 91 | SmartRAG uses a 
single `config.yaml` with Pydantic validation: 92 | 93 | ```yaml 94 | models: 95 | llm_model: "llama3.1:8b" 96 | embedding_model: "nomic-embed-text" 97 | vision_model: "Salesforce/blip-image-captioning-base" 98 | whisper_model: "base" 99 | 100 | vector_store: 101 | type: "chromadb" 102 | embedding_dimension: 768 103 | 104 | processing: 105 | chunk_size: 1000 106 | chunk_overlap: 200 107 | ocr_enabled: true 108 | 109 | generation: 110 | temperature: 0.7 111 | max_tokens: 2000 112 | context_window: 4096 113 | ``` 114 | 115 | Override via environment variables: 116 | 117 | ```bash 118 | export SMARTRAG_LLM_MODEL=llama2:7b 119 | export SMARTRAG_TEMPERATURE=0.5 120 | ``` 121 | 122 | ## Usage 123 | 124 | **Web Interface** 125 | 126 | 1. Upload files via drag-and-drop 127 | 2. Ask questions about your content 128 | 3. View source documents inline 129 | 4. Manage chat history and files 130 | 131 | **Python API** 132 | 133 | ```python 134 | from multimodal_rag.system import MultimodalRAGSystem 135 | 136 | system = MultimodalRAGSystem() 137 | 138 | # Ingest content 139 | system.ingest_file("document.pdf") 140 | system.ingest_file("screenshot.png") 141 | system.ingest_file("recording.mp3") 142 | 143 | # Query with context 144 | response = system.query("Summarize the key points") 145 | print(response.answer) 146 | ``` 147 | 148 | **Batch Processing** 149 | 150 | ```python 151 | # Process directories 152 | results = system.ingest_directory("./docs/", recursive=True) 153 | print(f"Processed {len(results)} files") 154 | ``` 155 | 156 | ## Project Structure 157 | 158 | ``` 159 | smartrag/ 160 | ├── chatbot_app.py # Streamlit application 161 | ├── config.yaml # Configuration 162 | ├── requirements.txt # Dependencies 163 | ├── multimodal_rag/ 164 | │ ├── system.py # RAG orchestrator 165 | │ ├── processors/ # File type handlers 166 | │ │ ├── document_processor.py 167 | │ │ ├── image_processor.py 168 | │ │ └── audio_processor.py 169 | │ └── vector_stores/ # DB implementations 170 | │ ├── chroma_store.py 171 | │ └── faiss_store.py 172 | ├── docker/ # Production deployment 173 | │ ├── Dockerfile 174 | │ ├── docker-compose.yml 175 | │ └── docker-compose.prod.yml 176 | └── tests/ # Test suite 177 | ``` 178 | 179 | ## Deployment Options 180 | 181 | **Standard** - All-in-one container with Ollama 182 | 183 | ```bash 184 | docker-compose up -d 185 | ``` 186 | 187 | **Lightweight** - External Ollama on host 188 | 189 | ```bash 190 | docker-compose -f docker-compose.lite.yml up -d 191 | ``` 192 | 193 | **Production** - Full stack with PostgreSQL, Redis, Nginx 194 | 195 | ```bash 196 | docker-compose -f docker-compose.prod.yml up -d 197 | ``` 198 | 199 | ## Development 200 | 201 | ```bash 202 | # Run tests 203 | pytest tests/ 204 | 205 | # Code formatting 206 | black multimodal_rag/ tests/ 207 | 208 | # Linting 209 | flake8 multimodal_rag/ tests/ 210 | ``` 211 | 212 | ## Performance 213 | 214 | - **Image size**: 4.2GB 215 | - **Memory**: 4-8GB recommended 216 | - **CPU**: 2-4 cores recommended 217 | - **Startup time**: ~90s (includes model downloads) 218 | - **Query latency**: <3s typical 219 | 220 | ## Security 221 | 222 | - Local inference - no external API calls 223 | - Non-root container execution 224 | - File size limits enforced (50MB default) 225 | - No privilege escalation 226 | - Security headers in production setup 227 | 228 | ## License 229 | 230 | MIT License - see [LICENSE](LICENSE) file for details. 
231 | 232 | ## Acknowledgments 233 | 234 | Built with ChromaDB, Ollama, Hugging Face Transformers, OpenAI Whisper, and Tesseract OCR. 235 | 236 | --- 237 | 238 | **SmartRAG** - Local-first multimodal AI for document intelligence. 239 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # SmartRAG Docker Image - Production Ready 2 | # Multi-stage build for optimized image size and security 3 | 4 | # ============================================ 5 | # Stage 1: Builder - Install and cache models 6 | # ============================================ 7 | FROM python:3.10-slim as builder 8 | 9 | WORKDIR /build 10 | 11 | # Install build dependencies 12 | RUN apt-get update && apt-get install -y --no-install-recommends \ 13 | gcc \ 14 | g++ \ 15 | git \ 16 | curl \ 17 | && rm -rf /var/lib/apt/lists/* 18 | 19 | # Copy and install Python dependencies 20 | COPY requirements.txt . 21 | RUN pip install --no-cache-dir --user -r requirements.txt 22 | 23 | # Pre-download Hugging Face models to cache directory 24 | RUN mkdir -p /build/.cache/huggingface 25 | ENV HF_HOME=/build/.cache/huggingface 26 | RUN python3 -c "from transformers import BlipProcessor, BlipForConditionalGeneration; \ 27 | processor = BlipProcessor.from_pretrained('Salesforce/blip-image-captioning-base'); \ 28 | model = BlipForConditionalGeneration.from_pretrained('Salesforce/blip-image-captioning-base'); \ 29 | print('✓ BLIP model cached successfully')" 30 | 31 | # Pre-download Whisper model 32 | RUN python3 -c "import whisper; \ 33 | model = whisper.load_model('base'); \ 34 | print('✓ Whisper model cached successfully')" 35 | 36 | # Pre-download CLIP model for image embeddings 37 | RUN python3 -c "from transformers import CLIPProcessor, CLIPModel; \ 38 | processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32'); \ 39 | model = CLIPModel.from_pretrained('openai/clip-vit-base-patch32'); \ 40 | print('✓ CLIP model cached successfully')" 41 | 42 | # ============================================ 43 | # Stage 2: Runtime - Final production image 44 | # ============================================ 45 | FROM python:3.10-slim 46 | 47 | # Add metadata labels 48 | LABEL maintainer="SmartRAG Team" 49 | LABEL version="1.0" 50 | LABEL description="Multimodal RAG System with Ollama" 51 | 52 | # Install runtime system dependencies only 53 | RUN apt-get update && apt-get install -y --no-install-recommends \ 54 | tesseract-ocr \ 55 | tesseract-ocr-eng \ 56 | ffmpeg \ 57 | libsm6 \ 58 | libxext6 \ 59 | libxrender-dev \ 60 | libgomp1 \ 61 | curl \ 62 | ca-certificates \ 63 | tini \ 64 | && rm -rf /var/lib/apt/lists/* \ 65 | && apt-get clean 66 | 67 | # Install Ollama 68 | RUN curl -fsSL https://ollama.com/install.sh | sh 69 | 70 | # Create non-root user for security 71 | RUN groupadd -r smartrag --gid=1000 && \ 72 | useradd -r -g smartrag --uid=1000 --create-home --shell /bin/bash smartrag 73 | 74 | # Set working directory 75 | WORKDIR /app 76 | 77 | # Copy Python packages from builder 78 | COPY --from=builder --chown=smartrag:smartrag /root/.local /home/smartrag/.local 79 | 80 | # Copy cached models from builder 81 | COPY --from=builder --chown=smartrag:smartrag /build/.cache /home/smartrag/.cache 82 | 83 | # Update PATH to use user-installed packages 84 | ENV PATH=/home/smartrag/.local/bin:$PATH 85 | ENV HF_HOME=/home/smartrag/.cache/huggingface 86 | ENV TRANSFORMERS_CACHE=/home/smartrag/.cache/huggingface 87 | 
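# Pointing HF_HOME/TRANSFORMERS_CACHE at the cache copied from the builder stage means
# BLIP, CLIP and Whisper load from local disk at runtime instead of re-downloading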
88 | # Copy application code 89 | COPY --chown=smartrag:smartrag . . 90 | 91 | # Create necessary directories with proper permissions 92 | RUN mkdir -p vector_db temp_uploads user_data logs data && \ 93 | chown -R smartrag:smartrag vector_db temp_uploads user_data logs data 94 | 95 | # Create directory for Ollama data 96 | RUN mkdir -p /home/smartrag/.ollama && \ 97 | chown -R smartrag:smartrag /home/smartrag/.ollama 98 | 99 | # Expose ports 100 | EXPOSE 8501 11434 101 | 102 | # Create production-ready startup script with proper error handling 103 | RUN echo '#!/bin/bash\n\ 104 | set -euo pipefail\n\ 105 | \n\ 106 | # Trap signals for graceful shutdown\n\ 107 | trap "echo Shutting down...; kill -TERM $OLLAMA_PID 2>/dev/null; exit 0" SIGTERM SIGINT\n\ 108 | \n\ 109 | echo "========================================"\n\ 110 | echo "SmartRAG Production Startup"\n\ 111 | echo "========================================"\n\ 112 | \n\ 113 | # Start Ollama service\n\ 114 | echo "[1/5] Starting Ollama service..."\n\ 115 | OLLAMA_HOST=0.0.0.0:11434 ollama serve &\n\ 116 | OLLAMA_PID=$!\n\ 117 | \n\ 118 | # Wait for Ollama to be ready with timeout\n\ 119 | echo "[2/5] Waiting for Ollama to be ready..."\n\ 120 | MAX_RETRIES=30\n\ 121 | RETRY_COUNT=0\n\ 122 | while ! curl -s http://localhost:11434/api/tags >/dev/null 2>&1; do\n\ 123 | RETRY_COUNT=$((RETRY_COUNT + 1))\n\ 124 | if [ $RETRY_COUNT -ge $MAX_RETRIES ]; then\n\ 125 | echo "ERROR: Ollama failed to start after ${MAX_RETRIES} seconds"\n\ 126 | exit 1\n\ 127 | fi\n\ 128 | echo " Waiting for Ollama... (${RETRY_COUNT}/${MAX_RETRIES})"\n\ 129 | sleep 1\n\ 130 | done\n\ 131 | echo " ✓ Ollama is ready"\n\ 132 | \n\ 133 | # Pull required models with retry logic\n\ 134 | echo "[3/5] Pulling required Ollama models..."\n\ 135 | pull_model() {\n\ 136 | local model=$1\n\ 137 | local max_attempts=3\n\ 138 | local attempt=1\n\ 139 | \n\ 140 | while [ $attempt -le $max_attempts ]; do\n\ 141 | echo " Pulling $model (attempt $attempt/$max_attempts)..."\n\ 142 | if ollama pull $model; then\n\ 143 | echo " ✓ $model downloaded successfully"\n\ 144 | return 0\n\ 145 | fi\n\ 146 | attempt=$((attempt + 1))\n\ 147 | [ $attempt -le $max_attempts ] && sleep 5\n\ 148 | done\n\ 149 | \n\ 150 | echo " ⚠ Warning: Failed to pull $model after $max_attempts attempts"\n\ 151 | return 1\n\ 152 | }\n\ 153 | \n\ 154 | pull_model "llama3.1:8b"\n\ 155 | pull_model "nomic-embed-text"\n\ 156 | \n\ 157 | # Verify Hugging Face models\n\ 158 | echo "[4/5] Verifying cached models..."\n\ 159 | python3 -c "\n\ 160 | from transformers import BlipProcessor, CLIPProcessor\n\ 161 | import whisper\n\ 162 | try:\n\ 163 | BlipProcessor.from_pretrained('\''Salesforce/blip-image-captioning-base'\'')\n\ 164 | print('\'' ✓ BLIP model verified'\'')\n\ 165 | CLIPProcessor.from_pretrained('\''openai/clip-vit-base-patch32'\'')\n\ 166 | print('\'' ✓ CLIP model verified'\'')\n\ 167 | whisper.load_model('\''base'\'')\n\ 168 | print('\'' ✓ Whisper model verified'\'')\n\ 169 | except Exception as e:\n\ 170 | print(f'\'' ⚠ Warning: Model verification failed: {e}'\'')\n\ 171 | " || echo " ⚠ Warning: Model verification failed"\n\ 172 | \n\ 173 | # Start SmartRAG application\n\ 174 | echo "[5/5] Starting SmartRAG application..."\n\ 175 | echo "========================================"\n\ 176 | echo "Application ready at http://localhost:8501"\n\ 177 | echo "Ollama API available at http://localhost:11434"\n\ 178 | echo "========================================"\n\ 179 | \n\ 180 | exec streamlit run chatbot_app.py \\\n\ 
181 | --server.port=8501 \\\n\ 182 | --server.address=0.0.0.0 \\\n\ 183 | --server.headless=true \\\n\ 184 | --server.fileWatcherType=none \\\n\ 185 | --browser.gatherUsageStats=false\n\ 186 | ' > /app/start.sh && chmod +x /app/start.sh 187 | 188 | # Set environment variables 189 | ENV STREAMLIT_SERVER_PORT=8501 \ 190 | STREAMLIT_SERVER_ADDRESS=0.0.0.0 \ 191 | STREAMLIT_SERVER_HEADLESS=true \ 192 | STREAMLIT_BROWSER_GATHER_USAGE_STATS=false \ 193 | OLLAMA_HOST=http://localhost:11434 \ 194 | OLLAMA_MODELS=/home/smartrag/.ollama/models \ 195 | PYTHONUNBUFFERED=1 \ 196 | PYTHONDONTWRITEBYTECODE=1 197 | 198 | # Security: Set file size limits via environment 199 | ENV MAX_FILE_SIZE_MB=50 \ 200 | MAX_UPLOAD_SIZE=52428800 201 | 202 | # Switch to non-root user 203 | USER smartrag 204 | 205 | # Comprehensive health check 206 | HEALTHCHECK --interval=30s --timeout=10s --start-period=90s --retries=3 \ 207 | CMD curl -f http://localhost:8501/_stcore/health && \ 208 | curl -f http://localhost:11434/api/tags > /dev/null || exit 1 209 | 210 | # Use tini as init system for proper signal handling 211 | ENTRYPOINT ["/usr/bin/tini", "--"] 212 | 213 | # Run startup script 214 | CMD ["/bin/bash", "/app/start.sh"] 215 | -------------------------------------------------------------------------------- /config_examples.py: -------------------------------------------------------------------------------- 1 | """ 2 | SmartRAG Configuration Examples 3 | Demonstrates how to use the new configuration system 4 | """ 5 | 6 | from config_schema import ( 7 | load_config, 8 | SmartRAGConfig, 9 | ConfigLoader, 10 | save_config 11 | ) 12 | from pathlib import Path 13 | 14 | print("=" * 80) 15 | print("SmartRAG Configuration System - Usage Examples") 16 | print("=" * 80) 17 | 18 | # ============================================================================ 19 | # Example 1: Load default configuration 20 | # ============================================================================ 21 | print("\n📋 Example 1: Load with defaults") 22 | print("-" * 80) 23 | 24 | config = load_config() 25 | print(f"✓ System: {config.system.name}") 26 | print(f"✓ LLM: {config.models.llm_type.value} - {config.models.llm_model}") 27 | print(f"✓ Embedding: {config.models.embedding_model} ({config.models.embedding_dimension}d)") 28 | print(f"✓ Vector Store: {config.vector_store.type.value}") 29 | print(f"✓ Chunk Size: {config.processing.chunk_size}") 30 | print(f"✓ Top K: {config.retrieval.top_k}") 31 | 32 | 33 | # ============================================================================ 34 | # Example 2: Load from config.yaml 35 | # ============================================================================ 36 | print("\n📋 Example 2: Load from config.yaml") 37 | print("-" * 80) 38 | 39 | try: 40 | config = load_config("config.yaml") 41 | print(f"✓ Loaded from config.yaml") 42 | print(f" System: {config.system.name} v{config.system.version}") 43 | print(f" Collection: {config.vector_store.collection_name}") 44 | except Exception as e: 45 | print(f"✗ Failed to load: {e}") 46 | 47 | 48 | # ============================================================================ 49 | # Example 3: Override specific parameters 50 | # ============================================================================ 51 | print("\n📋 Example 3: Runtime overrides with double underscore notation") 52 | print("-" * 80) 53 | 54 | config = load_config( 55 | "config.yaml", 56 | models__llm_model="llama2:7b", # Override LLM model 57 | generation__temperature=0.5, # Lower 
temperature 58 | retrieval__top_k=10, # More results 59 | processing__chunk_size=500 # Smaller chunks 60 | ) 61 | 62 | print(f"✓ LLM Model: {config.models.llm_model} (overridden)") 63 | print(f"✓ Temperature: {config.generation.temperature} (overridden)") 64 | print(f"✓ Top K: {config.retrieval.top_k} (overridden)") 65 | print(f"✓ Chunk Size: {config.processing.chunk_size} (overridden)") 66 | 67 | 68 | # ============================================================================ 69 | # Example 4: Direct ConfigLoader usage with explicit overrides 70 | # ============================================================================ 71 | print("\n📋 Example 4: Explicit override dictionary") 72 | print("-" * 80) 73 | 74 | overrides = { 75 | 'models': { 76 | 'llm_model': 'llama3.1:70b', 77 | 'ollama_host': 'http://remote-server:11434' 78 | }, 79 | 'vector_store': { 80 | 'type': 'faiss', 81 | 'collection_name': 'production_docs' 82 | } 83 | } 84 | 85 | config = ConfigLoader.load( 86 | config_path="config.yaml", 87 | override_params=overrides 88 | ) 89 | 90 | print(f"✓ LLM Model: {config.models.llm_model}") 91 | print(f"✓ Ollama Host: {config.models.ollama_host}") 92 | print(f"✓ Vector Store: {config.vector_store.type.value}") 93 | print(f"✓ Collection: {config.vector_store.collection_name}") 94 | 95 | 96 | # ============================================================================ 97 | # Example 5: Environment variable overrides 98 | # ============================================================================ 99 | print("\n📋 Example 5: Environment variable overrides") 100 | print("-" * 80) 101 | print("Set environment variables like:") 102 | print(" export SMARTRAG_LLM_MODEL=llama2:7b") 103 | print(" export SMARTRAG_TEMPERATURE=0.3") 104 | print(" export SMARTRAG_TOP_K=20") 105 | print("\nThese will automatically override config.yaml values!") 106 | print("Priority: CLI args > Env vars > YAML file > Defaults") 107 | 108 | 109 | # ============================================================================ 110 | # Example 6: Access typed configuration 111 | # ============================================================================ 112 | print("\n📋 Example 6: Type-safe configuration access") 113 | print("-" * 80) 114 | 115 | config = load_config("config.yaml") 116 | 117 | # All fields are type-checked and validated 118 | print(f"✓ Temperature (float): {config.generation.temperature}") 119 | print(f"✓ Top K (int): {config.retrieval.top_k}") 120 | print(f"✓ OCR Enabled (bool): {config.processing.ocr_enabled}") 121 | print(f"✓ Max Image Size (list): {config.processing.max_image_size}") 122 | print(f"✓ LLM Type (enum): {config.models.llm_type}") 123 | 124 | 125 | # ============================================================================ 126 | # Example 7: Validation errors 127 | # ============================================================================ 128 | print("\n📋 Example 7: Configuration validation") 129 | print("-" * 80) 130 | 131 | try: 132 | # This should fail validation (chunk_overlap >= chunk_size) 133 | from config_schema import ProcessingConfig, SmartRAGConfig 134 | 135 | bad_config = SmartRAGConfig( 136 | processing=ProcessingConfig( 137 | chunk_size=1000, 138 | chunk_overlap=1500 # Invalid: overlap > size 139 | ) 140 | ) 141 | except ValueError as e: 142 | print(f"✓ Validation caught error: {e}") 143 | 144 | 145 | try: 146 | # This should fail validation (invalid temperature range) 147 | from config_schema import GenerationConfig 148 | 149 | bad_gen = 
GenerationConfig(temperature=3.0) # Invalid: > 2.0 150 | except ValueError as e: 151 | print(f"✓ Validation caught error: {e}") 152 | 153 | 154 | # ============================================================================ 155 | # Example 8: Save configuration 156 | # ============================================================================ 157 | print("\n📋 Example 8: Save configuration to file") 158 | print("-" * 80) 159 | 160 | config = load_config("config.yaml") 161 | output_path = Path("config_backup.yaml") 162 | save_config(config, output_path) 163 | print(f"✓ Configuration saved to {output_path}") 164 | 165 | 166 | # ============================================================================ 167 | # Example 9: Convert to dictionary for backward compatibility 168 | # ============================================================================ 169 | print("\n📋 Example 9: Convert to dictionary") 170 | print("-" * 80) 171 | 172 | config = load_config("config.yaml") 173 | config_dict = config.to_dict() 174 | 175 | print(f"✓ Converted to dict with {len(config_dict)} top-level keys:") 176 | print(f" Keys: {', '.join(config_dict.keys())}") 177 | 178 | 179 | # ============================================================================ 180 | # Example 10: Production usage pattern 181 | # ============================================================================ 182 | print("\n📋 Example 10: Production initialization pattern") 183 | print("-" * 80) 184 | 185 | def initialize_smartrag(env: str = "production"): 186 | """Production-grade configuration loading""" 187 | 188 | # Different configs for different environments 189 | config_files = { 190 | 'development': 'config.dev.yaml', 191 | 'staging': 'config.staging.yaml', 192 | 'production': 'config.yaml' 193 | } 194 | 195 | config_file = config_files.get(env, 'config.yaml') 196 | 197 | try: 198 | # Load with validation 199 | config = load_config( 200 | config_file, 201 | # Override debug mode based on environment 202 | system__debug=(env == 'development'), 203 | system__log_level='DEBUG' if env == 'development' else 'INFO' 204 | ) 205 | 206 | print(f"✓ SmartRAG initialized for {env} environment") 207 | print(f" Config: {config_file}") 208 | print(f" Debug: {config.system.debug}") 209 | print(f" Log Level: {config.system.log_level}") 210 | 211 | return config 212 | 213 | except Exception as e: 214 | print(f"✗ Failed to initialize: {e}") 215 | raise 216 | 217 | # Simulate production initialization 218 | config = initialize_smartrag('production') 219 | 220 | 221 | print("\n" + "=" * 80) 222 | print("✅ All examples completed successfully!") 223 | print("=" * 80) 224 | -------------------------------------------------------------------------------- /docker/README.md: -------------------------------------------------------------------------------- 1 | # Docker Deployment Guide for SmartRAG 2 | 3 | This guide provides instructions for running SmartRAG using Docker. 4 | 5 | ## 📦 Deployment Options 6 | 7 | ### Option 1: Full Stack (Recommended for Production) 8 | 9 | Includes both Streamlit app and Ollama in one container. 10 | 11 | ### Option 2: Lightweight (Recommended for Development) 12 | 13 | Streamlit in Docker, Ollama runs on host machine. 
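For the lightweight option, the container has to reach the Ollama server running on your host machine. The sketch below is illustrative only (the authoritative settings live in `docker-compose.lite.yml`); it assumes the service is called `smartrag` and relies on Docker's `host.docker.internal` gateway, the same mechanism used in the manual run example later in this guide:

```yaml
# Illustrative excerpt - refer to docker-compose.lite.yml for the actual file
services:
  smartrag:
    image: smartrag:lite
    ports:
      - "8501:8501"
    environment:
      # Point the app at Ollama running on the host, not inside the container
      - OLLAMA_HOST=http://host.docker.internal:11434
    extra_hosts:
      - "host.docker.internal:host-gateway"
```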
14 | 15 | --- 16 | 17 | ## 🚀 Quick Start 18 | 19 | ### Prerequisites 20 | 21 | - Docker Desktop installed 22 | - Docker Compose installed 23 | - At least 8GB RAM available 24 | - 20GB free disk space (for models) 25 | 26 | ### Option 1: Full Stack Deployment 27 | 28 | ```bash 29 | # Build and start the container 30 | docker-compose up -d 31 | 32 | # View logs 33 | docker-compose logs -f 34 | 35 | # Access the application 36 | # Open http://localhost:8501 in your browser 37 | ``` 38 | 39 | **First Run Notes:** 40 | 41 | - Initial Docker build takes 10-15 minutes (downloading & caching models) 42 | - Container startup takes 8-10 minutes first time (downloading Ollama models) 43 | - Subsequent runs start in ~30 seconds 44 | 45 | **Models Downloaded During Build:** 46 | 47 | - BLIP Image Captioning: ~1.0GB (cached in image) 48 | - Whisper Base: ~140MB (cached in image) 49 | 50 | **Models Downloaded on First Run:** 51 | 52 | - Llama 3.1 8B model: ~4.7GB (cached in Docker volume) 53 | - Nomic Embed Text model: ~274MB (cached in Docker volume) 54 | 55 | **Total Storage:** ~7-8GB for Docker image + ~5GB for Ollama models = **~13GB** 56 | 57 | ### Option 2: Lightweight Deployment 58 | 59 | ```bash 60 | # First, ensure Ollama is running on your host 61 | ollama serve 62 | 63 | # In another terminal, pull required models 64 | ollama pull llama3.1:8b 65 | ollama pull nomic-embed-text 66 | 67 | # Build and start the lightweight container 68 | docker-compose -f docker-compose.lite.yml up -d 69 | 70 | # Access the application 71 | # Open http://localhost:8501 in your browser 72 | ``` 73 | 74 | --- 75 | 76 | ## 💾 Model Caching Strategy 77 | 78 | SmartRAG uses multiple AI models that are pre-downloaded to ensure fast startup and offline capability. 79 | 80 | ### Models Cached During Docker Build 81 | 82 | These models are downloaded once during `docker build` and stored in the Docker image: 83 | 84 | 1. **BLIP Image Captioning** (`Salesforce/blip-image-captioning-base`) 85 | 86 | - Size: ~1.0GB 87 | - Purpose: Generate captions for images 88 | - Location: `~/.cache/huggingface/` (inside image) 89 | 90 | 2. **Whisper Base** (OpenAI speech-to-text) 91 | - Size: ~140MB 92 | - Purpose: Audio transcription 93 | - Location: `~/.cache/whisper/` (inside image) 94 | 95 | ### Models Downloaded on First Container Run 96 | 97 | These models are downloaded when you first start the container and cached in Docker volumes: 98 | 99 | 3. **Llama 3.1 8B** (via Ollama) 100 | 101 | - Size: ~4.7GB 102 | - Purpose: Main language model for RAG 103 | - Location: `ollama_models` Docker volume 104 | 105 | 4. **Nomic Embed Text** (via Ollama) 106 | - Size: ~274MB 107 | - Purpose: Text embeddings for semantic search 108 | - Location: `ollama_models` Docker volume 109 | 110 | ### Benefits 111 | 112 | ✅ **Faster Startup** - No Hugging Face downloads during app runtime 113 | ✅ **Offline Ready** - All models cached after initial setup 114 | ✅ **Predictable Builds** - Same model versions every time 115 | ✅ **Better UX** - No waiting for downloads during image processing 116 | 117 | ### Verification 118 | 119 | The startup script automatically verifies all models are available: 120 | 121 | ```bash 122 | # During container startup you'll see: 123 | Starting Ollama service... 124 | Waiting for Ollama to start... 125 | Pulling Llama 3.1 8B model... 126 | Pulling Nomic Embed Text model... 127 | Verifying Hugging Face models... 128 | BLIP model cached successfully 129 | Starting SmartRAG application...
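# (Illustrative follow-up, assuming the compose service is named "smartrag" as in the
# examples above) Once the app is up, you can spot-check the caches yourself:
#   docker-compose exec smartrag ollama list
#   docker-compose exec smartrag ls /home/smartrag/.cache/huggingface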
130 | ``` 131 | 132 | --- 133 | 134 | ## 🔧 Configuration 135 | 136 | ### Environment Variables 137 | 138 | You can customize the deployment by setting environment variables in `docker-compose.yml`: 139 | 140 | ```yaml 141 | environment: 142 | - STREAMLIT_SERVER_PORT=8501 143 | - STREAMLIT_SERVER_ADDRESS=0.0.0.0 144 | - OLLAMA_HOST=http://localhost:11434 145 | ``` 146 | 147 | ### Volume Mounts 148 | 149 | Data persistence is handled through Docker volumes: 150 | 151 | - `./vector_db` - ChromaDB vector database 152 | - `./user_data` - User session data 153 | - `./temp_uploads` - Temporary file uploads 154 | - `./logs` - Application logs 155 | - `./file_storage.db` - SQLite database 156 | - `ollama_models` - Ollama model cache (Docker volume) 157 | 158 | --- 159 | 160 | ## 🛠️ Build from Source 161 | 162 | ### Build Full Stack Image 163 | 164 | ```bash 165 | docker build -t smartrag:latest -f Dockerfile . 166 | ``` 167 | 168 | ### Build Lightweight Image 169 | 170 | ```bash 171 | docker build -t smartrag:lite -f Dockerfile.lite . 172 | ``` 173 | 174 | ### Run Manually 175 | 176 | ```bash 177 | # Full stack 178 | docker run -d \ 179 | -p 8501:8501 \ 180 | -p 11434:11434 \ 181 | -v $(pwd)/vector_db:/app/vector_db \ 182 | -v $(pwd)/user_data:/app/user_data \ 183 | --name smartrag \ 184 | smartrag:latest 185 | 186 | # Lightweight (with host Ollama) 187 | docker run -d \ 188 | -p 8501:8501 \ 189 | -v $(pwd)/vector_db:/app/vector_db \ 190 | -v $(pwd)/user_data:/app/user_data \ 191 | --add-host host.docker.internal:host-gateway \ 192 | --name smartrag-lite \ 193 | smartrag:lite 194 | ``` 195 | 196 | --- 197 | 198 | ## 📊 Resource Requirements 199 | 200 | ### Minimum Requirements 201 | 202 | - **CPU**: 4 cores 203 | - **RAM**: 8GB 204 | - **Disk**: 20GB free space 205 | 206 | ### Recommended Requirements 207 | 208 | - **CPU**: 8+ cores 209 | - **RAM**: 16GB 210 | - **Disk**: 50GB free space (for multiple models and data) 211 | - **GPU**: Optional (for faster inference) 212 | 213 | ### Storage Breakdown 214 | 215 | **Docker Image Size:** 216 | 217 | ``` 218 | Base Python 3.10: ~900 MB 219 | System Dependencies: ~300 MB 220 | Python Packages: ~800 MB 221 | BLIP Model (cached): ~1000 MB 222 | Whisper Model (cached): ~140 MB 223 | -------------------------------- 224 | Total Image Size: ~3.1 GB 225 | ``` 226 | 227 | **Runtime Storage (Docker Volumes):** 228 | 229 | ``` 230 | Llama 3.1 8B: ~4.7 GB 231 | Nomic Embed Text: ~274 MB 232 | ChromaDB Vectors: ~100 MB per 1000 documents 233 | User Uploads: Varies 234 | Logs: ~10 MB 235 | -------------------------------- 236 | Minimum Runtime: ~5.1 GB 237 | Recommended: ~10 GB 238 | ``` 239 | 240 | **Total Required:** ~13GB minimum, ~20GB recommended 241 | 242 | ### Memory Allocation 243 | 244 | Adjust memory limits in `docker-compose.yml`: 245 | 246 | ```yaml 247 | deploy: 248 | resources: 249 | limits: 250 | memory: 8G # Maximum memory 251 | reservations: 252 | memory: 4G # Guaranteed memory 253 | ``` 254 | 255 | --- 256 | 257 | ## 🔍 Monitoring & Debugging 258 | 259 | ### View Logs 260 | 261 | ```bash 262 | # Follow logs in real-time 263 | docker-compose logs -f 264 | 265 | # View specific service logs 266 | docker-compose logs -f smartrag 267 | 268 | # View last 100 lines 269 | docker-compose logs --tail=100 270 | ``` 271 | 272 | ### Check Container Status 273 | 274 | ```bash 275 | # List running containers 276 | docker-compose ps 277 | 278 | # Check health status 279 | docker inspect --format='{{json .State.Health}}' smartrag-app 280 | ``` 281 | 282 | ### 
Access Container Shell 283 | 284 | ```bash 285 | # Execute bash inside container 286 | docker-compose exec smartrag bash 287 | 288 | # Check if Ollama is running 289 | docker-compose exec smartrag curl http://localhost:11434/api/tags 290 | ``` 291 | 292 | ### Common Issues 293 | 294 | **Issue: Ollama models not downloading** 295 | 296 | ```bash 297 | # Enter container and manually pull models 298 | docker-compose exec smartrag bash 299 | ollama pull llama3.1:8b 300 | ollama pull nomic-embed-text 301 | ``` 302 | 303 | **Issue: Out of memory** 304 | 305 | ```bash 306 | # Increase memory limits in docker-compose.yml 307 | # Or switch to a smaller model (e.g., llama3.2:3b) 308 | ``` 309 | 310 | **Issue: Connection refused to Ollama** 311 | 312 | ```bash 313 | # Check if Ollama service is running 314 | docker-compose exec smartrag ps aux | grep ollama 315 | 316 | # Restart the container 317 | docker-compose restart 318 | ``` 319 | 320 | --- 321 | 322 | ## 🔄 Updates & Maintenance 323 | 324 | ### Update Application 325 | 326 | ```bash 327 | # Pull latest changes 328 | git pull origin main 329 | 330 | # Rebuild and restart 331 | docker-compose down 332 | docker-compose build --no-cache 333 | docker-compose up -d 334 | ``` 335 | 336 | ### Backup Data 337 | 338 | ```bash 339 | # Backup vector database and user data 340 | tar -czf smartrag-backup-$(date +%Y%m%d).tar.gz \ 341 | vector_db/ user_data/ file_storage.db 342 | 343 | # Backup Ollama models (optional) 344 | docker run --rm \ 345 | -v smartrag_ollama_models:/models \ 346 | -v $(pwd):/backup \ 347 | alpine tar -czf /backup/ollama-models.tar.gz -C /models . 348 | ``` 349 | 350 | ### Restore Data 351 | 352 | ```bash 353 | # Restore from backup 354 | tar -xzf smartrag-backup-YYYYMMDD.tar.gz 355 | docker-compose up -d 356 | ``` 357 | 358 | --- 359 | 360 | ## 🌐 Production Deployment 361 | 362 | ### Using Nginx as Reverse Proxy 363 | 364 | ```nginx 365 | server { 366 | listen 80; 367 | server_name smartrag.yourdomain.com; 368 | 369 | location / { 370 | proxy_pass http://localhost:8501; 371 | proxy_http_version 1.1; 372 | proxy_set_header Upgrade $http_upgrade; 373 | proxy_set_header Connection "upgrade"; 374 | proxy_set_header Host $host; 375 | proxy_set_header X-Real-IP $remote_addr; 376 | proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; 377 | proxy_set_header X-Forwarded-Proto $scheme; 378 | } 379 | } 380 | ``` 381 | 382 | ### Enable HTTPS with Let's Encrypt 383 | 384 | ```bash 385 | # Install certbot 386 | sudo apt-get install certbot python3-certbot-nginx 387 | 388 | # Get SSL certificate 389 | sudo certbot --nginx -d smartrag.yourdomain.com 390 | ``` 391 | 392 | ### Docker Compose for Production 393 | 394 | ```yaml 395 | version: "3.8" 396 | 397 | services: 398 | smartrag: 399 | build: .
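    # (Illustrative reminder) keep the port mappings and volume mounts from the base
    # docker-compose.yml in your production file as well; they are omitted from this
    # excerpt for brevity, e.g.:
    #   ports:
    #     - "8501:8501"
    #   volumes:
    #     - ./vector_db:/app/vector_db
    #     - ./user_data:/app/user_data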
400 | restart: always 401 | environment: 402 | - STREAMLIT_SERVER_PORT=8501 403 | deploy: 404 | resources: 405 | limits: 406 | memory: 16G 407 | reservations: 408 | memory: 8G 409 | logging: 410 | driver: "json-file" 411 | options: 412 | max-size: "10m" 413 | max-file: "3" 414 | ``` 415 | 416 | --- 417 | 418 | ## 🧪 Testing Deployment 419 | 420 | ```bash 421 | # Health check 422 | curl http://localhost:8501/_stcore/health 423 | 424 | # Test Ollama 425 | curl http://localhost:11434/api/tags 426 | 427 | # Test application 428 | # Upload a test file through the web interface at http://localhost:8501 429 | ``` 430 | 431 | --- 432 | 433 | ## 🛑 Stopping & Cleanup 434 | 435 | ### Stop Containers 436 | 437 | ```bash 438 | # Stop services 439 | docker-compose down 440 | 441 | # Stop and remove volumes (WARNING: deletes all data) 442 | docker-compose down -v 443 | ``` 444 | 445 | ### Remove Images 446 | 447 | ```bash 448 | # Remove SmartRAG images 449 | docker rmi smartrag:latest smartrag:lite 450 | 451 | # Clean up unused images 452 | docker image prune -a 453 | ``` 454 | 455 | ### Complete Cleanup 456 | 457 | ```bash 458 | # Remove everything (containers, images, volumes) 459 | docker-compose down -v --rmi all 460 | docker system prune -a --volumes 461 | ``` 462 | 463 | --- 464 | 465 | ## 📝 Tips & Best Practices 466 | 467 | 1. **First Run**: Allow 15-20 minutes for initial model downloads 468 | 2. **Persistence**: Always use volume mounts for data persistence 469 | 3. **Resources**: Monitor memory usage with `docker stats` 470 | 4. **Security**: Run in production behind a reverse proxy with SSL 471 | 5. **Backups**: Regularly backup vector_db and user_data directories 472 | 6. **Updates**: Pull new models periodically for improved performance 473 | 7. **Logs**: Set up log rotation to prevent disk space issues 474 | 475 | --- 476 | 477 | ## 🔗 Additional Resources 478 | 479 | - [Docker Documentation](https://docs.docker.com/) 480 | - [Ollama Documentation](https://ollama.ai/docs) 481 | - [Streamlit Documentation](https://docs.streamlit.io/) 482 | - [SmartRAG GitHub](https://github.com/itanishqshelar/SmartRAG) 483 | -------------------------------------------------------------------------------- /multimodal_rag/processors/image_processor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Image processor for visual content with OCR and vision model support. 
3 | """ 4 | 5 | import logging 6 | import time 7 | from pathlib import Path 8 | from typing import Union, List, Dict, Any, Optional 9 | 10 | try: 11 | from PIL import Image 12 | import pytesseract 13 | except ImportError: 14 | Image = None 15 | pytesseract = None 16 | 17 | try: 18 | import cv2 19 | except ImportError: 20 | cv2 = None 21 | 22 | try: 23 | from transformers import BlipProcessor, BlipForConditionalGeneration 24 | import torch 25 | except ImportError: 26 | BlipProcessor = None 27 | BlipForConditionalGeneration = None 28 | torch = None 29 | 30 | from ..base import BaseProcessor, ProcessingResult, DocumentChunk 31 | 32 | logger = logging.getLogger(__name__) 33 | 34 | 35 | class ImageProcessor(BaseProcessor): 36 | """Processor for image files with OCR and vision capabilities.""" 37 | 38 | def __init__(self, config: Dict[str, Any]): 39 | super().__init__(config) 40 | self.supported_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'] 41 | self.processor_type = "image" 42 | self.chunk_size = config.get('processing', {}).get('chunk_size', 1000) 43 | self.chunk_overlap = config.get('processing', {}).get('chunk_overlap', 200) 44 | 45 | # Configuration 46 | self.max_image_size = tuple(config.get('processing', {}).get('max_image_size', [1024, 1024])) 47 | self.ocr_enabled = config.get('processing', {}).get('ocr_enabled', True) 48 | self.vision_model_name = config.get('models', {}).get('vision_model', 'Salesforce/blip-image-captioning-base') 49 | 50 | # Check availability 51 | self.pil_available = Image is not None 52 | self.tesseract_available = pytesseract is not None 53 | self.cv2_available = cv2 is not None 54 | self.vision_model_available = all([BlipProcessor, BlipForConditionalGeneration, torch]) 55 | 56 | if not self.pil_available: 57 | logger.warning("PIL not available. Install with: pip install Pillow") 58 | 59 | # Configure Tesseract path if available 60 | if self.tesseract_available: 61 | import os 62 | tesseract_path = r"C:\Program Files\Tesseract-OCR\tesseract.exe" 63 | if os.path.exists(tesseract_path): 64 | pytesseract.pytesseract.tesseract_cmd = tesseract_path 65 | logger.info(f"Configured Tesseract path: {tesseract_path}") 66 | else: 67 | logger.warning("Tesseract executable not found at expected path") 68 | 69 | if self.ocr_enabled and not self.tesseract_available: 70 | logger.warning("Tesseract not available. 
Install with: pip install pytesseract") 71 | 72 | # Initialize vision model 73 | self.vision_processor = None 74 | self.vision_model = None 75 | if self.vision_model_available: 76 | try: 77 | self._load_vision_model() 78 | except Exception as e: 79 | logger.warning(f"Failed to load vision model: {str(e)}") 80 | self.vision_model_available = False 81 | 82 | def _load_vision_model(self): 83 | """Load the vision model for image captioning.""" 84 | try: 85 | self.vision_processor = BlipProcessor.from_pretrained(self.vision_model_name) 86 | self.vision_model = BlipForConditionalGeneration.from_pretrained(self.vision_model_name) 87 | 88 | # Move to GPU if available 89 | if torch.cuda.is_available(): 90 | self.vision_model = self.vision_model.cuda() 91 | 92 | logger.info(f"Loaded vision model: {self.vision_model_name}") 93 | except Exception as e: 94 | logger.error(f"Failed to load vision model: {str(e)}") 95 | raise 96 | 97 | def can_process(self, file_path: Union[str, Path]) -> bool: 98 | """Check if file is a supported image format.""" 99 | path = Path(file_path) 100 | return (path.suffix.lower() in self.supported_extensions and 101 | self.pil_available) 102 | 103 | def extract_content(self, file_path: Union[str, Path]) -> ProcessingResult: 104 | """Extract content from image file.""" 105 | start_time = time.time() 106 | 107 | try: 108 | path = Path(file_path) 109 | 110 | # Load and preprocess image 111 | image = Image.open(path) 112 | original_size = image.size 113 | 114 | # Convert to RGB if necessary 115 | if image.mode != 'RGB': 116 | image = image.convert('RGB') 117 | 118 | # Resize if too large 119 | if (image.width > self.max_image_size[0] or 120 | image.height > self.max_image_size[1]): 121 | image.thumbnail(self.max_image_size, Image.Resampling.LANCZOS) 122 | 123 | # Get file metadata 124 | metadata = self._get_file_metadata(path) 125 | metadata['original_size'] = original_size 126 | metadata['processed_size'] = image.size 127 | metadata['image_mode'] = image.mode 128 | 129 | content_parts = [] 130 | 131 | # Extract text using OCR 132 | if self.ocr_enabled and self.tesseract_available: 133 | try: 134 | ocr_text = pytesseract.image_to_string(image, config='--psm 3') 135 | if ocr_text.strip(): 136 | content_parts.append(f"OCR Text: {ocr_text.strip()}") 137 | metadata['ocr_confidence'] = self._get_ocr_confidence(image) 138 | except Exception as e: 139 | logger.warning(f"OCR failed for {path}: {str(e)}") 140 | 141 | # Generate image caption using vision model 142 | if self.vision_model_available: 143 | try: 144 | caption = self._generate_caption(image) 145 | if caption: 146 | content_parts.append(f"Image Description: {caption}") 147 | metadata['has_caption'] = True 148 | except Exception as e: 149 | logger.warning(f"Caption generation failed for {path}: {str(e)}") 150 | 151 | # Extract EXIF data if available 152 | try: 153 | exif_data = self._extract_exif_data(Image.open(path)) 154 | if exif_data: 155 | metadata['exif_data'] = exif_data 156 | except Exception as e: 157 | logger.debug(f"EXIF extraction failed: {str(e)}") 158 | 159 | if not content_parts: 160 | return ProcessingResult( 161 | chunks=[], 162 | success=False, 163 | error_message="No text or description could be extracted from image" 164 | ) 165 | 166 | content = '\n'.join(content_parts) 167 | metadata['extraction_methods'] = len(content_parts) 168 | 169 | # Create chunks 170 | chunks = self._create_chunks(content, metadata, self.chunk_size, self.chunk_overlap) 171 | 172 | processing_time = time.time() - start_time 173 | 174 | 
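            # Descriptive note: only a summary (the chunk count) is surfaced on the result itself;
            # the detailed file/OCR/caption metadata travels inside each DocumentChunk's metadata.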
return ProcessingResult( 175 | chunks=chunks, 176 | success=True, 177 | processing_time=processing_time, 178 | metadata={'chunks_created': len(chunks)} 179 | ) 180 | 181 | except Exception as e: 182 | logger.error(f"Error processing image file {file_path}: {str(e)}") 183 | return ProcessingResult( 184 | chunks=[], 185 | success=False, 186 | error_message=f"Failed to process image: {str(e)}", 187 | processing_time=time.time() - start_time 188 | ) 189 | 190 | def _generate_caption(self, image) -> Optional[str]: 191 | """Generate a caption for the image using the vision model.""" 192 | try: 193 | inputs = self.vision_processor(image, return_tensors="pt") 194 | 195 | # Move inputs to same device as model 196 | if torch.cuda.is_available() and next(self.vision_model.parameters()).is_cuda: 197 | inputs = {k: v.cuda() for k, v in inputs.items()} 198 | 199 | with torch.no_grad(): 200 | # Try different generation parameters for better captions 201 | outputs = self.vision_model.generate( 202 | **inputs, 203 | max_length=100, # Increased for more detailed captions 204 | num_beams=8, # More beams for better quality 205 | early_stopping=True, 206 | do_sample=True, 207 | temperature=0.7, 208 | top_p=0.9 209 | ) 210 | 211 | caption = self.vision_processor.decode(outputs[0], skip_special_tokens=True) 212 | 213 | # Filter out poor quality captions 214 | if caption and len(caption.strip()) > 10: 215 | # Check for repetitive patterns 216 | words = caption.split() 217 | if len(set(words)) < len(words) * 0.5: # Too many repeated words 218 | logger.warning(f"Repetitive caption detected, using fallback: {caption}") 219 | return f"Image content: Screenshot or document image (automatic caption failed)" 220 | return f"Visual content: {caption.strip()}" 221 | else: 222 | return f"Image content: Screenshot or document image" 223 | 224 | except Exception as e: 225 | logger.error(f"Caption generation error: {str(e)}") 226 | return f"Image content: Screenshot or document image (caption generation failed)" 227 | 228 | def _get_ocr_confidence(self, image) -> Optional[float]: 229 | """Get OCR confidence score.""" 230 | try: 231 | data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT) 232 | confidences = [float(conf) for conf in data['conf'] if int(conf) > 0] 233 | return sum(confidences) / len(confidences) if confidences else None 234 | except Exception: 235 | return None 236 | 237 | def _extract_exif_data(self, image) -> Optional[Dict[str, Any]]: 238 | """Extract EXIF metadata from image.""" 239 | try: 240 | exif_dict = {} 241 | if hasattr(image, '_getexif') and image._getexif(): 242 | exif = image._getexif() 243 | for tag_id, value in exif.items(): 244 | tag = Image.ExifTags.TAGS.get(tag_id, tag_id) 245 | exif_dict[tag] = value 246 | return exif_dict if exif_dict else None 247 | except Exception: 248 | return None 249 | 250 | 251 | class ImageProcessorManager: 252 | """Manager for image processing operations.""" 253 | 254 | def __init__(self, config: Dict[str, Any]): 255 | self.config = config 256 | self.processor = ImageProcessor(config) 257 | 258 | def can_process(self, file_path: Union[str, Path]) -> bool: 259 | """Check if this processor can handle the given file.""" 260 | return self.processor.can_process(file_path) 261 | 262 | def process_file(self, file_path: Union[str, Path]) -> ProcessingResult: 263 | """Process an image file.""" 264 | return self.extract_content(file_path) 265 | 266 | def extract_content(self, file_path: Union[str, Path]) -> ProcessingResult: 267 | """Extract content from 
an image file.""" 268 | if not self.processor.can_process(file_path): 269 | return ProcessingResult( 270 | chunks=[], 271 | success=False, 272 | error_message=f"Cannot process file: {file_path}" 273 | ) 274 | 275 | return self.processor.extract_content(file_path) 276 | 277 | def get_supported_extensions(self) -> List[str]: 278 | """Get supported image extensions.""" 279 | return self.processor.supported_extensions 280 | 281 | def is_available(self) -> bool: 282 | """Check if image processing is available.""" 283 | return self.processor.pil_available -------------------------------------------------------------------------------- /multimodal_rag/processors/audio_processor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Audio processor for speech-to-text conversion and audio content extraction. 3 | """ 4 | 5 | import logging 6 | import time 7 | from pathlib import Path 8 | from typing import Union, List, Dict, Any, Optional 9 | 10 | try: 11 | import whisper 12 | except ImportError: 13 | whisper = None 14 | 15 | try: 16 | from pydub import AudioSegment 17 | from pydub.utils import which 18 | except ImportError: 19 | AudioSegment = None 20 | which = None 21 | 22 | try: 23 | import speech_recognition as sr 24 | except ImportError: 25 | sr = None 26 | 27 | try: 28 | import librosa 29 | import numpy as np 30 | except ImportError: 31 | librosa = None 32 | np = None 33 | 34 | from ..base import BaseProcessor, ProcessingResult, DocumentChunk 35 | 36 | logger = logging.getLogger(__name__) 37 | 38 | 39 | class AudioProcessor(BaseProcessor): 40 | """Processor for audio files with speech-to-text capabilities.""" 41 | 42 | def __init__(self, config: Dict[str, Any]): 43 | super().__init__(config) 44 | self.supported_extensions = ['.mp3', '.wav', '.m4a', '.ogg', '.flac', '.aac'] 45 | self.processor_type = "audio" 46 | self.chunk_size = config.get('processing', {}).get('chunk_size', 1000) 47 | self.chunk_overlap = config.get('processing', {}).get('chunk_overlap', 200) 48 | 49 | # Configuration 50 | self.whisper_model_name = config.get('models', {}).get('whisper_model', 'base') 51 | self.max_audio_duration = config.get('processing', {}).get('max_audio_duration', 300) # seconds 52 | self.sample_rate = config.get('processing', {}).get('audio_sample_rate', 16000) 53 | 54 | # Check availability 55 | self.whisper_available = whisper is not None 56 | self.pydub_available = AudioSegment is not None 57 | self.speech_recognition_available = sr is not None 58 | self.librosa_available = librosa is not None 59 | 60 | if not self.whisper_available: 61 | logger.warning("Whisper not available. Install with: pip install openai-whisper") 62 | 63 | if not self.pydub_available: 64 | logger.warning("pydub not available. 
Install with: pip install pydub") 65 | 66 | # Initialize Whisper model 67 | self.whisper_model = None 68 | if self.whisper_available: 69 | try: 70 | self._load_whisper_model() 71 | except Exception as e: 72 | logger.warning(f"Failed to load Whisper model: {str(e)}") 73 | self.whisper_available = False 74 | 75 | # Initialize speech recognition 76 | self.speech_recognizer = None 77 | if self.speech_recognition_available: 78 | self.speech_recognizer = sr.Recognizer() 79 | 80 | def _load_whisper_model(self): 81 | """Load the Whisper model for speech-to-text.""" 82 | try: 83 | self.whisper_model = whisper.load_model(self.whisper_model_name) 84 | logger.info(f"Loaded Whisper model: {self.whisper_model_name}") 85 | except Exception as e: 86 | logger.error(f"Failed to load Whisper model: {str(e)}") 87 | raise 88 | 89 | def can_process(self, file_path: Union[str, Path]) -> bool: 90 | """Check if file is a supported audio format.""" 91 | path = Path(file_path) 92 | return (path.suffix.lower() in self.supported_extensions and 93 | (self.whisper_available or self.speech_recognition_available)) 94 | 95 | def extract_content(self, file_path: Union[str, Path]) -> ProcessingResult: 96 | """Extract content from audio file.""" 97 | start_time = time.time() 98 | 99 | try: 100 | path = Path(file_path) 101 | 102 | # Get audio metadata 103 | metadata = self._get_file_metadata(path) 104 | audio_info = self._get_audio_info(path) 105 | metadata.update(audio_info) 106 | 107 | # Check duration limit 108 | if audio_info.get('duration', 0) > self.max_audio_duration: 109 | return ProcessingResult( 110 | chunks=[], 111 | success=False, 112 | error_message=f"Audio too long: {audio_info.get('duration', 0)}s > {self.max_audio_duration}s" 113 | ) 114 | 115 | # Convert audio to text 116 | transcript = None 117 | confidence = None 118 | 119 | if self.whisper_available: 120 | transcript, confidence = self._transcribe_with_whisper(path) 121 | elif self.speech_recognition_available: 122 | transcript, confidence = self._transcribe_with_speech_recognition(path) 123 | 124 | if not transcript or not transcript.strip(): 125 | return ProcessingResult( 126 | chunks=[], 127 | success=False, 128 | error_message="No speech could be detected or transcribed from audio" 129 | ) 130 | 131 | # Add transcription metadata 132 | metadata['transcript_length'] = len(transcript) 133 | metadata['word_count'] = len(transcript.split()) 134 | if confidence is not None: 135 | metadata['transcription_confidence'] = confidence 136 | 137 | # Create chunks from transcript 138 | chunks = self._create_chunks(transcript, metadata, self.chunk_size, self.chunk_overlap) 139 | 140 | processing_time = time.time() - start_time 141 | 142 | return ProcessingResult( 143 | chunks=chunks, 144 | success=True, 145 | processing_time=processing_time, 146 | metadata={'chunks_created': len(chunks)} 147 | ) 148 | 149 | except Exception as e: 150 | logger.error(f"Error processing audio file {file_path}: {str(e)}") 151 | return ProcessingResult( 152 | chunks=[], 153 | success=False, 154 | error_message=f"Failed to process audio: {str(e)}", 155 | processing_time=time.time() - start_time 156 | ) 157 | 158 | def _get_audio_info(self, path: Path) -> Dict[str, Any]: 159 | """Extract audio metadata.""" 160 | audio_info = {} 161 | 162 | try: 163 | if self.pydub_available: 164 | audio = AudioSegment.from_file(path) 165 | audio_info.update({ 166 | 'duration': len(audio) / 1000.0, # Convert to seconds 167 | 'sample_rate': audio.frame_rate, 168 | 'channels': audio.channels, 169 | 
'sample_width': audio.sample_width, 170 | 'frame_count': audio.frame_count() 171 | }) 172 | elif self.librosa_available: 173 | # Use librosa as fallback 174 | duration = librosa.get_duration(filename=str(path)) 175 | audio_info['duration'] = duration 176 | 177 | except Exception as e: 178 | logger.warning(f"Failed to extract audio info: {str(e)}") 179 | 180 | return audio_info 181 | 182 | def _transcribe_with_whisper(self, path: Path) -> tuple[Optional[str], Optional[float]]: 183 | """Transcribe audio using Whisper.""" 184 | try: 185 | result = self.whisper_model.transcribe(str(path)) 186 | 187 | transcript = result.get('text', '').strip() 188 | 189 | # Calculate average confidence from segments 190 | segments = result.get('segments', []) 191 | if segments: 192 | confidences = [seg.get('avg_logprob', 0) for seg in segments if 'avg_logprob' in seg] 193 | avg_confidence = sum(confidences) / len(confidences) if confidences else None 194 | else: 195 | avg_confidence = None 196 | 197 | return transcript, avg_confidence 198 | 199 | except Exception as e: 200 | logger.error(f"Whisper transcription failed: {str(e)}") 201 | return None, None 202 | 203 | def _transcribe_with_speech_recognition(self, path: Path) -> tuple[Optional[str], Optional[float]]: 204 | """Transcribe audio using SpeechRecognition library.""" 205 | try: 206 | # Convert to WAV if necessary 207 | if self.pydub_available and path.suffix.lower() != '.wav': 208 | audio = AudioSegment.from_file(path) 209 | # Convert to temporary WAV 210 | temp_wav_path = path.with_suffix('.temp.wav') 211 | audio.export(temp_wav_path, format="wav") 212 | wav_path = temp_wav_path 213 | else: 214 | wav_path = path 215 | 216 | # Transcribe using speech_recognition 217 | with sr.AudioFile(str(wav_path)) as source: 218 | audio_data = self.speech_recognizer.record(source) 219 | 220 | try: 221 | # Try Google Speech Recognition (requires internet) 222 | transcript = self.speech_recognizer.recognize_google(audio_data) 223 | confidence = None # Google API doesn't return confidence 224 | except sr.RequestError: 225 | # Fallback to offline recognition if available 226 | try: 227 | transcript = self.speech_recognizer.recognize_sphinx(audio_data) 228 | confidence = None 229 | except (sr.RequestError, sr.UnknownValueError): 230 | transcript = None 231 | confidence = None 232 | 233 | # Clean up temporary file 234 | if 'temp_wav_path' in locals() and temp_wav_path.exists(): 235 | temp_wav_path.unlink() 236 | 237 | return transcript, confidence 238 | 239 | except Exception as e: 240 | logger.error(f"Speech recognition failed: {str(e)}") 241 | return None, None 242 | 243 | def _segment_long_audio(self, path: Path, segment_duration: int = 30) -> List[Path]: 244 | """Split long audio into smaller segments.""" 245 | if not self.pydub_available: 246 | return [path] 247 | 248 | try: 249 | audio = AudioSegment.from_file(path) 250 | segment_length_ms = segment_duration * 1000 251 | 252 | segments = [] 253 | for i, start_ms in enumerate(range(0, len(audio), segment_length_ms)): 254 | end_ms = min(start_ms + segment_length_ms, len(audio)) 255 | segment = audio[start_ms:end_ms] 256 | 257 | segment_path = path.with_name(f"{path.stem}_segment_{i}{path.suffix}") 258 | segment.export(segment_path, format=path.suffix[1:]) 259 | segments.append(segment_path) 260 | 261 | return segments 262 | 263 | except Exception as e: 264 | logger.warning(f"Audio segmentation failed: {str(e)}") 265 | return [path] 266 | 267 | 268 | class AudioProcessorManager: 269 | """Manager for audio processing 
operations.""" 270 | 271 | def __init__(self, config: Dict[str, Any]): 272 | self.config = config 273 | self.processor = AudioProcessor(config) 274 | 275 | def process_file(self, file_path: Union[str, Path]) -> ProcessingResult: 276 | """Process an audio file.""" 277 | return self.extract_content(file_path) 278 | 279 | def extract_content(self, file_path: Union[str, Path]) -> ProcessingResult: 280 | """Extract content from an audio file.""" 281 | if not self.processor.can_process(file_path): 282 | return ProcessingResult( 283 | chunks=[], 284 | success=False, 285 | error_message=f"Cannot process file: {file_path}" 286 | ) 287 | 288 | return self.processor.extract_content(file_path) 289 | 290 | def get_supported_extensions(self) -> List[str]: 291 | """Get supported audio extensions.""" 292 | return self.processor.supported_extensions 293 | 294 | def is_available(self) -> bool: 295 | """Check if audio processing is available.""" 296 | return self.processor.whisper_available or self.processor.speech_recognition_available 297 | 298 | def can_process(self, file_path: Union[str, Path]) -> bool: 299 | """Check if the audio processor can handle this file.""" 300 | return self.processor.can_process(file_path) -------------------------------------------------------------------------------- /multimodal_rag/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base classes and interfaces for the multimodal RAG system. 3 | """ 4 | 5 | import logging 6 | from abc import ABC, abstractmethod 7 | from dataclasses import dataclass, field 8 | from typing import List, Dict, Any, Optional, Union 9 | from pathlib import Path 10 | import uuid 11 | from datetime import datetime 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | @dataclass 17 | class DocumentChunk: 18 | """Represents a chunk of processed content from any modality.""" 19 | 20 | content: str 21 | metadata: Dict[str, Any] = field(default_factory=dict) 22 | embedding: Optional[List[float]] = None 23 | document_type: str = "text" 24 | chunk_id: str = field(default_factory=lambda: str(uuid.uuid4())) 25 | source_file: Optional[str] = None 26 | page_number: Optional[int] = None 27 | timestamp: datetime = field(default_factory=datetime.now) 28 | 29 | def __post_init__(self): 30 | """Ensure metadata has required fields.""" 31 | if 'chunk_index' not in self.metadata: 32 | self.metadata['chunk_index'] = 0 33 | if 'total_chunks' not in self.metadata: 34 | self.metadata['total_chunks'] = 1 35 | 36 | 37 | @dataclass 38 | class ProcessingResult: 39 | """Result of processing a file through any processor.""" 40 | 41 | chunks: List[DocumentChunk] 42 | success: bool = True 43 | error_message: Optional[str] = None 44 | processing_time: Optional[float] = None 45 | metadata: Dict[str, Any] = field(default_factory=dict) 46 | 47 | 48 | @dataclass 49 | class RetrievalResult: 50 | """Result of a semantic search query.""" 51 | 52 | chunks: List[DocumentChunk] 53 | scores: List[float] 54 | query: str 55 | total_results: int 56 | retrieval_time: Optional[float] = None 57 | 58 | 59 | class BaseProcessor(ABC): 60 | """Abstract base class for all content processors.""" 61 | 62 | def __init__(self, config: Dict[str, Any]): 63 | self.config = config 64 | self.supported_extensions = [] 65 | self.processor_type = "base" 66 | 67 | @abstractmethod 68 | def can_process(self, file_path: Union[str, Path]) -> bool: 69 | """Check if this processor can handle the given file.""" 70 | pass 71 | 72 | @abstractmethod 73 | def 
extract_content(self, file_path: Union[str, Path]) -> ProcessingResult: 74 | """Extract and chunk content from the file.""" 75 | pass 76 | 77 | def _get_file_metadata(self, file_path: Union[str, Path]) -> Dict[str, Any]: 78 | """Extract common file metadata.""" 79 | path = Path(file_path) 80 | return { 81 | 'filename': path.name, 82 | 'file_extension': path.suffix.lower(), 83 | 'file_size': path.stat().st_size, 84 | 'created_date': datetime.fromtimestamp(path.stat().st_ctime), 85 | 'modified_date': datetime.fromtimestamp(path.stat().st_mtime), 86 | 'absolute_path': str(path.absolute()) 87 | } 88 | 89 | def _create_chunks(self, content: str, metadata: Dict[str, Any], 90 | chunk_size: int = 1000, overlap: int = 200) -> List[DocumentChunk]: 91 | """Create overlapping chunks from content.""" 92 | if not content.strip(): 93 | return [] 94 | 95 | chunks = [] 96 | content_length = len(content) 97 | 98 | if content_length <= chunk_size: 99 | # Content fits in one chunk - extract page number if available 100 | page_number = self._extract_page_number_from_content(content) 101 | chunk = DocumentChunk( 102 | content=content, 103 | metadata={**metadata, 'chunk_index': 0, 'total_chunks': 1}, 104 | document_type=self.processor_type, 105 | source_file=metadata.get('absolute_path'), 106 | page_number=page_number 107 | ) 108 | chunks.append(chunk) 109 | else: 110 | # Split into multiple chunks 111 | total_chunks = (content_length - overlap) // (chunk_size - overlap) + 1 112 | 113 | for i in range(0, content_length, chunk_size - overlap): 114 | chunk_content = content[i:i + chunk_size] 115 | if not chunk_content.strip(): 116 | continue 117 | 118 | # Extract page number from chunk content 119 | page_number = self._extract_page_number_from_content(chunk_content) 120 | 121 | chunk_metadata = { 122 | **metadata, 123 | 'chunk_index': len(chunks), 124 | 'total_chunks': total_chunks, 125 | 'start_index': i, 126 | 'end_index': min(i + chunk_size, content_length) 127 | } 128 | 129 | # Add page number to metadata if found 130 | if page_number: 131 | chunk_metadata['page_number'] = page_number 132 | 133 | chunk = DocumentChunk( 134 | content=chunk_content, 135 | metadata=chunk_metadata, 136 | document_type=self.processor_type, 137 | source_file=metadata.get('absolute_path'), 138 | page_number=page_number 139 | ) 140 | chunks.append(chunk) 141 | 142 | return chunks 143 | 144 | def _extract_page_number_from_content(self, content: str) -> Optional[int]: 145 | """Extract page number from content that contains page markers.""" 146 | import re 147 | # Look for page markers like "--- Page 2 ---" 148 | page_match = re.search(r'---\s*Page\s+(\d+)\s*---', content) 149 | if page_match: 150 | return int(page_match.group(1)) 151 | return None 152 | 153 | 154 | class BaseVectorStore(ABC): 155 | """Abstract base class for vector storage backends.""" 156 | 157 | def __init__(self, config: Dict[str, Any]): 158 | self.config = config 159 | self.collection_name = config.get('collection_name', 'default') 160 | 161 | @abstractmethod 162 | def add_documents(self, chunks: List[DocumentChunk]) -> bool: 163 | """Add document chunks to the vector store.""" 164 | pass 165 | 166 | @abstractmethod 167 | def similarity_search(self, query: str, k: int = 5, 168 | filter_dict: Optional[Dict[str, Any]] = None) -> RetrievalResult: 169 | """Perform similarity search and return results.""" 170 | pass 171 | 172 | @abstractmethod 173 | def delete_documents(self, chunk_ids: List[str]) -> bool: 174 | """Delete documents by chunk IDs.""" 175 | pass 176 | 177 | 
@abstractmethod 178 | def get_collection_stats(self) -> Dict[str, Any]: 179 | """Get statistics about the collection.""" 180 | pass 181 | 182 | 183 | class BaseEmbedding(ABC): 184 | """Abstract base class for embedding models.""" 185 | 186 | def __init__(self, config: Dict[str, Any]): 187 | self.config = config 188 | self.model_name = config.get('embedding_model', 'nomic-embed-text') 189 | 190 | @abstractmethod 191 | def embed_text(self, text: str) -> List[float]: 192 | """Generate embedding for a single text.""" 193 | pass 194 | 195 | @abstractmethod 196 | def embed_batch(self, texts: List[str]) -> List[List[float]]: 197 | """Generate embeddings for a batch of texts.""" 198 | pass 199 | 200 | @abstractmethod 201 | def get_embedding_dimension(self) -> int: 202 | """Get the dimension of embeddings produced by this model.""" 203 | pass 204 | 205 | 206 | class BaseLLM(ABC): 207 | """Abstract base class for language models.""" 208 | 209 | def __init__(self, config: Dict[str, Any]): 210 | self.config = config 211 | self.model_name = config.get('llm_model', 'microsoft/DialoGPT-medium') 212 | 213 | @abstractmethod 214 | def generate_response(self, prompt: str, context: str = "", 215 | max_tokens: int = 512) -> str: 216 | """Generate a response given a prompt and context.""" 217 | pass 218 | 219 | @abstractmethod 220 | def is_available(self) -> bool: 221 | """Check if the model is available and loaded.""" 222 | pass 223 | 224 | 225 | @dataclass 226 | class QueryRequest: 227 | """Represents a user query request.""" 228 | 229 | query: str 230 | query_type: str = "general" # general, document, image, audio 231 | filters: Dict[str, Any] = field(default_factory=dict) 232 | top_k: int = 5 233 | include_metadata: bool = True 234 | rerank: bool = False 235 | generation_params: Dict[str, Any] = field(default_factory=dict) 236 | 237 | 238 | @dataclass 239 | class QueryResponse: 240 | """Represents the response to a user query.""" 241 | 242 | answer: str 243 | sources: List[DocumentChunk] 244 | query: str 245 | confidence_score: Optional[float] = None 246 | processing_time: Optional[float] = None 247 | metadata: Dict[str, Any] = field(default_factory=dict) 248 | 249 | 250 | class OllamaLLM(BaseLLM): 251 | """Ollama LLM implementation.""" 252 | 253 | def __init__(self, config: Dict[str, Any]): 254 | super().__init__(config) 255 | try: 256 | import ollama 257 | self.ollama = ollama 258 | except ImportError: 259 | raise ImportError("Ollama not available. 
Install with: pip install ollama") 260 | 261 | self.model_name = config.get('models', {}).get('llm_model', 'llama3.1:8b') 262 | self.host = config.get('models', {}).get('ollama_host', 'http://localhost:11434') 263 | 264 | def is_available(self) -> bool: 265 | """Check if the LLM is available.""" 266 | try: 267 | # Try to connect to Ollama server 268 | logger.info(f"Checking LLM availability for model: {self.model_name}") 269 | response = self.ollama.list() 270 | # Check if our model is available 271 | available_models = [model.model for model in response.models] 272 | logger.info(f"Available models: {available_models}") 273 | 274 | # Check for exact match or partial match (handles version suffixes) 275 | model_found = False 276 | for available_model in available_models: 277 | if self.model_name == available_model or self.model_name in available_model or available_model in self.model_name: 278 | model_found = True 279 | logger.info(f"Model match found: {available_model} matches {self.model_name}") 280 | break 281 | 282 | logger.info(f"Model {self.model_name} found: {model_found}") 283 | return model_found 284 | except Exception as e: 285 | logger.error(f"Error checking LLM availability: {str(e)}") 286 | return False 287 | 288 | def generate_response(self, prompt: str, context: str = "", **kwargs) -> str: 289 | """Generate response using Ollama.""" 290 | try: 291 | # Use the prompt as-is if it already includes system instructions 292 | # Otherwise, add context if provided 293 | full_prompt = prompt 294 | if context and context.strip() and "Context from documents:" not in prompt: 295 | full_prompt = f"Context: {context}\n\nQuestion: {prompt}" 296 | 297 | response = self.ollama.chat( 298 | model=self.model_name, 299 | messages=[{'role': 'user', 'content': full_prompt}], 300 | options={ 301 | 'temperature': kwargs.get('temperature', 0.7), 302 | 'top_p': kwargs.get('top_p', 0.9), 303 | 'max_tokens': kwargs.get('max_tokens', 2048) 304 | } 305 | ) 306 | return response['message']['content'] 307 | except Exception as e: 308 | raise Exception(f"Error generating response with Ollama: {str(e)}") 309 | 310 | def generate_streaming_response(self, prompt: str, **kwargs): 311 | """Generate streaming response using Ollama.""" 312 | try: 313 | stream = self.ollama.chat( 314 | model=self.model_name, 315 | messages=[{'role': 'user', 'content': prompt}], 316 | stream=True, 317 | options={ 318 | 'temperature': kwargs.get('temperature', 0.7), 319 | 'top_p': kwargs.get('top_p', 0.9), 320 | 'max_tokens': kwargs.get('max_tokens', 2048) 321 | } 322 | ) 323 | for chunk in stream: 324 | if 'message' in chunk and 'content' in chunk['message']: 325 | yield chunk['message']['content'] 326 | except Exception as e: 327 | raise Exception(f"Error generating streaming response with Ollama: {str(e)}") -------------------------------------------------------------------------------- /multimodal_rag/processors/document_processor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Document processors for text-based formats (PDF, DOCX, TXT, etc.). 
3 | """ 4 | 5 | import logging 6 | import time 7 | from pathlib import Path 8 | from typing import Union, List, Dict, Any, Optional 9 | 10 | try: 11 | import PyPDF2 12 | from PyPDF2 import PdfReader 13 | except ImportError: 14 | PyPDF2 = None 15 | 16 | try: 17 | from docx import Document as DocxDocument 18 | except ImportError: 19 | DocxDocument = None 20 | 21 | try: 22 | import pdfplumber 23 | except ImportError: 24 | pdfplumber = None 25 | 26 | from ..base import BaseProcessor, ProcessingResult, DocumentChunk 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | 31 | class TextProcessor(BaseProcessor): 32 | """Processor for plain text files.""" 33 | 34 | def __init__(self, config: Dict[str, Any]): 35 | super().__init__(config) 36 | self.supported_extensions = ['.txt', '.md', '.rtf'] 37 | self.processor_type = "text" 38 | self.chunk_size = config.get('processing', {}).get('chunk_size', 1000) 39 | self.chunk_overlap = config.get('processing', {}).get('chunk_overlap', 200) 40 | 41 | def can_process(self, file_path: Union[str, Path]) -> bool: 42 | """Check if file is a supported text format.""" 43 | path = Path(file_path) 44 | return path.suffix.lower() in self.supported_extensions 45 | 46 | def extract_content(self, file_path: Union[str, Path]) -> ProcessingResult: 47 | """Extract content from text file.""" 48 | start_time = time.time() 49 | 50 | try: 51 | path = Path(file_path) 52 | 53 | # Read text content 54 | with open(path, 'r', encoding='utf-8', errors='ignore') as file: 55 | content = file.read() 56 | 57 | if not content.strip(): 58 | return ProcessingResult( 59 | chunks=[], 60 | success=False, 61 | error_message="File is empty or contains no readable text" 62 | ) 63 | 64 | # Get file metadata 65 | metadata = self._get_file_metadata(path) 66 | metadata['word_count'] = len(content.split()) 67 | metadata['char_count'] = len(content) 68 | 69 | # Create chunks 70 | chunks = self._create_chunks(content, metadata, self.chunk_size, self.chunk_overlap) 71 | 72 | processing_time = time.time() - start_time 73 | 74 | return ProcessingResult( 75 | chunks=chunks, 76 | success=True, 77 | processing_time=processing_time, 78 | metadata={'chunks_created': len(chunks)} 79 | ) 80 | 81 | except Exception as e: 82 | logger.error(f"Error processing text file {file_path}: {str(e)}") 83 | return ProcessingResult( 84 | chunks=[], 85 | success=False, 86 | error_message=f"Failed to process text file: {str(e)}", 87 | processing_time=time.time() - start_time 88 | ) 89 | 90 | 91 | class PDFProcessor(BaseProcessor): 92 | """Processor for PDF files.""" 93 | 94 | def __init__(self, config: Dict[str, Any]): 95 | super().__init__(config) 96 | self.supported_extensions = ['.pdf'] 97 | self.processor_type = "pdf" 98 | self.chunk_size = config.get('processing', {}).get('chunk_size', 1000) 99 | self.chunk_overlap = config.get('processing', {}).get('chunk_overlap', 200) 100 | 101 | # Check for available PDF libraries 102 | self.use_pdfplumber = pdfplumber is not None 103 | self.use_pypdf2 = PyPDF2 is not None 104 | 105 | if not (self.use_pdfplumber or self.use_pypdf2): 106 | logger.warning("No PDF processing libraries available. 
Install pdfplumber or PyPDF2.") 107 | 108 | def can_process(self, file_path: Union[str, Path]) -> bool: 109 | """Check if file is a PDF.""" 110 | path = Path(file_path) 111 | return (path.suffix.lower() == '.pdf' and 112 | (self.use_pdfplumber or self.use_pypdf2)) 113 | 114 | def extract_content(self, file_path: Union[str, Path]) -> ProcessingResult: 115 | """Extract content from PDF file.""" 116 | start_time = time.time() 117 | 118 | try: 119 | path = Path(file_path) 120 | 121 | if self.use_pdfplumber: 122 | content, page_info = self._extract_with_pdfplumber(path) 123 | elif self.use_pypdf2: 124 | content, page_info = self._extract_with_pypdf2(path) 125 | else: 126 | return ProcessingResult( 127 | chunks=[], 128 | success=False, 129 | error_message="No PDF processing library available" 130 | ) 131 | 132 | if not content.strip(): 133 | return ProcessingResult( 134 | chunks=[], 135 | success=False, 136 | error_message="PDF contains no extractable text" 137 | ) 138 | 139 | # Get file metadata 140 | metadata = self._get_file_metadata(path) 141 | metadata.update(page_info) 142 | metadata['word_count'] = len(content.split()) 143 | metadata['char_count'] = len(content) 144 | 145 | # Create chunks 146 | chunks = self._create_chunks(content, metadata, self.chunk_size, self.chunk_overlap) 147 | 148 | processing_time = time.time() - start_time 149 | 150 | return ProcessingResult( 151 | chunks=chunks, 152 | success=True, 153 | processing_time=processing_time, 154 | metadata={'chunks_created': len(chunks)} 155 | ) 156 | 157 | except Exception as e: 158 | logger.error(f"Error processing PDF file {file_path}: {str(e)}") 159 | return ProcessingResult( 160 | chunks=[], 161 | success=False, 162 | error_message=f"Failed to process PDF: {str(e)}", 163 | processing_time=time.time() - start_time 164 | ) 165 | 166 | def _extract_with_pdfplumber(self, path: Path) -> tuple[str, Dict[str, Any]]: 167 | """Extract text using pdfplumber.""" 168 | content_parts = [] 169 | page_info = {'total_pages': 0, 'extracted_pages': 0} 170 | 171 | with pdfplumber.open(path) as pdf: 172 | page_info['total_pages'] = len(pdf.pages) 173 | 174 | for page_num, page in enumerate(pdf.pages, 1): 175 | try: 176 | page_text = page.extract_text() 177 | if page_text and page_text.strip(): 178 | content_parts.append(f"\n--- Page {page_num} ---\n{page_text}") 179 | page_info['extracted_pages'] += 1 180 | except Exception as e: 181 | logger.warning(f"Failed to extract text from page {page_num}: {str(e)}") 182 | 183 | return '\n'.join(content_parts), page_info 184 | 185 | def _extract_with_pypdf2(self, path: Path) -> tuple[str, Dict[str, Any]]: 186 | """Extract text using PyPDF2.""" 187 | content_parts = [] 188 | page_info = {'total_pages': 0, 'extracted_pages': 0} 189 | 190 | with open(path, 'rb') as file: 191 | pdf_reader = PdfReader(file) 192 | page_info['total_pages'] = len(pdf_reader.pages) 193 | 194 | for page_num, page in enumerate(pdf_reader.pages, 1): 195 | try: 196 | page_text = page.extract_text() 197 | if page_text and page_text.strip(): 198 | content_parts.append(f"\n--- Page {page_num} ---\n{page_text}") 199 | page_info['extracted_pages'] += 1 200 | except Exception as e: 201 | logger.warning(f"Failed to extract text from page {page_num}: {str(e)}") 202 | 203 | return '\n'.join(content_parts), page_info 204 | 205 | 206 | class DOCXProcessor(BaseProcessor): 207 | """Processor for DOCX files.""" 208 | 209 | def __init__(self, config: Dict[str, Any]): 210 | super().__init__(config) 211 | self.supported_extensions = ['.docx', '.doc'] 
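# Note: '.doc' is listed here so the manager routes legacy Word files to this
# processor, but extract_content() below rejects them with an explicit
# "convert to .docx first" error instead of attempting extraction.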
212 | self.processor_type = "docx" 213 | self.chunk_size = config.get('processing', {}).get('chunk_size', 1000) 214 | self.chunk_overlap = config.get('processing', {}).get('chunk_overlap', 200) 215 | 216 | if DocxDocument is None: 217 | logger.warning("python-docx not available. Install with: pip install python-docx") 218 | 219 | def can_process(self, file_path: Union[str, Path]) -> bool: 220 | """Check if file is a DOCX.""" 221 | path = Path(file_path) 222 | return (path.suffix.lower() in self.supported_extensions and 223 | DocxDocument is not None) 224 | 225 | def extract_content(self, file_path: Union[str, Path]) -> ProcessingResult: 226 | """Extract content from DOCX file.""" 227 | start_time = time.time() 228 | 229 | try: 230 | path = Path(file_path) 231 | 232 | # Only process .docx files (not .doc) 233 | if path.suffix.lower() == '.doc': 234 | return ProcessingResult( 235 | chunks=[], 236 | success=False, 237 | error_message="Legacy .doc format not supported. Convert to .docx first." 238 | ) 239 | 240 | doc = DocxDocument(path) 241 | 242 | # Extract paragraphs 243 | content_parts = [] 244 | for paragraph in doc.paragraphs: 245 | if paragraph.text.strip(): 246 | content_parts.append(paragraph.text) 247 | 248 | content = '\n'.join(content_parts) 249 | 250 | if not content.strip(): 251 | return ProcessingResult( 252 | chunks=[], 253 | success=False, 254 | error_message="DOCX contains no extractable text" 255 | ) 256 | 257 | # Get file metadata 258 | metadata = self._get_file_metadata(path) 259 | metadata['paragraph_count'] = len(content_parts) 260 | metadata['word_count'] = len(content.split()) 261 | metadata['char_count'] = len(content) 262 | 263 | # Add document properties if available 264 | try: 265 | core_props = doc.core_properties 266 | metadata['doc_title'] = core_props.title or 'Unknown' 267 | metadata['doc_author'] = core_props.author or 'Unknown' 268 | metadata['doc_subject'] = core_props.subject or 'Unknown' 269 | except Exception: 270 | pass 271 | 272 | # Create chunks 273 | chunks = self._create_chunks(content, metadata, self.chunk_size, self.chunk_overlap) 274 | 275 | processing_time = time.time() - start_time 276 | 277 | return ProcessingResult( 278 | chunks=chunks, 279 | success=True, 280 | processing_time=processing_time, 281 | metadata={'chunks_created': len(chunks)} 282 | ) 283 | 284 | except Exception as e: 285 | logger.error(f"Error processing DOCX file {file_path}: {str(e)}") 286 | return ProcessingResult( 287 | chunks=[], 288 | success=False, 289 | error_message=f"Failed to process DOCX: {str(e)}", 290 | processing_time=time.time() - start_time 291 | ) 292 | 293 | 294 | class DocumentProcessorManager: 295 | """Manager class to handle different document types.""" 296 | 297 | def __init__(self, config: Dict[str, Any]): 298 | self.config = config 299 | self.processors = [ 300 | TextProcessor(config), 301 | PDFProcessor(config), 302 | DOCXProcessor(config) 303 | ] 304 | 305 | def get_processor(self, file_path: Union[str, Path]) -> Optional[BaseProcessor]: 306 | """Get the appropriate processor for a file.""" 307 | for processor in self.processors: 308 | if processor.can_process(file_path): 309 | return processor 310 | return None 311 | 312 | def process_file(self, file_path: Union[str, Path]) -> ProcessingResult: 313 | """Process a file using the appropriate processor.""" 314 | processor = self.get_processor(file_path) 315 | if processor is None: 316 | return ProcessingResult( 317 | chunks=[], 318 | success=False, 319 | error_message=f"No processor available for file: 
{file_path}" 320 | ) 321 | 322 | return processor.extract_content(file_path) 323 | 324 | def get_supported_extensions(self) -> List[str]: 325 | """Get all supported file extensions.""" 326 | extensions = [] 327 | for processor in self.processors: 328 | extensions.extend(processor.supported_extensions) 329 | return sorted(list(set(extensions))) 330 | 331 | def is_available(self) -> bool: 332 | """Check if document processors are available.""" 333 | return len(self.processors) > 0 334 | 335 | def can_process(self, file_path: Union[str, Path]) -> bool: 336 | """Check if any processor can handle this file.""" 337 | return self.get_processor(file_path) is not None -------------------------------------------------------------------------------- /multimodal_rag/vector_stores/faiss_store.py: -------------------------------------------------------------------------------- 1 | """ 2 | FAISS vector store implementation. 3 | """ 4 | 5 | import logging 6 | import pickle 7 | import time 8 | import uuid 9 | from pathlib import Path 10 | from typing import List, Dict, Any, Optional 11 | 12 | try: 13 | import faiss 14 | import numpy as np 15 | except ImportError: 16 | faiss = None 17 | np = None 18 | 19 | try: 20 | from sentence_transformers import SentenceTransformer 21 | except ImportError: 22 | SentenceTransformer = None 23 | 24 | from ..base import BaseVectorStore, DocumentChunk, RetrievalResult 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | class FAISSVectorStore(BaseVectorStore): 30 | """FAISS implementation of vector store.""" 31 | 32 | def __init__(self, config: Dict[str, Any]): 33 | super().__init__(config) 34 | 35 | if faiss is None or np is None: 36 | raise ImportError("FAISS or NumPy not available. Install with: pip install faiss-cpu numpy") 37 | 38 | if SentenceTransformer is None: 39 | raise ImportError("SentenceTransformers not available. 
Install with: pip install sentence-transformers") 40 | 41 | self.persist_directory = Path(config.get('persist_directory', './vector_db')) 42 | # Configuration 43 | self.embedding_model_name = config.get('embedding_model', 'nomic-embed-text') 44 | self.embedding_dimension = config.get('embedding_dimension', 384) 45 | 46 | # Initialize components 47 | self.embedding_model = None 48 | self.index = None 49 | self.document_store = {} # Store actual documents with metadata 50 | self.id_to_index_map = {} # Map chunk IDs to FAISS indices 51 | self.index_to_id_map = {} # Map FAISS indices to chunk IDs 52 | 53 | # Create persist directory 54 | self.persist_directory.mkdir(parents=True, exist_ok=True) 55 | 56 | self._initialize_embedding_model() 57 | self._initialize_or_load_index() 58 | 59 | def _initialize_embedding_model(self): 60 | """Initialize the embedding model.""" 61 | try: 62 | self.embedding_model = SentenceTransformer(self.embedding_model_name) 63 | logger.info(f"Initialized embedding model: {self.embedding_model_name}") 64 | except Exception as e: 65 | logger.error(f"Failed to initialize embedding model: {str(e)}") 66 | raise 67 | 68 | def _initialize_or_load_index(self): 69 | """Initialize or load existing FAISS index.""" 70 | index_path = self.persist_directory / f"{self.collection_name}.index" 71 | metadata_path = self.persist_directory / f"{self.collection_name}.metadata" 72 | 73 | if index_path.exists() and metadata_path.exists(): 74 | self._load_index() 75 | else: 76 | self._create_new_index() 77 | 78 | def _create_new_index(self): 79 | """Create a new FAISS index.""" 80 | try: 81 | # Create FAISS index (using Inner Product for cosine similarity) 82 | self.index = faiss.IndexFlatIP(self.embedding_dimension) 83 | 84 | # Initialize empty stores 85 | self.document_store = {} 86 | self.id_to_index_map = {} 87 | self.index_to_id_map = {} 88 | 89 | logger.info(f"Created new FAISS index with dimension {self.embedding_dimension}") 90 | 91 | except Exception as e: 92 | logger.error(f"Failed to create FAISS index: {str(e)}") 93 | raise 94 | 95 | def _load_index(self): 96 | """Load existing FAISS index and metadata.""" 97 | try: 98 | index_path = self.persist_directory / f"{self.collection_name}.index" 99 | metadata_path = self.persist_directory / f"{self.collection_name}.metadata" 100 | 101 | # Load FAISS index 102 | self.index = faiss.read_index(str(index_path)) 103 | 104 | # Load metadata 105 | with open(metadata_path, 'rb') as f: 106 | metadata = pickle.load(f) 107 | self.document_store = metadata['document_store'] 108 | self.id_to_index_map = metadata['id_to_index_map'] 109 | self.index_to_id_map = metadata['index_to_id_map'] 110 | 111 | logger.info(f"Loaded FAISS index with {self.index.ntotal} vectors") 112 | 113 | except Exception as e: 114 | logger.error(f"Failed to load FAISS index: {str(e)}") 115 | # Fallback to creating new index 116 | self._create_new_index() 117 | 118 | def _save_index(self): 119 | """Save FAISS index and metadata to disk.""" 120 | try: 121 | index_path = self.persist_directory / f"{self.collection_name}.index" 122 | metadata_path = self.persist_directory / f"{self.collection_name}.metadata" 123 | 124 | # Save FAISS index 125 | faiss.write_index(self.index, str(index_path)) 126 | 127 | # Save metadata 128 | metadata = { 129 | 'document_store': self.document_store, 130 | 'id_to_index_map': self.id_to_index_map, 131 | 'index_to_id_map': self.index_to_id_map 132 | } 133 | 134 | with open(metadata_path, 'wb') as f: 135 | pickle.dump(metadata, f) 136 | 137 | 
logger.debug("Saved FAISS index and metadata to disk") 138 | 139 | except Exception as e: 140 | logger.error(f"Failed to save FAISS index: {str(e)}") 141 | 142 | def add_documents(self, chunks: List[DocumentChunk]) -> bool: 143 | """Add document chunks to the vector store.""" 144 | try: 145 | if not chunks: 146 | return True 147 | 148 | # Generate embeddings 149 | texts = [chunk.content for chunk in chunks] 150 | embeddings = self.embedding_model.encode(texts, convert_to_numpy=True) 151 | 152 | # Normalize for cosine similarity 153 | faiss.normalize_L2(embeddings) 154 | 155 | # Add to FAISS index 156 | start_index = self.index.ntotal 157 | self.index.add(embeddings) 158 | 159 | # Store documents and update mappings 160 | for i, chunk in enumerate(chunks): 161 | chunk_id = chunk.chunk_id or str(uuid.uuid4()) 162 | faiss_index = start_index + i 163 | 164 | # Store document 165 | self.document_store[chunk_id] = { 166 | 'content': chunk.content, 167 | 'metadata': chunk.metadata, 168 | 'document_type': chunk.document_type, 169 | 'source_file': chunk.source_file, 170 | 'timestamp': chunk.timestamp.isoformat() if chunk.timestamp else None 171 | } 172 | 173 | # Update mappings 174 | self.id_to_index_map[chunk_id] = faiss_index 175 | self.index_to_id_map[faiss_index] = chunk_id 176 | 177 | # Save to disk 178 | self._save_index() 179 | 180 | logger.info(f"Added {len(chunks)} chunks to FAISS index") 181 | return True 182 | 183 | except Exception as e: 184 | logger.error(f"Failed to add documents to FAISS: {str(e)}") 185 | return False 186 | 187 | def similarity_search(self, query: str, k: int = 5, 188 | filter_dict: Optional[Dict[str, Any]] = None) -> RetrievalResult: 189 | """Perform similarity search and return results.""" 190 | start_time = time.time() 191 | 192 | try: 193 | if self.index.ntotal == 0: 194 | return RetrievalResult( 195 | chunks=[], 196 | scores=[], 197 | query=query, 198 | total_results=0, 199 | retrieval_time=time.time() - start_time 200 | ) 201 | 202 | # Generate query embedding 203 | query_embedding = self.embedding_model.encode([query], convert_to_numpy=True) 204 | faiss.normalize_L2(query_embedding) 205 | 206 | # Perform search 207 | search_k = min(k * 2, self.index.ntotal) # Get more results for filtering 208 | scores, indices = self.index.search(query_embedding, search_k) 209 | 210 | # Convert results to DocumentChunks 211 | chunks = [] 212 | final_scores = [] 213 | 214 | for score, index in zip(scores[0], indices[0]): 215 | if index == -1: # FAISS returns -1 for invalid indices 216 | continue 217 | 218 | chunk_id = self.index_to_id_map.get(index) 219 | if not chunk_id or chunk_id not in self.document_store: 220 | continue 221 | 222 | doc_data = self.document_store[chunk_id] 223 | 224 | # Apply filters if specified 225 | if filter_dict and not self._matches_filter(doc_data, filter_dict): 226 | continue 227 | 228 | # Create DocumentChunk 229 | chunk = DocumentChunk( 230 | content=doc_data['content'], 231 | metadata=doc_data['metadata'], 232 | document_type=doc_data['document_type'], 233 | chunk_id=chunk_id, 234 | source_file=doc_data['source_file'] 235 | ) 236 | 237 | chunks.append(chunk) 238 | final_scores.append(float(score)) 239 | 240 | if len(chunks) >= k: 241 | break 242 | 243 | retrieval_time = time.time() - start_time 244 | 245 | return RetrievalResult( 246 | chunks=chunks, 247 | scores=final_scores, 248 | query=query, 249 | total_results=len(chunks), 250 | retrieval_time=retrieval_time 251 | ) 252 | 253 | except Exception as e: 254 | logger.error(f"Similarity search 
failed: {str(e)}") 255 | return RetrievalResult( 256 | chunks=[], 257 | scores=[], 258 | query=query, 259 | total_results=0, 260 | retrieval_time=time.time() - start_time 261 | ) 262 | 263 | def delete_documents(self, chunk_ids: List[str]) -> bool: 264 | """Delete documents by chunk IDs.""" 265 | try: 266 | # Note: FAISS doesn't support deletion, so we mark as deleted 267 | # and rebuild index if too many deletions accumulate 268 | deleted_count = 0 269 | 270 | for chunk_id in chunk_ids: 271 | if chunk_id in self.document_store: 272 | del self.document_store[chunk_id] 273 | 274 | if chunk_id in self.id_to_index_map: 275 | faiss_index = self.id_to_index_map[chunk_id] 276 | del self.id_to_index_map[chunk_id] 277 | del self.index_to_id_map[faiss_index] 278 | 279 | deleted_count += 1 280 | 281 | if deleted_count > 0: 282 | self._save_index() 283 | logger.info(f"Marked {deleted_count} documents as deleted") 284 | 285 | return True 286 | 287 | except Exception as e: 288 | logger.error(f"Failed to delete documents: {str(e)}") 289 | return False 290 | 291 | def get_collection_stats(self) -> Dict[str, Any]: 292 | """Get statistics about the collection.""" 293 | try: 294 | return { 295 | 'collection_name': self.collection_name, 296 | 'document_count': len(self.document_store), 297 | 'faiss_index_size': self.index.ntotal, 298 | 'embedding_dimension': self.embedding_dimension, 299 | 'embedding_model': self.embedding_model_name, 300 | 'persist_directory': str(self.persist_directory) 301 | } 302 | 303 | except Exception as e: 304 | logger.error(f"Failed to get collection stats: {str(e)}") 305 | return {} 306 | 307 | def _matches_filter(self, doc_data: Dict[str, Any], filter_dict: Dict[str, Any]) -> bool: 308 | """Check if document matches filter criteria.""" 309 | for key, value in filter_dict.items(): 310 | if key in doc_data: 311 | doc_value = doc_data[key] 312 | if isinstance(value, list): 313 | if doc_value not in value: 314 | return False 315 | elif doc_value != value: 316 | return False 317 | elif key in doc_data.get('metadata', {}): 318 | doc_value = doc_data['metadata'][key] 319 | if isinstance(value, list): 320 | if doc_value not in value: 321 | return False 322 | elif str(doc_value) != str(value): 323 | return False 324 | else: 325 | return False 326 | 327 | return True 328 | 329 | def rebuild_index(self) -> bool: 330 | """Rebuild FAISS index to remove deleted documents.""" 331 | try: 332 | if not self.document_store: 333 | self._create_new_index() 334 | return True 335 | 336 | # Collect all valid documents 337 | chunks = [] 338 | for chunk_id, doc_data in self.document_store.items(): 339 | chunk = DocumentChunk( 340 | content=doc_data['content'], 341 | metadata=doc_data['metadata'], 342 | document_type=doc_data['document_type'], 343 | chunk_id=chunk_id, 344 | source_file=doc_data['source_file'] 345 | ) 346 | chunks.append(chunk) 347 | 348 | # Create new index 349 | self._create_new_index() 350 | 351 | # Re-add all documents 352 | return self.add_documents(chunks) 353 | 354 | except Exception as e: 355 | logger.error(f"Failed to rebuild index: {str(e)}") 356 | return False -------------------------------------------------------------------------------- /multimodal_rag/vector_stores/chroma_store.py: -------------------------------------------------------------------------------- 1 | """ 2 | ChromaDB vector store implementation. 
3 | """ 4 | 5 | import logging 6 | import time 7 | import uuid 8 | from typing import List, Dict, Any, Optional 9 | 10 | try: 11 | import chromadb 12 | from chromadb.config import Settings 13 | import ollama 14 | import requests 15 | except ImportError: 16 | chromadb = None 17 | ollama = None 18 | 19 | from ..base import BaseVectorStore, DocumentChunk, RetrievalResult 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | class OllamaEmbeddingFunction: 25 | """Custom embedding function for Ollama embeddings.""" 26 | 27 | def __init__(self, model_name: str = "nomic-embed-text", host: str = "http://localhost:11434"): 28 | self.model_name = model_name 29 | self.host = host 30 | if ollama is None: 31 | raise ImportError("Ollama not available. Install with: pip install ollama") 32 | 33 | # Check if Ollama server is running 34 | self._check_ollama_connection() 35 | 36 | # Ensure the embedding model is pulled 37 | self._ensure_model_available() 38 | 39 | def name(self) -> str: 40 | """Return the name of this embedding function (required by ChromaDB).""" 41 | return f"ollama-{self.model_name}" 42 | 43 | def _check_ollama_connection(self): 44 | """Check if Ollama server is running.""" 45 | try: 46 | response = requests.get(f"{self.host}/api/tags", timeout=5) 47 | if response.status_code == 200: 48 | logger.info(f"Ollama server is running at {self.host}") 49 | else: 50 | raise ConnectionError(f"Ollama server returned status {response.status_code}") 51 | except requests.exceptions.RequestException as e: 52 | raise ConnectionError(f"Cannot connect to Ollama server at {self.host}. Make sure Ollama is running. Error: {e}") 53 | 54 | def _ensure_model_available(self): 55 | """Ensure the embedding model is available in Ollama.""" 56 | try: 57 | # Check if model exists 58 | models_response = ollama.list() 59 | # Extract model names from the response objects 60 | model_names = [model.model for model in models_response.models] 61 | 62 | # Check for exact match or with :latest tag 63 | model_found = False 64 | actual_model_name = self.model_name 65 | 66 | if self.model_name in model_names: 67 | model_found = True 68 | actual_model_name = self.model_name 69 | elif f"{self.model_name}:latest" in model_names: 70 | model_found = True 71 | actual_model_name = f"{self.model_name}:latest" 72 | elif any(self.model_name in name for name in model_names): 73 | # Find the actual model name 74 | matching_models = [name for name in model_names if self.model_name in name] 75 | actual_model_name = matching_models[0] 76 | model_found = True 77 | 78 | if not model_found: 79 | logger.info(f"Model {self.model_name} not found. 
Pulling model...") 80 | ollama.pull(self.model_name) 81 | logger.info(f"Successfully pulled model {self.model_name}") 82 | actual_model_name = self.model_name 83 | else: 84 | logger.info(f"Model {actual_model_name} is available") 85 | 86 | # Update the model name to use the actual available name 87 | self.model_name = actual_model_name 88 | 89 | except Exception as e: 90 | logger.error(f"Error checking/pulling model {self.model_name}: {e}") 91 | raise 92 | 93 | def __call__(self, input: List[str]) -> List[List[float]]: 94 | """Generate embeddings for input texts (ChromaDB interface).""" 95 | try: 96 | embeddings = [] 97 | 98 | for text in input: 99 | # Generate embedding for each text 100 | response = ollama.embeddings( 101 | model=self.model_name, 102 | prompt=text 103 | ) 104 | 105 | embeddings.append(response['embedding']) 106 | 107 | logger.debug(f"Generated embeddings for {len(input)} texts using Ollama") 108 | return embeddings 109 | 110 | except Exception as e: 111 | logger.error(f"Error generating Ollama embeddings: {str(e)}") 112 | raise 113 | 114 | def embed_query(self, *args, **kwargs) -> List[float]: 115 | """Generate embedding for a single query (ChromaDB interface).""" 116 | try: 117 | # Debug what ChromaDB is calling us with 118 | logger.debug(f"embed_query called with args: {args}, kwargs: {kwargs}") 119 | 120 | # Extract query from arguments 121 | if args: 122 | query = args[0] 123 | elif 'query' in kwargs: 124 | query = kwargs['query'] 125 | elif 'input' in kwargs: 126 | query = kwargs['input'] 127 | else: 128 | raise ValueError("No query provided to embed_query") 129 | 130 | # Handle both single string and list inputs for compatibility 131 | if isinstance(query, list): 132 | query = query[0] if query else "" 133 | 134 | response = ollama.embeddings( 135 | model=self.model_name, 136 | prompt=str(query) 137 | ) 138 | embedding = response['embedding'] 139 | logger.debug(f"Generated embedding type: {type(embedding)}, length: {len(embedding) if hasattr(embedding, '__len__') else 'N/A'}") 140 | return embedding 141 | except Exception as e: 142 | logger.error(f"Error generating Ollama query embedding: {str(e)}") 143 | raise 144 | 145 | 146 | class ChromaVectorStore(BaseVectorStore): 147 | """ChromaDB implementation of vector store.""" 148 | 149 | def __init__(self, config: Dict[str, Any]): 150 | super().__init__(config) 151 | 152 | if chromadb is None: 153 | raise ImportError("ChromaDB not available. Install with: pip install chromadb") 154 | 155 | if ollama is None: 156 | raise ImportError("Ollama not available. 
Install with: pip install ollama") 157 | 158 | self.persist_directory = config.get('persist_directory', './vector_db') 159 | self.embedding_function_name = config.get('embedding_model', 'nomic-embed-text') 160 | 161 | # Initialize ChromaDB client 162 | self.client = None 163 | self.collection = None 164 | self.embedding_function = None 165 | 166 | self._initialize_client() 167 | self._initialize_collection() 168 | 169 | def _initialize_client(self): 170 | """Initialize ChromaDB client.""" 171 | try: 172 | settings = Settings( 173 | persist_directory=self.persist_directory, 174 | anonymized_telemetry=False 175 | ) 176 | 177 | self.client = chromadb.PersistentClient( 178 | path=self.persist_directory, 179 | settings=settings 180 | ) 181 | 182 | logger.info(f"Initialized ChromaDB client with persist directory: {self.persist_directory}") 183 | 184 | except Exception as e: 185 | logger.error(f"Failed to initialize ChromaDB client: {str(e)}") 186 | raise 187 | 188 | def _initialize_collection(self): 189 | """Initialize or get collection.""" 190 | try: 191 | # Always use Ollama embedding function to avoid confusion 192 | # Extract Ollama host from config if available 193 | ollama_host = self.config.get('ollama_host', 'http://localhost:11434') 194 | self.embedding_function = OllamaEmbeddingFunction( 195 | model_name=self.embedding_function_name, 196 | host=ollama_host 197 | ) 198 | logger.info(f"Using Ollama embedding function: {self.embedding_function_name}") 199 | logger.info("Note: System configured to use ONLY Ollama embeddings to prevent embedding confusion") 200 | 201 | # Get or create collection 202 | try: 203 | self.collection = self.client.get_collection( 204 | name=self.collection_name, 205 | embedding_function=self.embedding_function 206 | ) 207 | logger.info(f"Retrieved existing collection: {self.collection_name}") 208 | except Exception: 209 | self.collection = self.client.create_collection( 210 | name=self.collection_name, 211 | embedding_function=self.embedding_function, 212 | metadata={"hnsw:space": "cosine"} 213 | ) 214 | logger.info(f"Created new collection: {self.collection_name}") 215 | 216 | except Exception as e: 217 | logger.error(f"Failed to initialize collection: {str(e)}") 218 | raise 219 | 220 | def add_documents(self, chunks: List[DocumentChunk]) -> bool: 221 | """Add document chunks to the vector store.""" 222 | try: 223 | if not chunks: 224 | return True 225 | 226 | # Prepare data for ChromaDB 227 | ids = [] 228 | documents = [] 229 | metadatas = [] 230 | 231 | for chunk in chunks: 232 | # Use existing chunk_id or generate new one 233 | chunk_id = chunk.chunk_id or str(uuid.uuid4()) 234 | ids.append(chunk_id) 235 | documents.append(chunk.content) 236 | 237 | # Prepare metadata (ChromaDB requires dict with string values) 238 | metadata = self._prepare_metadata(chunk.metadata) 239 | metadata.update({ 240 | 'document_type': chunk.document_type, 241 | 'source_file': chunk.source_file or 'unknown', 242 | 'timestamp': chunk.timestamp.isoformat() if chunk.timestamp else None 243 | }) 244 | 245 | # Remove None values 246 | metadata = {k: v for k, v in metadata.items() if v is not None} 247 | metadatas.append(metadata) 248 | 249 | # Add to collection 250 | self.collection.add( 251 | ids=ids, 252 | documents=documents, 253 | metadatas=metadatas 254 | ) 255 | 256 | logger.info(f"Added {len(chunks)} chunks to ChromaDB collection") 257 | return True 258 | 259 | except Exception as e: 260 | logger.error(f"Failed to add documents to ChromaDB: {str(e)}") 261 | return False 262 | 263 
| def similarity_search(self, query: str, k: int = 5, 264 | filter_dict: Optional[Dict[str, Any]] = None) -> RetrievalResult: 265 | """Perform similarity search and return results.""" 266 | start_time = time.time() 267 | 268 | try: 269 | # Prepare where clause for filtering 270 | where_clause = None 271 | if filter_dict: 272 | where_clause = self._prepare_where_clause(filter_dict) 273 | 274 | # Generate query embedding manually to avoid ChromaDB interface issues 275 | try: 276 | query_embedding = self.embedding_function.embed_query(query) 277 | logger.debug(f"Generated query embedding length: {len(query_embedding)}") 278 | 279 | # Perform query with precomputed embedding 280 | results = self.collection.query( 281 | query_embeddings=[query_embedding], 282 | n_results=k, 283 | where=where_clause, 284 | include=['documents', 'metadatas', 'distances'] 285 | ) 286 | except Exception as e: 287 | logger.error(f"Error with manual embedding, trying query_texts: {str(e)}") 288 | # Fallback to query_texts if manual embedding fails 289 | results = self.collection.query( 290 | query_texts=[query], 291 | n_results=k, 292 | where=where_clause, 293 | include=['documents', 'metadatas', 'distances'] 294 | ) 295 | 296 | # Convert results to DocumentChunks 297 | chunks = [] 298 | scores = [] 299 | 300 | if results['documents'] and results['documents'][0]: 301 | for i, (doc, metadata, distance) in enumerate(zip( 302 | results['documents'][0], 303 | results['metadatas'][0], 304 | results['distances'][0] 305 | )): 306 | # Convert distance to similarity score (ChromaDB uses cosine distance) 307 | similarity_score = 1 - distance 308 | scores.append(similarity_score) 309 | 310 | # Create DocumentChunk 311 | chunk = DocumentChunk( 312 | content=doc, 313 | metadata=metadata, 314 | document_type=metadata.get('document_type', 'unknown'), 315 | chunk_id=results['ids'][0][i] if results['ids'] else str(uuid.uuid4()), 316 | source_file=metadata.get('source_file') 317 | ) 318 | chunks.append(chunk) 319 | 320 | retrieval_time = time.time() - start_time 321 | 322 | return RetrievalResult( 323 | chunks=chunks, 324 | scores=scores, 325 | query=query, 326 | total_results=len(chunks), 327 | retrieval_time=retrieval_time 328 | ) 329 | 330 | except Exception as e: 331 | logger.error(f"Similarity search failed: {str(e)}") 332 | return RetrievalResult( 333 | chunks=[], 334 | scores=[], 335 | query=query, 336 | total_results=0, 337 | retrieval_time=time.time() - start_time 338 | ) 339 | 340 | def delete_documents(self, chunk_ids: List[str]) -> bool: 341 | """Delete documents by chunk IDs.""" 342 | try: 343 | if not chunk_ids: 344 | return True 345 | 346 | self.collection.delete(ids=chunk_ids) 347 | logger.info(f"Deleted {len(chunk_ids)} documents from ChromaDB") 348 | return True 349 | 350 | except Exception as e: 351 | logger.error(f"Failed to delete documents: {str(e)}") 352 | return False 353 | 354 | def get_collection_stats(self) -> Dict[str, Any]: 355 | """Get statistics about the collection.""" 356 | try: 357 | count = self.collection.count() 358 | 359 | return { 360 | 'collection_name': self.collection_name, 361 | 'document_count': count, 362 | 'embedding_function': self.embedding_function_name, 363 | 'persist_directory': self.persist_directory 364 | } 365 | 366 | except Exception as e: 367 | logger.error(f"Failed to get collection stats: {str(e)}") 368 | return {} 369 | 370 | def _prepare_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]: 371 | """Prepare metadata for ChromaDB (convert to string values).""" 372 | 
prepared = {} 373 | 374 | for key, value in metadata.items(): 375 | if value is None: 376 | continue 377 | elif isinstance(value, (str, int, float, bool)): 378 | prepared[key] = str(value) 379 | elif isinstance(value, (list, dict)): 380 | prepared[key] = str(value) 381 | else: 382 | prepared[key] = str(value) 383 | 384 | return prepared 385 | 386 | def _prepare_where_clause(self, filter_dict: Dict[str, Any]) -> Dict[str, Any]: 387 | """Prepare where clause for ChromaDB filtering.""" 388 | where_clause = {} 389 | 390 | for key, value in filter_dict.items(): 391 | if isinstance(value, str): 392 | where_clause[key] = value 393 | elif isinstance(value, list): 394 | where_clause[key] = {"$in": value} 395 | else: 396 | where_clause[key] = str(value) 397 | 398 | return where_clause 399 | 400 | def clear_collection(self) -> bool: 401 | """Clear all documents from the collection.""" 402 | try: 403 | # Get all document IDs 404 | results = self.collection.get() 405 | if results['ids']: 406 | self.collection.delete(ids=results['ids']) 407 | 408 | logger.info(f"Cleared collection: {self.collection_name}") 409 | return True 410 | 411 | except Exception as e: 412 | logger.error(f"Failed to clear collection: {str(e)}") 413 | return False -------------------------------------------------------------------------------- /multimodal_rag/system.py: -------------------------------------------------------------------------------- 1 | """ 2 | Main multimodal RAG system implementation. 3 | Simple traditional system using Ollama + basic processors. 4 | """ 5 | 6 | import logging 7 | import time 8 | from pathlib import Path 9 | from typing import Dict, Any, List, Optional, Union 10 | 11 | # Import new config schema 12 | try: 13 | from config_schema import SmartRAGConfig, ConfigLoader, load_config 14 | USE_NEW_CONFIG = True 15 | except ImportError: 16 | USE_NEW_CONFIG = False 17 | logging.warning("config_schema not found, using legacy config loading") 18 | 19 | from .base import ( 20 | QueryRequest, QueryResponse, DocumentChunk, ProcessingResult, OllamaLLM 21 | ) 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | class SimpleRAGSystem: 27 | """Simple RAG system using available components.""" 28 | 29 | def __init__(self, config: Union[Dict[str, Any], SmartRAGConfig]): 30 | # Handle both new SmartRAGConfig and legacy dict 31 | if USE_NEW_CONFIG and isinstance(config, SmartRAGConfig): 32 | self.config = config.to_dict() 33 | self._typed_config = config 34 | else: 35 | self.config = config 36 | self._typed_config = None 37 | 38 | logger.info("Initializing Simple RAG System...") 39 | 40 | try: 41 | # Initialize LLM 42 | self.llm = OllamaLLM(self.config) 43 | logger.info("LLM initialized") 44 | 45 | # Initialize processors 46 | from .processors import DocumentProcessorManager, ImageProcessorManager, AudioProcessorManager 47 | self.document_processor = DocumentProcessorManager(self.config) 48 | self.image_processor = ImageProcessorManager(self.config) 49 | self.audio_processor = AudioProcessorManager(self.config) 50 | logger.info("Processors initialized") 51 | 52 | # Initialize vector store 53 | from .vector_stores.chroma_store import ChromaVectorStore 54 | self.vector_store = ChromaVectorStore(config) 55 | logger.info("Vector store initialized") 56 | 57 | logger.info("Simple RAG System initialized successfully") 58 | 59 | except Exception as e: 60 | logger.error(f"Failed to initialize Simple RAG System: {e}") 61 | self.llm = None 62 | self.document_processor = None 63 | self.image_processor = None 64 | 
self.audio_processor = None 65 | self.vector_store = None 66 | 67 | def is_available(self) -> bool: 68 | """Check if the system is available.""" 69 | return (self.llm is not None and 70 | hasattr(self.llm, 'is_available') and 71 | self.llm.is_available()) 72 | 73 | def ingest_file(self, file_path: Union[str, Path]) -> ProcessingResult: 74 | """Ingest a file into the system.""" 75 | if not self.is_available(): 76 | return ProcessingResult( 77 | chunks=[], 78 | success=False, 79 | error_message="System not available" 80 | ) 81 | 82 | try: 83 | file_path = Path(file_path) 84 | logger.info(f"Ingesting file: {file_path}") 85 | 86 | # Determine processor based on file type 87 | if self.document_processor and self.document_processor.can_process(file_path): 88 | result = self.document_processor.process_file(file_path) 89 | elif self.image_processor and self.image_processor.can_process(file_path): 90 | result = self.image_processor.extract_content(file_path) 91 | elif self.audio_processor and self.audio_processor.can_process(file_path): 92 | result = self.audio_processor.extract_content(file_path) 93 | else: 94 | return ProcessingResult( 95 | chunks=[], 96 | success=False, 97 | error_message=f"Unsupported file type: {file_path.suffix}" 98 | ) 99 | 100 | # Store chunks in vector database 101 | if result.success and result.chunks and self.vector_store: 102 | self.vector_store.add_documents(result.chunks) 103 | logger.info(f"Added {len(result.chunks)} chunks to vector store") 104 | 105 | return result 106 | 107 | except Exception as e: 108 | logger.error(f"Error ingesting file {file_path}: {e}") 109 | return ProcessingResult( 110 | chunks=[], 111 | success=False, 112 | error_message=str(e) 113 | ) 114 | 115 | def query(self, query: Union[str, QueryRequest]) -> QueryResponse: 116 | """Process a query and return response.""" 117 | if not self.is_available(): 118 | return QueryResponse( 119 | answer="System not available", 120 | sources=[], 121 | query=query if isinstance(query, str) else query.query, 122 | confidence=0.0 123 | ) 124 | 125 | try: 126 | if isinstance(query, str): 127 | query_text = query 128 | query_obj = QueryRequest(query) 129 | else: 130 | query_text = query.query 131 | query_obj = query 132 | 133 | logger.info(f"Processing query: {query_text}") 134 | 135 | # Check if this is a conversational query 136 | is_conversational = self._is_conversational_query(query_text) 137 | 138 | context = "" 139 | sources = [] 140 | 141 | # Only retrieve documents for non-conversational queries 142 | if not is_conversational and self.vector_store: 143 | retrieval_result = self.vector_store.similarity_search( 144 | query_text, 145 | k=self.config.get('retrieval', {}).get('top_k', 5) 146 | ) 147 | 148 | # Filter to only truly relevant chunks 149 | relevant_chunks = self._filter_relevant_context(retrieval_result.chunks, query_text) 150 | 151 | # Generate context from relevant chunks 152 | if relevant_chunks: 153 | context = self._build_context(relevant_chunks) 154 | sources = relevant_chunks 155 | 156 | logger.info(f"Retrieved {len(relevant_chunks)} relevant chunks for query") 157 | else: 158 | logger.info(f"Treating as conversational query: {query_text}") 159 | 160 | # Generate response using LLM with appropriate prompting 161 | response = self._generate_contextual_response(query_text, context, is_conversational, **query_obj.generation_params) 162 | 163 | return QueryResponse( 164 | answer=response, 165 | sources=sources, 166 | query=query_text, 167 | confidence_score=0.8 if sources else 0.9 # Higher 
confidence for conversational 168 | ) 169 | 170 | except Exception as e: 171 | logger.error(f"Error processing query: {e}") 172 | return QueryResponse( 173 | answer=f"Error processing query: {str(e)}", 174 | sources=[], 175 | query=query_text if 'query_text' in locals() else str(query), 176 | confidence_score=0.0 177 | ) 178 | 179 | def _is_conversational_query(self, query_text: str) -> bool: 180 | """Determine if a query is conversational vs document-based.""" 181 | conversational_patterns = [ 182 | 'hi', 'hello', 'hey', 'good morning', 'good afternoon', 'good evening', 183 | 'how are you', 'what is your name', 'who are you', 'thanks', 'thank you', 184 | 'bye', 'goodbye', 'see you', 'nice to meet you', 'how do you do' 185 | ] 186 | 187 | query_lower = query_text.lower().strip() 188 | 189 | # Check for exact matches or if query starts with conversational patterns 190 | for pattern in conversational_patterns: 191 | if query_lower == pattern or query_lower.startswith(pattern): 192 | return True 193 | 194 | # Check if query is very short and likely conversational 195 | if len(query_lower) <= 10 and not any(c in query_lower for c in ['?', 'what', 'how', 'when', 'where', 'why']): 196 | return True 197 | 198 | return False 199 | 200 | def _filter_relevant_context(self, chunks: List[DocumentChunk], query_text: str, min_score: float = 0.1) -> List[DocumentChunk]: 201 | """Filter chunks to only include truly relevant ones.""" 202 | if not chunks: 203 | return [] 204 | 205 | # For now, use a simple filtering approach 206 | # In a real implementation, you'd use semantic similarity scores 207 | query_lower = query_text.lower() 208 | relevant_chunks = [] 209 | 210 | for chunk in chunks[:3]: # Limit to top 3 most relevant chunks 211 | # Simple relevance check - in production, use actual similarity scores 212 | chunk_lower = chunk.content.lower() 213 | if len(chunk_lower) > 50: # Skip very short chunks 214 | relevant_chunks.append(chunk) 215 | 216 | return relevant_chunks 217 | 218 | def _build_context(self, chunks: List[DocumentChunk]) -> str: 219 | """Build context string from retrieved chunks with source attribution.""" 220 | context_parts = [] 221 | for i, chunk in enumerate(chunks): 222 | # Extract source information 223 | source_file = chunk.source_file or chunk.metadata.get('filename', 'Unknown Document') 224 | if source_file and source_file != 'Unknown Document': 225 | # Get just the filename without path 226 | source_name = source_file.split('/')[-1].split('\\')[-1] 227 | else: 228 | source_name = chunk.metadata.get('filename', 'Unknown Document') 229 | 230 | # Include page number if available 231 | page_info = "" 232 | if chunk.page_number: 233 | page_info = f" (Page {chunk.page_number})" 234 | elif 'page_number' in chunk.metadata: 235 | page_info = f" (Page {chunk.metadata['page_number']})" 236 | 237 | # Build context entry with source attribution 238 | context_parts.append(f"[{i+1}] From '{source_name}'{page_info}:\n{chunk.content}") 239 | return "\n\n".join(context_parts) 240 | 241 | def _generate_contextual_response(self, query_text: str, context: str, is_conversational: bool, **kwargs) -> str: 242 | """Generate response with appropriate prompting based on query type.""" 243 | if is_conversational: 244 | # For conversational queries, use simple, friendly prompting 245 | system_prompt = "You are a helpful AI assistant. Respond naturally and conversationally to greetings and casual interactions. Keep responses brief and friendly." 
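# Conversational turns bypass retrieval entirely: the prompt assembled below is just
# the system instruction plus the raw user message, with no document context attached.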
246 | full_prompt = f"{system_prompt}\n\nUser: {query_text}\nAssistant:" 247 | else: 248 | # For document queries, use RAG-style prompting 249 | if context and context.strip(): 250 | full_prompt = f"""You are a helpful AI assistant. Use the provided context to answer the user's question. Always mention the specific document name(s) where you found the information. If the context doesn't contain relevant information, say so and provide what general help you can. 251 | 252 | Context from documents: 253 | {context} 254 | 255 | Question: {query_text} 256 | 257 | Answer (include specific document names in your response):""" 258 | else: 259 | full_prompt = f"""You are a helpful AI assistant. The user asked a question but no relevant documents were found. Provide a helpful general response. 260 | 261 | Question: {query_text} 262 | 263 | Answer:""" 264 | 265 | # Generate response using LLM 266 | response = self.llm.generate_response( 267 | prompt=full_prompt, 268 | context="", # Context already included in prompt 269 | **kwargs 270 | ) 271 | 272 | return response 273 | 274 | def get_system_status(self) -> Dict[str, Any]: 275 | """Get current system status.""" 276 | try: 277 | return { 278 | 'system_type': 'simple_traditional', 279 | 'llm_available': self.is_available(), 280 | 'processors_available': { 281 | 'documents': self.document_processor is not None, 282 | 'images': self.image_processor is not None, 283 | 'audio': self.audio_processor is not None 284 | }, 285 | 'vector_store_available': self.vector_store is not None, 286 | 'model_name': self.config.get('models', {}).get('llm_model', 'unknown') 287 | } 288 | except Exception as e: 289 | logger.error(f"Error getting system status: {e}") 290 | return { 291 | 'system_type': 'simple_traditional', 292 | 'error': str(e), 293 | 'llm_available': False 294 | } 295 | 296 | 297 | class MultimodalRAGSystem: 298 | """Main multimodal RAG system - simplified version with new config system.""" 299 | 300 | def __init__( 301 | self, 302 | config_path: Optional[Union[str, Path]] = None, 303 | config_dict: Optional[Dict[str, Any]] = None, 304 | **overrides 305 | ): 306 | """ 307 | Initialize the multimodal RAG system. 
308 | 309 | Args: 310 | config_path: Path to config.yaml file 311 | config_dict: Dictionary with config values (legacy) 312 | **overrides: Explicit config overrides (e.g., models__llm_model="llama2:7b") 313 | """ 314 | 315 | # Load configuration using new single source of truth system 316 | if USE_NEW_CONFIG: 317 | try: 318 | # Use new config loader with priority chain 319 | typed_config = load_config(config_path=config_path, **overrides) 320 | config = typed_config.to_dict() 321 | logger.info("✅ Using validated configuration schema") 322 | except Exception as e: 323 | logger.warning(f"Failed to load with new config system: {e}, falling back to legacy") 324 | config = self._load_config_legacy(config_path, config_dict) 325 | else: 326 | # Fallback to legacy config loading 327 | config = self._load_config_legacy(config_path, config_dict) 328 | 329 | # Try enhanced system first, fallback to simple system 330 | logger.info("Initializing multimodal RAG system...") 331 | try: 332 | from .enhanced_system import EnhancedMultimodalRAGSystem 333 | self._system = EnhancedMultimodalRAGSystem(config) 334 | if self._system.is_available(): 335 | self.system_type = "traditional" 336 | logger.info("Enhanced Multimodal RAG System initialized successfully") 337 | else: 338 | logger.warning("Enhanced system initialized but LLM not available") 339 | self._system = None 340 | self.system_type = "none" 341 | except Exception as e: 342 | logger.error(f"Failed to initialize Enhanced system: {e}") 343 | try: 344 | # Fallback to simple system 345 | self._system = SimpleRAGSystem(config) 346 | if self._system.is_available(): 347 | self.system_type = "traditional" 348 | logger.info("Simple RAG System initialized as fallback") 349 | else: 350 | logger.error("Simple RAG System not available - LLM connection failed") 351 | self._system = None 352 | self.system_type = "none" 353 | except Exception as e2: 354 | logger.error(f"Failed to initialize fallback system: {e2}") 355 | self._system = None 356 | self.system_type = "none" 357 | 358 | def _load_config_legacy( 359 | self, 360 | config_path: Optional[Union[str, Path]], 361 | config_dict: Optional[Dict[str, Any]] 362 | ) -> Dict[str, Any]: 363 | """Legacy configuration loading (backward compatibility).""" 364 | import yaml 365 | 366 | if config_dict: 367 | return config_dict 368 | elif config_path: 369 | return self._load_config(config_path) 370 | else: 371 | return self._get_default_config() 372 | 373 | def _load_config(self, config_path: Union[str, Path]) -> Dict[str, Any]: 374 | """Load configuration from YAML file (legacy method).""" 375 | import yaml 376 | 377 | try: 378 | with open(config_path, 'r', encoding='utf-8') as file: 379 | config = yaml.safe_load(file) 380 | logger.info(f"Loaded configuration from {config_path}") 381 | return config 382 | except Exception as e: 383 | logger.error(f"Failed to load config from {config_path}: {str(e)}") 384 | return self._get_default_config() 385 | 386 | def _get_default_config(self) -> Dict[str, Any]: 387 | """Get default configuration matching config.yaml defaults.""" 388 | return { 389 | 'system': { 390 | 'name': 'SmartRAG System', 391 | 'offline_mode': True, 392 | 'debug': False, 393 | 'log_level': 'INFO' 394 | }, 395 | 'models': { 396 | 'llm_type': 'ollama', 397 | 'llm_model': 'llama3.1:8b', 398 | 'ollama_host': 'http://localhost:11434', 399 | 'embedding_model': 'nomic-embed-text', 400 | 'embedding_dimension': 768, 401 | 'vision_model': 'Salesforce/blip-image-captioning-base', 402 | 'whisper_model': 'base' 403 | }, 404 | 
'vector_store': { 405 | 'type': 'chromadb', 406 | 'persist_directory': './vector_db', 407 | 'collection_name': 'multimodal_documents', 408 | 'embedding_dimension': 768, 409 | 'ollama_host': 'http://localhost:11434' 410 | }, 411 | 'processing': { 412 | 'chunk_size': 1000, 413 | 'chunk_overlap': 200, 414 | 'max_image_size': [1024, 1024], 415 | 'ocr_enabled': True, 416 | 'batch_size': 32, 417 | 'store_original_images': True, 418 | 'image_preprocessing': 'resize', 419 | 'audio_sample_rate': 16000, 420 | 'max_audio_duration': 300 421 | }, 422 | 'retrieval': { 423 | 'top_k': 5, 424 | 'similarity_threshold': 0.7, 425 | 'rerank_enabled': False 426 | }, 427 | 'generation': { 428 | 'max_tokens': 2048, 429 | 'temperature': 0.7, 430 | 'top_p': 0.9, 431 | 'top_k': 50, 432 | 'do_sample': True, 433 | 'max_new_tokens': 1024 434 | } 435 | } 436 | 437 | def _initialize_traditional_system(self, config: Dict[str, Any]): 438 | """Initialize traditional system as fallback.""" 439 | # Implementation for traditional system fallback 440 | self.config = config 441 | self._system = None # Placeholder for traditional system 442 | logger.warning("Traditional system fallback not fully implemented") 443 | 444 | # Delegate methods to the underlying system 445 | def ingest_file(self, file_path: Union[str, Path]) -> ProcessingResult: 446 | """Ingest file using the active system.""" 447 | if self._system: 448 | return self._system.ingest_file(file_path) 449 | else: 450 | logger.error("No active system available") 451 | return ProcessingResult(chunks=[], success=False, error_message="No active system") 452 | 453 | def is_available(self) -> bool: 454 | """Check if the system is available and ready.""" 455 | return self._system is not None and self._system.is_available() 456 | 457 | def query(self, query: Union[str, QueryRequest]) -> QueryResponse: 458 | """Query using the active system.""" 459 | if not self._system: 460 | logger.error("No active system available") 461 | return QueryResponse( 462 | answer="No active system available. 
Please check Ollama is running and llama3.1:8b model is available.", 463 | sources=[], 464 | query=str(query), 465 | confidence=0.0 466 | ) 467 | 468 | # Use the simple system query method 469 | return self._system.query(query) 470 | 471 | def get_system_stats(self) -> Dict[str, Any]: 472 | """Get system statistics.""" 473 | if self._system and hasattr(self._system, 'get_system_status'): 474 | stats = self._system.get_system_status() 475 | stats['wrapper_system_type'] = self.system_type 476 | return stats 477 | else: 478 | return { 479 | 'wrapper_system_type': self.system_type, 480 | 'active_system': None, 481 | 'error': 'No active system' 482 | } 483 | 484 | 485 | # Backward compatibility aliases 486 | MultimodalRAG = MultimodalRAGSystem -------------------------------------------------------------------------------- /config_schema.py: -------------------------------------------------------------------------------- 1 | """ 2 | Configuration Schema and Management Pydantic validation 3 | """ 4 | import os 5 | import yaml 6 | from pathlib import Path 7 | from typing import Optional, List, Dict, Any, Union 8 | from pydantic import BaseModel, Field, field_validator, model_validator 9 | from enum import Enum 10 | import logging 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class LLMType(str, Enum): 16 | """Supported LLM types""" 17 | OLLAMA = "ollama" 18 | OPENAI = "openai" 19 | HUGGINGFACE = "huggingface" 20 | 21 | 22 | class VectorStoreType(str, Enum): 23 | """Supported vector store types""" 24 | CHROMADB = "chromadb" 25 | FAISS = "faiss" 26 | 27 | 28 | class ImagePreprocessing(str, Enum): 29 | """Image preprocessing strategies""" 30 | RESIZE = "resize" 31 | CROP = "crop" 32 | NONE = "none" 33 | 34 | 35 | # ============================================================================ 36 | # Configuration Domain Models 37 | # ============================================================================ 38 | 39 | class SystemConfig(BaseModel): 40 | """System-level configuration""" 41 | name: str = Field(default="SmartRAG System", description="System name") 42 | version: str = Field(default="2.0.0", description="System version") 43 | offline_mode: bool = Field(default=True, description="Run in offline mode") 44 | debug: bool = Field(default=False, description="Enable debug logging") 45 | log_level: str = Field(default="INFO", description="Logging level") 46 | 47 | @field_validator('log_level') 48 | @classmethod 49 | def validate_log_level(cls, v: str) -> str: 50 | valid_levels = ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'] 51 | v_upper = v.upper() 52 | if v_upper not in valid_levels: 53 | raise ValueError(f"log_level must be one of {valid_levels}, got '{v}'") 54 | return v_upper 55 | 56 | 57 | class ModelsConfig(BaseModel): 58 | """AI Models configuration""" 59 | # LLM Configuration 60 | llm_type: LLMType = Field(default=LLMType.OLLAMA, description="Type of LLM to use") 61 | llm_model: str = Field(default="llama3.1:8b", description="LLM model name") 62 | ollama_host: str = Field(default="http://localhost:11434", description="Ollama server URL") 63 | 64 | # Embedding Configuration 65 | embedding_model: str = Field(default="nomic-embed-text", description="Text embedding model") 66 | embedding_dimension: int = Field(default=768, description="Embedding vector dimension") 67 | 68 | # Vision Configuration 69 | vision_model: str = Field( 70 | default="Salesforce/blip-image-captioning-base", 71 | description="Vision/image captioning model" 72 | ) 73 | 74 | # Speech Configuration 75 | 
whisper_model: str = Field(default="base", description="Whisper model size for audio") 76 | 77 | @field_validator('embedding_dimension') 78 | @classmethod 79 | def validate_embedding_dim(cls, v: int) -> int: 80 | if v <= 0: 81 | raise ValueError(f"embedding_dimension must be positive, got {v}") 82 | if v > 4096: 83 | logger.warning(f"embedding_dimension {v} is very large, may impact performance") 84 | return v 85 | 86 | @field_validator('ollama_host') 87 | @classmethod 88 | def validate_ollama_host(cls, v: str) -> str: 89 | if not v.startswith(('http://', 'https://')): 90 | raise ValueError(f"ollama_host must start with http:// or https://, got '{v}'") 91 | return v.rstrip('/') # Remove trailing slash 92 | 93 | 94 | class VectorStoreConfig(BaseModel): 95 | """Vector store configuration""" 96 | type: VectorStoreType = Field(default=VectorStoreType.CHROMADB, description="Vector store type") 97 | persist_directory: Path = Field(default=Path("./vector_db"), description="Storage directory") 98 | collection_name: str = Field( 99 | default="multimodal_documents", 100 | description="Collection/index name" 101 | ) 102 | embedding_dimension: int = Field(default=768, description="Embedding dimension") 103 | ollama_host: str = Field(default="http://localhost:11434", description="Ollama host for embeddings") 104 | 105 | @field_validator('persist_directory') 106 | @classmethod 107 | def validate_persist_dir(cls, v: Path) -> Path: 108 | # Ensure path is absolute or resolve it 109 | if not v.is_absolute(): 110 | v = Path.cwd() / v 111 | return v 112 | 113 | @field_validator('collection_name') 114 | @classmethod 115 | def validate_collection_name(cls, v: str) -> str: 116 | if not v or not v.strip(): 117 | raise ValueError("collection_name cannot be empty") 118 | # Sanitize collection name (no special chars for some vector stores) 119 | if not v.replace('_', '').replace('-', '').isalnum(): 120 | raise ValueError(f"collection_name must be alphanumeric with _/-, got '{v}'") 121 | return v.strip() 122 | 123 | 124 | class ProcessingConfig(BaseModel): 125 | """Document and media processing configuration""" 126 | # Text chunking 127 | chunk_size: int = Field(default=1000, ge=100, le=10000, description="Text chunk size") 128 | chunk_overlap: int = Field(default=200, ge=0, le=1000, description="Chunk overlap size") 129 | 130 | # Image processing 131 | max_image_size: List[int] = Field(default=[1024, 1024], description="Max image dimensions [width, height]") 132 | ocr_enabled: bool = Field(default=True, description="Enable Tesseract OCR") 133 | store_original_images: bool = Field(default=True, description="Store original image data") 134 | image_preprocessing: ImagePreprocessing = Field(default=ImagePreprocessing.RESIZE, description="Image preprocessing") 135 | 136 | # Audio processing 137 | audio_sample_rate: int = Field(default=16000, ge=8000, le=48000, description="Audio sample rate (Hz)") 138 | max_audio_duration: int = Field(default=300, ge=1, le=3600, description="Max audio length (seconds)") 139 | 140 | # Batch processing 141 | batch_size: int = Field(default=32, ge=1, le=128, description="Processing batch size") 142 | 143 | @field_validator('max_image_size') 144 | @classmethod 145 | def validate_image_size(cls, v: List[int]) -> List[int]: 146 | if len(v) != 2: 147 | raise ValueError(f"max_image_size must be [width, height], got {v}") 148 | if any(dim <= 0 or dim > 4096 for dim in v): 149 | raise ValueError(f"max_image_size dimensions must be in range [1, 4096], got {v}") 150 | return v 151 | 152 | 
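# Example (sketch): ProcessingConfig(chunk_size=500, chunk_overlap=600) fails validation
# via the model-level check below, since the overlap must stay strictly smaller than the
# chunk size; the defaults (chunk_size=1000, chunk_overlap=200) pass.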
@model_validator(mode='after') 153 | def validate_chunk_overlap(self) -> 'ProcessingConfig': 154 | if self.chunk_overlap >= self.chunk_size: 155 | raise ValueError( 156 | f"chunk_overlap ({self.chunk_overlap}) must be less than " 157 | f"chunk_size ({self.chunk_size})" 158 | ) 159 | return self 160 | 161 | 162 | class RetrievalConfig(BaseModel): 163 | """Retrieval and search configuration""" 164 | top_k: int = Field(default=5, ge=1, le=50, description="Number of results to retrieve") 165 | similarity_threshold: float = Field( 166 | default=0.7, 167 | ge=0.0, 168 | le=1.0, 169 | description="Minimum similarity score" 170 | ) 171 | rerank_enabled: bool = Field(default=False, description="Enable result reranking") 172 | 173 | @field_validator('similarity_threshold') 174 | @classmethod 175 | def validate_threshold(cls, v: float) -> float: 176 | if v < 0.0 or v > 1.0: 177 | raise ValueError(f"similarity_threshold must be in [0.0, 1.0], got {v}") 178 | return v 179 | 180 | 181 | class GenerationConfig(BaseModel): 182 | """LLM generation configuration""" 183 | max_tokens: int = Field(default=2048, ge=1, le=8192, description="Max output tokens") 184 | temperature: float = Field(default=0.7, ge=0.0, le=2.0, description="Sampling temperature") 185 | top_p: float = Field(default=0.9, ge=0.0, le=1.0, description="Nucleus sampling threshold") 186 | top_k: int = Field(default=50, ge=1, le=100, description="Top-k sampling") 187 | do_sample: bool = Field(default=True, description="Enable sampling") 188 | max_new_tokens: int = Field(default=1024, ge=1, le=4096, description="Max new tokens to generate") 189 | 190 | @model_validator(mode='after') 191 | def validate_sampling_params(self) -> 'GenerationConfig': 192 | if self.max_new_tokens > self.max_tokens: 193 | logger.warning( 194 | f"max_new_tokens ({self.max_new_tokens}) > max_tokens ({self.max_tokens}), " 195 | f"adjusting max_new_tokens to {self.max_tokens}" 196 | ) 197 | self.max_new_tokens = self.max_tokens 198 | return self 199 | 200 | 201 | class UIConfig(BaseModel): 202 | """User interface configuration""" 203 | title: str = Field(default="SmartRAG - Multimodal AI Assistant", description="App title") 204 | page_icon: str = Field(default="🤖", description="Page icon emoji") 205 | layout: str = Field(default="wide", description="Page layout") 206 | theme: str = Field(default="dark", description="UI theme") 207 | show_recent_uploads: bool = Field(default=True, description="Show recent uploads section") 208 | max_upload_size_mb: int = Field(default=200, ge=1, le=1000, description="Max file upload size (MB)") 209 | 210 | @field_validator('layout') 211 | @classmethod 212 | def validate_layout(cls, v: str) -> str: 213 | valid_layouts = ['wide', 'centered'] 214 | if v not in valid_layouts: 215 | raise ValueError(f"layout must be one of {valid_layouts}, got '{v}'") 216 | return v 217 | 218 | 219 | class StorageConfig(BaseModel): 220 | """Storage and persistence configuration""" 221 | data_directory: Path = Field(default=Path("./data"), description="Data storage directory") 222 | logs_directory: Path = Field(default=Path("./logs"), description="Logs directory") 223 | cache_directory: Path = Field(default=Path("./cache"), description="Cache directory") 224 | temp_uploads_directory: Path = Field(default=Path("./temp_uploads"), description="Temporary uploads") 225 | 226 | @field_validator('data_directory', 'logs_directory', 'cache_directory', 'temp_uploads_directory') 227 | @classmethod 228 | def validate_directory(cls, v: Path) -> Path: 229 | if not 
v.is_absolute(): 230 | v = Path.cwd() / v 231 | return v 232 | 233 | 234 | class SupportedFormatsConfig(BaseModel): 235 | """Supported file formats""" 236 | documents: List[str] = Field( 237 | default=[".pdf", ".docx", ".doc", ".txt", ".md", ".rtf"], 238 | description="Supported document formats" 239 | ) 240 | images: List[str] = Field( 241 | default=[".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"], 242 | description="Supported image formats" 243 | ) 244 | audio: List[str] = Field( 245 | default=[".mp3", ".wav", ".m4a", ".ogg", ".flac"], 246 | description="Supported audio formats" 247 | ) 248 | 249 | @field_validator('documents', 'images', 'audio') 250 | @classmethod 251 | def validate_extensions(cls, v: List[str]) -> List[str]: 252 | # Ensure all extensions start with a dot and are lowercase 253 | validated = [] 254 | for ext in v: 255 | ext = ext.strip().lower() 256 | if not ext.startswith('.'): 257 | ext = f'.{ext}' 258 | validated.append(ext) 259 | return validated 260 | 261 | 262 | # ============================================================================ 263 | # Main Configuration Model 264 | # ============================================================================ 265 | 266 | class SmartRAGConfig(BaseModel): 267 | """Complete SmartRAG configuration with all domains""" 268 | 269 | system: SystemConfig = Field(default_factory=SystemConfig) 270 | models: ModelsConfig = Field(default_factory=ModelsConfig) 271 | vector_store: VectorStoreConfig = Field(default_factory=VectorStoreConfig) 272 | processing: ProcessingConfig = Field(default_factory=ProcessingConfig) 273 | retrieval: RetrievalConfig = Field(default_factory=RetrievalConfig) 274 | generation: GenerationConfig = Field(default_factory=GenerationConfig) 275 | ui: UIConfig = Field(default_factory=UIConfig) 276 | storage: StorageConfig = Field(default_factory=StorageConfig) 277 | supported_formats: SupportedFormatsConfig = Field(default_factory=SupportedFormatsConfig) 278 | 279 | @model_validator(mode='after') 280 | def validate_cross_domain(self) -> 'SmartRAGConfig': 281 | """Cross-domain validation""" 282 | # Ensure embedding dimensions match across configs 283 | if self.models.embedding_dimension != self.vector_store.embedding_dimension: 284 | logger.warning( 285 | f"Embedding dimension mismatch: models={self.models.embedding_dimension}, " 286 | f"vector_store={self.vector_store.embedding_dimension}. " 287 | f"Using models.embedding_dimension={self.models.embedding_dimension}" 288 | ) 289 | self.vector_store.embedding_dimension = self.models.embedding_dimension 290 | 291 | # Ensure ollama_host is consistent 292 | if self.models.ollama_host != self.vector_store.ollama_host: 293 | logger.warning( 294 | f"Ollama host mismatch: models={self.models.ollama_host}, " 295 | f"vector_store={self.vector_store.ollama_host}. 
" 296 | f"Using models.ollama_host={self.models.ollama_host}" 297 | ) 298 | self.vector_store.ollama_host = self.models.ollama_host 299 | 300 | return self 301 | 302 | def to_dict(self) -> Dict[str, Any]: 303 | """Convert to dictionary (for backward compatibility)""" 304 | return { 305 | 'system': self.system.model_dump(), 306 | 'models': self.models.model_dump(), 307 | 'vector_store': { 308 | **self.vector_store.model_dump(), 309 | 'persist_directory': str(self.vector_store.persist_directory) 310 | }, 311 | 'processing': self.processing.model_dump(), 312 | 'retrieval': self.retrieval.model_dump(), 313 | 'generation': self.generation.model_dump(), 314 | 'ui': self.ui.model_dump(), 315 | 'storage': { 316 | k: str(v) for k, v in self.storage.model_dump().items() 317 | }, 318 | 'supported_formats': self.supported_formats.model_dump() 319 | } 320 | 321 | 322 | # ============================================================================ 323 | # Configuration Loader (Single Source of Truth) 324 | # ============================================================================ 325 | 326 | class ConfigLoader: 327 | 328 | DEFAULT_CONFIG_PATH = Path("config.yaml") 329 | ENV_PREFIX = "SMARTRAG_" 330 | 331 | @classmethod 332 | def load( 333 | cls, 334 | config_path: Optional[Union[str, Path]] = None, 335 | config_dict: Optional[Dict[str, Any]] = None, 336 | override_params: Optional[Dict[str, Any]] = None, 337 | validate: bool = True 338 | ) -> SmartRAGConfig: 339 | """ 340 | Load configuration with priority chain. 341 | 342 | Args: 343 | config_path: Path to YAML config file 344 | config_dict: Dictionary with config values 345 | override_params: Explicit overrides (highest priority) 346 | validate: Whether to validate with Pydantic schema 347 | 348 | Returns: 349 | SmartRAGConfig: Validated configuration object 350 | 351 | Raises: 352 | ValueError: If configuration is invalid 353 | """ 354 | try: 355 | # Step 1: Start with defaults 356 | config_data = cls._get_defaults() 357 | 358 | # Step 2: Load from YAML file 359 | yaml_config = cls._load_from_yaml(config_path) 360 | config_data = cls._deep_merge(config_data, yaml_config) 361 | 362 | # Step 3: Override with config_dict if provided 363 | if config_dict: 364 | config_data = cls._deep_merge(config_data, config_dict) 365 | 366 | # Step 4: Override with environment variables 367 | env_overrides = cls._load_from_env() 368 | config_data = cls._deep_merge(config_data, env_overrides) 369 | 370 | # Step 5: Apply explicit overrides (highest priority) 371 | if override_params: 372 | config_data = cls._deep_merge(config_data, override_params) 373 | 374 | # Step 6: Validate and return 375 | if validate: 376 | config = SmartRAGConfig(**config_data) 377 | logger.info("✅ Configuration loaded and validated successfully") 378 | return config 379 | else: 380 | # Return unvalidated dict (not recommended) 381 | logger.warning("⚠️ Configuration loaded without validation") 382 | return config_data 383 | 384 | except Exception as e: 385 | logger.error(f"❌ Configuration loading failed: {e}") 386 | raise ValueError(f"Failed to load configuration: {e}") from e 387 | 388 | @staticmethod 389 | def _get_defaults() -> Dict[str, Any]: 390 | """Get hardcoded defaults""" 391 | return SmartRAGConfig().to_dict() 392 | 393 | @classmethod 394 | def _load_from_yaml(cls, config_path: Optional[Union[str, Path]] = None) -> Dict[str, Any]: 395 | """Load configuration from YAML file""" 396 | if config_path is None: 397 | config_path = cls.DEFAULT_CONFIG_PATH 398 | 399 | config_path = 
Path(config_path) 400 | 401 | if not config_path.exists(): 402 | logger.info(f"Config file not found: {config_path}, using defaults") 403 | return {} 404 | 405 | try: 406 | with open(config_path, 'r', encoding='utf-8') as f: 407 | yaml_data = yaml.safe_load(f) 408 | logger.info(f"📄 Loaded configuration from {config_path}") 409 | return yaml_data or {} 410 | except Exception as e: 411 | logger.warning(f"Failed to load YAML from {config_path}: {e}") 412 | return {} 413 | 414 | @classmethod 415 | def _load_from_env(cls) -> Dict[str, Any]: 416 | """Load configuration overrides from environment variables""" 417 | env_config = {} 418 | 419 | # Map environment variables to config structure 420 | # Format: SMARTRAG_MODELS_LLM_MODEL=llama3.1:8b -> models.llm_model 421 | env_mappings = { 422 | f"{cls.ENV_PREFIX}DEBUG": "system.debug", 423 | f"{cls.ENV_PREFIX}LOG_LEVEL": "system.log_level", 424 | f"{cls.ENV_PREFIX}LLM_TYPE": "models.llm_type", 425 | f"{cls.ENV_PREFIX}LLM_MODEL": "models.llm_model", 426 | f"{cls.ENV_PREFIX}OLLAMA_HOST": "models.ollama_host", 427 | f"{cls.ENV_PREFIX}EMBEDDING_MODEL": "models.embedding_model", 428 | f"{cls.ENV_PREFIX}VISION_MODEL": "models.vision_model", 429 | f"{cls.ENV_PREFIX}VECTOR_STORE_TYPE": "vector_store.type", 430 | f"{cls.ENV_PREFIX}PERSIST_DIR": "vector_store.persist_directory", 431 | f"{cls.ENV_PREFIX}COLLECTION_NAME": "vector_store.collection_name", 432 | f"{cls.ENV_PREFIX}CHUNK_SIZE": "processing.chunk_size", 433 | f"{cls.ENV_PREFIX}TOP_K": "retrieval.top_k", 434 | f"{cls.ENV_PREFIX}TEMPERATURE": "generation.temperature", 435 | f"{cls.ENV_PREFIX}MAX_TOKENS": "generation.max_tokens", 436 | } 437 | 438 | for env_var, config_path in env_mappings.items(): 439 | value = os.environ.get(env_var) 440 | if value is not None: 441 | # Convert value to appropriate type 442 | value = cls._convert_env_value(value) 443 | cls._set_nested_value(env_config, config_path, value) 444 | logger.debug(f"🌍 Environment override: {env_var} -> {config_path} = {value}") 445 | 446 | return env_config 447 | 448 | @staticmethod 449 | def _convert_env_value(value: str) -> Any: 450 | """Convert string environment variable to appropriate type""" 451 | # Boolean 452 | if value.lower() in ('true', 'yes', '1'): 453 | return True 454 | if value.lower() in ('false', 'no', '0'): 455 | return False 456 | 457 | # Number 458 | try: 459 | if '.' 
in value: 460 | return float(value) 461 | return int(value) 462 | except ValueError: 463 | pass 464 | 465 | # String 466 | return value 467 | 468 | @staticmethod 469 | def _set_nested_value(data: Dict, path: str, value: Any): 470 | """Set nested dictionary value using dot notation path""" 471 | keys = path.split('.') 472 | current = data 473 | for key in keys[:-1]: 474 | if key not in current: 475 | current[key] = {} 476 | current = current[key] 477 | current[keys[-1]] = value 478 | 479 | @staticmethod 480 | def _deep_merge(base: Dict, override: Dict) -> Dict: 481 | """Deep merge two dictionaries, override takes precedence""" 482 | result = base.copy() 483 | for key, value in override.items(): 484 | if key in result and isinstance(result[key], dict) and isinstance(value, dict): 485 | result[key] = ConfigLoader._deep_merge(result[key], value) 486 | else: 487 | result[key] = value 488 | return result 489 | 490 | 491 | # ============================================================================ 492 | # Convenience Functions 493 | # ============================================================================ 494 | 495 | def load_config( 496 | config_path: Optional[Union[str, Path]] = None, 497 | **overrides 498 | ) -> SmartRAGConfig: 499 | """ 500 | Convenience function to load configuration. 501 | 502 | Args: 503 | config_path: Path to YAML config file 504 | **overrides: Explicit parameter overrides 505 | 506 | Returns: 507 | SmartRAGConfig: Validated configuration 508 | 509 | Example: 510 | >>> config = load_config("config.yaml", models__llm_model="llama3.1:8b") 511 | """ 512 | # Convert double underscore kwargs to nested dict 513 | override_dict = {} 514 | for key, value in overrides.items(): 515 | path = key.replace('__', '.') 516 | ConfigLoader._set_nested_value(override_dict, path, value) 517 | 518 | return ConfigLoader.load( 519 | config_path=config_path, 520 | override_params=override_dict 521 | ) 522 | 523 | 524 | def save_config(config: SmartRAGConfig, output_path: Union[str, Path]): 525 | """Save configuration to YAML file""" 526 | output_path = Path(output_path) 527 | output_path.parent.mkdir(parents=True, exist_ok=True) 528 | 529 | with open(output_path, 'w', encoding='utf-8') as f: 530 | yaml.dump(config.to_dict(), f, default_flow_style=False, sort_keys=False) 531 | 532 | logger.info(f"💾 Configuration saved to {output_path}") 533 | 534 | 535 | if __name__ == "__main__": 536 | # Example usage and testing 537 | logging.basicConfig(level=logging.INFO) 538 | 539 | print("=" * 80) 540 | print("SmartRAG Configuration System - Testing") 541 | print("=" * 80) 542 | 543 | # Test 1: Load with defaults 544 | print("\n1. Loading with defaults...") 545 | config = load_config() 546 | print(f"✓ System: {config.system.name} v{config.system.version}") 547 | print(f"✓ LLM: {config.models.llm_type.value} - {config.models.llm_model}") 548 | print(f"✓ Vector Store: {config.vector_store.type.value}") 549 | 550 | # Test 2: Load from file 551 | print("\n2. Loading from config.yaml...") 552 | try: 553 | config = load_config("config.yaml") 554 | print(f"✓ Loaded from file successfully") 555 | except Exception as e: 556 | print(f"✗ Failed: {e}") 557 | 558 | # Test 3: Test overrides 559 | print("\n3. 
Testing overrides...") 560 | config = load_config( 561 | models__llm_model="llama2:7b", 562 | generation__temperature=0.5, 563 | retrieval__top_k=10 564 | ) 565 | print(f"✓ LLM Model: {config.models.llm_model}") 566 | print(f"✓ Temperature: {config.generation.temperature}") 567 | print(f"✓ Top K: {config.retrieval.top_k}") 568 | 569 | # Test 4: Validation 570 | print("\n4. Testing validation...") 571 | try: 572 | bad_config = SmartRAGConfig( 573 | processing=ProcessingConfig(chunk_overlap=1500, chunk_size=1000) 574 | ) 575 | print("✗ Should have failed validation!") 576 | except Exception as e: 577 | print(f"✓ Validation works: {e}") 578 | 579 | print("\n" + "=" * 80) 580 | print("All tests completed!") 581 | --------------------------------------------------------------------------------