├── app
│   ├── api
│   │   ├── __init__.py
│   │   └── chat.py
│   ├── models
│   │   ├── __init__.py
│   │   └── chat.py
│   ├── services
│   │   ├── __init__.py
│   │   ├── image_service.py
│   │   ├── search_service.py
│   │   ├── message_service.py
│   │   └── document_service.py
│   ├── config
│   │   ├── __init__.py
│   │   └── settings.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── logging.py
│   │   └── auth.py
│   ├── utils
│   │   └── file_utils.py
│   └── __init__.py
├── .dockerignore
├── requirements.txt
├── .gitignore
├── main.py
├── Dockerfile
├── .env.example
├── docker-compose.yml
└── README.md
/app/api/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | API routes package
3 | """
--------------------------------------------------------------------------------
/app/models/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Data models package
3 | """
--------------------------------------------------------------------------------
/app/services/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Services package
3 | """
--------------------------------------------------------------------------------
/app/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | ExtendAI application package
3 | """
--------------------------------------------------------------------------------
/app/config/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Configuration package
3 | """
--------------------------------------------------------------------------------
/app/core/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Core functionality package
3 | """
--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
1 | .env
2 | .env.*
3 | .git
4 | .gitignore
5 | __pycache__
6 | *.pyc
7 | *.pyo
8 | *.pyd
9 | .Python
10 | env/
11 | venv/
12 | .venv/
13 | pip-log.txt
14 | pip-delete-this-directory.txt
15 | .tox/
16 | .coverage
17 | .coverage.*
18 | .cache
19 | nosetests.xml
20 | coverage.xml
21 | *.cover
22 | *.log
23 | .pytest_cache/
24 | .idea/
25 | .vscode/
26 | *.swp
27 | *.swo
28 | *~
29 | cache/
--------------------------------------------------------------------------------
/app/core/logging.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import sys
3 |
4 | def setup_logging():
5 | """Configure logging with a specific format and level"""
6 | logging.basicConfig(
7 | level=logging.INFO,
8 | format='%(asctime)s - %(levelname)s - %(message)s',
9 | stream=sys.stdout,
10 | force=True
11 | )
12 |
13 |     # Quiet noisy third-party loggers (WARNING and above still pass through)
14 | logging.getLogger("httpx").setLevel(logging.WARNING)
15 | logging.getLogger("httpcore").setLevel(logging.WARNING)
16 | logging.getLogger("uvicorn").setLevel(logging.WARNING)
17 | logging.getLogger("asyncio").setLevel(logging.WARNING)
18 |
19 | return logging.getLogger(__name__)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | fastapi>=0.109.0
2 | uvicorn>=0.27.0
3 | httpx>=0.26.0
4 | python-dotenv>=1.0.0
5 | langchain>=0.1.0
6 | langchain-community>=0.0.16
7 | langchain-core>=0.1.18
8 | langchain-openai>=0.0.5
9 | langchain-text-splitters>=0.0.1
10 | langchain-pinecone>=0.0.2
11 | langchain-postgres>=0.0.1
12 | faiss-cpu>=1.7.4
13 | pinecone-client>=3.0.0
14 | psycopg[binary]>=3.1.18
15 | psycopg2-binary>=2.9.9
16 | ftfy>=6.1.3
17 | python-multipart>=0.0.6
18 | beautifulsoup4>=4.12.0
19 | docx2txt>=0.8
20 | pypdf>=4.0.0
21 | openpyxl>=3.1.0
22 | python-pptx>=0.6.0
23 | unstructured>=0.11.0
24 | markdown>=3.5.0
25 | docstring-parser>=0.15
26 | tqdm>=4.66.0
27 | googlesearch-python
28 | cachetools
29 | networkx
30 | openai>=1.0.0
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Python
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | *.so
6 | .Python
7 | build/
8 | develop-eggs/
9 | dist/
10 | downloads/
11 | eggs/
12 | .eggs/
13 | lib/
14 | lib64/
15 | parts/
16 | sdist/
17 | var/
18 | wheels/
19 | *.egg-info/
20 | .installed.cfg
21 | *.egg
22 |
23 | # Virtual Environment
24 | venv/
25 | ENV/
26 | env/
27 | .env
28 |
29 | # IDE
30 | .idea/
31 | .vscode/
32 | *.swp
33 | *.swo
34 | .DS_Store
35 |
36 | # Project specific
37 | cache/
38 | *.log
39 | *.db
40 | *.sqlite3
41 |
42 | # Environment variables
43 | .env
44 | .env.local
45 | .env.*.local
46 |
47 | # Testing
48 | .coverage
49 | htmlcov/
50 | .pytest_cache/
51 | .tox/
52 | nosetests.xml
53 | coverage.xml
54 | *.cover
55 | .hypothesis/
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from fastapi import FastAPI
2 | from fastapi.middleware.cors import CORSMiddleware
3 | import os
4 |
5 | from app.api.chat import router as chat_router
6 | from app.core.logging import setup_logging
7 | from app.config.settings import PORT, HOST
8 |
9 | # Setup logging
10 | logger = setup_logging()
11 |
12 | # Create FastAPI app
13 | app = FastAPI()
14 |
15 | # Add CORS middleware
16 | app.add_middleware(
17 | CORSMiddleware,
18 | allow_origins=["*"],
19 | allow_credentials=True,
20 | allow_methods=["*"],
21 | allow_headers=["*"],
22 | )
23 |
24 | # Include routers
25 | app.include_router(chat_router)
26 |
27 | if __name__ == "__main__":
28 | import uvicorn
29 | port = int(os.getenv("PORT", PORT))
30 | host = os.getenv("HOST", HOST)
31 | logger.info(f"Starting server on {host}:{port}")
32 | uvicorn.run(app, host=host, port=port)
--------------------------------------------------------------------------------
/app/core/auth.py:
--------------------------------------------------------------------------------
1 | from fastapi import HTTPException, Request
2 | from app.config.settings import MY_API_KEY
3 |
4 | async def verify_api_key(request: Request) -> str:
5 | """Verify the API key from the request header"""
6 | if not MY_API_KEY:
7 | raise HTTPException(status_code=500, detail="Server API key not configured")
8 |
9 | auth_header = request.headers.get("Authorization")
10 | if not auth_header:
11 | raise HTTPException(status_code=401, detail="Authorization header missing")
12 |
13 | try:
14 | scheme, token = auth_header.split()
15 | if scheme.lower() != "bearer":
16 | raise HTTPException(status_code=401, detail="Invalid authentication scheme")
17 | if token != MY_API_KEY:
18 | raise HTTPException(status_code=401, detail="Invalid API key")
19 | return token
20 | except ValueError:
21 | raise HTTPException(status_code=401, detail="Invalid authorization header format")
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # Use Python 3.11 as base image
2 | FROM python:3.11-slim
3 |
4 | # Set working directory
5 | WORKDIR /app
6 |
7 | # Set environment variables
8 | ENV PYTHONUNBUFFERED=1 \
9 | PYTHONDONTWRITEBYTECODE=1 \
10 | PIP_NO_CACHE_DIR=1
11 |
12 | # Install system dependencies
13 | RUN apt-get update && apt-get install -y --no-install-recommends \
14 | build-essential \
15 | curl \
16 | git \
17 | libpq-dev \
18 | postgresql-client \
19 | gcc \
20 | python3-dev \
21 | && rm -rf /var/lib/apt/lists/*
22 |
23 | # Create cache directory
24 | RUN mkdir -p /app/cache && chmod 777 /app/cache
25 |
26 | # Install Python dependencies
27 | COPY requirements.txt .
28 | RUN pip install --no-cache-dir -r requirements.txt
29 |
30 | # Copy application code
31 | COPY . .
32 |
33 | # Create a non-root user
34 | RUN useradd -m -u 1000 extendai && \
35 | chown -R extendai:extendai /app
36 | USER extendai
37 |
38 | # Expose port
39 | EXPOSE 8096
40 |
41 | # Set default command
42 | CMD ["python", "main.py"]
--------------------------------------------------------------------------------
/app/models/chat.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel, HttpUrl
2 | from typing import List, Optional, Union, Literal, Dict, Any
3 | from enum import Enum
4 |
5 | class Role(str, Enum):
6 | SYSTEM = "system"
7 | USER = "user"
8 | ASSISTANT = "assistant"
9 |
10 | class ContentType(str, Enum):
11 | TEXT = "text"
12 | IMAGE_URL = "image_url"
13 | DOCUMENT = "document"
14 |
15 | class ImageUrl(BaseModel):
16 | url: str
17 |
18 | class Document(BaseModel):
19 | """Document model for storing document content and metadata"""
20 | page_content: str
21 | metadata: Dict[str, Any]
22 | id: Optional[str] = None
23 |
24 | class MessageContent(BaseModel):
25 | type: ContentType
26 | text: Optional[str] = None
27 | image_url: Optional[ImageUrl] = None
28 | document: Optional[Document] = None
29 |
30 | class Message(BaseModel):
31 | role: Role
32 | content: Union[str, List[MessageContent]]
33 |
34 | class ChatRequest(BaseModel):
35 | messages: List[Message]
36 |     model: Optional[str] = None
37 | stream: bool = False
38 |
39 | class SearchAnalysis(BaseModel):
40 | """Schema for analyzing whether search context is needed"""
41 | needs_search: bool
42 | search_keywords: Optional[List[str]] = None
43 |
44 | class SearchResult(BaseModel):
45 | """Schema for search results"""
46 | url: str
47 | title: str
48 | content: str
--------------------------------------------------------------------------------
/.env.example:
--------------------------------------------------------------------------------
1 | # Document Processing Settings
2 | CHUNK_SIZE=1000
3 | CHUNK_OVERLAP=200
4 | MAX_CHUNKS_PER_DOC=5
5 | EMBEDDING_BATCH_SIZE=50
6 | MIN_CHUNK_LENGTH=10
7 | MAX_FILE_SIZE=10485760
8 | MAX_IMAGE_SIZE=5242880
9 |
10 | # Vector Store Settings
11 | VECTOR_STORE_TYPE=postgres
12 | POSTGRES_CONNECTION_STRING=postgresql://extendai:extendai@pgvector:5432/extendai
13 | POSTGRES_COLLECTION_NAME=embeddings
14 | # if using Pinecone
15 | PINECONE_API_KEY=your-pinecone-api-key
16 | PINECONE_INDEX_NAME=extendai
17 |
18 | # Cache Settings
19 | VECTOR_CACHE_DIR=cache/vectors
20 | VECTOR_CACHE_TTL=7200
21 |
22 | API_TIMEOUT=600
23 | MY_API_KEY=sk-planetzero-api-key
24 |
25 | # Model API Settings
26 | TARGET_MODEL_BASE_URL=
27 | TARGET_MODEL_API_KEY=
28 | OPENAI_BASE_URL=
29 | OPENAI_API_KEY=
30 | OPENAI_ENHANCE_MODEL=
31 |
32 | # Embedding API Settings
33 | EMBEDDING_BASE_URL=
34 | EMBEDDING_API_KEY=
35 | EMBEDDING_MODEL=text-embedding-3-small
36 | EMBEDDING_DIMENSIONS=1536
37 |
38 | # Search Settings
39 | SEARXNG_URL=
40 | DEFAULT_MODEL=deepseek-r1
41 | SEARCH_ENGINE=google
42 | SEARCH_RESULT_LIMIT=5
43 | SEARCH_RESULT_MULTIPLIER=2
44 | WEB_CONTENT_CHUNK_SIZE=512
45 | WEB_CONTENT_CHUNK_OVERLAP=50
46 | WEB_CONTENT_MAX_CHUNKS=5
47 |
48 | # Proxy Settings (Optional)
49 | PROXY_ENABLED=false
50 | PROXY_HOST=proxy.example.com
51 | PROXY_PORT=8080
52 | PROXY_USERNAME=your-proxy-username
53 | PROXY_PASSWORD=your-proxy-password
54 |
55 | # Feature Switches
56 | ENABLE_PROGRESS_MESSAGES=false
57 | ENABLE_IMAGE_ANALYSIS=true
58 | ENABLE_WEB_SEARCH=true
59 | ENABLE_DOCUMENT_ANALYSIS=true
60 |
61 | # Progress Messages (Customizable)
62 | PROGRESS_MSG_IMAGE="Analyzing image content..."
63 | PROGRESS_MSG_DOC="Analyzing document..."
64 | PROGRESS_MSG_DOC_SEARCH="Searching document..."
65 | PROGRESS_MSG_WEB_SEARCH="Searching web content..."
66 |
67 | # Server Settings
68 | PORT=8096
69 | HOST=0.0.0.0
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: "3"
2 | name: extendai
3 |
4 | services:
5 | app:
6 | build: .
7 | ports:
8 | - "8096:8096"
9 | env_file:
10 | - .env
11 | environment:
12 | # Default settings that will be overridden by .env file if present
13 | VECTOR_STORE_TYPE: ${VECTOR_STORE_TYPE:-postgres}
14 | POSTGRES_CONNECTION_STRING: ${POSTGRES_CONNECTION_STRING:-postgresql://extendai:extendai@pgvector:5432/extendai}
15 | POSTGRES_COLLECTION_NAME: ${POSTGRES_COLLECTION_NAME:-embeddings}
16 | ENABLE_IMAGE_ANALYSIS: ${ENABLE_IMAGE_ANALYSIS:-true}
17 | ENABLE_WEB_SEARCH: ${ENABLE_WEB_SEARCH:-true}
18 | ENABLE_DOCUMENT_ANALYSIS: ${ENABLE_DOCUMENT_ANALYSIS:-true}
19 | depends_on:
20 | pgvector:
21 | condition: service_healthy
22 | volumes:
23 | - ./cache:/app/cache
24 | - ./.env:/app/.env:ro # Mount .env file as read-only
25 |
26 | postgres:
27 | image: postgres:16
28 | environment:
29 | POSTGRES_DB: extendai
30 | POSTGRES_USER: extendai
31 | POSTGRES_PASSWORD: extendai
32 | ports:
33 | - "6023:5432"
34 | command: postgres -c log_statement=all
35 | healthcheck:
36 | test: ["CMD-SHELL", "psql postgresql://extendai:extendai@localhost/extendai --command 'SELECT 1;' || exit 1"]
37 | interval: 5s
38 | retries: 60
39 | volumes:
40 | - postgres_data:/var/lib/postgresql/data
41 |
42 | pgvector:
43 | image: ankane/pgvector
44 | environment:
45 | POSTGRES_DB: extendai
46 | POSTGRES_USER: extendai
47 | POSTGRES_PASSWORD: extendai
48 | ports:
49 | - "6024:5432"
50 | command: postgres -c log_statement=all
51 | healthcheck:
52 | test: ["CMD-SHELL", "psql postgresql://extendai:extendai@localhost/extendai --command 'SELECT 1;' || exit 1"]
53 | interval: 5s
54 | retries: 60
55 | volumes:
56 | - postgres_data_pgvector:/var/lib/postgresql/data
57 |
58 | volumes:
59 | postgres_data:
60 | postgres_data_pgvector:
61 | cache:
--------------------------------------------------------------------------------
/app/utils/file_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import base64
4 | from typing import Tuple
5 | from urllib.parse import urlparse, parse_qs, unquote
6 | from app.config.settings import (
7 | SUPPORTED_IMAGE_FORMATS,
8 | SUPPORTED_DOCUMENT_FORMATS
9 | )
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 | # Known source file extensions
14 | KNOWN_SOURCE_EXT = [
15 | "go", "py", "java", "sh", "bat", "ps1", "cmd", "js", "ts", "css",
16 | "cpp", "hpp", "h", "c", "cs", "sql", "log", "ini", "pl", "pm",
17 | "r", "dart", "dockerfile", "env", "php", "hs", "hsc", "lua",
18 | "nginxconf", "conf", "m", "mm", "plsql", "perl", "rb", "rs",
19 | "db2", "scala", "bash", "swift", "vue", "svelte", "msg", "ex",
20 |     "exs", "erl", "tsx", "jsx", "lhs"
21 | ]
22 |
23 | def extract_filename_from_url(url: str) -> str:
24 | """Extract filename from URL using various methods"""
25 | # Try to get filename from URL parameters
26 | parsed_url = urlparse(url)
27 | query_params = parse_qs(parsed_url.query)
28 |
29 | # Check for filename in query parameters
30 |     for param in ['filename', 'rscd']:
31 |         if param in query_params:
32 |             value = query_params[param][0]
33 |             if param == 'filename':
34 |                 return unquote(value)
35 |             if 'filename=' in value:
36 |                 # Extract filename from the rscd parameter, dropping quotes and extras
37 |                 filename = value.split('filename=')[-1].strip()
38 |                 return unquote(filename.split(';')[0].strip('"\''))
39 |
40 | # Fallback to path
41 | path = parsed_url.path
42 | if path:
43 | return os.path.basename(path)
44 |
45 | return ""
46 |
47 | def is_base64(s: str) -> bool:
48 | """Check if a string is base64 encoded"""
49 | try:
50 | # Check if string starts with data URI scheme
51 | if s.startswith('data:'):
52 | return True
53 |
54 | # Skip URLs that look like normal web URLs
55 | if s.startswith(('http://', 'https://')):
56 | return False
57 |
58 | # Try to decode the string if it looks like base64
59 | if len(s) % 4 == 0 and not s.startswith(('/', '\\')):
60 | try:
61 | # Check if string contains valid base64 characters
62 | if not all(c in 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=' for c in s):
63 | return False
64 |
65 | decoded = base64.b64decode(s)
66 | # Check if decoded data looks like binary data
67 | return len(decoded) > 0 and not all(32 <= byte <= 126 for byte in decoded)
68 |             except Exception:
69 | pass
70 | return False
71 |     except Exception:
72 | return False
73 |
74 | def detect_file_type(url: str, content_type: str, content_disposition: str = "") -> Tuple[str, bool]:
75 | """
76 | Detect if a file is an image or document based on URL, content type and content disposition
77 | Returns: (file_type, is_supported)
78 | where file_type is either 'image' or 'document'
79 | """
80 | # Handle base64 image data
81 | if is_base64(url):
82 | return 'image', True
83 |
84 | # Try to get filename from Content-Disposition
85 | filename = ""
86 | if content_disposition and "filename=" in content_disposition:
87 | filename = content_disposition.split("filename=")[-1].strip('"\'')
88 | filename = unquote(filename)
89 |
90 | # If no filename in Content-Disposition, try URL
91 | if not filename:
92 | filename = extract_filename_from_url(url)
93 |
94 | # Get file extension
95 | ext = os.path.splitext(filename.lower())[1] if filename else ""
96 |
97 | # Check content type and extension
98 | if (content_type in SUPPORTED_IMAGE_FORMATS or
99 | ext in SUPPORTED_IMAGE_FORMATS):
100 | return 'image', True
101 | elif (content_type in SUPPORTED_DOCUMENT_FORMATS or
102 | ext in SUPPORTED_DOCUMENT_FORMATS):
103 | return 'document', True
104 | elif ext.lstrip('.') in KNOWN_SOURCE_EXT:
105 | return 'document', True
106 |
107 | # If content type starts with image/, treat as image
108 | if content_type.startswith('image/'):
109 | return 'image', False
110 |
111 | # Default to document for all other types
112 | return 'document', False
--------------------------------------------------------------------------------
/app/config/settings.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dotenv import load_dotenv
3 | from enum import Enum
4 |
5 | # Load environment variables from .env
6 | load_dotenv()
7 |
8 | class SearchEngine(Enum):
9 | GOOGLE = "google"
10 | SEARXNG = "searxng"
11 |
12 | # Server Settings
13 | PORT = int(os.getenv("PORT", "8096"))
14 | HOST = os.getenv("HOST", "0.0.0.0")
15 |
16 | # Document processing settings
17 | CHUNK_SIZE = int(os.getenv("CHUNK_SIZE", "1000"))
18 | CHUNK_OVERLAP = int(os.getenv("CHUNK_OVERLAP", "200"))
19 | MAX_CHUNKS_PER_DOC = int(os.getenv("MAX_CHUNKS_PER_DOC", "5"))
20 | EMBEDDING_BATCH_SIZE = int(os.getenv("EMBEDDING_BATCH_SIZE", "50"))
21 | MIN_CHUNK_LENGTH = int(os.getenv("MIN_CHUNK_LENGTH", "10"))
22 |
23 | # Vector store settings
24 | VECTOR_STORE_TYPE = os.getenv("VECTOR_STORE_TYPE", "faiss") # "faiss", "postgres", or "pinecone"
25 | POSTGRES_CONNECTION_STRING = os.getenv("POSTGRES_CONNECTION_STRING", "")
26 | POSTGRES_COLLECTION_NAME = os.getenv("POSTGRES_COLLECTION_NAME", "document_vectors")
27 |
28 | # Pinecone settings
29 | PINECONE_API_KEY = os.getenv("PINECONE_API_KEY", "pcsk_xxx")
30 | PINECONE_INDEX_NAME = os.getenv("PINECONE_INDEX_NAME", "extendai")
31 | # Cache settings
32 | VECTOR_CACHE_DIR = os.getenv("VECTOR_CACHE_DIR", "cache/vectors")
33 | VECTOR_CACHE_TTL = int(os.getenv("VECTOR_CACHE_TTL", str(2 * 60 * 60))) # 2 hours in seconds
34 |
35 | # File size limits (in bytes)
36 | MAX_FILE_SIZE = int(os.getenv("MAX_FILE_SIZE", str(10 * 1024 * 1024))) # 10MB default
37 | MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", str(10 * 1024 * 1024))) # 10MB default
38 |
39 | # Supported file formats
40 | SUPPORTED_IMAGE_FORMATS = {
41 | # MIME types
42 | "image/jpeg", "image/png", "image/gif", "image/webp", "image/tiff", "image/bmp",
43 | # Extensions
44 | ".jpg", ".jpeg", ".png", ".gif", ".webp", ".tiff", ".bmp"
45 | }
46 |
47 | SUPPORTED_DOCUMENT_FORMATS = {
48 | # MIME types
49 | "application/pdf",
50 | "application/msword",
51 | "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
52 | "application/vnd.ms-excel",
53 | "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
54 | "application/vnd.ms-powerpoint",
55 | "application/vnd.openxmlformats-officedocument.presentationml.presentation",
56 | "text/plain",
57 | "text/csv",
58 | "text/html",
59 | "text/xml",
60 | "text/markdown",
61 | "application/epub+zip",
62 | # Extensions
63 | ".pdf", ".doc", ".docx", ".xls", ".xlsx", ".ppt", ".pptx",
64 | ".txt", ".csv", ".html", ".htm", ".xml", ".md", ".rst", ".epub"
65 | }
66 |
67 | # API URLs and Keys
68 | TARGET_MODEL_BASE_URL = os.getenv("TARGET_MODEL_BASE_URL", "https://api.openai.com")
69 | TARGET_MODEL_API_URL = f"{TARGET_MODEL_BASE_URL}/v1/chat/completions"
70 | TARGET_MODEL_API_KEY = os.getenv("TARGET_MODEL_API_KEY")
71 | OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL", "https://api.openai.com")
72 | OPENAI_API_URL = f"{OPENAI_BASE_URL}/v1/chat/completions"
73 | OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
74 | OPENAI_ENHANCE_MODEL = os.getenv("OPENAI_ENHANCE_MODEL") # Model for image/document enhancement
75 |
76 | # Embedding Settings
77 | EMBEDDING_BASE_URL = os.getenv("EMBEDDING_BASE_URL", "https://api.openai.com/v1")
78 | EMBEDDING_API_KEY = os.getenv("EMBEDDING_API_KEY", OPENAI_API_KEY) # Default to OpenAI key if not set
79 | EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
80 | EMBEDDING_DIMENSIONS = int(os.getenv("EMBEDDING_DIMENSIONS", "1536"))
81 |
82 | MY_API_KEY = os.getenv("MY_API_KEY")
83 |
84 | # Search Settings
85 | SEARXNG_URL = os.getenv("SEARXNG_URL", "https://searxng.example.com")
86 | DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "deepseek-r1")
87 | SEARCH_ENGINE = os.getenv("SEARCH_ENGINE", "google").lower()
88 |
89 | # Search Result Settings
90 | SEARCH_RESULT_LIMIT = int(os.getenv("SEARCH_RESULT_LIMIT", "5")) # Number of results to return
91 | SEARCH_RESULT_MULTIPLIER = int(os.getenv("SEARCH_RESULT_MULTIPLIER", "2")) # Multiplier for raw results to fetch
92 | WEB_CONTENT_CHUNK_SIZE = int(os.getenv("WEB_CONTENT_CHUNK_SIZE", "512")) # Size of each content chunk
93 | WEB_CONTENT_CHUNK_OVERLAP = int(os.getenv("WEB_CONTENT_CHUNK_OVERLAP", "50")) # Overlap between chunks
94 | WEB_CONTENT_MAX_CHUNKS = int(os.getenv("WEB_CONTENT_MAX_CHUNKS", "5")) # Number of most relevant chunks to return
95 |
96 | # Feature Switches
97 | ENABLE_PROGRESS_MESSAGES = os.getenv("ENABLE_PROGRESS_MESSAGES", "true").lower() == "true"
98 | ENABLE_IMAGE_ANALYSIS = os.getenv("ENABLE_IMAGE_ANALYSIS", "true").lower() == "true"
99 | ENABLE_WEB_SEARCH = os.getenv("ENABLE_WEB_SEARCH", "true").lower() == "true"
100 | ENABLE_DOCUMENT_ANALYSIS = os.getenv("ENABLE_DOCUMENT_ANALYSIS", "true").lower() == "true"
101 |
102 | # Progress Messages
103 | PROGRESS_MESSAGES = {
104 | "image_analysis": os.getenv("PROGRESS_MSG_IMAGE", "Analyzing image content..."),
105 | "document_analysis": os.getenv("PROGRESS_MSG_DOC", "Analyzing document content..."),
106 | "document_search": os.getenv("PROGRESS_MSG_DOC_SEARCH", "Searching document for relevant content..."),
107 | "web_search": os.getenv("PROGRESS_MSG_WEB_SEARCH", "Doing web search for relevant information...")
108 | }
109 |
110 | # Proxy Settings
111 | PROXY_ENABLED = os.getenv("PROXY_ENABLED", "false").lower() == "true"
112 | PROXY_HOST = os.getenv("PROXY_HOST", "")
113 | PROXY_PORT = int(os.getenv("PROXY_PORT") or "0")  # 0 when unset; validated before use
114 | PROXY_USERNAME = os.getenv("PROXY_USERNAME", "")
115 | PROXY_PASSWORD = os.getenv("PROXY_PASSWORD", "")
116 |
117 | # Timeouts
118 | API_TIMEOUT = int(os.getenv("API_TIMEOUT", "600"))
119 | SEARCH_TIMEOUT = 10
120 | REQUEST_TIMEOUT = {
121 | "connect": 3,
122 | "read": 5,
123 | "write": 3,
124 | "pool": 3
125 | }
126 |
127 | # Headers
128 | DEFAULT_HEADERS = {
129 | "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
130 | "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
131 | "Accept-Language": "en-US,en;q=0.5",
132 | "Accept-Encoding": "gzip, deflate",
133 | "DNT": "1"
134 | }
135 |
136 | # Search Engine Weights
137 | SEARCH_ENGINE_WEIGHTS = {
138 | "google": 3,
139 | "bing": 2,
140 | "duckduckgo": 2,
141 | "brave": 1,
142 | "qwant": 1
143 | }
--------------------------------------------------------------------------------
/app/services/image_service.py:
--------------------------------------------------------------------------------
1 | import httpx
2 | import logging
3 | import base64
4 | from fastapi import HTTPException
5 | from app.utils.file_utils import is_base64
6 |
7 | from app.config.settings import (
8 | OPENAI_API_URL, OPENAI_API_KEY,
9 | API_TIMEOUT, MAX_IMAGE_SIZE,
10 | OPENAI_ENHANCE_MODEL
11 | )
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 | def get_headers(api_key: str) -> dict:
16 | return {
17 | "Content-Type": "application/json",
18 | "Authorization": f"Bearer {api_key}"
19 | }
20 |
21 | async def process_image(image_url: str) -> str:
22 | """Process image using OpenAI and return description"""
23 | if not OPENAI_API_KEY:
24 | raise HTTPException(status_code=500, detail="OPENAI_API_KEY not configured")
25 |
26 | try:
27 | # Handle base64 image data
28 | if is_base64(image_url):
29 | # If it's a data URI, get the actual base64 data
30 | if image_url.startswith('data:'):
31 | base64_data = image_url.split(',')[1]
32 | else:
33 | base64_data = image_url
34 |
35 |             # Decode, then check the size of the decoded data
36 |             try:
37 |                 decoded_data = base64.b64decode(base64_data)
38 |             except Exception as e:
39 |                 logger.error(f"Error decoding base64 data: {str(e)}")
40 |                 raise HTTPException(status_code=400, detail="Invalid base64 image data")
41 |             # Size check lives outside the try so the 413 isn't remapped to a 400
42 |             if len(decoded_data) > MAX_IMAGE_SIZE:
43 |                 raise HTTPException(
44 |                     status_code=413,
45 |                     detail=f"Image size ({len(decoded_data)} bytes) exceeds maximum allowed size ({MAX_IMAGE_SIZE} bytes)")
46 |
47 | image_data = image_url
48 | else:
49 | # Handle regular URL
50 | # Check if this is an S3 presigned URL
51 | is_s3_url = "X-Amz-Algorithm=AWS4-HMAC-SHA256" in image_url and "X-Amz-Credential" in image_url
52 |
53 | # Set headers based on URL type
54 | headers = {
55 | "Accept": "*/*",
56 | "Accept-Encoding": "gzip, deflate",
57 | "Connection": "keep-alive",
58 | "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
59 | }
60 |
61 | async with httpx.AsyncClient() as client:
62 | if is_s3_url:
63 | # For S3 presigned URLs, skip HEAD request and directly download
64 | response = await client.get(
65 | image_url,
66 | headers=headers,
67 | follow_redirects=True
68 | )
69 | response.raise_for_status()
70 |
71 | # Check content length after download
72 | image_bytes = response.content
73 | if len(image_bytes) > MAX_IMAGE_SIZE:
74 | raise HTTPException(
75 | status_code=413,
76 | detail=f"Image size ({len(image_bytes)} bytes) exceeds maximum allowed size ({MAX_IMAGE_SIZE} bytes)"
77 | )
78 |
79 | # Detect content type
80 | content_type = response.headers.get("content-type", "image/jpeg")
81 | if not content_type.startswith("image/"):
82 | content_type = "image/jpeg" # Default to jpeg if no valid content type
83 |
84 | # Create data URI
85 | image_data = f"data:{content_type};base64,{base64.b64encode(image_bytes).decode()}"
86 | else:
87 | # For regular URLs, do HEAD request first
88 | head_response = await client.head(
89 | image_url,
90 | headers=headers,
91 | follow_redirects=True
92 | )
93 | head_response.raise_for_status()
94 |
95 | # Check content length if available
96 | content_length = head_response.headers.get("content-length")
97 | if content_length:
98 | file_size = int(content_length)
99 | if file_size > MAX_IMAGE_SIZE:
100 | raise HTTPException(
101 | status_code=413,
102 | detail=f"Image size ({file_size} bytes) exceeds maximum allowed size ({MAX_IMAGE_SIZE} bytes)"
103 | )
104 |
105 | # For regular URLs, pass the URL directly
106 | image_data = image_url
107 |
108 | # Process image with OpenAI
109 | async with httpx.AsyncClient() as client:
110 | openai_request = {
111 | "model": OPENAI_ENHANCE_MODEL,
112 | "messages": [{
113 | "role": "user",
114 | "content": [{
115 | "type": "text",
116 |                         "text": "Only output the OCR result of the user's uploaded image. If the image contains data structures like a graph or table, convert the information into text. If this is an image of an object, describe it in as much detail as possible."
117 | }, {
118 | "type": "image_url",
119 | "image_url": {"url": image_data}
120 | }]
121 | }],
122 | "stream": False
123 | }
124 |
125 | # Add headers for non-base64 and non-S3 URLs
126 | if not is_base64(image_data) and not is_s3_url:
127 | openai_request["messages"][0]["content"][1]["image_url"]["headers"] = headers
128 |
129 | response = await client.post(
130 | OPENAI_API_URL,
131 | json=openai_request,
132 | headers=get_headers(OPENAI_API_KEY),
133 | timeout=API_TIMEOUT
134 | )
135 | response.raise_for_status()
136 | result = response.json()
137 | return result.get("choices", [{}])[0].get("message", {}).get("content", "")
138 |
139 |     except HTTPException:
140 |         raise
141 |     except Exception as e:
142 |         logger.error(f"Image processing error: {str(e)}")
143 |         raise HTTPException(status_code=500, detail=f"Image processing failed: {str(e)}")
--------------------------------------------------------------------------------
/app/api/chat.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import asyncio
4 | from fastapi import APIRouter, HTTPException, Request, Depends
5 | from fastapi.responses import StreamingResponse
6 | from openai import AsyncOpenAI
7 |
8 | from app.models.chat import ChatRequest
9 | from app.services.message_service import process_messages
10 | from app.config.settings import (
11 | TARGET_MODEL_BASE_URL, TARGET_MODEL_API_KEY,
12 | DEFAULT_MODEL
13 | )
14 | from app.core.auth import verify_api_key
15 |
16 | logger = logging.getLogger(__name__)
17 |
18 | router = APIRouter()
19 |
20 | # Initialize OpenAI client
21 | client = AsyncOpenAI(
22 | base_url=f"{TARGET_MODEL_BASE_URL}/v1",
23 | api_key=TARGET_MODEL_API_KEY
24 | )
25 |
26 | async def stream_progress_message(message: str):
27 | """Stream a progress message"""
28 | chunk_data = {
29 | "id": "progress_msg",
30 | "object": "chat.completion.chunk",
31 | "created": 0,
32 | "model": DEFAULT_MODEL,
33 | "choices": [{
34 | "index": 0,
35 | "delta": {"role": "assistant", "content": message},
36 | "finish_reason": None
37 | }]
38 | }
39 | yield f"data: {json.dumps(chunk_data)}\n\n"
40 |
41 | async def stream_response(response):
42 | """Stream the response from the upstream API"""
43 | try:
44 | buffer = ""
45 | async for line in response.aiter_lines():
46 | if not line.strip():
47 | continue
48 |
49 | buffer += line
50 | if buffer.strip() == 'data: [DONE]':
51 | yield 'data: [DONE]\n\n'
52 | break
53 |
54 | try:
55 | if buffer.startswith('data: '):
56 | json.loads(buffer[6:])
57 | yield f"{buffer}\n\n"
58 | buffer = ""
59 | except json.JSONDecodeError:
60 | continue
61 | except Exception as e:
62 | logger.error(f"Error parsing stream chunk: {str(e)}")
63 | continue
64 | except Exception as e:
65 | logger.error(f"Error in stream_response: {str(e)}")
66 | yield f"data: {json.dumps({'error': str(e)})}\n\n"
67 |
68 | @router.post("/v1/chat/completions")
69 | async def chat(
70 | request: ChatRequest,
71 | raw_request: Request,
72 | api_key: str = Depends(verify_api_key)
73 | ):
74 | """Handle chat requests with support for streaming and image processing"""
75 | if not TARGET_MODEL_API_KEY:
76 | raise HTTPException(status_code=500, detail="TARGET_MODEL_API_KEY not configured")
77 |
78 | try:
79 | logger.info("Processing chat request...")
80 |
81 | if request.stream:
82 | # Create a queue for progress messages
83 | progress_queue = asyncio.Queue()
84 |
85 | # Process messages in background task
86 | async def process_messages_task():
87 | try:
88 | processed_messages, _ = await process_messages(
89 | request.messages,
90 | stream=True,
91 | progress_queue=progress_queue
92 | )
93 | # Convert messages for API request
94 | api_messages = [{"role": msg.role, "content": msg.content} for msg in processed_messages]
95 |
96 | # Create streaming completion
97 | stream = await client.chat.completions.create(
98 | messages=api_messages,
99 |                         model=request.model or DEFAULT_MODEL,
100 | stream=True
101 | )
102 | return stream
103 | except Exception as e:
104 | logger.error(f"Error in process_messages_task: {str(e)}")
105 | raise
106 |
107 | # Start processing task
108 | processing_task = asyncio.create_task(process_messages_task())
109 |
110 | async def generate():
111 | progress_task = None
112 | current_progress_message = None
113 | progress_generator = None
114 | had_progress_message = False
115 |
116 | try:
117 | while True:
118 | # Check for new progress message
119 | try:
120 | progress_message = progress_queue.get_nowait()
121 | if had_progress_message:
122 | chunk_data = {
123 | "id": "progress_msg",
124 | "object": "chat.completion.chunk",
125 | "created": 0,
126 | "model": DEFAULT_MODEL,
127 | "choices": [{
128 | "index": 0,
129 | "delta": {"role": "assistant", "content": "\n"},
130 | "finish_reason": None
131 | }]
132 | }
133 | yield f"data: {json.dumps(chunk_data)}\n\n"
134 | progress_generator = stream_progress_message(progress_message)
135 | current_progress_message = progress_message
136 | yield await anext(progress_generator)
137 | had_progress_message = True
138 | except asyncio.QueueEmpty:
139 | pass
140 | except StopAsyncIteration:
141 | pass
142 |
143 | if processing_task.done():
144 | break
145 |
146 | await asyncio.sleep(0.1)
147 |
148 | stream = await processing_task
149 |
150 | if had_progress_message:
151 | chunk_data = {
152 | "id": "progress_msg",
153 | "object": "chat.completion.chunk",
154 | "created": 0,
155 | "model": DEFAULT_MODEL,
156 | "choices": [{
157 | "index": 0,
158 | "delta": {"role": "assistant", "content": "\n"},
159 | "finish_reason": None
160 | }]
161 | }
162 | yield f"data: {json.dumps(chunk_data)}\n\n"
163 |
164 |                     last_content_chunk = None  # hold back each content chunk so finish_reason="stop" can be set on the last one
165 | async for chunk in stream:
166 | chunk_dict = chunk.model_dump()
167 |
168 | if chunk.choices[0].delta.content is not None:
169 | if last_content_chunk:
170 | yield f"data: {json.dumps(last_content_chunk)}\n\n"
171 | last_content_chunk = chunk_dict
172 | else:
173 | if last_content_chunk:
174 | last_content_chunk["choices"][0]["finish_reason"] = "stop"
175 | yield f"data: {json.dumps(last_content_chunk)}\n\n"
176 | last_content_chunk = None
177 | if not (chunk.choices[0].finish_reason == "stop" and not chunk.choices[0].delta.content):
178 | yield f"data: {json.dumps(chunk_dict)}\n\n"
179 |
180 | if last_content_chunk:
181 | last_content_chunk["choices"][0]["finish_reason"] = "stop"
182 | yield f"data: {json.dumps(last_content_chunk)}\n\n"
183 |
184 | yield "data: [DONE]\n\n"
185 |
186 | except Exception as e:
187 | logger.error(f"Error in generate: {str(e)}")
188 | if progress_task and not progress_task.done():
189 | progress_task.cancel()
190 | raise
191 |
192 | return StreamingResponse(
193 | generate(),
194 | media_type="text/event-stream",
195 | headers={
196 | "Cache-Control": "no-cache, no-transform",
197 | "Connection": "keep-alive",
198 | "Content-Type": "text/event-stream",
199 | "Transfer-Encoding": "chunked",
200 | "X-Accel-Buffering": "no"
201 | }
202 | )
203 | else:
204 | processed_messages, _ = await process_messages(request.messages)
205 |
206 | response = await client.chat.completions.create(
207 | messages=[{"role": msg.role, "content": msg.content} for msg in processed_messages],
208 |                 model=request.model or DEFAULT_MODEL,
209 | stream=False
210 | )
211 | return response.model_dump()
212 |
213 | except Exception as e:
214 | logger.error(f"Chat error: {str(e)}")
215 | raise HTTPException(status_code=500, detail=str(e))
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ExtendAI
2 |
3 | ExtendAI is a universal framework that extends any AI model with multi-modal, web-search, and document-analysis capabilities.
4 |
5 | ## Features
6 |
7 | - **Multi-format Document Support**: Process various document formats including:
8 | - PDF files
9 | - Microsoft Office documents (Word, Excel, PowerPoint)
10 | - Markdown files
11 | - Text files
12 | - Source code files
13 | - HTML/XML files
14 | - CSV files
15 | - RST files
16 | - Outlook MSG files
17 |   - EPUB files
18 |
19 | - **Advanced Document Processing**:
20 | - Automatic text encoding detection and fixing
21 |   - Batch processing of multiple documents
22 | - Smart document chunking with configurable size and overlap
23 |   - Ultra-fast processing of large documents
24 | - File type detection based on content and extensions
25 |
26 | - **Intelligent Web Search**:
27 | - Integration with multiple search engines (Google, Bing)
28 | - SearXNG support for privacy-focused searches
29 | - Smart result filtering and ranking
30 | - Configurable search depth and limits
31 | - Automatic content extraction and processing
32 |
33 | - **Advanced Image Comprehension**:
34 | - Deep image analysis and understanding
35 |   - Batch processing of multiple images
36 | - Text extraction from images (OCR)
37 | - Support for multiple image formats
38 | - Context-aware image interpretation
39 |
40 | - **Flexible Vector Store Integration**:
41 | - PostgreSQL vector store support
42 | - Pinecone vector store support
43 | - Local FAISS vector store with caching
44 | - Automatic fallback mechanisms
45 |
46 | - **Performance Optimizations**:
47 | - Asynchronous document processing
48 | - Efficient caching system for FAISS vectors
49 | - Batched embedding generation
50 | - Configurable chunk sizes and overlap
51 | - File size limits and safety checks
52 |
53 | - **Robust Error Handling**:
54 | - Graceful fallbacks for failed operations
55 | - Comprehensive logging system
56 | - Automatic cleanup of temporary files
57 | - Cache invalidation and management
58 |
59 | ## Screenshots
60 |
61 |

62 |
63 |

64 |
65 |

66 |
67 |
68 |
69 | ## Configuration
70 |
71 | The system is highly configurable through environment variables:
72 |
73 | ```env
74 | # Document Processing Settings
75 | CHUNK_SIZE=1000 # Size of text chunks for processing (in characters)
76 | CHUNK_OVERLAP=200 # Overlap between chunks to maintain context
77 | MAX_CHUNKS_PER_DOC=5 # Maximum number of chunks to return per document
78 | EMBEDDING_BATCH_SIZE=50 # Number of chunks to process in parallel for embeddings
79 | MIN_CHUNK_LENGTH=10 # Minimum length of a chunk to be processed
80 | MAX_FILE_SIZE=10485760 # Maximum file size in bytes (10MB)
81 | MAX_IMAGE_SIZE=5242880 # Maximum image size in bytes (5MB)
82 |
83 | # Vector Store Settings
84 | VECTOR_STORE_TYPE=faiss # Vector store backend: 'faiss', 'postgres', or 'pinecone'
85 | POSTGRES_CONNECTION_STRING= # PostgreSQL connection URL for vector storage
86 | POSTGRES_COLLECTION_NAME= # Collection name for storing vectors in PostgreSQL
87 |
88 | # Pinecone Settings
89 | PINECONE_API_KEY= # API key for Pinecone vector database
90 | PINECONE_INDEX_NAME= # Name of the Pinecone index to use
91 |
92 | # Cache Settings
93 | VECTOR_CACHE_DIR=cache/vectors # Directory for storing FAISS vector cache
94 | VECTOR_CACHE_TTL=7200 # Cache time-to-live in seconds (2 hours)
95 |
96 | API_TIMEOUT=600 # API Timeout (in seconds)
97 | MY_API_KEY=sk-planetzero-api-key # Authorization key
98 |
99 | # Model API Settings
100 | TARGET_MODEL_BASE_URL= # Base URL for the target AI model API
101 | TARGET_MODEL_API_KEY= # API key for the target model
102 | OPENAI_BASE_URL= # Base URL for OpenAI API (or compatible endpoint)
103 | OPENAI_API_KEY= # OpenAI API key (or compatible key)
104 | OPENAI_ENHANCE_MODEL= # The model used to facilitate image processing and search analysis
105 |
106 | # Embedding API Settings
107 | EMBEDDING_BASE_URL= # Base URL for embedding API
108 | EMBEDDING_API_KEY= # API key for embedding service
109 | EMBEDDING_MODEL=text-embedding-3-small # Model to use for text embeddings
110 | EMBEDDING_DIMENSIONS=1536 # Dimension of the embedding vectors
111 |
112 | # Search Settings
113 | SEARXNG_URL= # URL for SearXNG instance (for web search)
114 | DEFAULT_MODEL=deepseek-r1 # Default AI model to use
115 | SEARCH_ENGINE=google # Search engine to use (google, bing, etc.)
116 | SEARCH_RESULT_LIMIT=5 # Number of search results to return
117 | SEARCH_RESULT_MULTIPLIER=2 # Multiplier for raw results to fetch
118 | WEB_CONTENT_CHUNK_SIZE=512 # Size of chunks for web content
119 | WEB_CONTENT_CHUNK_OVERLAP=50 # Overlap for web content chunks
120 | WEB_CONTENT_MAX_CHUNKS=5 # Maximum chunks to process from web content
121 |
122 | # Proxy Settings (Optional)
123 | PROXY_ENABLED=false # Whether to use proxy for requests
124 | PROXY_HOST= # Proxy server hostname
125 | PROXY_PORT= # Proxy server port
126 | PROXY_USERNAME= # Proxy authentication username
127 | PROXY_PASSWORD= # Proxy authentication password
128 | PROXY_COUNTRY= # Preferred proxy server country
129 | PROXY_SESSION_ID= # Session ID for proxy (if required)
130 |
131 | # Feature Switches
132 | ENABLE_PROGRESS_MESSAGES=false # Enable/disable progress message updates
133 | ENABLE_IMAGE_ANALYSIS=true # Enable/disable image analysis capability
134 | ENABLE_WEB_SEARCH=true # Enable/disable web search capability
135 | ENABLE_DOCUMENT_ANALYSIS=true # Enable/disable document analysis capability
136 |
137 | # Progress Messages (Customizable)
138 | PROGRESS_MSG_IMAGE="Analyzing image content..." # Message shown during image analysis
139 | PROGRESS_MSG_DOC="Analyzing document..." # Message shown during document processing
140 | PROGRESS_MSG_DOC_SEARCH="Searching document..." # Message shown during document search
141 | PROGRESS_MSG_WEB_SEARCH="Searching web content..." # Message shown during web search
142 | ```
143 |
144 | ## Advanced Configuration Details
145 |
146 | ### Document Processing
147 | - Uses the OpenAI message format: just pass a document URL through the `image_url` parameter
148 | - The chunking system breaks documents into manageable pieces while maintaining context through overlap (see the sketch below)
149 | - Batch processing helps optimize embedding generation and API usage
150 | - File size limits protect against resource exhaustion and API limitations
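
A sketch of how the chunking settings map onto the text splitter: `app/services/search_service.py` configures LangChain's `RecursiveCharacterTextSplitter` the same way, the values below simply mirror the `CHUNK_SIZE`/`CHUNK_OVERLAP` defaults, and `example.txt` is a hypothetical input:

```python
# Illustrative sketch: how CHUNK_SIZE / CHUNK_OVERLAP drive the splitter.
from langchain_text_splitters import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,    # CHUNK_SIZE
    chunk_overlap=200,  # CHUNK_OVERLAP: context shared between adjacent chunks
    add_start_index=True,
)

with open("example.txt") as f:  # any long document
    chunks = splitter.split_text(f.read())
print(f"{len(chunks)} chunks, first 80 chars: {chunks[0][:80]!r}")
```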
151 |
152 | ### Vector Store Options
153 | 1. **FAISS** (Local):
154 | - Fast, efficient local vector storage
155 | - Good for development and smaller deployments
156 | - Includes local caching system for performance
157 |
158 | 2. **PostgreSQL**:
159 | - Persistent vector storage in PostgreSQL database
160 | - Suitable for production deployments
161 | - Supports concurrent access and backups
162 |
163 | 3. **Pinecone**:
164 | - Cloud-based vector database
165 | - Excellent for large-scale deployments
166 | - Provides automatic scaling and management
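
The backend is picked from `VECTOR_STORE_TYPE` at startup; a condensed sketch of the selection-and-fallback logic in `app/services/search_service.py` (simplified, with placeholder credentials):

```python
# Condensed from app/services/search_service.py: choose the configured
# backend, falling back to a local FAISS store if initialization fails.
from langchain_community.vectorstores import FAISS
from langchain_postgres import PGVector
from langchain_pinecone import PineconeVectorStore
from pinecone import Pinecone

def init_vector_store(embeddings, store_type: str, pg_conn: str = "",
                      collection: str = "embeddings"):
    try:
        if store_type == "postgres" and pg_conn:
            return PGVector(connection=pg_conn, collection_name=collection,
                            embeddings=embeddings)
        if store_type == "pinecone":
            index = Pinecone(api_key="your-pinecone-api-key").Index("extendai")
            return PineconeVectorStore(embedding=embeddings, index=index)
    except Exception:
        pass  # fall through to the local default below
    return FAISS.from_texts(["placeholder"], embeddings,
                            metadatas=[{"placeholder": True}])
```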
167 |
168 | ### API Integration
169 | - Supports multiple model endpoints (OpenAI-compatible)
170 | - Configurable embedding services
171 | - Automatic fallbacks and error handling
172 |
173 | ### Search Capabilities
174 | - Integrated web search through SearXNG and Google
175 | - Configurable search engines and result limits
176 | - Smart content chunking for web results
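
Ranking of the fetched chunks uses FAISS maximal-marginal-relevance search, which trades pure relevance against diversity; condensed from `app/services/search_service.py`:

```python
# Condensed from app/services/search_service.py: rank web-content chunks
# against the query with MMR so results are relevant but not redundant.
from langchain_community.vectorstores import FAISS

def top_chunks(docs, embeddings, query: str, limit: int = 5):
    store = FAISS.from_documents(docs, embeddings)
    return store.max_marginal_relevance_search(
        query,
        k=min(limit, len(docs)),            # SEARCH_RESULT_LIMIT
        fetch_k=min(limit * 2, len(docs)),
        lambda_mult=0.7,  # 1.0 = pure relevance, 0.0 = maximum diversity
    )
```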
177 |
178 | ### Progress Tracking
179 | - Customizable progress messages
180 | - Feature toggles for different capabilities
181 | - Localization support for messages
182 |
183 | ## Installation & Deployment
184 |
185 | ### Option 1: Docker Compose (Recommended for Production)
186 |
187 | 1. Clone the repository:
188 | ```bash
189 | git clone https://github.com/realnoob007/ExtendAI.git
190 | cd ExtendAI
191 | ```
192 |
193 | 2. Set up environment variables:
194 | ```bash
195 | cp .env.example .env
196 | # Edit .env with your configuration (API keys, model endpoints, etc.)
197 | ```
198 |
199 | 3. Start the services using Docker Compose:
200 | ```bash
201 | docker compose up -d
202 | ```
203 |
204 | This will start:
205 | - ExtendAI application on port 8096
206 | - PostgreSQL database on port 6023
207 | - PostgreSQL with pgvector extension on port 6024
208 |
209 | To view logs:
210 | ```bash
211 | docker compose logs -f
212 | ```
213 |
214 | To stop the services:
215 | ```bash
216 | docker compose down
217 | ```
218 |
219 | ### Option 2: Docker (Single Container)
220 |
221 | If you want to run only the ExtendAI application container and use external databases:
222 |
223 | 1. Clone and configure:
224 | ```bash
225 | git clone https://github.com/realnoob007/ExtendAI.git
226 | cd ExtendAI
227 | cp .env.example .env
228 | # Edit .env with your configuration
229 | ```
230 |
231 | 2. Run the container:
232 | ```bash
233 | docker run -d \
234 | --name extendai \
235 | -p 8096:8096 \
236 | --env-file .env \
237 | -v $(pwd)/cache:/app/cache \
238 | -v $(pwd)/.env:/app/.env:ro \
239 | chasney/extendai:latest
240 | ```
241 |
242 | Useful Docker commands:
243 | ```bash
244 | # View logs
245 | docker logs -f extendai
246 |
247 | # Stop container
248 | docker stop extendai
249 |
250 | # Remove container
251 | docker rm extendai
252 |
253 | # Rebuild image (if you made changes)
254 | docker build --no-cache -t extendai .
255 | ```
256 |
257 | ### Option 3: Local Development
258 |
259 | For development purposes, you can run the application directly:
260 |
261 | 1. Create a virtual environment (Python 3.11+ recommended):
262 | ```bash
263 | python -m venv venv
264 | source venv/bin/activate # On Windows: venv\Scripts\activate
265 | ```
266 |
267 | 2. Install dependencies:
268 | ```bash
269 | pip install -r requirements.txt
270 | ```
271 |
272 | 3. Set up environment variables:
273 | ```bash
274 | cp .env.example .env
275 | # Edit .env with your configuration
276 | ```
277 |
278 | 4. Run the application:
279 | ```bash
280 | python main.py
281 | ```
282 |
283 | ## Architecture
284 |
285 | The project layout (not fully refactored yet):
286 |
287 | - `main.py`: Application entry point and FastAPI setup
288 | - `app/api/`: API route definitions
289 | - `app/services/`: Core business logic services
290 | - `app/models/`: Data models and schemas
291 | - `app/utils/`: Utility functions and helpers
292 | - `app/core/`: Core system components
293 | - `app/config/`: Configuration management
294 |
295 | #### Example Request Payload (POST to http://localhost:8096/v1/chat/completions)
296 | ```json
297 | {
298 | "messages": [
299 | {
300 | "role": "user",
301 | "content": [
302 | {
303 | "type": "text",
304 | "text": "what is this document about?"
305 | },
306 | {
307 | "type": "image_url",
308 | "image_url": {
309 | "url": "https://xxx.com/document.pdf"
310 | }
311 | }
312 | ]
313 | }
314 | ],
315 | "model": "deepseek-r1-all",
316 | "stream": false
317 | }
318 | ```
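
The same request from Python, as a minimal sketch (it assumes the server is running locally on port 8096 and that the bearer token matches the `MY_API_KEY` value from your `.env`):

```python
# Minimal client sketch: send the example payload above to a local instance.
import httpx

payload = {
    "messages": [{
        "role": "user",
        "content": [
            {"type": "text", "text": "what is this document about?"},
            {"type": "image_url", "image_url": {"url": "https://xxx.com/document.pdf"}},
        ],
    }],
    "model": "deepseek-r1-all",
    "stream": False,
}

response = httpx.post(
    "http://localhost:8096/v1/chat/completions",
    json=payload,
    headers={"Authorization": "Bearer sk-planetzero-api-key"},  # your MY_API_KEY
    timeout=600,
)
response.raise_for_status()
print(response.json()["choices"][0]["message"]["content"])
```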
319 |
320 | #### Content Types Support
321 | The API supports multiple content types in messages:
322 |
323 | 1. **Text Content**
324 | ```json
325 | {
326 | "type": "text",
327 | "text": "your question or prompt here"
328 | }
329 | ```
330 |
331 | 2. **Image URL**
332 | ```json
333 | {
334 | "type": "image_url",
335 | "image_url": {
336 | "url": "https://example.com/image.jpg"
337 | }
338 | }
339 | ```
340 |
341 | 3. **Document URL**
342 | ```json
343 | {
344 | "type": "image_url",
345 | "image_url": {
346 | "url": "https://example.com/document.pdf"
347 | }
348 | }
349 | ```
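
Documents travel through the same `image_url` field; the server decides how to handle the content using `detect_file_type` from `app/utils/file_utils.py`, roughly like this:

```python
# How the server classifies a fetched URL (see app/utils/file_utils.py).
from app.utils.file_utils import detect_file_type

file_type, supported = detect_file_type(
    "https://example.com/report.pdf",  # URL from the image_url field
    "application/pdf",                 # Content-Type of the HTTP response
)
print(file_type, supported)  # -> document True
```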
350 |
351 | #### Parameters
352 |
353 | | Parameter | Type | Required | Description |
354 | |-----------|------|----------|-------------|
355 | | messages | array | Yes | Array of message objects |
356 | | model | string | No | Model to use (default: deepseek-r1) |
357 | | stream | boolean | No | Whether to stream the response (default: false) |
358 |
359 | #### Non-Streaming Response
360 | ```json
361 | {
362 | "id": "chatcmpl-123",
363 | "object": "chat.completion",
364 | "created": 1677858242,
365 | "model": "deepseek-r1-all",
366 | "usage": {
367 | "prompt_tokens": 56,
368 | "completion_tokens": 31,
369 | "total_tokens": 87
370 | },
371 | "choices": [
372 | {
373 | "message": {
374 | "role": "assistant",
375 | "content": "The document appears to be about..."
376 | },
377 | "finish_reason": "stop",
378 | "index": 0
379 | }
380 | ]
381 | }
382 | ```
383 |
384 | #### Streaming Response
385 | When `stream` is set to `true`, the response will be sent as Server-Sent Events (SSE):
386 | ```http
387 | data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677858242,"model":"deepseek-r1-all","choices":[{"delta":{"role":"assistant","content":"The"},"index":0,"finish_reason":null}]}
388 |
389 | data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677858242,"model":"deepseek-r1-all","choices":[{"delta":{"content":" document"},"index":0,"finish_reason":null}]}
390 |
391 | data: [DONE]
392 | ```
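
The stream can be consumed line by line; a minimal sketch with `httpx` (again assuming a local instance and the `MY_API_KEY` value from your `.env`):

```python
# Minimal SSE consumer sketch for the streaming endpoint.
import json
import httpx

payload = {
    "messages": [{"role": "user", "content": "hello"}],
    "model": "deepseek-r1-all",
    "stream": True,
}

with httpx.stream(
    "POST",
    "http://localhost:8096/v1/chat/completions",
    json=payload,
    headers={"Authorization": "Bearer sk-planetzero-api-key"},  # your MY_API_KEY
    timeout=600,
) as response:
    for line in response.iter_lines():
        if not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        delta = json.loads(data)["choices"][0]["delta"].get("content")
        if delta:
            print(delta, end="", flush=True)
```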
393 |
394 | ## Contributing
395 |
396 | 1. Fork the repository
397 | 2. Create a feature branch
398 | 3. Commit your changes
399 | 4. Push to the branch
400 | 5. Create a Pull Request
401 |
402 | ## Sponsorship
403 | You can save 20% on API costs by running this project with [PlanetZero API](https://api.planetzeroapi.com/).
404 |
405 | ## License
406 |
407 | MIT License
408 |
--------------------------------------------------------------------------------
/app/services/search_service.py:
--------------------------------------------------------------------------------
1 | import httpx
2 | import asyncio
3 | import logging
4 | from bs4 import BeautifulSoup
5 | from urllib.parse import urlparse
6 | from googlesearch import search
7 | from typing import List, Optional, Dict, Any
8 | from langchain_text_splitters import RecursiveCharacterTextSplitter
9 | from langchain_openai import OpenAIEmbeddings
10 | from langchain_core.documents import Document
11 | from langchain_community.vectorstores import FAISS
12 | from langchain_postgres import PGVector
13 | from langchain_pinecone import PineconeVectorStore
14 | from pinecone import Pinecone
15 |
16 | from app.config.settings import (
17 | PROXY_ENABLED, PROXY_HOST, PROXY_PORT, PROXY_USERNAME,
18 | PROXY_PASSWORD, SEARCH_RESULT_LIMIT, SEARCH_RESULT_MULTIPLIER,
19 | WEB_CONTENT_CHUNK_SIZE, WEB_CONTENT_CHUNK_OVERLAP,
20 | EMBEDDING_MODEL, EMBEDDING_BASE_URL, EMBEDDING_API_KEY, EMBEDDING_DIMENSIONS,
21 | VECTOR_STORE_TYPE, POSTGRES_CONNECTION_STRING, POSTGRES_COLLECTION_NAME,
22 | PINECONE_API_KEY, PINECONE_INDEX_NAME
23 | )
24 | from app.models.chat import SearchResult
25 |
26 | logger = logging.getLogger(__name__)
27 |
28 | def get_proxy_url() -> str:
29 | """Generate basic proxy URL"""
30 | return f"http://{PROXY_USERNAME}:{PROXY_PASSWORD}@{PROXY_HOST}:{PROXY_PORT}"
31 |
32 | class SearchService:
33 | def __init__(self):
34 | self.text_splitter = RecursiveCharacterTextSplitter(
35 | chunk_size=WEB_CONTENT_CHUNK_SIZE,
36 | chunk_overlap=WEB_CONTENT_CHUNK_OVERLAP,
37 | add_start_index=True,
38 | separators=["\n\n", "\n", "。", ". ", " ", ""],
39 | length_function=len,
40 | is_separator_regex=False
41 | )
42 | self.embeddings = OpenAIEmbeddings(
43 | model=EMBEDDING_MODEL,
44 | openai_api_key=EMBEDDING_API_KEY,
45 | openai_api_base=EMBEDDING_BASE_URL,
46 | dimensions=EMBEDDING_DIMENSIONS,
47 | request_timeout=60.0,
48 | show_progress_bar=True,
49 | retry_min_seconds=1,
50 | retry_max_seconds=60,
51 | max_retries=3,
52 | skip_empty=True,
53 | )
54 | # Initialize vector store
55 | self.vector_store = None
56 | self._init_vector_store()
57 |
58 | def _init_vector_store(self):
59 | """Initialize vector store based on configuration"""
60 | try:
61 | if VECTOR_STORE_TYPE == "postgres" and POSTGRES_CONNECTION_STRING:
62 | logger.info("Initializing PostgreSQL vector store")
63 | self.vector_store = PGVector(
64 | connection=POSTGRES_CONNECTION_STRING,
65 | collection_name=f"{POSTGRES_COLLECTION_NAME}_web_search",
66 | embeddings=self.embeddings
67 | )
68 | elif VECTOR_STORE_TYPE == "pinecone":
69 | logger.info("Initializing Pinecone vector store")
70 | pc = Pinecone(api_key=PINECONE_API_KEY)
71 | index = pc.Index(PINECONE_INDEX_NAME)
72 | self.vector_store = PineconeVectorStore(
73 | embedding=self.embeddings,
74 | index=index,
75 | namespace="web_search"
76 | )
77 | else:
78 | logger.info("Initializing FAISS vector store")
79 | self.vector_store = FAISS.from_texts(
80 | ["placeholder"], # Need at least one document to initialize
81 | self.embeddings,
82 | metadatas=[{"placeholder": True}]
83 | )
84 | except Exception as e:
85 | logger.error(f"Failed to initialize vector store: {str(e)}")
86 | logger.warning("Falling back to FAISS vector store")
87 | self.vector_store = FAISS.from_texts(
88 | ["placeholder"],
89 | self.embeddings,
90 | metadatas=[{"placeholder": True}]
91 | )
92 |
93 | async def _get_proxy_config(self) -> Optional[str]:
94 | """Get proxy configuration if enabled"""
95 | if not PROXY_ENABLED or not all([PROXY_HOST, PROXY_PORT, PROXY_USERNAME, PROXY_PASSWORD]):
96 | return None
97 |
98 | return f"http://{PROXY_USERNAME}:{PROXY_PASSWORD}@{PROXY_HOST}:{PROXY_PORT}"
99 |
100 | async def _fetch_page_content(self, url: str) -> str:
101 | """Fetch and extract text content from a webpage"""
102 | try:
103 | proxy = await self._get_proxy_config()
104 | async with httpx.AsyncClient(proxy=proxy, timeout=30.0) as client:
105 | response = await client.get(url, follow_redirects=True)
106 | response.raise_for_status()
107 |
108 | soup = BeautifulSoup(response.text, 'html.parser')
109 |
110 | # Remove unwanted elements
111 | for element in soup.find_all(['script', 'style', 'nav', 'header', 'footer', 'iframe', 'aside', 'form']):
112 | element.decompose()
113 |
114 | # First try to find main content area
115 | main_content = ""
116 | content_tags = soup.find_all(['article', 'main', 'div'],
117 | class_=lambda x: x and any(c in str(x).lower() for c in ['content', 'article', 'post', 'entry', 'main', 'text']))
118 |
119 | if content_tags:
120 | # Get the tag with most text content
121 | main_tag = max(content_tags, key=lambda x: len(x.get_text().strip()))
122 |
123 | # Extract paragraphs from main content area
124 | paragraphs = []
125 | for p in main_tag.find_all(['p', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6']):
126 | text = p.get_text().strip()
127 | if len(text) > 20: # Skip very short paragraphs
128 | paragraphs.append(text)
129 |
130 | main_content = '\n'.join(paragraphs)
131 |
132 | # Fallback to all paragraphs if no main content found
133 | if not main_content:
134 | paragraphs = []
135 | for p in soup.find_all(['p', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6']):
136 | text = p.get_text().strip()
137 | if len(text) > 20:
138 | paragraphs.append(text)
139 | main_content = '\n'.join(paragraphs)
140 |
141 | return main_content.strip()
142 | except Exception as e:
143 | logger.error(f"Error fetching content from {url}: {str(e)}")
144 | return ""
145 |
146 | async def search(self, query: str) -> List[Dict[str, Any]]:
147 | """Perform web search and return relevant content chunks"""
148 | try:
149 | # Get more initial results to account for failures
150 | raw_limit = SEARCH_RESULT_LIMIT * SEARCH_RESULT_MULTIPLIER
151 |
152 | # Configure proxy for Google search
153 | proxy_url = get_proxy_url() if PROXY_ENABLED else None
154 |
155 | # Get search results using googlesearch
156 | search_urls = []
157 | try:
158 | results = search(
159 | query,
160 | num_results=raw_limit,
161 | proxy=proxy_url,
162 | ssl_verify=False # Skip SSL verification for proxy
163 | )
164 | search_urls = list(results) # Convert generator to list
165 | logger.info(f"Found {len(search_urls)} results from Google search")
166 | except Exception as e:
167 | logger.error(f"Google search error: {str(e)}")
168 | return []
169 |
170 | if not search_urls:
171 | return []
172 |
173 | # Deduplicate URLs while keeping order
174 | seen_urls = set()
175 | unique_results = []
176 | for url in search_urls:
177 | if not url or url in seen_urls:
178 | continue
179 | try:
180 | parsed = urlparse(url)
181 | if not all([parsed.scheme, parsed.netloc]):
182 | continue
183 | if any(parsed.path.lower().endswith(ext) for ext in [
184 | '.pdf', '.doc', '.docx', '.ppt', '.pptx', '.xls', '.xlsx',
185 | '.zip', '.rar', '.7z', '.tar', '.gz', '.mp3', '.mp4', '.avi'
186 | ]):
187 | continue
188 | seen_urls.add(url)
189 | unique_results.append({
190 | "url": url,
191 | "title": url.split("/")[-1] or parsed.netloc, # Simple title extraction
192 | })
193 |                 except Exception:
194 | continue
195 |
196 | logger.info(f"Found {len(unique_results)} unique URLs to process")
197 |
198 | # Process URLs in batches until we have enough valid results
199 | valid_chunks = []
200 | batch_size = SEARCH_RESULT_LIMIT
201 | current_index = 0
202 |
203 | while current_index < len(unique_results) and len(valid_chunks) < SEARCH_RESULT_LIMIT:
204 | batch = unique_results[current_index:current_index + batch_size]
205 | current_index += batch_size
206 |
207 | # Fetch content from URLs in parallel
208 | proxies = await self._get_proxy_config()
209 |                 async with httpx.AsyncClient(proxy=proxies, timeout=30.0, verify=False) as client:  # NOTE: the fetch tasks below do not take this client as an argument
210 | tasks = []
211 | for result in batch:
212 | if "url" in result:
213 | tasks.append(self._fetch_page_content(result["url"]))
214 |
215 |                 # Gather fetch results with an overall batch timeout (60s is an assumed cap); without wait_for, the TimeoutError handler below would be unreachable
216 |                 try:
217 |                     contents = await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=60.0)
218 |
219 | # Process successful results
220 | for result, content in zip(batch, contents):
221 | if isinstance(content, Exception):
222 | logger.error(f"Failed to fetch {result.get('url')}: {str(content)}")
223 | continue
224 |
225 | if not content:
226 | continue
227 |
228 | # Split content into chunks
229 | chunks = self.text_splitter.split_text(content)
230 | logger.info(f"Split content from {result.get('url')} into {len(chunks)} chunks")
231 |
232 | # Create documents with metadata
233 | for chunk in chunks:
234 | chunk = chunk.strip()
235 | if not chunk or len(chunk) < 50: # Skip very short chunks
236 | continue
237 | doc = Document(
238 | page_content=chunk,
239 | metadata={
240 | "url": result.get("url", ""),
241 | "title": result.get("title", ""),
242 | "source": "web_search"
243 | }
244 | )
245 | valid_chunks.append(doc)
246 |
247 | # Break if we have enough chunks
248 | if len(valid_chunks) >= SEARCH_RESULT_LIMIT * 2: # Get extra for diversity
249 | break
250 |
251 | if len(valid_chunks) >= SEARCH_RESULT_LIMIT * 2:
252 | break
253 |
254 | except asyncio.TimeoutError:
255 | logger.warning(f"Timeout processing batch starting at index {current_index}")
256 | continue
257 | except Exception as e:
258 | logger.error(f"Error processing batch: {str(e)}")
259 | continue
260 |
261 | if not valid_chunks:
262 | return []
263 |
264 | logger.info(f"Total valid chunks collected: {len(valid_chunks)}")
265 |
266 | # Create new FAISS store for similarity search
267 | vector_store = FAISS.from_documents(valid_chunks, self.embeddings)
268 |
269 | # Search for similar chunks using MMR for diversity
270 | similar_chunks = vector_store.max_marginal_relevance_search(
271 | query,
272 | k=min(SEARCH_RESULT_LIMIT, len(valid_chunks)),
273 | fetch_k=min(SEARCH_RESULT_LIMIT * 2, len(valid_chunks)),
274 | lambda_mult=0.7
275 | )
276 |
277 | # Return relevant chunks with metadata
278 | relevant_chunks = []
279 | logger.info("\nMatched chunks content:")
280 | logger.info("="*80)
281 |
282 | # Get embeddings for scoring
283 | query_embedding = self.embeddings.embed_query(query)
284 | chunk_embeddings = self.embeddings.embed_documents([doc.page_content for doc in similar_chunks])
285 |
286 | from numpy import dot
287 | from numpy.linalg import norm
288 |
289 | for i, (doc, chunk_embedding) in enumerate(zip(similar_chunks, chunk_embeddings), 1):
290 | # Calculate cosine similarity
291 |                 similarity = dot(query_embedding, chunk_embedding) / (norm(query_embedding) * norm(chunk_embedding))
292 |
293 | logger.info(f"\nChunk [{i}] (similarity: {similarity:.3f})")
294 | logger.info(f"Source: {doc.metadata['title']} ({doc.metadata['url']})")
295 | logger.info(f"Content:\n{doc.page_content}")
296 | logger.info("-"*80)
297 |
298 | relevant_chunks.append({
299 | "content": doc.page_content,
300 | "url": doc.metadata["url"],
301 | "title": doc.metadata["title"],
302 | "similarity": float(similarity)
303 | })
304 |
305 | logger.info(f"\nSelected {len(relevant_chunks)} most relevant chunks")
306 | return relevant_chunks
307 |
308 | except Exception as e:
309 | logger.error(f"Search error: {str(e)}")
310 | return []
--------------------------------------------------------------------------------
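A note on the scoring step above: `max_marginal_relevance_search` picks a diverse top-k, and the loop that follows re-embeds the query and chunks to attach an explicit cosine score to each result. Below is a minimal, self-contained sketch of that cosine computation; the vectors are made up for illustration, whereas in the service they come from `self.embeddings.embed_query` / `embed_documents`.

import numpy as np
from numpy import dot
from numpy.linalg import norm

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    # Same formula as in search(): dot(a, b) / (|a| * |b|)
    return float(dot(a, b) / (norm(a) * norm(b)))

# Dummy 3-dimensional embeddings purely for illustration
query_vec = np.array([0.1, 0.7, 0.2])
chunk_vecs = [np.array([0.2, 0.6, 0.1]), np.array([0.9, 0.0, 0.4])]

for i, vec in enumerate(chunk_vecs, 1):
    print(f"chunk [{i}] similarity={cosine_similarity(query_vec, vec):.3f}")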
/app/services/message_service.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import datetime
3 | import httpx
4 | from typing import List, Optional, Tuple
5 | from fastapi import HTTPException
6 | import asyncio
7 | import time
8 | from cachetools import TTLCache
9 |
10 | from app.models.chat import Message, Role, SearchAnalysis, ContentType
11 | from app.config.settings import (
12 | OPENAI_API_URL, OPENAI_API_KEY,
13 | API_TIMEOUT, PROGRESS_MESSAGES,
14 | ENABLE_PROGRESS_MESSAGES, ENABLE_IMAGE_ANALYSIS,
15 | ENABLE_WEB_SEARCH, ENABLE_DOCUMENT_ANALYSIS,
16 | VECTOR_CACHE_TTL, OPENAI_ENHANCE_MODEL
17 | )
18 | from app.services.search_service import SearchService
19 | from app.services.image_service import process_image
20 | from app.services.document_service import DocumentService
21 | from app.utils.file_utils import detect_file_type
22 |
23 | logger = logging.getLogger(__name__)
24 |
25 | # Initialize services
26 | document_service = DocumentService()
27 | search_service = SearchService()
28 |
29 | # Global cache for processed files
30 | # Cache structure: {file_url: {'type': 'image'|'document', 'content': str, 'timestamp': float}}
31 | file_cache = TTLCache(maxsize=1000, ttl=VECTOR_CACHE_TTL)
32 |
33 | def get_headers(api_key: str) -> dict:
34 | return {
35 | "Content-Type": "application/json",
36 | "Authorization": f"Bearer {api_key}"
37 | }
38 |
39 | async def analyze_search_need(query: str) -> SearchAnalysis:
40 | """Analyze if the query needs search context"""
41 | if not OPENAI_API_KEY:
42 | raise HTTPException(status_code=500, detail="OPENAI_API_KEY not configured")
43 |
44 | async with httpx.AsyncClient() as client:
45 | try:
46 | response = await client.post(
47 | OPENAI_API_URL,
48 | json={
49 | "model": OPENAI_ENHANCE_MODEL,
50 | "messages": [{
51 | "role": "system",
52 | "content": "You are an expert at analyzing whether queries need real-time information. Return true only for queries about current events, real-time info, or those needing factual verification."
53 | }, {
54 | "role": "user",
55 | "content": query
56 | }],
57 | "response_format": {
58 | "type": "json_schema",
59 | "json_schema": {
60 | "name": "search_analysis",
61 | "schema": {
62 | "type": "object",
63 | "properties": {
64 | "needs_search": {"type": "boolean"},
65 | "search_keywords": {
66 | "type": "array",
67 | "items": {"type": "string"}
68 | }
69 | },
70 | "required": ["needs_search", "search_keywords"],
71 | "additionalProperties": False
72 | },
73 | "strict": True
74 | }
75 | }
76 | },
77 | headers=get_headers(OPENAI_API_KEY),
78 | timeout=API_TIMEOUT
79 | )
80 | response.raise_for_status()
81 | result = response.json()
82 |
83 | content = result.get("choices", [{}])[0].get("message", {}).get("content", {})
84 | if isinstance(content, str):
85 | import json
86 | try:
87 | content = json.loads(content)
88 | except json.JSONDecodeError:
89 | logger.error(f"Failed to parse response content as JSON: {content}")
90 | content = {"needs_search": False, "search_keywords": []}
91 |
92 | analysis = SearchAnalysis.model_validate(content)
93 |
94 | logger.info(f"Search analysis for query '{query}':")
95 | logger.info(f"Needs search: {analysis.needs_search}")
96 | if analysis.search_keywords:
97 | logger.info(f"Search keywords: {analysis.search_keywords}")
98 |
99 | return analysis
100 | except Exception as e:
101 | logger.error(f"Search analysis error: {str(e)}")
102 | raise HTTPException(status_code=500, detail=f"Search analysis failed: {str(e)}")
103 |
104 | async def process_messages(messages: List[Message], stream: bool = False, progress_queue: Optional[asyncio.Queue] = None) -> Tuple[List[Message], List[str]]:
105 | """Process messages and handle any images or documents in them"""
106 |
107 | async def send_progress(message_type: str):
108 | if ENABLE_PROGRESS_MESSAGES and stream and progress_queue:
109 | await progress_queue.put(PROGRESS_MESSAGES[message_type])
110 |
111 | processed_messages = []
112 | system_content_parts = []
113 | current_search_parts = []
114 | current_search_urls = []
115 |
116 | current_date = datetime.datetime.now().strftime("%Y-%m-%d")
117 | system_content_parts.append(f"Current date: {current_date}")
118 |
119 | # Track processed files in current request to avoid duplicates
120 | request_processed_files = set()
121 |
122 | for message in messages:
123 | if isinstance(message.content, list):
124 | text_parts = []
125 | query_text = ""
126 |
127 | for content in message.content:
128 | if content.type == ContentType.TEXT:
129 | query_text = content.text
130 | text_parts.append(query_text)
131 | # Truncate long text in logs
132 | log_text = query_text[:50] + "..." if len(query_text) > 50 else query_text
133 | logger.info(f"Added text content: {log_text}")
134 | elif content.type == ContentType.IMAGE_URL and content.image_url:
135 | file_url = content.image_url.url
136 |
137 | # Skip if already processed in this request
138 | if file_url in request_processed_files:
139 | continue
140 | request_processed_files.add(file_url)
141 |
142 | # For base64 URLs, show truncated version in logs
143 | log_url = file_url
144 | if file_url.startswith('data:'):
145 | log_url = file_url[:30] + "..." + file_url[-10:]
146 | elif len(file_url) > 100:
147 | log_url = file_url[:50] + "..." + file_url[-50:]
148 |
149 | # Check cache first
150 | cache_hit = file_cache.get(file_url)
151 | if cache_hit:
152 | logger.info(f"Cache hit for file: {log_url}")
153 | if cache_hit['type'] == 'image':
154 | text_parts.append(f"[Image Content: {cache_hit['content']}]")
155 | logger.info("Added cached image content")
156 | elif cache_hit['type'] == 'document' and query_text:
157 | # For documents, we still need to search with the current query
158 | logger.info("Using cached document for search")
159 | await send_progress("document_search")
160 | relevant_chunks = await document_service.search_similar(query_text, source_url=file_url)
161 | if relevant_chunks:
162 | chunks_text = "\n\n".join(
163 | f"Relevant section {i+1} from {file_url}:\n{chunk.page_content}"
164 | for i, chunk in enumerate(relevant_chunks)
165 | )
166 | text_parts.append(f"[Document Content:\n{chunks_text}]")
167 | logger.info(f"Added {len(relevant_chunks)} relevant document sections from cache")
168 | else:
169 | text_parts.append(f"[No relevant content found in cached document: {log_url}]")
170 | continue
171 |
172 | logger.info(f"\nProcessing file URL: {log_url}")
173 |
174 | # Check if it's base64 data
175 | file_type, is_supported = detect_file_type(file_url, "")
176 |
177 | # Skip processing based on feature flags
178 | if file_type == 'image':
179 | if not ENABLE_IMAGE_ANALYSIS:
180 | continue
181 | await send_progress("image_analysis")
182 | else: # document type
183 | if not ENABLE_DOCUMENT_ANALYSIS:
184 | continue
185 | await send_progress("document_analysis")
186 |
187 | try:
188 | if file_type == 'image' and is_supported and ENABLE_IMAGE_ANALYSIS:
189 | logger.info("Processing as IMAGE")
190 | image_content = await process_image(file_url)
191 | if image_content:
192 | # Cache the result
193 | file_cache[file_url] = {
194 | 'type': 'image',
195 | 'content': image_content,
196 | 'timestamp': time.time()
197 | }
198 | text_parts.append(f"[Image Content: {image_content}]")
199 | logger.info("Image processing completed and cached")
200 |                         else:  # non-image files (and unsupported image types) are handled as documents
201 | if ENABLE_DOCUMENT_ANALYSIS:
202 | logger.info("Processing as DOCUMENT")
203 | documents = await document_service.process_document(file_url)
204 | if documents:
205 | # Cache the document processing status
206 | file_cache[file_url] = {
207 | 'type': 'document',
208 | 'content': 'processed',
209 | 'timestamp': time.time()
210 | }
211 | if query_text:
212 | logger.info("Searching document for relevant content...")
213 | await send_progress("document_search")
214 | relevant_chunks = await document_service.search_similar(query_text, source_url=file_url)
215 | if relevant_chunks:
216 | chunks_text = "\n\n".join(
217 | f"Relevant section {i+1} from {file_url}:\n{chunk.page_content}"
218 | for i, chunk in enumerate(relevant_chunks)
219 | )
220 | text_parts.append(f"[Document Content:\n{chunks_text}]")
221 | logger.info(f"Added {len(relevant_chunks)} relevant document sections")
222 | else:
223 | text_parts.append(f"[No relevant content found in document: {file_url}]")
224 | except Exception as e:
225 | logger.error(f"Error processing file {file_url}: {str(e)}")
226 | text_parts.append(f"[Error processing file {file_url}: {str(e)}]")
227 |
228 |                 # Only perform web search if this is the last user message
229 |                 if query_text and message is messages[-1] and message.role == Role.USER:
230 |                     if ENABLE_WEB_SEARCH:
231 |                         # Report progress only when search will actually run
232 |                         await send_progress("web_search")
233 | search_analysis = await analyze_search_need(query_text)
234 |
235 | if search_analysis.needs_search and search_analysis.search_keywords:
236 | logger.info("Performing web search...")
237 | search_results = await search_service.search(query_text)
238 |
239 | search_context_parts = []
240 | current_search_urls = []
241 |
242 | for idx, result in enumerate(search_results, 1):
243 | url = result["url"]
244 | title = result["title"]
245 | content = result["content"]
246 | similarity = result["similarity"]
247 |
248 | current_search_urls.append(f"> [{idx}] {url}")
249 | search_context_parts.append(
250 | f"Source [{idx}] (Similarity: {similarity:.2f}): {title}\n{content}"
251 | )
252 |
253 | if search_context_parts:
254 | search_context = "\n\n".join(search_context_parts)
255 | current_search_parts = [
256 | "When using information from the search results, cite the sources using [n] format where n is the source number.",
257 | f"Search Results:\n{search_context}",
258 | f"Include these references at the end of your response:\n{chr(10).join(current_search_urls)}"
259 | ]
260 | logger.info(f"Added {len(search_context_parts)} search results to context")
261 | else:
262 | logger.info("Web search is disabled")
263 |
264 | if text_parts:
265 | processed_messages.append(Message(
266 | role=message.role,
267 | content=" ".join(text_parts)
268 | ))
269 | else:
270 | # Handle string content (direct text input)
271 | query_text = message.content
272 | processed_messages.append(message)
273 |
274 |             # Only perform web search if this is the last user message and it's a text query
275 |             if query_text and message is messages[-1] and message.role == Role.USER:
276 |                 if ENABLE_WEB_SEARCH:
277 |                     # Report progress only when search will actually run
278 |                     await send_progress("web_search")
279 | search_analysis = await analyze_search_need(query_text)
280 |
281 | if search_analysis.needs_search and search_analysis.search_keywords:
282 | logger.info("Performing web search...")
283 | search_results = await search_service.search(query_text)
284 |
285 | search_context_parts = []
286 | current_search_urls = []
287 |
288 | for idx, result in enumerate(search_results, 1):
289 | url = result["url"]
290 | title = result["title"]
291 | content = result["content"]
292 | similarity = result["similarity"]
293 |
294 | current_search_urls.append(f"> [{idx}] {url}")
295 | search_context_parts.append(
296 | f"Source [{idx}] (Similarity: {similarity:.2f}): {title}\n{content}"
297 | )
298 |
299 | if search_context_parts:
300 | search_context = "\n\n".join(search_context_parts)
301 | current_search_parts = [
302 | "When using information from the search results, cite the sources using [n] format where n is the source number.",
303 | f"Search Results:\n{search_context}",
304 | f"Include these references at the end of your response:\n{chr(10).join(current_search_urls)}"
305 | ]
306 | logger.info(f"Added {len(search_context_parts)} search results to context")
307 | else:
308 | logger.info("Web search is disabled")
309 |
310 | final_system_parts = system_content_parts + current_search_parts
311 |
312 | if final_system_parts:
313 | final_system_content = "\n\n".join(final_system_parts)
314 | processed_messages.insert(0, Message(
315 | role=Role.SYSTEM,
316 | content=final_system_content
317 | ))
318 |
319 | return processed_messages, current_search_urls
--------------------------------------------------------------------------------
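For context on the `file_cache` above: `cachetools.TTLCache` behaves like a dict whose entries expire `ttl` seconds after insertion, so a repeated request for the same file URL within the window skips reprocessing entirely. A minimal sketch of the pattern follows, with an assumed 3600-second TTL and a placeholder `process()` standing in for the real image/document services.

import time
from cachetools import TTLCache

file_cache = TTLCache(maxsize=1000, ttl=3600)  # 3600s TTL assumed for this sketch

def process(url: str) -> dict:
    # Placeholder for process_image / document processing
    return {"type": "image", "content": f"analysis of {url}", "timestamp": time.time()}

def get_or_process(url: str) -> dict:
    hit = file_cache.get(url)
    if hit is not None:
        return hit            # cache hit: entry is younger than the TTL
    record = process(url)
    file_cache[url] = record  # kept until the TTL expires or maxsize eviction
    return record

get_or_process("https://example.com/a.png")  # miss: processed and cached
get_or_process("https://example.com/a.png")  # hit: returned from cache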
/app/services/document_service.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import tempfile
4 | import httpx
5 | import ftfy
6 | import json
7 | import time
8 | import hashlib
9 | from pathlib import Path
10 | from typing import List, Optional, Tuple
11 | import faiss
12 | from langchain_community.document_loaders import (
13 | BSHTMLLoader,
14 | CSVLoader,
15 | Docx2txtLoader,
16 | OutlookMessageLoader,
17 | PyPDFLoader,
18 | TextLoader,
19 | UnstructuredEPubLoader,
20 | UnstructuredExcelLoader,
21 | UnstructuredMarkdownLoader,
22 | UnstructuredPowerPointLoader,
23 | UnstructuredRSTLoader,
24 | UnstructuredXMLLoader
25 | )
26 | from langchain_text_splitters import RecursiveCharacterTextSplitter
27 | from langchain_openai import OpenAIEmbeddings
28 | from langchain_core.vectorstores import VectorStore
29 | from langchain_community.vectorstores import FAISS
30 | from langchain_postgres import PGVector
31 | from langchain_pinecone import PineconeVectorStore
32 | from pinecone import Pinecone
33 | from langchain_core.documents import Document as LangchainDocument
34 | from fastapi import HTTPException
35 | import asyncio
36 |
37 | from app.models.chat import Document
38 | from app.config.settings import (
39 | EMBEDDING_BASE_URL, EMBEDDING_API_KEY,
40 | EMBEDDING_MODEL, EMBEDDING_DIMENSIONS,
41 | CHUNK_SIZE, CHUNK_OVERLAP, MAX_CHUNKS_PER_DOC,
42 | EMBEDDING_BATCH_SIZE, MIN_CHUNK_LENGTH,
43 | MAX_FILE_SIZE, VECTOR_CACHE_DIR, VECTOR_CACHE_TTL,
44 | VECTOR_STORE_TYPE, POSTGRES_CONNECTION_STRING,
45 | POSTGRES_COLLECTION_NAME, PINECONE_API_KEY,
46 | PINECONE_INDEX_NAME
47 | )
48 |
49 | logger = logging.getLogger(__name__)
50 |
51 | # Suppress FAISS GPU warning since we're explicitly using CPU
52 | faiss.get_num_gpus = lambda: 0
53 |
54 | # Known source file extensions
55 | KNOWN_SOURCE_EXT = [
56 | "go", "py", "java", "sh", "bat", "ps1", "cmd", "js", "ts", "css",
57 | "cpp", "hpp", "h", "c", "cs", "sql", "log", "ini", "pl", "pm",
58 | "r", "dart", "dockerfile", "env", "php", "hs", "hsc", "lua",
59 | "nginxconf", "conf", "m", "mm", "plsql", "perl", "rb", "rs",
60 | "db2", "scala", "bash", "swift", "vue", "svelte", "msg", "ex",
61 | "exs", "erl", "tsx", "jsx", "hs", "lhs"
62 | ]
63 |
64 | class DocumentService:
65 | def __init__(self):
66 | self.text_splitter = RecursiveCharacterTextSplitter(
67 | chunk_size=CHUNK_SIZE,
68 | chunk_overlap=CHUNK_OVERLAP,
69 | add_start_index=True,
70 | separators=["\n\n", "\n", ". ", " ", ""],
71 | length_function=len,
72 | is_separator_regex=False
73 | )
74 | self.embeddings = OpenAIEmbeddings(
75 | model=EMBEDDING_MODEL,
76 | openai_api_key=EMBEDDING_API_KEY,
77 | openai_api_base=EMBEDDING_BASE_URL,
78 | dimensions=EMBEDDING_DIMENSIONS,
79 | request_timeout=60.0,
80 | show_progress_bar=True,
81 | retry_min_seconds=1,
82 | retry_max_seconds=60,
83 | max_retries=3,
84 | skip_empty=True,
85 | )
86 | self.vector_store: Optional[VectorStore] = None
87 |
88 | # Initialize vector store based on configuration
89 | logger.info(f"Initializing vector store with type: {VECTOR_STORE_TYPE}")
90 |
91 | if VECTOR_STORE_TYPE == "postgres":
92 | if not POSTGRES_CONNECTION_STRING:
93 | logger.error("PostgreSQL connection string is not configured")
94 | logger.warning("Falling back to FAISS vector store")
95 | self._init_faiss_store()
96 | else:
97 | try:
98 | logger.info(f"Attempting to connect to PostgreSQL with collection: {POSTGRES_COLLECTION_NAME}")
99 | self.vector_store = PGVector(
100 | connection=POSTGRES_CONNECTION_STRING,
101 | collection_name=POSTGRES_COLLECTION_NAME,
102 | embeddings=self.embeddings,
103 | )
104 | logger.info(f"Successfully initialized PostgreSQL vector store with collection: {POSTGRES_COLLECTION_NAME}")
105 | except Exception as e:
106 | logger.error(f"Failed to initialize PostgreSQL vector store: {str(e)}")
107 | logger.warning("Falling back to FAISS vector store")
108 | self._init_faiss_store()
109 | elif VECTOR_STORE_TYPE == "pinecone":
110 | try:
111 | pc = Pinecone(api_key=PINECONE_API_KEY)
112 | index = pc.Index(PINECONE_INDEX_NAME)
113 | self.vector_store = PineconeVectorStore(
114 | embedding=self.embeddings,
115 | index=index,
116 | namespace="default"
117 | )
118 | logger.info(f"Initialized Pinecone vector store with index: {PINECONE_INDEX_NAME}")
119 | except Exception as e:
120 | logger.error(f"Failed to initialize Pinecone vector store: {str(e)}")
121 | logger.warning("Falling back to FAISS vector store")
122 | self._init_faiss_store()
123 | else:
124 | logger.info("Using FAISS vector store as configured")
125 | self._init_faiss_store()
126 |
127 |     def _init_faiss_store(self):
128 |         """Prepare FAISS usage; the index itself is created lazily on the first document batch"""
129 |         if VECTOR_STORE_TYPE == "faiss":
130 |             os.makedirs(VECTOR_CACHE_DIR, exist_ok=True)  # disk cache is only used when FAISS is the configured store
131 |         logger.info("Using FAISS vector store (CPU mode)")
132 |
133 | def _create_vector_store(self, documents: List[LangchainDocument]) -> VectorStore:
134 | """Create a new vector store instance"""
135 | try:
136 | if VECTOR_STORE_TYPE == "postgres" and POSTGRES_CONNECTION_STRING:
137 | store = PGVector.from_documents(
138 | documents=documents,
139 | embedding=self.embeddings,
140 | collection_name=POSTGRES_COLLECTION_NAME,
141 | connection=POSTGRES_CONNECTION_STRING,
142 | )
143 | logger.info("Created new PostgreSQL vector store")
144 | return store
145 | elif VECTOR_STORE_TYPE == "pinecone":
146 | pc = Pinecone(api_key=PINECONE_API_KEY)
147 | index = pc.Index(PINECONE_INDEX_NAME)
148 | store = PineconeVectorStore.from_documents(
149 | documents=documents,
150 | embedding=self.embeddings,
151 | index=index,
152 | namespace="default"
153 | )
154 | logger.info("Created new Pinecone vector store")
155 | return store
156 | except Exception as e:
157 | logger.error(f"Failed to create vector store: {str(e)}")
158 | logger.warning("Falling back to FAISS vector store")
159 |
160 | # Default or fallback to FAISS
161 | store = FAISS.from_documents(
162 | documents,
163 | self.embeddings,
164 | distance_strategy="COSINE"
165 | )
166 | logger.info("Created new FAISS vector store")
167 | return store
168 |
169 | def _add_to_vector_store(self, store: VectorStore, documents: List[LangchainDocument]):
170 | """Add documents to existing vector store"""
171 | try:
172 | if isinstance(store, PGVector):
173 | store.add_documents(documents)
174 | logger.info(f"Added {len(documents)} documents to PostgreSQL vector store")
175 | elif isinstance(store, PineconeVectorStore):
176 | store.add_documents(documents)
177 | logger.info(f"Added {len(documents)} documents to Pinecone vector store")
178 | elif isinstance(store, FAISS):
179 | store.add_documents(documents)
180 | logger.info(f"Added {len(documents)} documents to FAISS vector store")
181 | else:
182 | raise ValueError(f"Unsupported vector store type: {type(store)}")
183 | except Exception as e:
184 | logger.error(f"Failed to add documents to vector store: {str(e)}")
185 | raise
186 |
187 | def _get_cache_key(self, filename: str, file_size: int) -> str:
188 | """Generate cache key from filename and size"""
189 | return hashlib.md5(f"{filename}_{file_size}".encode()).hexdigest()
190 |
191 | def _get_cache_path(self, cache_key: str) -> Tuple[Path, Path]:
192 | """Get cache file paths for documents and vectors (FAISS only)"""
193 | docs_path = Path(VECTOR_CACHE_DIR) / f"{cache_key}_docs.json"
194 | vectors_path = Path(VECTOR_CACHE_DIR) / f"{cache_key}_vectors.faiss"
195 | return docs_path, vectors_path
196 |
197 | def _save_to_cache(self, cache_key: str, documents: List[Document], vector_store: VectorStore):
198 | """Save documents and vector store to cache (FAISS only)"""
199 | if not isinstance(vector_store, FAISS):
200 | return
201 |
202 | try:
203 | docs_path, vectors_path = self._get_cache_path(cache_key)
204 |
205 | # Save documents and timestamp
206 | docs_data = {
207 | "timestamp": time.time(),
208 | "documents": [doc.model_dump() for doc in documents]
209 | }
210 | with open(docs_path, 'w', encoding='utf-8') as f:
211 | json.dump(docs_data, f)
212 | logger.info(f"Saved document cache to {docs_path}")
213 |
214 | # Save FAISS index
215 | vector_store.save_local(str(vectors_path))
216 | logger.info(f"Saved FAISS vectors to {vectors_path}")
217 |
218 | except Exception as e:
219 | logger.error(f"Failed to save cache: {str(e)}")
220 |
221 | def _load_from_cache(self, cache_key: str) -> Tuple[Optional[List[Document]], Optional[VectorStore]]:
222 | """Load documents and vector store from cache if valid (FAISS only)"""
223 | try:
224 | docs_path, vectors_path = self._get_cache_path(cache_key)
225 |
226 | # Check if both cache files exist
227 | if not docs_path.exists() or not vectors_path.exists():
228 | return None, None
229 |
230 | # Load and validate documents cache
231 | try:
232 | with open(docs_path, 'r', encoding='utf-8') as f:
233 | docs_data = json.load(f)
234 | except (PermissionError, json.JSONDecodeError) as e:
235 | logger.error(f"Failed to read document cache: {str(e)}")
236 | return None, None
237 |
238 | # Check if cache is expired
239 | if time.time() - docs_data["timestamp"] > VECTOR_CACHE_TTL:
240 | logger.info("Cache expired, will reprocess document")
241 | # Clean up expired cache files
242 | try:
243 | if docs_path.exists():
244 | os.unlink(docs_path)
245 | if vectors_path.exists():
246 | os.unlink(vectors_path)
247 | except PermissionError as e:
248 | logger.warning(f"Failed to clean up expired cache: {str(e)}")
249 | return None, None
250 |
251 | # Restore documents
252 | try:
253 | documents = [Document.model_validate(doc) for doc in docs_data["documents"]]
254 | except Exception as e:
255 | logger.error(f"Failed to validate documents: {str(e)}")
256 | return None, None
257 |
258 | # Restore FAISS vector store
259 | try:
260 | vector_store = FAISS.load_local(
261 | str(vectors_path),
262 | self.embeddings,
263 | allow_dangerous_deserialization=True # Allow since we control the cache
264 | )
265 | logger.info(f"Loaded FAISS vectors from {vectors_path}")
266 | return documents, vector_store
267 | except Exception as e:
268 | logger.error(f"Failed to load vector store: {str(e)}")
269 | # Clean up invalid cache
270 | try:
271 | if docs_path.exists():
272 | os.unlink(docs_path)
273 | if vectors_path.exists():
274 | os.unlink(vectors_path)
275 | except PermissionError as e:
276 | logger.warning(f"Failed to clean up invalid cache: {str(e)}")
277 | return None, None
278 |
279 | except Exception as e:
280 | logger.error(f"Failed to load cache: {str(e)}")
281 | return None, None
282 |
283 |     async def download_file(self, url: str) -> Tuple[str, str, str]:
284 | """Download file from URL and return file path, name and content type"""
285 | async with httpx.AsyncClient() as client:
286 | try:
287 | # First do a HEAD request to check content length
288 | head_response = await client.head(url, follow_redirects=True)
289 | head_response.raise_for_status()
290 |
291 | # Check content length if available
292 | content_length = head_response.headers.get("content-length")
293 | if content_length:
294 | file_size = int(content_length)
295 | if file_size > MAX_FILE_SIZE:
296 | raise HTTPException(
297 | status_code=413,
298 | detail=f"File size ({file_size} bytes) exceeds maximum allowed size ({MAX_FILE_SIZE} bytes)"
299 | )
300 |
301 | # Proceed with download using streaming to enforce size limit
302 | async with client.stream("GET", url, follow_redirects=True) as response:
303 | response.raise_for_status()
304 |
305 | # Get filename from URL or Content-Disposition
306 | content_disposition = response.headers.get("content-disposition")
307 | if content_disposition and "filename=" in content_disposition:
308 | filename = content_disposition.split("filename=")[-1].strip('"')
309 | else:
310 | filename = url.split("/")[-1]
311 |
312 | content_type = response.headers.get("content-type", "")
313 |
314 | # Create temporary file with proper extension
315 | ext = os.path.splitext(filename)[1]
316 | with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as temp_file:
317 | total_size = 0
318 | async for chunk in response.aiter_bytes(chunk_size=8192):
319 | total_size += len(chunk)
320 | if total_size > MAX_FILE_SIZE:
321 | # Clean up and raise error
322 | temp_file.close()
323 | os.unlink(temp_file.name)
324 | raise HTTPException(
325 | status_code=413,
326 | detail=f"File size exceeds maximum allowed size ({MAX_FILE_SIZE} bytes)"
327 | )
328 | temp_file.write(chunk)
329 |
330 | return temp_file.name, filename, content_type
331 | except HTTPException:
332 | raise
333 | except Exception as e:
334 | logger.error(f"Error downloading file from {url}: {str(e)}")
335 | raise
336 |
337 | def _get_loader(self, filename: str, content_type: str, file_path: str):
338 | """Get appropriate document loader based on file type"""
339 | file_ext = filename.split(".")[-1].lower() if "." in filename else ""
340 |
341 | if file_ext == "pdf":
342 | return PyPDFLoader(file_path)
343 | elif file_ext == "csv":
344 | return CSVLoader(file_path)
345 | elif file_ext == "rst":
346 | return UnstructuredRSTLoader(file_path, mode="elements")
347 | elif file_ext == "xml":
348 | return UnstructuredXMLLoader(file_path)
349 | elif file_ext in ["htm", "html"]:
350 | return BSHTMLLoader(file_path, open_encoding="unicode_escape")
351 | elif file_ext == "md":
352 | return UnstructuredMarkdownLoader(file_path)
353 | elif content_type == "application/epub+zip":
354 | return UnstructuredEPubLoader(file_path)
355 | elif content_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document" or file_ext == "docx":
356 | return Docx2txtLoader(file_path)
357 | elif content_type in ["application/vnd.ms-excel", "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"] or file_ext in ["xls", "xlsx"]:
358 | return UnstructuredExcelLoader(file_path)
359 | elif content_type in ["application/vnd.ms-powerpoint", "application/vnd.openxmlformats-officedocument.presentationml.presentation"] or file_ext in ["ppt", "pptx"]:
360 | return UnstructuredPowerPointLoader(file_path)
361 | elif file_ext == "msg":
362 | return OutlookMessageLoader(file_path)
363 |         else:
364 |             # Known source files, text/* content types, and anything
365 |             # unrecognized all fall back to the plain text loader
366 |             return TextLoader(file_path, autodetect_encoding=True)
367 |
368 | async def _process_chunk_batch(self, batch: List[LangchainDocument]) -> None:
369 | """Process a batch of document chunks asynchronously"""
370 | try:
371 | if not self.vector_store:
372 | logger.info("Initializing new vector store")
373 | self.vector_store = self._create_vector_store(batch)
374 | else:
375 | logger.info("Adding documents to existing vector store")
376 | self._add_to_vector_store(self.vector_store, batch)
377 | except Exception as e:
378 | logger.error(f"Failed to process batch: {str(e)}")
379 | raise
380 |
381 | async def process_document(self, url: str) -> List[Document]:
382 | """Process a document from URL and store it in vector store"""
383 | temp_file_path = None
384 | try:
385 | # Download file
386 | temp_file_path, filename, content_type = await self.download_file(url)
387 | file_size = os.path.getsize(temp_file_path)
388 |
389 | # Log file information
390 | logger.info("="*50)
391 | logger.info("File Type Detection:")
392 | logger.info(f"URL: {url}")
393 | logger.info(f"Content-Type: {content_type}")
394 | logger.info(f"Filename: {filename}")
395 | logger.info(f"File Size: {file_size} bytes")
396 | logger.info(f"File Extension: {os.path.splitext(filename)[1]}")
397 | logger.info("="*50)
398 |
399 | # For PostgreSQL, check if document already exists
400 | if VECTOR_STORE_TYPE == "postgres" and isinstance(self.vector_store, PGVector):
401 | try:
402 |                 # Existence check: empty query string, filtering on source URL and file size
403 | existing_docs = self.vector_store.similarity_search(
404 | "",
405 | k=1000,
406 | filter={
407 | "source": url,
408 | "file_size": file_size
409 | }
410 | )
411 | if existing_docs:
412 | logger.info(f"Found existing document in PostgreSQL: {filename}")
413 | return [
414 | Document(
415 | page_content=doc.page_content,
416 | metadata=doc.metadata,
417 | id=f"{url}_{i}"
418 | ) for i, doc in enumerate(existing_docs)
419 | ]
420 | except Exception as e:
421 | logger.error(f"Failed to query PostgreSQL: {str(e)}")
422 | # For FAISS, check file cache
423 | elif VECTOR_STORE_TYPE == "faiss":
424 | cache_key = self._get_cache_key(filename, file_size)
425 | cached_docs, cached_store = self._load_from_cache(cache_key)
426 | if cached_docs and cached_store:
427 | logger.info("Using cached vectors from FAISS")
428 | self.vector_store = cached_store
429 | return cached_docs
430 |
431 | logger.info(f"Processing document: {filename} ({file_size} bytes)")
432 |
433 | # Get appropriate loader
434 | loader = self._get_loader(filename, content_type, temp_file_path)
435 | logger.info(f"Using loader: {loader.__class__.__name__}")
436 |
437 | # Load document
438 | docs = loader.load()
439 | logger.info(f"Loaded {len(docs)} document sections")
440 |
441 | # Fix text encoding and create metadata
442 | fixed_docs = []
443 | for doc in docs:
444 | if isinstance(doc.page_content, (str, bytes)):
445 | content = ftfy.fix_text(str(doc.page_content))
446 | else:
447 | content = str(doc.page_content)
448 |
449 | fixed_doc = LangchainDocument(
450 | page_content=content,
451 | metadata={
452 | **doc.metadata,
453 | "source": url,
454 | "filename": filename,
455 | "content_type": content_type,
456 | "file_size": file_size # Add file size to metadata for future lookups
457 | }
458 | )
459 | fixed_docs.append(fixed_doc)
460 |
461 | # Split into chunks
462 | splits = self.text_splitter.split_documents(fixed_docs)
463 | logger.info(f"Split into {len(splits)} chunks")
464 |
465 | # Convert to our Document model first
466 | documents = []
467 | for i, split in enumerate(splits):
468 | doc = Document(
469 | page_content=split.page_content,
470 | metadata=split.metadata,
471 | id=f"{url}_{i}"
472 | )
473 | documents.append(doc)
474 |
475 | # Try to update vector store if needed
476 | try:
477 | # Verify splits have content
478 | valid_splits = []
479 | for split in splits:
480 | if not isinstance(split.page_content, str):
481 | continue
482 | content = split.page_content.strip()
483 | if not content:
484 | continue
485 | if len(content) < MIN_CHUNK_LENGTH:
486 | continue
487 | valid_splits.append(split)
488 |
489 | logger.info(f"Found {len(valid_splits)} valid chunks for embedding")
490 | logger.info(f"Average chunk size: {sum(len(split.page_content) for split in valid_splits) / len(valid_splits) if valid_splits else 0:.0f} characters")
491 |
492 |             # Embed valid splits in batches of EMBEDDING_BATCH_SIZE
493 | if valid_splits:
494 | tasks = []
495 | for i in range(0, len(valid_splits), EMBEDDING_BATCH_SIZE):
496 | batch = valid_splits[i:i + EMBEDDING_BATCH_SIZE]
497 | logger.info(f"Processing batch {i//EMBEDDING_BATCH_SIZE + 1} of {(len(valid_splits) + EMBEDDING_BATCH_SIZE - 1)//EMBEDDING_BATCH_SIZE} ({len(batch)} chunks)")
498 | tasks.append(self._process_chunk_batch(batch))
499 |
500 |                 # Gather the batches; _process_chunk_batch never awaits, so they run back-to-back on the event loop
501 | await asyncio.gather(*tasks)
502 | logger.info(f"Processed all {len(valid_splits)} chunks")
503 | else:
504 | logger.warning("No valid content found for vector store update")
505 | except Exception as ve:
506 | logger.error(f"Vector store operation failed: {str(ve)}")
507 |             # Continue without vector store update; the parsed documents
508 |             # are still returned (and cached below when FAISS is configured)
509 |
510 |         if documents:
511 |             logger.info(f"Document processed into {len(documents)} sections")
512 |         else:
513 |             # No usable sections were extracted from the file; an empty
514 |             # list is returned to the caller
515 |             logger.warning("Document produced no content sections")
516 |
517 | # Only save to cache if using FAISS
518 | if VECTOR_STORE_TYPE == "faiss":
519 | cache_key = self._get_cache_key(filename, file_size)
520 | self._save_to_cache(cache_key, documents, self.vector_store)
521 |
522 | return documents
523 |
524 | except Exception as e:
525 | logger.error(f"Error processing document from {url}: {str(e)}")
526 | raise
527 | finally:
528 | # Clean up temporary file
529 | if temp_file_path:
530 | try:
531 | os.unlink(temp_file_path)
532 | except Exception as e:
533 | logger.warning(f"Failed to clean up temporary file {temp_file_path}: {str(e)}")
534 | pass
535 |
536 |     async def search_similar(self, query: str, source_url: Optional[str] = None, k: Optional[int] = None) -> List[Document]:
537 | """Search for similar documents"""
538 | try:
539 | if not self.vector_store:
540 | logger.info("Vector store not initialized")
541 | return []
542 |
543 | # Use configured max chunks if k is not specified
544 | if k is None:
545 | k = MAX_CHUNKS_PER_DOC
546 |
547 | try:
548 | # Add source URL filter if provided
549 | filter_dict = {"source": source_url} if source_url else None
550 |
551 | results = self.vector_store.similarity_search(
552 | query,
553 | k=k,
554 | filter=filter_dict
555 | )
556 | except Exception as e:
557 | logger.error(f"Vector search failed: {str(e)}")
558 | return []
559 |
560 | return [
561 | Document(
562 | page_content=doc.page_content,
563 | metadata=doc.metadata,
564 | id=f"{doc.metadata.get('source', '')}_{i}"
565 | ) for i, doc in enumerate(results)
566 | ]
567 | except Exception as e:
568 | logger.error(f"Error: {str(e)}")
569 | return []
--------------------------------------------------------------------------------
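The `download_file` method above enforces `MAX_FILE_SIZE` twice: once against the Content-Length header (when present) and again while streaming, since the header can be missing or wrong. Below is a stripped-down sketch of the streaming cap, with an assumed 10 MB limit, an illustrative URL, and no HEAD pre-check.

import asyncio
import os
import tempfile
import httpx

MAX_BYTES = 10 * 1024 * 1024  # assumed 10 MB cap for this sketch

async def download_capped(url: str) -> str:
    """Stream a file to a temp path, aborting if it grows past MAX_BYTES."""
    async with httpx.AsyncClient() as client:
        async with client.stream("GET", url, follow_redirects=True) as response:
            response.raise_for_status()
            with tempfile.NamedTemporaryFile(delete=False) as tmp:
                total = 0
                async for chunk in response.aiter_bytes(chunk_size=8192):
                    total += len(chunk)
                    if total > MAX_BYTES:
                        tmp.close()
                        os.unlink(tmp.name)  # discard the partial download
                        raise ValueError(f"download exceeds {MAX_BYTES} bytes")
                    tmp.write(chunk)
    return tmp.name

# asyncio.run(download_capped("https://example.com/report.pdf"))

Counting bytes as they arrive, rather than trusting Content-Length alone, is what actually bounds memory and disk usage for chunked or mislabeled responses.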