├── app ├── api │ ├── __init__.py │ └── chat.py ├── models │ ├── __init__.py │ └── chat.py ├── services │ ├── __init__.py │ ├── image_service.py │ ├── search_service.py │ ├── message_service.py │ └── document_service.py ├── __init__.py ├── config │ ├── __init__.py │ └── settings.py ├── core │ ├── __init__.py │ ├── logging.py │ └── auth.py └── utils │ └── file_utils.py ├── .dockerignore ├── requirements.txt ├── .gitignore ├── main.py ├── Dockerfile ├── .env.example ├── docker-compose.yml └── README.md /app/api/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | API routes package 3 | """ -------------------------------------------------------------------------------- /app/models/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data models package 3 | """ -------------------------------------------------------------------------------- /app/services/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Services package 3 | """ -------------------------------------------------------------------------------- /app/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | ExtendAI application package 3 | """ -------------------------------------------------------------------------------- /app/config/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Configuration package 3 | """ -------------------------------------------------------------------------------- /app/core/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Core functionality package 3 | """ -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | .env 2 | .env.* 3 | .git 4 | .gitignore 5 | __pycache__ 6 | *.pyc 7 | *.pyo 8 | *.pyd 9 | .Python 10 | env/ 11 | venv/ 12 | .venv/ 13 | pip-log.txt 14 | pip-delete-this-directory.txt 15 | .tox/ 16 | .coverage 17 | .coverage.* 18 | .cache 19 | nosetests.xml 20 | coverage.xml 21 | *.cover 22 | *.log 23 | .pytest_cache/ 24 | .idea/ 25 | .vscode/ 26 | *.swp 27 | *.swo 28 | *~ 29 | cache/ -------------------------------------------------------------------------------- /app/core/logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | 4 | def setup_logging(): 5 | """Configure logging with a specific format and level""" 6 | logging.basicConfig( 7 | level=logging.INFO, 8 | format='%(asctime)s - %(levelname)s - %(message)s', 9 | stream=sys.stdout, 10 | force=True 11 | ) 12 | 13 | # Disable other loggers that might be noisy 14 | logging.getLogger("httpx").setLevel(logging.WARNING) 15 | logging.getLogger("httpcore").setLevel(logging.WARNING) 16 | logging.getLogger("uvicorn").setLevel(logging.WARNING) 17 | logging.getLogger("asyncio").setLevel(logging.WARNING) 18 | 19 | return logging.getLogger(__name__) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fastapi>=0.109.0 2 | uvicorn>=0.27.0 3 | httpx>=0.26.0 4 | python-dotenv>=1.0.0 5 | langchain>=0.1.0 6 | langchain-community>=0.0.16 7 | langchain-core>=0.1.18 8 | langchain-openai>=0.0.5 9 | 
langchain-text-splitters>=0.0.1 10 | langchain-pinecone>=0.0.2 11 | langchain-postgres>=0.0.1 12 | faiss-cpu>=1.7.4 13 | pinecone-client>=3.0.0 14 | psycopg[binary]>=3.1.18 15 | psycopg2-binary>=2.9.9 16 | ftfy>=6.1.3 17 | python-multipart>=0.0.6 18 | beautifulsoup4>=4.12.0 19 | docx2txt>=0.8 20 | pypdf>=4.0.0 21 | openpyxl>=3.1.0 22 | python-pptx>=0.6.0 23 | unstructured>=0.11.0 24 | markdown>=3.5.0 25 | docstring-parser>=0.15 26 | tqdm>=4.66.0 27 | googlesearch-python 28 | cachetools 29 | networkx -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | .Python 7 | build/ 8 | develop-eggs/ 9 | dist/ 10 | downloads/ 11 | eggs/ 12 | .eggs/ 13 | lib/ 14 | lib64/ 15 | parts/ 16 | sdist/ 17 | var/ 18 | wheels/ 19 | *.egg-info/ 20 | .installed.cfg 21 | *.egg 22 | 23 | # Virtual Environment 24 | venv/ 25 | ENV/ 26 | env/ 27 | .env 28 | 29 | # IDE 30 | .idea/ 31 | .vscode/ 32 | *.swp 33 | *.swo 34 | .DS_Store 35 | 36 | # Project specific 37 | cache/ 38 | *.log 39 | *.db 40 | *.sqlite3 41 | 42 | # Environment variables 43 | .env 44 | .env.local 45 | .env.*.local 46 | 47 | # Testing 48 | .coverage 49 | htmlcov/ 50 | .pytest_cache/ 51 | .tox/ 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | .hypothesis/ -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI 2 | from fastapi.middleware.cors import CORSMiddleware 3 | import os 4 | 5 | from app.api.chat import router as chat_router 6 | from app.core.logging import setup_logging 7 | from app.config.settings import PORT, HOST 8 | 9 | # Setup logging 10 | logger = setup_logging() 11 | 12 | # Create FastAPI app 13 | app = FastAPI() 14 | 15 | # Add CORS middleware 16 | app.add_middleware( 17 | CORSMiddleware, 18 | allow_origins=["*"], 19 | allow_credentials=True, 20 | allow_methods=["*"], 21 | allow_headers=["*"], 22 | ) 23 | 24 | # Include routers 25 | app.include_router(chat_router) 26 | 27 | if __name__ == "__main__": 28 | import uvicorn 29 | port = int(os.getenv("PORT", PORT)) 30 | host = os.getenv("HOST", HOST) 31 | logger.info(f"Starting server on {host}:{port}") 32 | uvicorn.run(app, host=host, port=port) -------------------------------------------------------------------------------- /app/core/auth.py: -------------------------------------------------------------------------------- 1 | from fastapi import HTTPException, Request 2 | from app.config.settings import MY_API_KEY 3 | 4 | async def verify_api_key(request: Request) -> str: 5 | """Verify the API key from the request header""" 6 | if not MY_API_KEY: 7 | raise HTTPException(status_code=500, detail="Server API key not configured") 8 | 9 | auth_header = request.headers.get("Authorization") 10 | if not auth_header: 11 | raise HTTPException(status_code=401, detail="Authorization header missing") 12 | 13 | try: 14 | scheme, token = auth_header.split() 15 | if scheme.lower() != "bearer": 16 | raise HTTPException(status_code=401, detail="Invalid authentication scheme") 17 | if token != MY_API_KEY: 18 | raise HTTPException(status_code=401, detail="Invalid API key") 19 | return token 20 | except ValueError: 21 | raise HTTPException(status_code=401, detail="Invalid authorization header format") 
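
# Usage sketch (illustrative only — the host/port and key below are the
# .env.example placeholders, not values fixed by this module): a client
# authenticates by sending the configured MY_API_KEY as a bearer token.
#
#   import httpx
#   resp = httpx.post(
#       "http://localhost:8096/v1/chat/completions",
#       headers={"Authorization": "Bearer sk-planetzero-api-key"},  # must equal MY_API_KEY
#       json={
#           "model": "deepseek-r1",
#           "messages": [{"role": "user", "content": "hello"}],
#           "stream": False,
#       },
#       timeout=600,
#   )
#   resp.raise_for_status()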
-------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Use Python 3.11 as base image 2 | FROM python:3.11-slim 3 | 4 | # Set working directory 5 | WORKDIR /app 6 | 7 | # Set environment variables 8 | ENV PYTHONUNBUFFERED=1 \ 9 | PYTHONDONTWRITEBYTECODE=1 \ 10 | PIP_NO_CACHE_DIR=1 11 | 12 | # Install system dependencies 13 | RUN apt-get update && apt-get install -y --no-install-recommends \ 14 | build-essential \ 15 | curl \ 16 | git \ 17 | libpq-dev \ 18 | postgresql-client \ 19 | gcc \ 20 | python3-dev \ 21 | && rm -rf /var/lib/apt/lists/* 22 | 23 | # Create cache directory 24 | RUN mkdir -p /app/cache && chmod 777 /app/cache 25 | 26 | # Install Python dependencies 27 | COPY requirements.txt . 28 | RUN pip install --no-cache-dir -r requirements.txt 29 | 30 | # Copy application code 31 | COPY . . 32 | 33 | # Create a non-root user 34 | RUN useradd -m -u 1000 extendai && \ 35 | chown -R extendai:extendai /app 36 | USER extendai 37 | 38 | # Expose port 39 | EXPOSE 8096 40 | 41 | # Set default command 42 | CMD ["python", "main.py"] -------------------------------------------------------------------------------- /app/models/chat.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, HttpUrl 2 | from typing import List, Optional, Union, Literal, Dict, Any 3 | from enum import Enum 4 | 5 | class Role(str, Enum): 6 | SYSTEM = "system" 7 | USER = "user" 8 | ASSISTANT = "assistant" 9 | 10 | class ContentType(str, Enum): 11 | TEXT = "text" 12 | IMAGE_URL = "image_url" 13 | DOCUMENT = "document" 14 | 15 | class ImageUrl(BaseModel): 16 | url: str 17 | 18 | class Document(BaseModel): 19 | """Document model for storing document content and metadata""" 20 | page_content: str 21 | metadata: Dict[str, Any] 22 | id: Optional[str] = None 23 | 24 | class MessageContent(BaseModel): 25 | type: ContentType 26 | text: Optional[str] = None 27 | image_url: Optional[ImageUrl] = None 28 | document: Optional[Document] = None 29 | 30 | class Message(BaseModel): 31 | role: Role 32 | content: Union[str, List[MessageContent]] 33 | 34 | class ChatRequest(BaseModel): 35 | messages: List[Message] 36 | model: str 37 | stream: bool = False 38 | 39 | class SearchAnalysis(BaseModel): 40 | """Schema for analyzing whether search context is needed""" 41 | needs_search: bool 42 | search_keywords: Optional[List[str]] = None 43 | 44 | class SearchResult(BaseModel): 45 | """Schema for search results""" 46 | url: str 47 | title: str 48 | content: str -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | # Document Processing Settings 2 | CHUNK_SIZE=1000 3 | CHUNK_OVERLAP=200 4 | MAX_CHUNKS_PER_DOC=5 5 | EMBEDDING_BATCH_SIZE=50 6 | MIN_CHUNK_LENGTH=10 7 | MAX_FILE_SIZE=10485760 8 | MAX_IMAGE_SIZE=5242880 9 | 10 | # Vector Store Settings 11 | VECTOR_STORE_TYPE=postgres 12 | POSTGRES_CONNECTION_STRING=postgresql://extendai:extendai@pgvector:5432/extendai 13 | POSTGRES_COLLECTION_NAME=embeddings 14 | # if using Pinecone 15 | PINECONE_API_KEY=your-pinecone-api-key 16 | PINECONE_INDEX_NAME=extendai 17 | 18 | # Cache Settings 19 | VECTOR_CACHE_DIR=cache/vectors 20 | VECTOR_CACHE_TTL=7200 21 | 22 | API_TIMEOUT=600 23 | MY_API_KEY=sk-planetzero-api-key 24 | 25 | # Model API Settings 26 | TARGET_MODEL_BASE_URL= 27 | 
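# (leave TARGET_MODEL_BASE_URL / OPENAI_BASE_URL without a /v1 suffix — the app appends /v1/chat/completions itself)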
TARGET_MODEL_API_KEY= 28 | OPENAI_BASE_URL= 29 | OPENAI_API_KEY= 30 | OPENAI_ENHANCE_MODEL= 31 | 32 | # Embedding API Settings 33 | EMBEDDING_BASE_URL= 34 | EMBEDDING_API_KEY= 35 | EMBEDDING_MODEL=text-embedding-3-small 36 | EMBEDDING_DIMENSIONS=1536 37 | 38 | # Search Settings 39 | SEARXNG_URL= 40 | DEFAULT_MODEL=deepseek-r1 41 | SEARCH_ENGINE=google 42 | SEARCH_RESULT_LIMIT=5 43 | SEARCH_RESULT_MULTIPLIER=2 44 | WEB_CONTENT_CHUNK_SIZE=512 45 | WEB_CONTENT_CHUNK_OVERLAP=50 46 | WEB_CONTENT_MAX_CHUNKS=5 47 | 48 | # Proxy Settings (Optional) 49 | PROXY_ENABLED=false 50 | PROXY_HOST=proxy.example.com 51 | PROXY_PORT=8080 52 | PROXY_USERNAME=your-proxy-username 53 | PROXY_PASSWORD=your-proxy-password 54 | 55 | # Feature Switches 56 | ENABLE_PROGRESS_MESSAGES=false 57 | ENABLE_IMAGE_ANALYSIS=true 58 | ENABLE_WEB_SEARCH=true 59 | ENABLE_DOCUMENT_ANALYSIS=true 60 | 61 | # Progress Messages (Customizable) 62 | PROGRESS_MSG_IMAGE="Analyzing image content..." 63 | PROGRESS_MSG_DOC="Analyzing document..." 64 | PROGRESS_MSG_DOC_SEARCH="Searching document..." 65 | PROGRESS_MSG_WEB_SEARCH="Searching web content..." 66 | # Server Settings 67 | PORT=8096 68 | HOST=0.0.0.0 69 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3" 2 | name: extendai 3 | 4 | services: 5 | app: 6 | build: . 7 | ports: 8 | - "8096:8096" 9 | env_file: 10 | - .env 11 | environment: 12 | # Default settings that will be overridden by .env file if present 13 | VECTOR_STORE_TYPE: ${VECTOR_STORE_TYPE:-postgres} 14 | POSTGRES_CONNECTION_STRING: ${POSTGRES_CONNECTION_STRING:-postgresql://extendai:extendai@pgvector:5432/extendai} 15 | POSTGRES_COLLECTION_NAME: ${POSTGRES_COLLECTION_NAME:-embeddings} 16 | ENABLE_IMAGE_ANALYSIS: ${ENABLE_IMAGE_ANALYSIS:-true} 17 | ENABLE_WEB_SEARCH: ${ENABLE_WEB_SEARCH:-true} 18 | ENABLE_DOCUMENT_ANALYSIS: ${ENABLE_DOCUMENT_ANALYSIS:-true} 19 | depends_on: 20 | pgvector: 21 | condition: service_healthy 22 | volumes: 23 | - ./cache:/app/cache 24 | - ./.env:/app/.env:ro # Mount .env file as read-only 25 | 26 | postgres: 27 | image: postgres:16 28 | environment: 29 | POSTGRES_DB: extendai 30 | POSTGRES_USER: extendai 31 | POSTGRES_PASSWORD: extendai 32 | ports: 33 | - "6023:5432" 34 | command: postgres -c log_statement=all 35 | healthcheck: 36 | test: ["CMD-SHELL", "psql postgresql://extendai:extendai@localhost/extendai --command 'SELECT 1;' || exit 1"] 37 | interval: 5s 38 | retries: 60 39 | volumes: 40 | - postgres_data:/var/lib/postgresql/data 41 | 42 | pgvector: 43 | image: ankane/pgvector 44 | environment: 45 | POSTGRES_DB: extendai 46 | POSTGRES_USER: extendai 47 | POSTGRES_PASSWORD: extendai 48 | ports: 49 | - "6024:5432" 50 | command: postgres -c log_statement=all 51 | healthcheck: 52 | test: ["CMD-SHELL", "psql postgresql://extendai:extendai@localhost/extendai --command 'SELECT 1;' || exit 1"] 53 | interval: 5s 54 | retries: 60 55 | volumes: 56 | - postgres_data_pgvector:/var/lib/postgresql/data 57 | 58 | volumes: 59 | postgres_data: 60 | postgres_data_pgvector: 61 | cache: -------------------------------------------------------------------------------- /app/utils/file_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import base64 4 | from typing import Tuple 5 | from urllib.parse import urlparse, parse_qs, unquote 6 | from app.config.settings import ( 7 | 
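    # NOTE: both settings sets deliberately mix MIME types and dotted file
    # extensions; detect_file_type() below checks candidates against both forms.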
SUPPORTED_IMAGE_FORMATS,
8 |     SUPPORTED_DOCUMENT_FORMATS
9 | )
10 | 
11 | logger = logging.getLogger(__name__)
12 | 
13 | # Known source file extensions
14 | KNOWN_SOURCE_EXT = [
15 |     "go", "py", "java", "sh", "bat", "ps1", "cmd", "js", "ts", "css",
16 |     "cpp", "hpp", "h", "c", "cs", "sql", "log", "ini", "pl", "pm",
17 |     "r", "dart", "dockerfile", "env", "php", "hs", "hsc", "lua",
18 |     "nginxconf", "conf", "m", "mm", "plsql", "perl", "rb", "rs",
19 |     "db2", "scala", "bash", "swift", "vue", "svelte", "msg", "ex",
20 |     "exs", "erl", "tsx", "jsx", "lhs"
21 | ]
22 | 
23 | def extract_filename_from_url(url: str) -> str:
24 |     """Extract filename from URL using various methods"""
25 |     # Try to get filename from URL parameters
26 |     parsed_url = urlparse(url)
27 |     query_params = parse_qs(parsed_url.query)
28 | 
29 |     # Check for filename in query parameters
30 |     for param in ['filename', 'rscd']:
31 |         if param in query_params:
32 |             value = query_params[param][0]
33 |             # 'filename' carries the name directly; 'rscd' embeds it as "...filename=<name>"
34 |             if param == 'filename' and 'filename=' not in value:
35 |                 return unquote(value.split(';')[0].strip('"\''))
36 |             if 'filename=' in value:
37 |                 filename = value.split('filename=')[-1].strip().split(';')[0].strip('"\'')
38 |                 return unquote(filename)
39 | 
40 |     # Fallback to path
41 |     path = parsed_url.path
42 |     if path:
43 |         return os.path.basename(path)
44 | 
45 |     return ""
46 | 
47 | def is_base64(s: str) -> bool:
48 |     """Check if a string is base64 encoded"""
49 |     try:
50 |         # Check if string starts with data URI scheme
51 |         if s.startswith('data:'):
52 |             return True
53 | 
54 |         # Skip URLs that look like normal web URLs
55 |         if s.startswith(('http://', 'https://')):
56 |             return False
57 | 
58 |         # Try to decode the string if it looks like base64
59 |         if len(s) % 4 == 0 and not s.startswith(('/', '\\')):
60 |             try:
61 |                 # Check if string contains valid base64 characters
62 |                 if not all(c in 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=' for c in s):
63 |                     return False
64 | 
65 |                 decoded = base64.b64decode(s)
66 |                 # Check if decoded data looks like binary data
67 |                 return len(decoded) > 0 and not all(32 <= byte <= 126 for byte in decoded)
68 |             except Exception:
69 |                 pass
70 |         return False
71 |     except Exception:
72 |         return False
73 | 
74 | def detect_file_type(url: str, content_type: str, content_disposition: str = "") -> Tuple[str, bool]:
75 |     """
76 |     Detect if a file is an image or document based on URL, content type and content disposition
77 |     Returns: (file_type, is_supported)
78 |     where file_type is either 'image' or 'document'
79 |     """
80 |     # Handle base64 image data
81 |     if is_base64(url):
82 |         return 'image', True
83 | 
84 |     # Try to get filename from Content-Disposition
85 |     filename = ""
86 |     if content_disposition and "filename=" in content_disposition:
87 |         filename = content_disposition.split("filename=")[-1].strip('"\'')
88 |         filename = unquote(filename)
89 | 
90 |     # If no filename in Content-Disposition, try URL
91 |     if not filename:
92 |         filename = extract_filename_from_url(url)
93 | 
94 |     # Get file extension
95 |     ext = os.path.splitext(filename.lower())[1] if filename else ""
96 | 
97 |     # Check content type and extension
98 |     if (content_type in SUPPORTED_IMAGE_FORMATS or
99 |         ext in SUPPORTED_IMAGE_FORMATS):
100 |         return 'image', True
101 |     elif (content_type in SUPPORTED_DOCUMENT_FORMATS or
102 |         ext in SUPPORTED_DOCUMENT_FORMATS):
103 |         return 'document', True
104 |     elif ext.lstrip('.') in KNOWN_SOURCE_EXT:
105 |         return 'document', True
106 | 
107 |     # If content type starts with image/, treat as image
108 |     if 
content_type.startswith('image/'): 109 | return 'image', False 110 | 111 | # Default to document for all other types 112 | return 'document', False -------------------------------------------------------------------------------- /app/config/settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dotenv import load_dotenv 3 | from enum import Enum 4 | 5 | # Load environment variables from .env 6 | load_dotenv() 7 | 8 | class SearchEngine(Enum): 9 | GOOGLE = "google" 10 | SEARXNG = "searxng" 11 | 12 | # Server Settings 13 | PORT = int(os.getenv("PORT", "8096")) 14 | HOST = os.getenv("HOST", "0.0.0.0") 15 | 16 | # Document processing settings 17 | CHUNK_SIZE = int(os.getenv("CHUNK_SIZE", "1000")) 18 | CHUNK_OVERLAP = int(os.getenv("CHUNK_OVERLAP", "200")) 19 | MAX_CHUNKS_PER_DOC = int(os.getenv("MAX_CHUNKS_PER_DOC", "5")) 20 | EMBEDDING_BATCH_SIZE = int(os.getenv("EMBEDDING_BATCH_SIZE", "50")) 21 | MIN_CHUNK_LENGTH = int(os.getenv("MIN_CHUNK_LENGTH", "10")) 22 | 23 | # Vector store settings 24 | VECTOR_STORE_TYPE = os.getenv("VECTOR_STORE_TYPE", "faiss") # "faiss", "postgres", or "pinecone" 25 | POSTGRES_CONNECTION_STRING = os.getenv("POSTGRES_CONNECTION_STRING", "") 26 | POSTGRES_COLLECTION_NAME = os.getenv("POSTGRES_COLLECTION_NAME", "document_vectors") 27 | 28 | # Pinecone settings 29 | PINECONE_API_KEY = os.getenv("PINECONE_API_KEY", "pcsk_xxx") 30 | PINECONE_INDEX_NAME = os.getenv("PINECONE_INDEX_NAME", "extendai") 31 | # Cache settings 32 | VECTOR_CACHE_DIR = os.getenv("VECTOR_CACHE_DIR", "cache/vectors") 33 | VECTOR_CACHE_TTL = int(os.getenv("VECTOR_CACHE_TTL", str(2 * 60 * 60))) # 2 hours in seconds 34 | 35 | # File size limits (in bytes) 36 | MAX_FILE_SIZE = int(os.getenv("MAX_FILE_SIZE", str(10 * 1024 * 1024))) # 10MB default 37 | MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", str(10 * 1024 * 1024))) # 10MB default 38 | 39 | # Supported file formats 40 | SUPPORTED_IMAGE_FORMATS = { 41 | # MIME types 42 | "image/jpeg", "image/png", "image/gif", "image/webp", "image/tiff", "image/bmp", 43 | # Extensions 44 | ".jpg", ".jpeg", ".png", ".gif", ".webp", ".tiff", ".bmp" 45 | } 46 | 47 | SUPPORTED_DOCUMENT_FORMATS = { 48 | # MIME types 49 | "application/pdf", 50 | "application/msword", 51 | "application/vnd.openxmlformats-officedocument.wordprocessingml.document", 52 | "application/vnd.ms-excel", 53 | "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", 54 | "application/vnd.ms-powerpoint", 55 | "application/vnd.openxmlformats-officedocument.presentationml.presentation", 56 | "text/plain", 57 | "text/csv", 58 | "text/html", 59 | "text/xml", 60 | "text/markdown", 61 | "application/epub+zip", 62 | # Extensions 63 | ".pdf", ".doc", ".docx", ".xls", ".xlsx", ".ppt", ".pptx", 64 | ".txt", ".csv", ".html", ".htm", ".xml", ".md", ".rst", ".epub" 65 | } 66 | 67 | # API URLs and Keys 68 | TARGET_MODEL_BASE_URL = os.getenv("TARGET_MODEL_BASE_URL", "https://api.openai.com") 69 | TARGET_MODEL_API_URL = f"{TARGET_MODEL_BASE_URL}/v1/chat/completions" 70 | TARGET_MODEL_API_KEY = os.getenv("TARGET_MODEL_API_KEY") 71 | OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL", "https://api.openai.com") 72 | OPENAI_API_URL = f"{OPENAI_BASE_URL}/v1/chat/completions" 73 | OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") 74 | OPENAI_ENHANCE_MODEL = os.getenv("OPENAI_ENHANCE_MODEL") # Model for image/document enhancement 75 | 76 | # Embedding Settings 77 | EMBEDDING_BASE_URL = os.getenv("EMBEDDING_BASE_URL", "https://api.openai.com/v1") 78 | 
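# NOTE: EMBEDDING_BASE_URL (above) already includes the /v1 path because it is
# handed directly to the embeddings client, whereas TARGET_MODEL_BASE_URL and
# OPENAI_BASE_URL are bare hosts and get /v1/chat/completions appended in code.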
EMBEDDING_API_KEY = os.getenv("EMBEDDING_API_KEY", OPENAI_API_KEY) # Default to OpenAI key if not set
79 | EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
80 | EMBEDDING_DIMENSIONS = int(os.getenv("EMBEDDING_DIMENSIONS", "1536"))
81 | 
82 | MY_API_KEY = os.getenv("MY_API_KEY")
83 | 
84 | # Search Settings
85 | SEARXNG_URL = os.getenv("SEARXNG_URL", "https://searxng.example.com")
86 | DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "deepseek-r1")
87 | SEARCH_ENGINE = os.getenv("SEARCH_ENGINE", "google").lower()
88 | 
89 | # Search Result Settings
90 | SEARCH_RESULT_LIMIT = int(os.getenv("SEARCH_RESULT_LIMIT", "5")) # Number of results to return
91 | SEARCH_RESULT_MULTIPLIER = int(os.getenv("SEARCH_RESULT_MULTIPLIER", "2")) # Multiplier for raw results to fetch
92 | WEB_CONTENT_CHUNK_SIZE = int(os.getenv("WEB_CONTENT_CHUNK_SIZE", "512")) # Size of each content chunk
93 | WEB_CONTENT_CHUNK_OVERLAP = int(os.getenv("WEB_CONTENT_CHUNK_OVERLAP", "50")) # Overlap between chunks
94 | WEB_CONTENT_MAX_CHUNKS = int(os.getenv("WEB_CONTENT_MAX_CHUNKS", "5")) # Number of most relevant chunks to return
95 | 
96 | # Feature Switches
97 | ENABLE_PROGRESS_MESSAGES = os.getenv("ENABLE_PROGRESS_MESSAGES", "true").lower() == "true"
98 | ENABLE_IMAGE_ANALYSIS = os.getenv("ENABLE_IMAGE_ANALYSIS", "true").lower() == "true"
99 | ENABLE_WEB_SEARCH = os.getenv("ENABLE_WEB_SEARCH", "true").lower() == "true"
100 | ENABLE_DOCUMENT_ANALYSIS = os.getenv("ENABLE_DOCUMENT_ANALYSIS", "true").lower() == "true"
101 | 
102 | # Progress Messages
103 | PROGRESS_MESSAGES = {
104 |     "image_analysis": os.getenv("PROGRESS_MSG_IMAGE", "Analyzing image content..."),
105 |     "document_analysis": os.getenv("PROGRESS_MSG_DOC", "Analyzing document content..."),
106 |     "document_search": os.getenv("PROGRESS_MSG_DOC_SEARCH", "Searching document for relevant content..."),
107 |     "web_search": os.getenv("PROGRESS_MSG_WEB_SEARCH", "Doing web search for relevant information...")
108 | }
109 | 
110 | # Proxy Settings
111 | PROXY_ENABLED = os.getenv("PROXY_ENABLED", "false").lower() == "true"
112 | PROXY_HOST = os.getenv("PROXY_HOST", "")
113 | PROXY_PORT = int(os.getenv("PROXY_PORT", "0") or "0") # "or" guards against an empty value, which would crash int()
114 | PROXY_USERNAME = os.getenv("PROXY_USERNAME", "")
115 | PROXY_PASSWORD = os.getenv("PROXY_PASSWORD", "")
116 | 
117 | # Timeouts
118 | API_TIMEOUT = int(os.getenv("API_TIMEOUT", "600"))
119 | SEARCH_TIMEOUT = 10
120 | REQUEST_TIMEOUT = {
121 |     "connect": 3,
122 |     "read": 5,
123 |     "write": 3,
124 |     "pool": 3
125 | }
126 | 
127 | # Headers
128 | DEFAULT_HEADERS = {
129 |     "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
130 |     "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
131 |     "Accept-Language": "en-US,en;q=0.5",
132 |     "Accept-Encoding": "gzip, deflate",
133 |     "DNT": "1"
134 | }
135 | 
136 | # Search Engine Weights
137 | SEARCH_ENGINE_WEIGHTS = {
138 |     "google": 3,
139 |     "bing": 2,
140 |     "duckduckgo": 2,
141 |     "brave": 1,
142 |     "qwant": 1
143 | }
--------------------------------------------------------------------------------
/app/services/image_service.py:
--------------------------------------------------------------------------------
1 | import httpx
2 | import logging
3 | import base64
4 | from fastapi import HTTPException
5 | from app.utils.file_utils import is_base64
6 | 
7 | from app.config.settings import (
8 |     OPENAI_API_URL, OPENAI_API_KEY,
9 |     API_TIMEOUT, MAX_IMAGE_SIZE,
10 |     OPENAI_ENHANCE_MODEL
11 
| ) 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | def get_headers(api_key: str) -> dict: 16 | return { 17 | "Content-Type": "application/json", 18 | "Authorization": f"Bearer {api_key}" 19 | } 20 | 21 | async def process_image(image_url: str) -> str: 22 | """Process image using OpenAI and return description""" 23 | if not OPENAI_API_KEY: 24 | raise HTTPException(status_code=500, detail="OPENAI_API_KEY not configured") 25 | 26 | try: 27 | # Handle base64 image data 28 | if is_base64(image_url): 29 | # If it's a data URI, get the actual base64 data 30 | if image_url.startswith('data:'): 31 | base64_data = image_url.split(',')[1] 32 | else: 33 | base64_data = image_url 34 | 35 | # Check size of decoded data 36 | try: 37 | decoded_data = base64.b64decode(base64_data) 38 | if len(decoded_data) > MAX_IMAGE_SIZE: 39 | raise HTTPException( 40 | status_code=413, 41 | detail=f"Image size ({len(decoded_data)} bytes) exceeds maximum allowed size ({MAX_IMAGE_SIZE} bytes)" 42 | ) 43 | except Exception as e: 44 | logger.error(f"Error decoding base64 data: {str(e)}") 45 | raise HTTPException(status_code=400, detail="Invalid base64 image data") 46 | 47 | image_data = image_url 48 | else: 49 | # Handle regular URL 50 | # Check if this is an S3 presigned URL 51 | is_s3_url = "X-Amz-Algorithm=AWS4-HMAC-SHA256" in image_url and "X-Amz-Credential" in image_url 52 | 53 | # Set headers based on URL type 54 | headers = { 55 | "Accept": "*/*", 56 | "Accept-Encoding": "gzip, deflate", 57 | "Connection": "keep-alive", 58 | "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" 59 | } 60 | 61 | async with httpx.AsyncClient() as client: 62 | if is_s3_url: 63 | # For S3 presigned URLs, skip HEAD request and directly download 64 | response = await client.get( 65 | image_url, 66 | headers=headers, 67 | follow_redirects=True 68 | ) 69 | response.raise_for_status() 70 | 71 | # Check content length after download 72 | image_bytes = response.content 73 | if len(image_bytes) > MAX_IMAGE_SIZE: 74 | raise HTTPException( 75 | status_code=413, 76 | detail=f"Image size ({len(image_bytes)} bytes) exceeds maximum allowed size ({MAX_IMAGE_SIZE} bytes)" 77 | ) 78 | 79 | # Detect content type 80 | content_type = response.headers.get("content-type", "image/jpeg") 81 | if not content_type.startswith("image/"): 82 | content_type = "image/jpeg" # Default to jpeg if no valid content type 83 | 84 | # Create data URI 85 | image_data = f"data:{content_type};base64,{base64.b64encode(image_bytes).decode()}" 86 | else: 87 | # For regular URLs, do HEAD request first 88 | head_response = await client.head( 89 | image_url, 90 | headers=headers, 91 | follow_redirects=True 92 | ) 93 | head_response.raise_for_status() 94 | 95 | # Check content length if available 96 | content_length = head_response.headers.get("content-length") 97 | if content_length: 98 | file_size = int(content_length) 99 | if file_size > MAX_IMAGE_SIZE: 100 | raise HTTPException( 101 | status_code=413, 102 | detail=f"Image size ({file_size} bytes) exceeds maximum allowed size ({MAX_IMAGE_SIZE} bytes)" 103 | ) 104 | 105 | # For regular URLs, pass the URL directly 106 | image_data = image_url 107 | 108 | # Process image with OpenAI 109 | async with httpx.AsyncClient() as client: 110 | openai_request = { 111 | "model": OPENAI_ENHANCE_MODEL, 112 | "messages": [{ 113 | "role": "user", 114 | "content": [{ 115 | "type": "text", 116 | "text": "Only output the ocr result of user's uploaded image. 
If the image contains data structures like a graph or table, convert the info into text. If this is an image of an object, describe it in as much detail as possible."
117 |                 }, {
118 |                     "type": "image_url",
119 |                     "image_url": {"url": image_data}
120 |                 }]
121 |             }],
122 |             "stream": False
123 |         }
124 | 
125 |         # Add headers for non-base64 and non-S3 URLs
126 |         if not is_base64(image_data) and not is_s3_url:
127 |             openai_request["messages"][0]["content"][1]["image_url"]["headers"] = headers
128 | 
129 |         response = await client.post(
130 |             OPENAI_API_URL,
131 |             json=openai_request,
132 |             headers=get_headers(OPENAI_API_KEY),
133 |             timeout=API_TIMEOUT
134 |         )
135 |         response.raise_for_status()
136 |         result = response.json()
137 |         return result.get("choices", [{}])[0].get("message", {}).get("content", "")
138 | 
139 |     except HTTPException:
140 |         # Let deliberate errors (e.g. 413 size limit, 400 bad base64) keep their status codes
141 |         raise
142 |     except Exception as e:
143 |         logger.error(f"Image processing error: {str(e)}")
144 |         raise HTTPException(status_code=500, detail=f"Image processing failed: {str(e)}")
--------------------------------------------------------------------------------
/app/api/chat.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import asyncio
4 | from fastapi import APIRouter, HTTPException, Request, Depends
5 | from fastapi.responses import StreamingResponse
6 | from openai import AsyncOpenAI
7 | 
8 | from app.models.chat import ChatRequest
9 | from app.services.message_service import process_messages
10 | from app.config.settings import (
11 |     TARGET_MODEL_BASE_URL, TARGET_MODEL_API_KEY,
12 |     DEFAULT_MODEL
13 | )
14 | from app.core.auth import verify_api_key
15 | 
16 | logger = logging.getLogger(__name__)
17 | 
18 | router = APIRouter()
19 | 
20 | # Initialize OpenAI client
21 | client = AsyncOpenAI(
22 |     base_url=f"{TARGET_MODEL_BASE_URL}/v1",
23 |     api_key=TARGET_MODEL_API_KEY
24 | )
25 | 
26 | async def stream_progress_message(message: str):
27 |     """Stream a progress message"""
28 |     chunk_data = {
29 |         "id": "progress_msg",
30 |         "object": "chat.completion.chunk",
31 |         "created": 0,
32 |         "model": DEFAULT_MODEL,
33 |         "choices": [{
34 |             "index": 0,
35 |             "delta": {"role": "assistant", "content": message},
36 |             "finish_reason": None
37 |         }]
38 |     }
39 |     yield f"data: {json.dumps(chunk_data)}\n\n"
40 | 
41 | async def stream_response(response):
42 |     """Stream the response from the upstream API"""
43 |     try:
44 |         buffer = ""
45 |         async for line in response.aiter_lines():
46 |             if not line.strip():
47 |                 continue
48 | 
49 |             buffer += line
50 |             if buffer.strip() == 'data: [DONE]':
51 |                 yield 'data: [DONE]\n\n'
52 |                 break
53 | 
54 |             try:
55 |                 if buffer.startswith('data: '):
56 |                     json.loads(buffer[6:])
57 |                     yield f"{buffer}\n\n"
58 |                     buffer = ""
59 |             except json.JSONDecodeError:
60 |                 continue
61 |             except Exception as e:
62 |                 logger.error(f"Error parsing stream chunk: {str(e)}")
63 |                 continue
64 |     except Exception as e:
65 |         logger.error(f"Error in stream_response: {str(e)}")
66 |         yield f"data: {json.dumps({'error': str(e)})}\n\n"
67 | 
68 | @router.post("/v1/chat/completions")
69 | async def chat(
70 |     request: ChatRequest,
71 |     raw_request: Request,
72 |     api_key: str = Depends(verify_api_key)
73 | ):
74 |     """Handle chat requests with support for streaming and image processing"""
75 |     if not TARGET_MODEL_API_KEY:
76 |         raise HTTPException(status_code=500, detail="TARGET_MODEL_API_KEY not configured")
77 | 
78 |     try:
79 |         logger.info("Processing chat request...")
80 | 
81 |         if request.stream:
82 |             # Create a queue for progress messages
83 |             progress_queue = 
asyncio.Queue() 84 | 85 | # Process messages in background task 86 | async def process_messages_task(): 87 | try: 88 | processed_messages, _ = await process_messages( 89 | request.messages, 90 | stream=True, 91 | progress_queue=progress_queue 92 | ) 93 | # Convert messages for API request 94 | api_messages = [{"role": msg.role, "content": msg.content} for msg in processed_messages] 95 | 96 | # Create streaming completion 97 | stream = await client.chat.completions.create( 98 | messages=api_messages, 99 | model=DEFAULT_MODEL, 100 | stream=True 101 | ) 102 | return stream 103 | except Exception as e: 104 | logger.error(f"Error in process_messages_task: {str(e)}") 105 | raise 106 | 107 | # Start processing task 108 | processing_task = asyncio.create_task(process_messages_task()) 109 | 110 | async def generate(): 111 | progress_task = None 112 | current_progress_message = None 113 | progress_generator = None 114 | had_progress_message = False 115 | 116 | try: 117 | while True: 118 | # Check for new progress message 119 | try: 120 | progress_message = progress_queue.get_nowait() 121 | if had_progress_message: 122 | chunk_data = { 123 | "id": "progress_msg", 124 | "object": "chat.completion.chunk", 125 | "created": 0, 126 | "model": DEFAULT_MODEL, 127 | "choices": [{ 128 | "index": 0, 129 | "delta": {"role": "assistant", "content": "\n"}, 130 | "finish_reason": None 131 | }] 132 | } 133 | yield f"data: {json.dumps(chunk_data)}\n\n" 134 | progress_generator = stream_progress_message(progress_message) 135 | current_progress_message = progress_message 136 | yield await anext(progress_generator) 137 | had_progress_message = True 138 | except asyncio.QueueEmpty: 139 | pass 140 | except StopAsyncIteration: 141 | pass 142 | 143 | if processing_task.done(): 144 | break 145 | 146 | await asyncio.sleep(0.1) 147 | 148 | stream = await processing_task 149 | 150 | if had_progress_message: 151 | chunk_data = { 152 | "id": "progress_msg", 153 | "object": "chat.completion.chunk", 154 | "created": 0, 155 | "model": DEFAULT_MODEL, 156 | "choices": [{ 157 | "index": 0, 158 | "delta": {"role": "assistant", "content": "\n"}, 159 | "finish_reason": None 160 | }] 161 | } 162 | yield f"data: {json.dumps(chunk_data)}\n\n" 163 | 164 | last_content_chunk = None 165 | async for chunk in stream: 166 | chunk_dict = chunk.model_dump() 167 | 168 | if chunk.choices[0].delta.content is not None: 169 | if last_content_chunk: 170 | yield f"data: {json.dumps(last_content_chunk)}\n\n" 171 | last_content_chunk = chunk_dict 172 | else: 173 | if last_content_chunk: 174 | last_content_chunk["choices"][0]["finish_reason"] = "stop" 175 | yield f"data: {json.dumps(last_content_chunk)}\n\n" 176 | last_content_chunk = None 177 | if not (chunk.choices[0].finish_reason == "stop" and not chunk.choices[0].delta.content): 178 | yield f"data: {json.dumps(chunk_dict)}\n\n" 179 | 180 | if last_content_chunk: 181 | last_content_chunk["choices"][0]["finish_reason"] = "stop" 182 | yield f"data: {json.dumps(last_content_chunk)}\n\n" 183 | 184 | yield "data: [DONE]\n\n" 185 | 186 | except Exception as e: 187 | logger.error(f"Error in generate: {str(e)}") 188 | if progress_task and not progress_task.done(): 189 | progress_task.cancel() 190 | raise 191 | 192 | return StreamingResponse( 193 | generate(), 194 | media_type="text/event-stream", 195 | headers={ 196 | "Cache-Control": "no-cache, no-transform", 197 | "Connection": "keep-alive", 198 | "Content-Type": "text/event-stream", 199 | "Transfer-Encoding": "chunked", 200 | "X-Accel-Buffering": "no" 201 | } 
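                # These headers keep the SSE stream incremental end to end:
                # "no-cache, no-transform" stops intermediaries from rewriting or
                # compressing chunks, and "X-Accel-Buffering: no" tells nginx-style
                # reverse proxies not to buffer the event stream.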
202 |             )
203 |         else:
204 |             processed_messages, _ = await process_messages(request.messages)
205 | 
206 |             response = await client.chat.completions.create(
207 |                 messages=[{"role": msg.role, "content": msg.content} for msg in processed_messages],
208 |                 model=DEFAULT_MODEL,
209 |                 stream=False
210 |             )
211 |             return response.model_dump()
212 | 
213 |     except HTTPException:
214 |         # Preserve intentional status codes from auth and processing helpers
215 |         raise
216 |     except Exception as e:
217 |         logger.error(f"Chat error: {str(e)}")
218 |         raise HTTPException(status_code=500, detail=str(e))
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ExtendAI
2 | 
3 | ExtendAI is a universal framework that extends any AI model with multi-modal understanding, web search, and document analysis capabilities.
4 | 
5 | ## Features
6 | 
7 | - **Multi-format Document Support**: Process various document formats including:
8 |   - PDF files
9 |   - Microsoft Office documents (Word, Excel, PowerPoint)
10 |   - Markdown files
11 |   - Text files
12 |   - Source code files
13 |   - HTML/XML files
14 |   - CSV files
15 |   - RST files
16 |   - Outlook MSG files
17 |   - EPub files
18 | 
19 | - **Advanced Document Processing**:
20 |   - Automatic text encoding detection and fixing
21 |   - Multiple-document batch processing
22 |   - Smart document chunking with configurable size and overlap
23 |   - Ultra-fast processing of large documents
24 |   - File type detection based on content and extensions
25 | 
26 | - **Intelligent Web Search**:
27 |   - Integration with multiple search engines (Google, Bing)
28 |   - SearXNG support for privacy-focused searches
29 |   - Smart result filtering and ranking
30 |   - Configurable search depth and limits
31 |   - Automatic content extraction and processing
32 | 
33 | - **Advanced Image Comprehension**:
34 |   - Deep image analysis and understanding
35 |   - Multiple-image batch processing
36 |   - Text extraction from images (OCR)
37 |   - Support for multiple image formats
38 |   - Context-aware image interpretation
39 | 
40 | - **Flexible Vector Store Integration**:
41 |   - PostgreSQL vector store support
42 |   - Pinecone vector store support
43 |   - Local FAISS vector store with caching
44 |   - Automatic fallback mechanisms
45 | 
46 | - **Performance Optimizations**:
47 |   - Asynchronous document processing
48 |   - Efficient caching system for FAISS vectors
49 |   - Batched embedding generation
50 |   - Configurable chunk sizes and overlap
51 |   - File size limits and safety checks
52 | 
53 | - **Robust Error Handling**:
54 |   - Graceful fallbacks for failed operations
55 |   - Comprehensive logging system
56 |   - Automatic cleanup of temporary files
57 |   - Cache invalidation and management
58 | 
59 | ## Screenshots
60 | 
61 | ExtendAI Screenshot 1 62 |

63 | ExtendAI Screenshot 2 64 |

65 | ExtendAI Screenshot 3 66 | 67 |
68 | 69 | ## Configuration 70 | 71 | The system is highly configurable through environment variables: 72 | 73 | ```env 74 | # Document Processing Settings 75 | CHUNK_SIZE=1000 # Size of text chunks for processing (in characters) 76 | CHUNK_OVERLAP=200 # Overlap between chunks to maintain context 77 | MAX_CHUNKS_PER_DOC=5 # Maximum number of chunks to return per document 78 | EMBEDDING_BATCH_SIZE=50 # Number of chunks to process in parallel for embeddings 79 | MIN_CHUNK_LENGTH=10 # Minimum length of a chunk to be processed 80 | MAX_FILE_SIZE=10485760 # Maximum file size in bytes (10MB) 81 | MAX_IMAGE_SIZE=5242880 # Maximum image size in bytes (5MB) 82 | 83 | # Vector Store Settings 84 | VECTOR_STORE_TYPE=faiss # Vector store backend: 'faiss', 'postgres', or 'pinecone' 85 | POSTGRES_CONNECTION_STRING= # PostgreSQL connection URL for vector storage 86 | POSTGRES_COLLECTION_NAME= # Collection name for storing vectors in PostgreSQL 87 | 88 | # Pinecone Settings 89 | PINECONE_API_KEY= # API key for Pinecone vector database 90 | PINECONE_INDEX_NAME= # Name of the Pinecone index to use 91 | 92 | # Cache Settings 93 | VECTOR_CACHE_DIR=cache/vectors # Directory for storing FAISS vector cache 94 | VECTOR_CACHE_TTL=7200 # Cache time-to-live in seconds (2 hours) 95 | 96 | API_TIMEOUT=600 # API Timeout (in seconds) 97 | MY_API_KEY=sk-planetzero-api-key # Authorization key 98 | 99 | # Model API Settings 100 | TARGET_MODEL_BASE_URL= # Base URL for the target AI model API 101 | TARGET_MODEL_API_KEY= # API key for the target model 102 | OPENAI_BASE_URL= # Base URL for OpenAI API (or compatible endpoint) 103 | OPENAI_API_KEY= # OpenAI API key (or compatible key) 104 | OPENAI_ENHANCE_MODEL= # The model used to facilitate image processing and search analysis 105 | 106 | # Embedding API Settings 107 | EMBEDDING_BASE_URL= # Base URL for embedding API 108 | EMBEDDING_API_KEY= # API key for embedding service 109 | EMBEDDING_MODEL=text-embedding-3-small # Model to use for text embeddings 110 | EMBEDDING_DIMENSIONS=1536 # Dimension of the embedding vectors 111 | 112 | # Search Settings 113 | SEARXNG_URL= # URL for SearXNG instance (for web search) 114 | DEFAULT_MODEL=deepseek-r1 # Default AI model to use 115 | SEARCH_ENGINE=google # Search engine to use (google, bing, etc.) 116 | SEARCH_RESULT_LIMIT=5 # Number of search results to return 117 | SEARCH_RESULT_MULTIPLIER=2 # Multiplier for raw results to fetch 118 | WEB_CONTENT_CHUNK_SIZE=512 # Size of chunks for web content 119 | WEB_CONTENT_CHUNK_OVERLAP=50 # Overlap for web content chunks 120 | WEB_CONTENT_MAX_CHUNKS=5 # Maximum chunks to process from web content 121 | 122 | # Proxy Settings (Optional) 123 | PROXY_ENABLED=false # Whether to use proxy for requests 124 | PROXY_HOST= # Proxy server hostname 125 | PROXY_PORT= # Proxy server port 126 | PROXY_USERNAME= # Proxy authentication username 127 | PROXY_PASSWORD= # Proxy authentication password 128 | PROXY_COUNTRY= # Preferred proxy server country 129 | PROXY_SESSION_ID= # Session ID for proxy (if required) 130 | 131 | # Feature Switches 132 | ENABLE_PROGRESS_MESSAGES=false # Enable/disable progress message updates 133 | ENABLE_IMAGE_ANALYSIS=true # Enable/disable image analysis capability 134 | ENABLE_WEB_SEARCH=true # Enable/disable web search capability 135 | ENABLE_DOCUMENT_ANALYSIS=true # Enable/disable document analysis capability 136 | 137 | # Progress Messages (Customizable) 138 | PROGRESS_MSG_IMAGE="Analyzing image content..." 
# Message shown during image analysis
139 | PROGRESS_MSG_DOC="Analyzing document..." # Message shown during document processing
140 | PROGRESS_MSG_DOC_SEARCH="Searching document..." # Message shown during document search
141 | PROGRESS_MSG_WEB_SEARCH="Searching web content..." # Message shown during web search
142 | ```
143 | 
144 | ## Advanced Configuration Details
145 | 
146 | ### Document Processing
147 | - OpenAI-compatible request format: just pass the document URL through the `image_url` content part
148 | - The chunking system breaks down documents into manageable pieces while maintaining context through overlap
149 | - Batch processing helps optimize embedding generation and API usage
150 | - File size limits protect against resource exhaustion and API limitations
151 | 
152 | ### Vector Store Options
153 | 1. **FAISS** (Local):
154 |    - Fast, efficient local vector storage
155 |    - Good for development and smaller deployments
156 |    - Includes local caching system for performance
157 | 
158 | 2. **PostgreSQL**:
159 |    - Persistent vector storage in PostgreSQL database
160 |    - Suitable for production deployments
161 |    - Supports concurrent access and backups
162 | 
163 | 3. **Pinecone**:
164 |    - Cloud-based vector database
165 |    - Excellent for large-scale deployments
166 |    - Provides automatic scaling and management
167 | 
168 | ### API Integration
169 | - Supports multiple model endpoints (OpenAI-compatible)
170 | - Configurable embedding services
171 | - Automatic fallbacks and error handling
172 | 
173 | ### Search Capabilities
174 | - Integrated web search through SearXNG and Google
175 | - Configurable search engines and result limits
176 | - Smart content chunking for web results
177 | 
178 | ### Progress Tracking
179 | - Customizable progress messages
180 | - Feature toggles for different capabilities
181 | - Localization support for messages
182 | 
183 | ## Installation & Deployment
184 | 
185 | ### Option 1: Docker Compose (Recommended for Production)
186 | 
187 | 1. Clone the repository:
188 | ```bash
189 | git clone https://github.com/realnoob007/ExtendAI.git
190 | cd ExtendAI
191 | ```
192 | 
193 | 2. Set up environment variables:
194 | ```bash
195 | cp .env.example .env
196 | # Edit .env with your configuration (API keys, model endpoints, etc.)
197 | ```
198 | 
199 | 3. Start the services using Docker Compose:
200 | ```bash
201 | docker compose up -d
202 | ```
203 | 
204 | This will start:
205 | - ExtendAI application on port 8096
206 | - PostgreSQL database on port 6023
207 | - PostgreSQL with pgvector extension on port 6024
208 | 
209 | To view logs:
210 | ```bash
211 | docker compose logs -f
212 | ```
213 | 
214 | To stop the services:
215 | ```bash
216 | docker compose down
217 | ```
218 | 
219 | ### Option 2: Docker (Single Container)
220 | 
221 | If you want to run only the ExtendAI application container and use external databases:
222 | 
223 | 1. Clone and configure:
224 | ```bash
225 | git clone https://github.com/realnoob007/ExtendAI.git
226 | cd ExtendAI
227 | cp .env.example .env
228 | # Edit .env with your configuration
229 | ```
230 | 
231 | 2. 
Run the container:
232 | ```bash
233 | docker run -d \
234 |   --name extendai \
235 |   -p 8096:8096 \
236 |   --env-file .env \
237 |   -v $(pwd)/cache:/app/cache \
238 |   -v $(pwd)/.env:/app/.env:ro \
239 |   chasney/extendai:latest
240 | ```
241 | 
242 | Useful Docker commands:
243 | ```bash
244 | # View logs
245 | docker logs -f extendai
246 | 
247 | # Stop container
248 | docker stop extendai
249 | 
250 | # Remove container
251 | docker rm extendai
252 | 
253 | # Rebuild image (if you made changes)
254 | docker build --no-cache -t extendai .
255 | ```
256 | 
257 | ### Option 3: Local Development
258 | 
259 | For development purposes, you can run the application directly:
260 | 
261 | 1. Create a virtual environment (Python 3.11+ recommended):
262 | ```bash
263 | python -m venv venv
264 | source venv/bin/activate  # On Windows: venv\Scripts\activate
265 | ```
266 | 
267 | 2. Install dependencies:
268 | ```bash
269 | pip install -r requirements.txt
270 | ```
271 | 
272 | 3. Set up environment variables:
273 | ```bash
274 | cp .env.example .env
275 | # Edit .env with your configuration
276 | ```
277 | 
278 | 4. Run the application:
279 | ```bash
280 | python main.py
281 | ```
282 | 
283 | ## Architecture
284 | 
285 | Project layout (not fully refactored yet):
286 | 
287 | - `main.py`: Application entry point and FastAPI setup
288 | - `app/api/`: API route definitions
289 | - `app/services/`: Core business logic services
290 | - `app/models/`: Data models and schemas
291 | - `app/utils/`: Utility functions and helpers
292 | - `app/core/`: Core system components
293 | - `app/config/`: Configuration management
294 | 
295 | #### Example Request Payload (POST request to http://localhost:8096/v1/chat/completions)
296 | ```json
297 | {
298 |   "messages": [
299 |     {
300 |       "role": "user",
301 |       "content": [
302 |         {
303 |           "type": "text",
304 |           "text": "what is this document about?"
305 |         },
306 |         {
307 |           "type": "image_url",
308 |           "image_url": {
309 |             "url": "https://xxx.com/document.pdf"
310 |           }
311 |         }
312 |       ]
313 |     }
314 |   ],
315 |   "model": "deepseek-r1-all",
316 |   "stream": false
317 | }
318 | ```
319 | 
320 | #### Content Types Support
321 | The API supports multiple content types in messages:
322 | 
323 | 1. **Text Content**
324 | ```json
325 | {
326 |   "type": "text",
327 |   "text": "your question or prompt here"
328 | }
329 | ```
330 | 
331 | 2. **Image URL**
332 | ```json
333 | {
334 |   "type": "image_url",
335 |   "image_url": {
336 |     "url": "https://example.com/image.jpg"
337 |   }
338 | }
339 | ```
340 | 
341 | 3. **Document URL**
342 | ```json
343 | {
344 |   "type": "image_url",
345 |   "image_url": {
346 |     "url": "https://example.com/document.pdf"
347 |   }
348 | }
349 | ```
350 | 
351 | #### Parameters
352 | 
353 | | Parameter | Type | Required | Description |
354 | |-----------|------|----------|-------------|
355 | | messages | array | Yes | Array of message objects |
356 | | model | string | Yes | Model name (the server currently forwards DEFAULT_MODEL upstream) |
357 | | stream | boolean | No | Whether to stream the response (default: false) |
358 | 
359 | #### Non-Streaming Response
360 | ```json
361 | {
362 |   "id": "chatcmpl-123",
363 |   "object": "chat.completion",
364 |   "created": 1677858242,
365 |   "model": "deepseek-r1-all",
366 |   "usage": {
367 |     "prompt_tokens": 56,
368 |     "completion_tokens": 31,
369 |     "total_tokens": 87
370 |   },
371 |   "choices": [
372 |     {
373 |       "message": {
374 |         "role": "assistant",
375 |         "content": "The document appears to be about..." 
376 | }, 377 | "finish_reason": "stop", 378 | "index": 0 379 | } 380 | ] 381 | } 382 | ``` 383 | 384 | #### Streaming Response 385 | When `stream` is set to `true`, the response will be sent as Server-Sent Events (SSE): 386 | ```http 387 | data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677858242,"model":"deepseek-r1-all","choices":[{"delta":{"role":"assistant","content":"The"},"index":0,"finish_reason":null}]} 388 | 389 | data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677858242,"model":"deepseek-r1-all","choices":[{"delta":{"content":" document"},"index":0,"finish_reason":null}]} 390 | 391 | data: [DONE] 392 | ``` 393 | 394 | ## Contributing 395 | 396 | 1. Fork the repository 397 | 2. Create a feature branch 398 | 3. Commit your changes 399 | 4. Push to the branch 400 | 5. Create a Pull Request 401 | 402 | ## Sponsorship (You can run this project with this api platform to save 20%) 403 | [PlanetZero API](https://api.planetzeroapi.com/) 404 | 405 | ## License 406 | 407 | MIT License 408 | -------------------------------------------------------------------------------- /app/services/search_service.py: -------------------------------------------------------------------------------- 1 | import httpx 2 | import asyncio 3 | import logging 4 | from bs4 import BeautifulSoup 5 | from urllib.parse import urlparse 6 | from googlesearch import search 7 | from typing import List, Optional, Dict, Any 8 | from langchain_text_splitters import RecursiveCharacterTextSplitter 9 | from langchain_openai import OpenAIEmbeddings 10 | from langchain_core.documents import Document 11 | from langchain_community.vectorstores import FAISS 12 | from langchain_postgres import PGVector 13 | from langchain_pinecone import PineconeVectorStore 14 | from pinecone import Pinecone 15 | 16 | from app.config.settings import ( 17 | PROXY_ENABLED, PROXY_HOST, PROXY_PORT, PROXY_USERNAME, 18 | PROXY_PASSWORD, SEARCH_RESULT_LIMIT, SEARCH_RESULT_MULTIPLIER, 19 | WEB_CONTENT_CHUNK_SIZE, WEB_CONTENT_CHUNK_OVERLAP, 20 | EMBEDDING_MODEL, EMBEDDING_BASE_URL, EMBEDDING_API_KEY, EMBEDDING_DIMENSIONS, 21 | VECTOR_STORE_TYPE, POSTGRES_CONNECTION_STRING, POSTGRES_COLLECTION_NAME, 22 | PINECONE_API_KEY, PINECONE_INDEX_NAME 23 | ) 24 | from app.models.chat import SearchResult 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | def get_proxy_url() -> str: 29 | """Generate basic proxy URL""" 30 | return f"http://{PROXY_USERNAME}:{PROXY_PASSWORD}@{PROXY_HOST}:{PROXY_PORT}" 31 | 32 | class SearchService: 33 | def __init__(self): 34 | self.text_splitter = RecursiveCharacterTextSplitter( 35 | chunk_size=WEB_CONTENT_CHUNK_SIZE, 36 | chunk_overlap=WEB_CONTENT_CHUNK_OVERLAP, 37 | add_start_index=True, 38 | separators=["\n\n", "\n", "。", ". 
", " ", ""], 39 | length_function=len, 40 | is_separator_regex=False 41 | ) 42 | self.embeddings = OpenAIEmbeddings( 43 | model=EMBEDDING_MODEL, 44 | openai_api_key=EMBEDDING_API_KEY, 45 | openai_api_base=EMBEDDING_BASE_URL, 46 | dimensions=EMBEDDING_DIMENSIONS, 47 | request_timeout=60.0, 48 | show_progress_bar=True, 49 | retry_min_seconds=1, 50 | retry_max_seconds=60, 51 | max_retries=3, 52 | skip_empty=True, 53 | ) 54 | # Initialize vector store 55 | self.vector_store = None 56 | self._init_vector_store() 57 | 58 | def _init_vector_store(self): 59 | """Initialize vector store based on configuration""" 60 | try: 61 | if VECTOR_STORE_TYPE == "postgres" and POSTGRES_CONNECTION_STRING: 62 | logger.info("Initializing PostgreSQL vector store") 63 | self.vector_store = PGVector( 64 | connection=POSTGRES_CONNECTION_STRING, 65 | collection_name=f"{POSTGRES_COLLECTION_NAME}_web_search", 66 | embeddings=self.embeddings 67 | ) 68 | elif VECTOR_STORE_TYPE == "pinecone": 69 | logger.info("Initializing Pinecone vector store") 70 | pc = Pinecone(api_key=PINECONE_API_KEY) 71 | index = pc.Index(PINECONE_INDEX_NAME) 72 | self.vector_store = PineconeVectorStore( 73 | embedding=self.embeddings, 74 | index=index, 75 | namespace="web_search" 76 | ) 77 | else: 78 | logger.info("Initializing FAISS vector store") 79 | self.vector_store = FAISS.from_texts( 80 | ["placeholder"], # Need at least one document to initialize 81 | self.embeddings, 82 | metadatas=[{"placeholder": True}] 83 | ) 84 | except Exception as e: 85 | logger.error(f"Failed to initialize vector store: {str(e)}") 86 | logger.warning("Falling back to FAISS vector store") 87 | self.vector_store = FAISS.from_texts( 88 | ["placeholder"], 89 | self.embeddings, 90 | metadatas=[{"placeholder": True}] 91 | ) 92 | 93 | async def _get_proxy_config(self) -> Optional[str]: 94 | """Get proxy configuration if enabled""" 95 | if not PROXY_ENABLED or not all([PROXY_HOST, PROXY_PORT, PROXY_USERNAME, PROXY_PASSWORD]): 96 | return None 97 | 98 | return f"http://{PROXY_USERNAME}:{PROXY_PASSWORD}@{PROXY_HOST}:{PROXY_PORT}" 99 | 100 | async def _fetch_page_content(self, url: str) -> str: 101 | """Fetch and extract text content from a webpage""" 102 | try: 103 | proxy = await self._get_proxy_config() 104 | async with httpx.AsyncClient(proxy=proxy, timeout=30.0) as client: 105 | response = await client.get(url, follow_redirects=True) 106 | response.raise_for_status() 107 | 108 | soup = BeautifulSoup(response.text, 'html.parser') 109 | 110 | # Remove unwanted elements 111 | for element in soup.find_all(['script', 'style', 'nav', 'header', 'footer', 'iframe', 'aside', 'form']): 112 | element.decompose() 113 | 114 | # First try to find main content area 115 | main_content = "" 116 | content_tags = soup.find_all(['article', 'main', 'div'], 117 | class_=lambda x: x and any(c in str(x).lower() for c in ['content', 'article', 'post', 'entry', 'main', 'text'])) 118 | 119 | if content_tags: 120 | # Get the tag with most text content 121 | main_tag = max(content_tags, key=lambda x: len(x.get_text().strip())) 122 | 123 | # Extract paragraphs from main content area 124 | paragraphs = [] 125 | for p in main_tag.find_all(['p', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6']): 126 | text = p.get_text().strip() 127 | if len(text) > 20: # Skip very short paragraphs 128 | paragraphs.append(text) 129 | 130 | main_content = '\n'.join(paragraphs) 131 | 132 | # Fallback to all paragraphs if no main content found 133 | if not main_content: 134 | paragraphs = [] 135 | for p in soup.find_all(['p', 'h1', 
'h2', 'h3', 'h4', 'h5', 'h6']): 136 | text = p.get_text().strip() 137 | if len(text) > 20: 138 | paragraphs.append(text) 139 | main_content = '\n'.join(paragraphs) 140 | 141 | return main_content.strip() 142 | except Exception as e: 143 | logger.error(f"Error fetching content from {url}: {str(e)}") 144 | return "" 145 | 146 | async def search(self, query: str) -> List[Dict[str, Any]]: 147 | """Perform web search and return relevant content chunks""" 148 | try: 149 | # Get more initial results to account for failures 150 | raw_limit = SEARCH_RESULT_LIMIT * SEARCH_RESULT_MULTIPLIER 151 | 152 | # Configure proxy for Google search 153 | proxy_url = get_proxy_url() if PROXY_ENABLED else None 154 | 155 | # Get search results using googlesearch 156 | search_urls = [] 157 | try: 158 | results = search( 159 | query, 160 | num_results=raw_limit, 161 | proxy=proxy_url, 162 | ssl_verify=False # Skip SSL verification for proxy 163 | ) 164 | search_urls = list(results) # Convert generator to list 165 | logger.info(f"Found {len(search_urls)} results from Google search") 166 | except Exception as e: 167 | logger.error(f"Google search error: {str(e)}") 168 | return [] 169 | 170 | if not search_urls: 171 | return [] 172 | 173 | # Deduplicate URLs while keeping order 174 | seen_urls = set() 175 | unique_results = [] 176 | for url in search_urls: 177 | if not url or url in seen_urls: 178 | continue 179 | try: 180 | parsed = urlparse(url) 181 | if not all([parsed.scheme, parsed.netloc]): 182 | continue 183 | if any(parsed.path.lower().endswith(ext) for ext in [ 184 | '.pdf', '.doc', '.docx', '.ppt', '.pptx', '.xls', '.xlsx', 185 | '.zip', '.rar', '.7z', '.tar', '.gz', '.mp3', '.mp4', '.avi' 186 | ]): 187 | continue 188 | seen_urls.add(url) 189 | unique_results.append({ 190 | "url": url, 191 | "title": url.split("/")[-1] or parsed.netloc, # Simple title extraction 192 | }) 193 | except: 194 | continue 195 | 196 | logger.info(f"Found {len(unique_results)} unique URLs to process") 197 | 198 | # Process URLs in batches until we have enough valid results 199 | valid_chunks = [] 200 | batch_size = SEARCH_RESULT_LIMIT 201 | current_index = 0 202 | 203 | while current_index < len(unique_results) and len(valid_chunks) < SEARCH_RESULT_LIMIT: 204 | batch = unique_results[current_index:current_index + batch_size] 205 | current_index += batch_size 206 | 207 | # Fetch content from URLs in parallel 208 | proxies = await self._get_proxy_config() 209 | async with httpx.AsyncClient(proxy=proxies, timeout=30.0, verify=False) as client: 210 | tasks = [] 211 | for result in batch: 212 | if "url" in result: 213 | tasks.append(self._fetch_page_content(result["url"])) 214 | 215 | # Wait for all tasks with timeout 216 | try: 217 | contents = await asyncio.gather(*tasks, return_exceptions=True) 218 | 219 | # Process successful results 220 | for result, content in zip(batch, contents): 221 | if isinstance(content, Exception): 222 | logger.error(f"Failed to fetch {result.get('url')}: {str(content)}") 223 | continue 224 | 225 | if not content: 226 | continue 227 | 228 | # Split content into chunks 229 | chunks = self.text_splitter.split_text(content) 230 | logger.info(f"Split content from {result.get('url')} into {len(chunks)} chunks") 231 | 232 | # Create documents with metadata 233 | for chunk in chunks: 234 | chunk = chunk.strip() 235 | if not chunk or len(chunk) < 50: # Skip very short chunks 236 | continue 237 | doc = Document( 238 | page_content=chunk, 239 | metadata={ 240 | "url": result.get("url", ""), 241 | "title": 
result.get("title", ""), 242 | "source": "web_search" 243 | } 244 | ) 245 | valid_chunks.append(doc) 246 | 247 | # Break if we have enough chunks 248 | if len(valid_chunks) >= SEARCH_RESULT_LIMIT * 2: # Get extra for diversity 249 | break 250 | 251 | if len(valid_chunks) >= SEARCH_RESULT_LIMIT * 2: 252 | break 253 | 254 | except asyncio.TimeoutError: 255 | logger.warning(f"Timeout processing batch starting at index {current_index}") 256 | continue 257 | except Exception as e: 258 | logger.error(f"Error processing batch: {str(e)}") 259 | continue 260 | 261 | if not valid_chunks: 262 | return [] 263 | 264 | logger.info(f"Total valid chunks collected: {len(valid_chunks)}") 265 | 266 | # Create new FAISS store for similarity search 267 | vector_store = FAISS.from_documents(valid_chunks, self.embeddings) 268 | 269 | # Search for similar chunks using MMR for diversity 270 | similar_chunks = vector_store.max_marginal_relevance_search( 271 | query, 272 | k=min(SEARCH_RESULT_LIMIT, len(valid_chunks)), 273 | fetch_k=min(SEARCH_RESULT_LIMIT * 2, len(valid_chunks)), 274 | lambda_mult=0.7 275 | ) 276 | 277 | # Return relevant chunks with metadata 278 | relevant_chunks = [] 279 | logger.info("\nMatched chunks content:") 280 | logger.info("="*80) 281 | 282 | # Get embeddings for scoring 283 | query_embedding = self.embeddings.embed_query(query) 284 | chunk_embeddings = self.embeddings.embed_documents([doc.page_content for doc in similar_chunks]) 285 | 286 | from numpy import dot 287 | from numpy.linalg import norm 288 | 289 | for i, (doc, chunk_embedding) in enumerate(zip(similar_chunks, chunk_embeddings), 1): 290 | # Calculate cosine similarity 291 | similarity = dot(query_embedding, chunk_embedding)/(norm(query_embedding)*norm(chunk_embedding)) 292 | 293 | logger.info(f"\nChunk [{i}] (similarity: {similarity:.3f})") 294 | logger.info(f"Source: {doc.metadata['title']} ({doc.metadata['url']})") 295 | logger.info(f"Content:\n{doc.page_content}") 296 | logger.info("-"*80) 297 | 298 | relevant_chunks.append({ 299 | "content": doc.page_content, 300 | "url": doc.metadata["url"], 301 | "title": doc.metadata["title"], 302 | "similarity": float(similarity) 303 | }) 304 | 305 | logger.info(f"\nSelected {len(relevant_chunks)} most relevant chunks") 306 | return relevant_chunks 307 | 308 | except Exception as e: 309 | logger.error(f"Search error: {str(e)}") 310 | return [] -------------------------------------------------------------------------------- /app/services/message_service.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import datetime 3 | import httpx 4 | from typing import List, Tuple, Dict 5 | from fastapi import HTTPException 6 | import asyncio 7 | import time 8 | from cachetools import TTLCache 9 | 10 | from app.models.chat import Message, Role, SearchAnalysis, ContentType 11 | from app.config.settings import ( 12 | OPENAI_API_URL, OPENAI_API_KEY, 13 | API_TIMEOUT, PROGRESS_MESSAGES, 14 | ENABLE_PROGRESS_MESSAGES, ENABLE_IMAGE_ANALYSIS, 15 | ENABLE_WEB_SEARCH, ENABLE_DOCUMENT_ANALYSIS, 16 | VECTOR_CACHE_TTL, OPENAI_ENHANCE_MODEL 17 | ) 18 | from app.services.search_service import SearchService 19 | from app.services.image_service import process_image 20 | from app.services.document_service import DocumentService 21 | from app.utils.file_utils import detect_file_type 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | # Initialize services 26 | document_service = DocumentService() 27 | search_service = SearchService() 28 | 29 | # Global 
cache for processed files 30 | # Cache structure: {file_url: {'type': 'image'|'document', 'content': str, 'timestamp': float}} 31 | file_cache = TTLCache(maxsize=1000, ttl=VECTOR_CACHE_TTL) 32 | 33 | def get_headers(api_key: str) -> dict: 34 | return { 35 | "Content-Type": "application/json", 36 | "Authorization": f"Bearer {api_key}" 37 | } 38 | 39 | async def analyze_search_need(query: str) -> SearchAnalysis: 40 | """Analyze if the query needs search context""" 41 | if not OPENAI_API_KEY: 42 | raise HTTPException(status_code=500, detail="OPENAI_API_KEY not configured") 43 | 44 | async with httpx.AsyncClient() as client: 45 | try: 46 | response = await client.post( 47 | OPENAI_API_URL, 48 | json={ 49 | "model": OPENAI_ENHANCE_MODEL, 50 | "messages": [{ 51 | "role": "system", 52 | "content": "You are an expert at analyzing whether queries need real-time information. Return true only for queries about current events, real-time info, or those needing factual verification." 53 | }, { 54 | "role": "user", 55 | "content": query 56 | }], 57 | "response_format": { 58 | "type": "json_schema", 59 | "json_schema": { 60 | "name": "search_analysis", 61 | "schema": { 62 | "type": "object", 63 | "properties": { 64 | "needs_search": {"type": "boolean"}, 65 | "search_keywords": { 66 | "type": "array", 67 | "items": {"type": "string"} 68 | } 69 | }, 70 | "required": ["needs_search", "search_keywords"], 71 | "additionalProperties": False 72 | }, 73 | "strict": True 74 | } 75 | } 76 | }, 77 | headers=get_headers(OPENAI_API_KEY), 78 | timeout=API_TIMEOUT 79 | ) 80 | response.raise_for_status() 81 | result = response.json() 82 | 83 | content = result.get("choices", [{}])[0].get("message", {}).get("content", {}) 84 | if isinstance(content, str): 85 | import json 86 | try: 87 | content = json.loads(content) 88 | except json.JSONDecodeError: 89 | logger.error(f"Failed to parse response content as JSON: {content}") 90 | content = {"needs_search": False, "search_keywords": []} 91 | 92 | analysis = SearchAnalysis.model_validate(content) 93 | 94 | logger.info(f"Search analysis for query '{query}':") 95 | logger.info(f"Needs search: {analysis.needs_search}") 96 | if analysis.search_keywords: 97 | logger.info(f"Search keywords: {analysis.search_keywords}") 98 | 99 | return analysis 100 | except Exception as e: 101 | logger.error(f"Search analysis error: {str(e)}") 102 | raise HTTPException(status_code=500, detail=f"Search analysis failed: {str(e)}") 103 | 104 | async def process_messages(messages: List[Message], stream: bool = False, progress_queue: asyncio.Queue = None) -> Tuple[List[Message], List[str]]: 105 | """Process messages and handle any images or documents in them""" 106 | 107 | async def send_progress(message_type: str): 108 | if ENABLE_PROGRESS_MESSAGES and stream and progress_queue: 109 | await progress_queue.put(PROGRESS_MESSAGES[message_type]) 110 | 111 | processed_messages = [] 112 | system_content_parts = [] 113 | current_search_parts = [] 114 | current_search_urls = [] 115 | 116 | current_date = datetime.datetime.now().strftime("%Y-%m-%d") 117 | system_content_parts.append(f"Current date: {current_date}") 118 | 119 | # Track processed files in current request to avoid duplicates 120 | request_processed_files = set() 121 | 122 | for message in messages: 123 | if isinstance(message.content, list): 124 | text_parts = [] 125 | query_text = "" 126 | 127 | for content in message.content: 128 | if content.type == ContentType.TEXT: 129 | query_text = content.text 130 | text_parts.append(query_text) 131 | # 
Truncate long text in logs 132 | log_text = query_text[:50] + "..." if len(query_text) > 50 else query_text 133 | logger.info(f"Added text content: {log_text}") 134 | elif content.type == ContentType.IMAGE_URL and content.image_url: 135 | file_url = content.image_url.url 136 | 137 | # Skip if already processed in this request 138 | if file_url in request_processed_files: 139 | continue 140 | request_processed_files.add(file_url) 141 | 142 | # For base64 URLs, show truncated version in logs 143 | log_url = file_url 144 | if file_url.startswith('data:'): 145 | log_url = file_url[:30] + "..." + file_url[-10:] 146 | elif len(file_url) > 100: 147 | log_url = file_url[:50] + "..." + file_url[-50:] 148 | 149 | # Check cache first 150 | cache_hit = file_cache.get(file_url) 151 | if cache_hit: 152 | logger.info(f"Cache hit for file: {log_url}") 153 | if cache_hit['type'] == 'image': 154 | text_parts.append(f"[Image Content: {cache_hit['content']}]") 155 | logger.info("Added cached image content") 156 | elif cache_hit['type'] == 'document' and query_text: 157 | # For documents, we still need to search with the current query 158 | logger.info("Using cached document for search") 159 | await send_progress("document_search") 160 | relevant_chunks = await document_service.search_similar(query_text, source_url=file_url) 161 | if relevant_chunks: 162 | chunks_text = "\n\n".join( 163 | f"Relevant section {i+1} from {file_url}:\n{chunk.page_content}" 164 | for i, chunk in enumerate(relevant_chunks) 165 | ) 166 | text_parts.append(f"[Document Content:\n{chunks_text}]") 167 | logger.info(f"Added {len(relevant_chunks)} relevant document sections from cache") 168 | else: 169 | text_parts.append(f"[No relevant content found in cached document: {log_url}]") 170 | continue 171 | 172 | logger.info(f"\nProcessing file URL: {log_url}") 173 | 174 | # Check if it's base64 data 175 | file_type, is_supported = detect_file_type(file_url, "") 176 | 177 | # Skip processing based on feature flags 178 | if file_type == 'image': 179 | if not ENABLE_IMAGE_ANALYSIS: 180 | continue 181 | await send_progress("image_analysis") 182 | else: # document type 183 | if not ENABLE_DOCUMENT_ANALYSIS: 184 | continue 185 | await send_progress("document_analysis") 186 | 187 | try: 188 | if file_type == 'image' and is_supported and ENABLE_IMAGE_ANALYSIS: 189 | logger.info("Processing as IMAGE") 190 | image_content = await process_image(file_url) 191 | if image_content: 192 | # Cache the result 193 | file_cache[file_url] = { 194 | 'type': 'image', 195 | 'content': image_content, 196 | 'timestamp': time.time() 197 | } 198 | text_parts.append(f"[Image Content: {image_content}]") 199 | logger.info("Image processing completed and cached") 200 | else: 201 | if ENABLE_DOCUMENT_ANALYSIS: 202 | logger.info("Processing as DOCUMENT") 203 | documents = await document_service.process_document(file_url) 204 | if documents: 205 | # Cache the document processing status 206 | file_cache[file_url] = { 207 | 'type': 'document', 208 | 'content': 'processed', 209 | 'timestamp': time.time() 210 | } 211 | if query_text: 212 | logger.info("Searching document for relevant content...") 213 | await send_progress("document_search") 214 | relevant_chunks = await document_service.search_similar(query_text, source_url=file_url) 215 | if relevant_chunks: 216 | chunks_text = "\n\n".join( 217 | f"Relevant section {i+1} from {file_url}:\n{chunk.page_content}" 218 | for i, chunk in enumerate(relevant_chunks) 219 | ) 220 | text_parts.append(f"[Document Content:\n{chunks_text}]") 
221 | logger.info(f"Added {len(relevant_chunks)} relevant document sections") 222 | else: 223 | text_parts.append(f"[No relevant content found in document: {file_url}]") 224 | except Exception as e: 225 | logger.error(f"Error processing file {file_url}: {str(e)}") 226 | text_parts.append(f"[Error processing file {file_url}: {str(e)}]") 227 | 228 | # Only perform web search if this is the last user message 229 | if query_text and message == messages[-1] and message.role == Role.USER: 230 | await send_progress("web_search") 231 | 232 | if ENABLE_WEB_SEARCH: 233 | search_analysis = await analyze_search_need(query_text) 234 | 235 | if search_analysis.needs_search and search_analysis.search_keywords: 236 | logger.info("Performing web search...") 237 | search_results = await search_service.search(query_text) 238 | 239 | search_context_parts = [] 240 | current_search_urls = [] 241 | 242 | for idx, result in enumerate(search_results, 1): 243 | url = result["url"] 244 | title = result["title"] 245 | content = result["content"] 246 | similarity = result["similarity"] 247 | 248 | current_search_urls.append(f"> [{idx}] {url}") 249 | search_context_parts.append( 250 | f"Source [{idx}] (Similarity: {similarity:.2f}): {title}\n{content}" 251 | ) 252 | 253 | if search_context_parts: 254 | search_context = "\n\n".join(search_context_parts) 255 | current_search_parts = [ 256 | "When using information from the search results, cite the sources using [n] format where n is the source number.", 257 | f"Search Results:\n{search_context}", 258 | f"Include these references at the end of your response:\n{chr(10).join(current_search_urls)}" 259 | ] 260 | logger.info(f"Added {len(search_context_parts)} search results to context") 261 | else: 262 | logger.info("Web search is disabled") 263 | 264 | if text_parts: 265 | processed_messages.append(Message( 266 | role=message.role, 267 | content=" ".join(text_parts) 268 | )) 269 | else: 270 | # Handle string content (direct text input) 271 | query_text = message.content 272 | processed_messages.append(message) 273 | 274 | # Only perform web search if this is the last user message and it's a text query 275 | if query_text and message == messages[-1] and message.role == Role.USER: 276 | await send_progress("web_search") 277 | 278 | if ENABLE_WEB_SEARCH: 279 | search_analysis = await analyze_search_need(query_text) 280 | 281 | if search_analysis.needs_search and search_analysis.search_keywords: 282 | logger.info("Performing web search...") 283 | search_results = await search_service.search(query_text) 284 | 285 | search_context_parts = [] 286 | current_search_urls = [] 287 | 288 | for idx, result in enumerate(search_results, 1): 289 | url = result["url"] 290 | title = result["title"] 291 | content = result["content"] 292 | similarity = result["similarity"] 293 | 294 | current_search_urls.append(f"> [{idx}] {url}") 295 | search_context_parts.append( 296 | f"Source [{idx}] (Similarity: {similarity:.2f}): {title}\n{content}" 297 | ) 298 | 299 | if search_context_parts: 300 | search_context = "\n\n".join(search_context_parts) 301 | current_search_parts = [ 302 | "When using information from the search results, cite the sources using [n] format where n is the source number.", 303 | f"Search Results:\n{search_context}", 304 | f"Include these references at the end of your response:\n{chr(10).join(current_search_urls)}" 305 | ] 306 | logger.info(f"Added {len(search_context_parts)} search results to context") 307 | else: 308 | logger.info("Web search is disabled") 309 | 310 | 
final_system_parts = system_content_parts + current_search_parts 311 | 312 | if final_system_parts: 313 | final_system_content = "\n\n".join(final_system_parts) 314 | processed_messages.insert(0, Message( 315 | role=Role.SYSTEM, 316 | content=final_system_content 317 | )) 318 | 319 | return processed_messages, current_search_urls -------------------------------------------------------------------------------- /app/services/document_service.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import tempfile 4 | import httpx 5 | import ftfy 6 | import json 7 | import time 8 | import hashlib 9 | from pathlib import Path 10 | from typing import List, Optional, Tuple 11 | import faiss 12 | from langchain_community.document_loaders import ( 13 | BSHTMLLoader, 14 | CSVLoader, 15 | Docx2txtLoader, 16 | OutlookMessageLoader, 17 | PyPDFLoader, 18 | TextLoader, 19 | UnstructuredEPubLoader, 20 | UnstructuredExcelLoader, 21 | UnstructuredMarkdownLoader, 22 | UnstructuredPowerPointLoader, 23 | UnstructuredRSTLoader, 24 | UnstructuredXMLLoader 25 | ) 26 | from langchain_text_splitters import RecursiveCharacterTextSplitter 27 | from langchain_openai import OpenAIEmbeddings 28 | from langchain_core.vectorstores import VectorStore 29 | from langchain_community.vectorstores import FAISS 30 | from langchain_postgres import PGVector 31 | from langchain_pinecone import PineconeVectorStore 32 | from pinecone import Pinecone 33 | from langchain_core.documents import Document as LangchainDocument 34 | from fastapi import HTTPException 35 | import asyncio 36 | 37 | from app.models.chat import Document 38 | from app.config.settings import ( 39 | EMBEDDING_BASE_URL, EMBEDDING_API_KEY, 40 | EMBEDDING_MODEL, EMBEDDING_DIMENSIONS, 41 | CHUNK_SIZE, CHUNK_OVERLAP, MAX_CHUNKS_PER_DOC, 42 | EMBEDDING_BATCH_SIZE, MIN_CHUNK_LENGTH, 43 | MAX_FILE_SIZE, VECTOR_CACHE_DIR, VECTOR_CACHE_TTL, 44 | VECTOR_STORE_TYPE, POSTGRES_CONNECTION_STRING, 45 | POSTGRES_COLLECTION_NAME, PINECONE_API_KEY, 46 | PINECONE_INDEX_NAME 47 | ) 48 | 49 | logger = logging.getLogger(__name__) 50 | 51 | # Suppress FAISS GPU warning since we're explicitly using CPU 52 | faiss.get_num_gpus = lambda: 0 53 | 54 | # Known source file extensions 55 | KNOWN_SOURCE_EXT = [ 56 | "go", "py", "java", "sh", "bat", "ps1", "cmd", "js", "ts", "css", 57 | "cpp", "hpp", "h", "c", "cs", "sql", "log", "ini", "pl", "pm", 58 | "r", "dart", "dockerfile", "env", "php", "hs", "hsc", "lua", 59 | "nginxconf", "conf", "m", "mm", "plsql", "perl", "rb", "rs", 60 | "db2", "scala", "bash", "swift", "vue", "svelte", "msg", "ex", 61 | "exs", "erl", "tsx", "jsx", "lhs" 62 | ] 63 | 64 | class DocumentService: 65 | def __init__(self): 66 | self.text_splitter = RecursiveCharacterTextSplitter( 67 | chunk_size=CHUNK_SIZE, 68 | chunk_overlap=CHUNK_OVERLAP, 69 | add_start_index=True, 70 | separators=["\n\n", "\n", ". 
", " ", ""], 71 | length_function=len, 72 | is_separator_regex=False 73 | ) 74 | self.embeddings = OpenAIEmbeddings( 75 | model=EMBEDDING_MODEL, 76 | openai_api_key=EMBEDDING_API_KEY, 77 | openai_api_base=EMBEDDING_BASE_URL, 78 | dimensions=EMBEDDING_DIMENSIONS, 79 | request_timeout=60.0, 80 | show_progress_bar=True, 81 | retry_min_seconds=1, 82 | retry_max_seconds=60, 83 | max_retries=3, 84 | skip_empty=True, 85 | ) 86 | self.vector_store: Optional[VectorStore] = None 87 | 88 | # Initialize vector store based on configuration 89 | logger.info(f"Initializing vector store with type: {VECTOR_STORE_TYPE}") 90 | 91 | if VECTOR_STORE_TYPE == "postgres": 92 | if not POSTGRES_CONNECTION_STRING: 93 | logger.error("PostgreSQL connection string is not configured") 94 | logger.warning("Falling back to FAISS vector store") 95 | self._init_faiss_store() 96 | else: 97 | try: 98 | logger.info(f"Attempting to connect to PostgreSQL with collection: {POSTGRES_COLLECTION_NAME}") 99 | self.vector_store = PGVector( 100 | connection=POSTGRES_CONNECTION_STRING, 101 | collection_name=POSTGRES_COLLECTION_NAME, 102 | embeddings=self.embeddings, 103 | ) 104 | logger.info(f"Successfully initialized PostgreSQL vector store with collection: {POSTGRES_COLLECTION_NAME}") 105 | except Exception as e: 106 | logger.error(f"Failed to initialize PostgreSQL vector store: {str(e)}") 107 | logger.warning("Falling back to FAISS vector store") 108 | self._init_faiss_store() 109 | elif VECTOR_STORE_TYPE == "pinecone": 110 | try: 111 | pc = Pinecone(api_key=PINECONE_API_KEY) 112 | index = pc.Index(PINECONE_INDEX_NAME) 113 | self.vector_store = PineconeVectorStore( 114 | embedding=self.embeddings, 115 | index=index, 116 | namespace="default" 117 | ) 118 | logger.info(f"Initialized Pinecone vector store with index: {PINECONE_INDEX_NAME}") 119 | except Exception as e: 120 | logger.error(f"Failed to initialize Pinecone vector store: {str(e)}") 121 | logger.warning("Falling back to FAISS vector store") 122 | self._init_faiss_store() 123 | else: 124 | logger.info("Using FAISS vector store as configured") 125 | self._init_faiss_store() 126 | 127 | def _init_faiss_store(self): 128 | """Initialize FAISS vector store and ensure cache directory exists""" 129 | if VECTOR_STORE_TYPE == "faiss": 130 | os.makedirs(VECTOR_CACHE_DIR, exist_ok=True) 131 | logger.info("Initialized FAISS vector store (CPU mode)") 132 | 133 | def _create_vector_store(self, documents: List[LangchainDocument]) -> VectorStore: 134 | """Create a new vector store instance""" 135 | try: 136 | if VECTOR_STORE_TYPE == "postgres" and POSTGRES_CONNECTION_STRING: 137 | store = PGVector.from_documents( 138 | documents=documents, 139 | embedding=self.embeddings, 140 | collection_name=POSTGRES_COLLECTION_NAME, 141 | connection=POSTGRES_CONNECTION_STRING, 142 | ) 143 | logger.info("Created new PostgreSQL vector store") 144 | return store 145 | elif VECTOR_STORE_TYPE == "pinecone": 146 | pc = Pinecone(api_key=PINECONE_API_KEY) 147 | index = pc.Index(PINECONE_INDEX_NAME) 148 | store = PineconeVectorStore.from_documents( 149 | documents=documents, 150 | embedding=self.embeddings, 151 | index=index, 152 | namespace="default" 153 | ) 154 | logger.info("Created new Pinecone vector store") 155 | return store 156 | except Exception as e: 157 | logger.error(f"Failed to create vector store: {str(e)}") 158 | logger.warning("Falling back to FAISS vector store") 159 | 160 | # Default or fallback to FAISS 161 | store = FAISS.from_documents( 162 | documents, 163 | self.embeddings, 164 | 
distance_strategy="COSINE" 165 | ) 166 | logger.info("Created new FAISS vector store") 167 | return store 168 | 169 | def _add_to_vector_store(self, store: VectorStore, documents: List[LangchainDocument]): 170 | """Add documents to existing vector store""" 171 | try: 172 | if isinstance(store, PGVector): 173 | store.add_documents(documents) 174 | logger.info(f"Added {len(documents)} documents to PostgreSQL vector store") 175 | elif isinstance(store, PineconeVectorStore): 176 | store.add_documents(documents) 177 | logger.info(f"Added {len(documents)} documents to Pinecone vector store") 178 | elif isinstance(store, FAISS): 179 | store.add_documents(documents) 180 | logger.info(f"Added {len(documents)} documents to FAISS vector store") 181 | else: 182 | raise ValueError(f"Unsupported vector store type: {type(store)}") 183 | except Exception as e: 184 | logger.error(f"Failed to add documents to vector store: {str(e)}") 185 | raise 186 | 187 | def _get_cache_key(self, filename: str, file_size: int) -> str: 188 | """Generate cache key from filename and size""" 189 | return hashlib.md5(f"{filename}_{file_size}".encode()).hexdigest() 190 | 191 | def _get_cache_path(self, cache_key: str) -> Tuple[Path, Path]: 192 | """Get cache file paths for documents and vectors (FAISS only)""" 193 | docs_path = Path(VECTOR_CACHE_DIR) / f"{cache_key}_docs.json" 194 | vectors_path = Path(VECTOR_CACHE_DIR) / f"{cache_key}_vectors.faiss" 195 | return docs_path, vectors_path 196 | 197 | def _save_to_cache(self, cache_key: str, documents: List[Document], vector_store: VectorStore): 198 | """Save documents and vector store to cache (FAISS only)""" 199 | if not isinstance(vector_store, FAISS): 200 | return 201 | 202 | try: 203 | docs_path, vectors_path = self._get_cache_path(cache_key) 204 | 205 | # Save documents and timestamp 206 | docs_data = { 207 | "timestamp": time.time(), 208 | "documents": [doc.model_dump() for doc in documents] 209 | } 210 | with open(docs_path, 'w', encoding='utf-8') as f: 211 | json.dump(docs_data, f) 212 | logger.info(f"Saved document cache to {docs_path}") 213 | 214 | # Save FAISS index 215 | vector_store.save_local(str(vectors_path)) 216 | logger.info(f"Saved FAISS vectors to {vectors_path}") 217 | 218 | except Exception as e: 219 | logger.error(f"Failed to save cache: {str(e)}") 220 | 221 | def _load_from_cache(self, cache_key: str) -> Tuple[Optional[List[Document]], Optional[VectorStore]]: 222 | """Load documents and vector store from cache if valid (FAISS only)""" 223 | try: 224 | docs_path, vectors_path = self._get_cache_path(cache_key) 225 | 226 | # Check if both cache files exist 227 | if not docs_path.exists() or not vectors_path.exists(): 228 | return None, None 229 | 230 | # Load and validate documents cache 231 | try: 232 | with open(docs_path, 'r', encoding='utf-8') as f: 233 | docs_data = json.load(f) 234 | except (PermissionError, json.JSONDecodeError) as e: 235 | logger.error(f"Failed to read document cache: {str(e)}") 236 | return None, None 237 | 238 | # Check if cache is expired 239 | if time.time() - docs_data["timestamp"] > VECTOR_CACHE_TTL: 240 | logger.info("Cache expired, will reprocess document") 241 | # Clean up expired cache files 242 | try: 243 | if docs_path.exists(): 244 | os.unlink(docs_path) 245 | if vectors_path.exists(): 246 | os.unlink(vectors_path) 247 | except PermissionError as e: 248 | logger.warning(f"Failed to clean up expired cache: {str(e)}") 249 | return None, None 250 | 251 | # Restore documents 252 | try: 253 | documents = 
[Document.model_validate(doc) for doc in docs_data["documents"]] 254 | except Exception as e: 255 | logger.error(f"Failed to validate documents: {str(e)}") 256 | return None, None 257 | 258 | # Restore FAISS vector store 259 | try: 260 | vector_store = FAISS.load_local( 261 | str(vectors_path), 262 | self.embeddings, 263 | allow_dangerous_deserialization=True # Allow since we control the cache 264 | ) 265 | logger.info(f"Loaded FAISS vectors from {vectors_path}") 266 | return documents, vector_store 267 | except Exception as e: 268 | logger.error(f"Failed to load vector store: {str(e)}") 269 | # Clean up invalid cache 270 | try: 271 | if docs_path.exists(): 272 | os.unlink(docs_path) 273 | if vectors_path.exists(): 274 | os.unlink(vectors_path) 275 | except PermissionError as e: 276 | logger.warning(f"Failed to clean up invalid cache: {str(e)}") 277 | return None, None 278 | 279 | except Exception as e: 280 | logger.error(f"Failed to load cache: {str(e)}") 281 | return None, None 282 | 283 | async def download_file(self, url: str) -> tuple[str, str, str]: 284 | """Download file from URL and return file path, name and content type""" 285 | async with httpx.AsyncClient() as client: 286 | try: 287 | # First do a HEAD request to check content length 288 | head_response = await client.head(url, follow_redirects=True) 289 | head_response.raise_for_status() 290 | 291 | # Check content length if available 292 | content_length = head_response.headers.get("content-length") 293 | if content_length: 294 | file_size = int(content_length) 295 | if file_size > MAX_FILE_SIZE: 296 | raise HTTPException( 297 | status_code=413, 298 | detail=f"File size ({file_size} bytes) exceeds maximum allowed size ({MAX_FILE_SIZE} bytes)" 299 | ) 300 | 301 | # Proceed with download using streaming to enforce size limit 302 | async with client.stream("GET", url, follow_redirects=True) as response: 303 | response.raise_for_status() 304 | 305 | # Get filename from URL or Content-Disposition 306 | content_disposition = response.headers.get("content-disposition") 307 | if content_disposition and "filename=" in content_disposition: 308 | filename = content_disposition.split("filename=")[-1].strip('"') 309 | else: 310 | filename = url.split("/")[-1] 311 | 312 | content_type = response.headers.get("content-type", "") 313 | 314 | # Create temporary file with proper extension 315 | ext = os.path.splitext(filename)[1] 316 | with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as temp_file: 317 | total_size = 0 318 | async for chunk in response.aiter_bytes(chunk_size=8192): 319 | total_size += len(chunk) 320 | if total_size > MAX_FILE_SIZE: 321 | # Clean up and raise error 322 | temp_file.close() 323 | os.unlink(temp_file.name) 324 | raise HTTPException( 325 | status_code=413, 326 | detail=f"File size exceeds maximum allowed size ({MAX_FILE_SIZE} bytes)" 327 | ) 328 | temp_file.write(chunk) 329 | 330 | return temp_file.name, filename, content_type 331 | except HTTPException: 332 | raise 333 | except Exception as e: 334 | logger.error(f"Error downloading file from {url}: {str(e)}") 335 | raise 336 | 337 | def _get_loader(self, filename: str, content_type: str, file_path: str): 338 | """Get appropriate document loader based on file type""" 339 | file_ext = filename.split(".")[-1].lower() if "." 
in filename else "" 340 | 341 | if file_ext == "pdf": 342 | return PyPDFLoader(file_path) 343 | elif file_ext == "csv": 344 | return CSVLoader(file_path) 345 | elif file_ext == "rst": 346 | return UnstructuredRSTLoader(file_path, mode="elements") 347 | elif file_ext == "xml": 348 | return UnstructuredXMLLoader(file_path) 349 | elif file_ext in ["htm", "html"]: 350 | return BSHTMLLoader(file_path, open_encoding="unicode_escape") 351 | elif file_ext == "md": 352 | return UnstructuredMarkdownLoader(file_path) 353 | elif content_type == "application/epub+zip": 354 | return UnstructuredEPubLoader(file_path) 355 | elif content_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document" or file_ext == "docx": 356 | return Docx2txtLoader(file_path) 357 | elif content_type in ["application/vnd.ms-excel", "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"] or file_ext in ["xls", "xlsx"]: 358 | return UnstructuredExcelLoader(file_path) 359 | elif content_type in ["application/vnd.ms-powerpoint", "application/vnd.openxmlformats-officedocument.presentationml.presentation"] or file_ext in ["ppt", "pptx"]: 360 | return UnstructuredPowerPointLoader(file_path) 361 | elif file_ext == "msg": 362 | return OutlookMessageLoader(file_path) 363 | elif file_ext in KNOWN_SOURCE_EXT or (content_type and content_type.find("text/") >= 0): 364 | return TextLoader(file_path, autodetect_encoding=True) 365 | else: 366 | return TextLoader(file_path, autodetect_encoding=True) 367 | 368 | async def _process_chunk_batch(self, batch: List[LangchainDocument]) -> None: 369 | """Process a batch of document chunks asynchronously""" 370 | try: 371 | if not self.vector_store: 372 | logger.info("Initializing new vector store") 373 | self.vector_store = self._create_vector_store(batch) 374 | else: 375 | logger.info("Adding documents to existing vector store") 376 | self._add_to_vector_store(self.vector_store, batch) 377 | except Exception as e: 378 | logger.error(f"Failed to process batch: {str(e)}") 379 | raise 380 | 381 | async def process_document(self, url: str) -> List[Document]: 382 | """Process a document from URL and store it in vector store""" 383 | temp_file_path = None 384 | try: 385 | # Download file 386 | temp_file_path, filename, content_type = await self.download_file(url) 387 | file_size = os.path.getsize(temp_file_path) 388 | 389 | # Log file information 390 | logger.info("="*50) 391 | logger.info("File Type Detection:") 392 | logger.info(f"URL: {url}") 393 | logger.info(f"Content-Type: {content_type}") 394 | logger.info(f"Filename: {filename}") 395 | logger.info(f"File Size: {file_size} bytes") 396 | logger.info(f"File Extension: {os.path.splitext(filename)[1]}") 397 | logger.info("="*50) 398 | 399 | # For PostgreSQL, check if document already exists 400 | if VECTOR_STORE_TYPE == "postgres" and isinstance(self.vector_store, PGVector): 401 | try: 402 | # Query by source URL and file size 403 | existing_docs = self.vector_store.similarity_search( 404 | "", 405 | k=1000, 406 | filter={ 407 | "source": url, 408 | "file_size": file_size 409 | } 410 | ) 411 | if existing_docs: 412 | logger.info(f"Found existing document in PostgreSQL: {filename}") 413 | return [ 414 | Document( 415 | page_content=doc.page_content, 416 | metadata=doc.metadata, 417 | id=f"{url}_{i}" 418 | ) for i, doc in enumerate(existing_docs) 419 | ] 420 | except Exception as e: 421 | logger.error(f"Failed to query PostgreSQL: {str(e)}") 422 | # For FAISS, check file cache 423 | elif VECTOR_STORE_TYPE == "faiss": 424 
| cache_key = self._get_cache_key(filename, file_size) 425 | cached_docs, cached_store = self._load_from_cache(cache_key) 426 | if cached_docs and cached_store: 427 | logger.info("Using cached vectors from FAISS") 428 | self.vector_store = cached_store 429 | return cached_docs 430 | 431 | logger.info(f"Processing document: {filename} ({file_size} bytes)") 432 | 433 | # Get appropriate loader 434 | loader = self._get_loader(filename, content_type, temp_file_path) 435 | logger.info(f"Using loader: {loader.__class__.__name__}") 436 | 437 | # Load document 438 | docs = loader.load() 439 | logger.info(f"Loaded {len(docs)} document sections") 440 | 441 | # Fix text encoding and create metadata 442 | fixed_docs = [] 443 | for doc in docs: 444 | if isinstance(doc.page_content, (str, bytes)): 445 | content = ftfy.fix_text(str(doc.page_content)) 446 | else: 447 | content = str(doc.page_content) 448 | 449 | fixed_doc = LangchainDocument( 450 | page_content=content, 451 | metadata={ 452 | **doc.metadata, 453 | "source": url, 454 | "filename": filename, 455 | "content_type": content_type, 456 | "file_size": file_size # Add file size to metadata for future lookups 457 | } 458 | ) 459 | fixed_docs.append(fixed_doc) 460 | 461 | # Split into chunks 462 | splits = self.text_splitter.split_documents(fixed_docs) 463 | logger.info(f"Split into {len(splits)} chunks") 464 | 465 | # Convert to our Document model first 466 | documents = [] 467 | for i, split in enumerate(splits): 468 | doc = Document( 469 | page_content=split.page_content, 470 | metadata=split.metadata, 471 | id=f"{url}_{i}" 472 | ) 473 | documents.append(doc) 474 | 475 | # Try to update vector store if needed 476 | try: 477 | # Verify splits have content 478 | valid_splits = [] 479 | for split in splits: 480 | if not isinstance(split.page_content, str): 481 | continue 482 | content = split.page_content.strip() 483 | if not content: 484 | continue 485 | if len(content) < MIN_CHUNK_LENGTH: 486 | continue 487 | valid_splits.append(split) 488 | 489 | logger.info(f"Found {len(valid_splits)} valid chunks for embedding") 490 | logger.info(f"Average chunk size: {sum(len(split.page_content) for split in valid_splits) / len(valid_splits) if valid_splits else 0:.0f} characters") 491 | 492 | # Process valid splits in batches 493 | if valid_splits: 494 | tasks = [] 495 | for i in range(0, len(valid_splits), EMBEDDING_BATCH_SIZE): 496 | batch = valid_splits[i:i + EMBEDDING_BATCH_SIZE] 497 | logger.info(f"Processing batch {i//EMBEDDING_BATCH_SIZE + 1} of {(len(valid_splits) + EMBEDDING_BATCH_SIZE - 1)//EMBEDDING_BATCH_SIZE} ({len(batch)} chunks)") 498 | tasks.append(self._process_chunk_batch(batch)) 499 | 500 | # Run the batch coroutines; the vector store calls are synchronous, so batches are applied in order on the event loop 501 | await asyncio.gather(*tasks) 502 | logger.info(f"Processed all {len(valid_splits)} chunks") 503 | else: 504 | logger.warning("No valid content found for vector store update") 505 | except Exception as ve: 506 | logger.error(f"Vector store operation failed: {str(ve)}") 507 | # Continue without vector store update 508 | pass 509 | 510 | if documents: 511 | logger.info(f"Document processed into {len(documents)} sections") 512 | # The chunks carry their own metadata, so no extra 513 | # description needs to be attached here 514 | else: 515 | logger.warning("Document produced no sections") 516 | 517 | # Only save to cache if using FAISS 518 | if VECTOR_STORE_TYPE == "faiss": 519 | cache_key = self._get_cache_key(filename, file_size) 520 | self._save_to_cache(cache_key, documents, self.vector_store) 521 | 522 | return 
documents 523 | 524 | except Exception as e: 525 | logger.error(f"Error processing document from {url}: {str(e)}") 526 | raise 527 | finally: 528 | # Clean up temporary file 529 | if temp_file_path: 530 | try: 531 | os.unlink(temp_file_path) 532 | except Exception as e: 533 | logger.warning(f"Failed to clean up temporary file {temp_file_path}: {str(e)}") 534 | 535 | 536 | async def search_similar(self, query: str, source_url: Optional[str] = None, k: Optional[int] = None) -> List[Document]: 537 | """Search for similar documents""" 538 | try: 539 | if not self.vector_store: 540 | logger.info("Vector store not initialized") 541 | return [] 542 | 543 | # Use configured max chunks if k is not specified 544 | if k is None: 545 | k = MAX_CHUNKS_PER_DOC 546 | 547 | try: 548 | # Add source URL filter if provided 549 | filter_dict = {"source": source_url} if source_url else None 550 | 551 | results = self.vector_store.similarity_search( 552 | query, 553 | k=k, 554 | filter=filter_dict 555 | ) 556 | except Exception as e: 557 | logger.error(f"Vector search failed: {str(e)}") 558 | return [] 559 | 560 | return [ 561 | Document( 562 | page_content=doc.page_content, 563 | metadata=doc.metadata, 564 | id=f"{doc.metadata.get('source', '')}_{i}" 565 | ) for i, doc in enumerate(results) 566 | ] 567 | except Exception as e: 568 | logger.error(f"Error searching documents: {str(e)}") 569 | return [] --------------------------------------------------------------------------------
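
For orientation, here is a minimal usage sketch of the two services above. It is not part of the repository: the import paths, class names, and method signatures come from the files listed here, but the document URL and queries are placeholder assumptions, and the embedding and vector store settings read by app/config/settings.py must be configured before it will run.

import asyncio

from app.services.document_service import DocumentService
from app.services.search_service import SearchService


async def main() -> None:
    doc_service = DocumentService()

    # Ingest a document: download, split, embed, and store its chunks.
    url = "https://example.com/report.pdf"  # placeholder URL
    chunks = await doc_service.process_document(url)
    print(f"Stored {len(chunks)} chunks from {url}")

    # Retrieve the chunks most similar to a query, scoped to that document.
    hits = await doc_service.search_similar("main conclusions", source_url=url)
    for doc in hits:
        print(doc.metadata.get("filename"), doc.page_content[:80])

    # Web search: pages are fetched, chunked, embedded, and re-ranked
    # with max-marginal-relevance before the top chunks are returned.
    results = await SearchService().search("current FastAPI release")
    for r in results:
        print(f"[{r['similarity']:.2f}] {r['url']}")


if __name__ == "__main__":
    asyncio.run(main())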
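
The same services are normally driven through process_messages in app/services/message_service.py, which stitches document content and web search context into a system message. Below is a sketch of that entry point, assuming the Message and Role models from app/models/chat.py (whose definitions are not shown above) accept a role and plain string content, as they are used inside message_service.py itself.

import asyncio

from app.models.chat import Message, Role
from app.services.message_service import process_messages


async def main() -> None:
    # A plain-text user message; process_messages may prepend a system
    # message carrying the current date and any web search context.
    messages = [Message(role=Role.USER, content="What changed in Python 3.12?")]
    processed, search_urls = await process_messages(messages)
    for m in processed:
        print(m.role, str(m.content)[:100])
    print("References:", search_urls)


if __name__ == "__main__":
    asyncio.run(main())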