├── .python-version ├── tests ├── __init__.py ├── lib │ ├── __init__.py │ ├── test_embedder.py │ ├── test_crawler.py │ ├── test_chunker.py │ └── test_crawler_enhanced.py ├── integration │ ├── __init__.py │ ├── common │ │ ├── __init__.py │ │ └── test_vector_indexer.py │ ├── services │ │ ├── __init__.py │ │ └── test_document_service.py │ ├── conftest.py │ └── test_processor_enhanced.py ├── services │ ├── __init__.py │ ├── test_admin_service.py │ ├── test_map_service_legacy.py │ ├── test_job_service.py │ └── test_map_service.py ├── README.md ├── conftest.py ├── api │ └── test_map_api.py └── common │ └── test_processor.py ├── doctor.png ├── src ├── common │ ├── __init__.py │ ├── logger.py │ ├── config.py │ ├── processor.py │ ├── models.py │ └── processor_enhanced.py ├── lib │ ├── __init__.py │ ├── database │ │ ├── __init__.py │ │ ├── utils.py │ │ ├── migrations │ │ │ └── 001_add_hierarchy.sql │ │ ├── migrations.py │ │ └── schema.py │ ├── embedder.py │ ├── chunker.py │ ├── crawler.py │ └── crawler_enhanced.py ├── web_service │ ├── __init__.py │ ├── services │ │ ├── __init__.py │ │ ├── admin_service.py │ │ ├── debug_bm25.py │ │ └── job_service.py │ ├── Dockerfile │ ├── api │ │ ├── __init__.py │ │ ├── admin.py │ │ ├── map.py │ │ ├── jobs.py │ │ ├── documents.py │ │ └── diagnostics.py │ └── main.py └── crawl_worker │ ├── __init__.py │ ├── Dockerfile │ └── main.py ├── .cursor └── rules │ ├── code-style.mdc │ ├── testing.mdc │ └── refactors.mdc ├── pytest.ini ├── Dockerfile.base ├── .pre-commit-config.yaml ├── .gitignore ├── LICENSE.md ├── .github └── workflows │ └── pytest.yml ├── docker-compose.yml ├── pyproject.toml ├── llms.txt ├── docs └── maps_api_examples.md └── README.md /.python-version: -------------------------------------------------------------------------------- 1 | 3.12 2 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for the Doctor application.""" 2 | -------------------------------------------------------------------------------- /tests/lib/__init__.py: -------------------------------------------------------------------------------- 1 | """Unit tests for the lib module.""" 2 | -------------------------------------------------------------------------------- /tests/integration/__init__.py: -------------------------------------------------------------------------------- 1 | """Integration tests package.""" 2 | -------------------------------------------------------------------------------- /doctor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sisig-ai/doctor/HEAD/doctor.png -------------------------------------------------------------------------------- /src/common/__init__.py: -------------------------------------------------------------------------------- 1 | """Common utilities for the Doctor project.""" 2 | -------------------------------------------------------------------------------- /tests/services/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for the web service's services.""" 2 | -------------------------------------------------------------------------------- /src/lib/__init__.py: -------------------------------------------------------------------------------- 1 | """Shared library components for the Doctor application.""" 2 | 
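The Dockerfiles and compose file later in this dump set `PYTHONPATH=/app`, so the tree above maps directly onto import paths under the `src` package (and the `tests/` tree mirrors it, which is what `pytest.ini`'s `testpaths = tests` relies on). A minimal sketch of that mapping, using module names taken from the tree; how they are combined below is illustrative only, not code from the repository:

```python
# Illustrative only: module paths come from the tree above; the usage shown is an assumption.
from src.common.config import CHUNK_SIZE, WEB_SERVICE_PORT  # src/common/config.py
from src.lib.chunker import TextChunker                     # src/lib/chunker.py
from src.web_service.main import create_application         # src/web_service/main.py

chunker = TextChunker(chunk_size=CHUNK_SIZE)  # shared library code also used by the crawl worker
app = create_application()                    # FastAPI app served on WEB_SERVICE_PORT (9111)
```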
-------------------------------------------------------------------------------- /src/web_service/__init__.py: -------------------------------------------------------------------------------- 1 | """Web Service module for the Doctor project.""" 2 | -------------------------------------------------------------------------------- /src/crawl_worker/__init__.py: -------------------------------------------------------------------------------- 1 | """Crawl worker package for the Doctor application.""" 2 | -------------------------------------------------------------------------------- /src/web_service/services/__init__.py: -------------------------------------------------------------------------------- 1 | """Services package for the web service.""" 2 | -------------------------------------------------------------------------------- /tests/integration/common/__init__.py: -------------------------------------------------------------------------------- 1 | """Integration tests for common modules.""" 2 | -------------------------------------------------------------------------------- /tests/integration/services/__init__.py: -------------------------------------------------------------------------------- 1 | """Integration tests for service modules.""" 2 | -------------------------------------------------------------------------------- /.cursor/rules/code-style.mdc: -------------------------------------------------------------------------------- 1 | --- 2 | description: 3 | globs: 4 | alwaysApply: true 5 | --- 6 | - Use Google-style docstrings 7 | - Always type-annotate variables and methods 8 | - After making all changes, run `ruff format` and `ruff check --fix` 9 | -------------------------------------------------------------------------------- /src/crawl_worker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM doctor-base:latest AS crawl_worker 2 | 3 | # The WORKDIR, ENV, and code are inherited from the base image. 4 | # If specific overrides or additional ENV vars are needed for crawl_worker, set them here. 5 | 6 | RUN crawl4ai-setup 7 | 8 | # Run the worker 9 | CMD ["python", "-m", "src.crawl_worker.main"] 10 | -------------------------------------------------------------------------------- /src/web_service/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM doctor-base:latest AS web_service 2 | 3 | # The WORKDIR, ENV, and code are inherited from the base image. 4 | # If specific overrides or additional ENV vars are needed for web_service, set them here. 
5 | 6 | # Expose the port 7 | EXPOSE 9111 8 | 9 | # Run the web service 10 | CMD ["python", "src/web_service/main.py"] 11 | -------------------------------------------------------------------------------- /.cursor/rules/testing.mdc: -------------------------------------------------------------------------------- 1 | --- 2 | description: Writing tests or implementing new logic 3 | globs: 4 | alwaysApply: false 5 | --- 6 | - Always write tests using functional pytest 7 | - Tests should always maintain or increase code coverage 8 | - Use parametrized testing whenever possible, instead of multiple test functions 9 | - Use snapshot tests to check for known/constant values 10 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | testpaths = tests 3 | python_files = test_*.py 4 | python_classes = Test* 5 | python_functions = test_* 6 | addopts = -v --cov=src --cov-report=term-missing 7 | asyncio_mode = auto 8 | asyncio_default_fixture_loop_scope = function 9 | markers = 10 | unit: marks tests as unit tests 11 | integration: marks tests as integration tests 12 | async_test: marks tests as asynchronous 13 | -------------------------------------------------------------------------------- /.cursor/rules/refactors.mdc: -------------------------------------------------------------------------------- 1 | --- 2 | description: Refactoring and implementing new logic 3 | globs: 4 | alwaysApply: false 5 | --- 6 | - Run tests after refactor/implementation is completed, ensuring they pass, and add new tests as necessary 7 | - To run the tests, do `pytest` (making sure that you are using `.venv/bin/python`) 8 | - When tests fail, adjust them according to the changes in code OR fix the code if it's a logic issue 9 | -------------------------------------------------------------------------------- /Dockerfile.base: -------------------------------------------------------------------------------- 1 | FROM python:3.12-slim 2 | 3 | WORKDIR /app 4 | 5 | # Install uv 6 | RUN pip install --no-cache-dir uv 7 | 8 | # Copy project configuration and readme 9 | COPY pyproject.toml README.md /app/ 10 | 11 | # Install dependencies 12 | RUN uv pip install --system -e . 
13 | 14 | # Copy source code 15 | COPY src /app/src 16 | 17 | # Create data directory 18 | RUN mkdir -p /app/data 19 | 20 | # Set environment variables 21 | ENV PYTHONPATH=/app 22 | ENV PYTHONUNBUFFERED=1 23 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | rev: v0.11.8 4 | hooks: 5 | - id: ruff 6 | name: ruff check 7 | args: [--fix] 8 | - id: ruff-format 9 | name: ruff format 10 | - repo: https://github.com/pre-commit/pre-commit-hooks 11 | rev: v4.5.0 12 | hooks: 13 | - id: trailing-whitespace 14 | - id: end-of-file-fixer 15 | - id: check-yaml 16 | - repo: local 17 | hooks: 18 | - id: mypy-warn-only 19 | name: mypy (warning only) 20 | entry: bash -c "mypy --ignore-missing-imports || true" 21 | language: system 22 | types: [python] 23 | require_serial: true 24 | verbose: true 25 | -------------------------------------------------------------------------------- /src/web_service/api/__init__.py: -------------------------------------------------------------------------------- 1 | """API routes package for the web service.""" 2 | 3 | from fastapi import APIRouter 4 | 5 | from src.web_service.api.admin import router as admin_router 6 | from src.web_service.api.diagnostics import router as diagnostics_router 7 | from src.web_service.api.documents import router as documents_router 8 | from src.web_service.api.jobs import router as jobs_router 9 | from src.web_service.api.map import router as map_router 10 | 11 | # Create a main router that includes all the other routers 12 | api_router = APIRouter() 13 | 14 | # Include the routers 15 | api_router.include_router(documents_router) 16 | api_router.include_router(jobs_router) 17 | api_router.include_router(admin_router) 18 | api_router.include_router(diagnostics_router) 19 | api_router.include_router(map_router) 20 | -------------------------------------------------------------------------------- /src/lib/database/__init__.py: -------------------------------------------------------------------------------- 1 | """Initializes the database package, exposing key components for database interaction. 2 | 3 | This package provides a modular approach to database management for the Doctor project, 4 | separating connection handling, high-level operations, schema definitions, and utilities. 5 | 6 | The main interface for database operations is the `DatabaseOperations` class. 
7 | """ 8 | 9 | from .connection import DuckDBConnectionManager 10 | from .operations import DatabaseOperations 11 | from .schema import ( 12 | CREATE_DOCUMENT_EMBEDDINGS_TABLE_SQL, 13 | CREATE_JOBS_TABLE_SQL, 14 | CREATE_PAGES_TABLE_SQL, 15 | ) 16 | from .utils import deserialize_tags, serialize_tags 17 | 18 | __all__ = [ 19 | "CREATE_DOCUMENT_EMBEDDINGS_TABLE_SQL", 20 | "CREATE_JOBS_TABLE_SQL", 21 | "CREATE_PAGES_TABLE_SQL", 22 | "DatabaseOperations", 23 | "DuckDBConnectionManager", 24 | "deserialize_tags", 25 | "serialize_tags", 26 | ] 27 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | .Python 7 | build/ 8 | develop-eggs/ 9 | dist/ 10 | downloads/ 11 | eggs/ 12 | .eggs/ 13 | parts/ 14 | sdist/ 15 | var/ 16 | wheels/ 17 | *.egg-info/ 18 | .installed.cfg 19 | *.egg 20 | .ropeproject/ 21 | .ruff_cache/ 22 | .pytest_cache/ 23 | data/ 24 | 25 | # Virtual Environment 26 | .env 27 | .venv 28 | env/ 29 | venv/ 30 | ENV/ 31 | env.bak/ 32 | venv.bak/ 33 | 34 | # DuckDB 35 | *.duckdb 36 | *.duckdb.wal 37 | 38 | # IDE 39 | .idea/ 40 | .vscode/ 41 | *.swp 42 | *.swo 43 | .zed 44 | 45 | # Logs 46 | *.log 47 | logs/ 48 | 49 | # Local configuration 50 | .env.local 51 | .env.development.local 52 | .env.test.local 53 | .env.production.local 54 | 55 | # Docker 56 | .docker/ 57 | docker-compose.override.yml 58 | 59 | # OS specific 60 | .DS_Store 61 | Thumbs.db 62 | 63 | .roo 64 | .coverage 65 | coverage.xml 66 | .aider* 67 | 68 | **/.claude/settings.local.json 69 | CLAUDE.md 70 | .claude/ 71 | -------------------------------------------------------------------------------- /tests/integration/conftest.py: -------------------------------------------------------------------------------- 1 | """Fixtures for integration tests.""" 2 | 3 | import duckdb 4 | import pytest 5 | 6 | 7 | @pytest.fixture 8 | def in_memory_duckdb_connection(): 9 | """Create an in-memory DuckDB connection for integration testing. 
10 | 11 | This connection has the proper setup for vector search using the same 12 | setup logic as the main application: 13 | - VSS extension loaded 14 | - document_embeddings table created with the proper schema 15 | - HNSW index created 16 | - pages table created (for document service tests) 17 | """ 18 | from src.lib.database import DatabaseOperations 19 | 20 | conn = duckdb.connect(":memory:") 21 | 22 | # Use the same setup functions as the main application 23 | try: 24 | # Create a Database instance and initialize with our in-memory connection 25 | db = DatabaseOperations() 26 | db.db.conn = conn 27 | 28 | # Initialize all tables and extensions 29 | db.db.initialize() 30 | 31 | yield conn 32 | finally: 33 | conn.close() 34 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 sisig-ai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | # Tests for Doctor Application 2 | 3 | This directory contains tests for the Doctor application. 
4 | 5 | ## Structure 6 | 7 | - `conftest.py`: Common fixtures for all tests 8 | - `lib/`: Tests for the library components 9 | - `test_crawler.py` and `test_crawler_enhanced.py`: Tests for the crawler modules 10 | - `test_chunker.py`: Tests for the chunker module 11 | - `test_embedder.py`: Tests for the embedder module 12 | - `common/`: Tests for the processor module 13 | - `services/` and `api/`: Tests for the web service's services and API routes 14 | - `integration/`: Integration tests that run against a real in-memory DuckDB 15 | 16 | ## Running Tests 17 | 18 | To run all tests: 19 | 20 | ```bash 21 | pytest 22 | ``` 23 | 24 | To run tests with coverage: 25 | 26 | ```bash 27 | pytest --cov=src 28 | ``` 29 | 30 | To run specific test categories: 31 | 32 | ```bash 33 | # Run all unit tests 34 | pytest -m unit 35 | 36 | # Run all async tests 37 | pytest -m async_test 38 | 39 | # Run tests for a specific module 40 | pytest tests/lib/test_crawler.py 41 | ``` 42 | 43 | ## Test Markers 44 | 45 | - `unit`: Unit tests 46 | - `integration`: Integration tests 47 | - `async_test`: Tests that use asyncio 48 | -------------------------------------------------------------------------------- /.github/workflows/pytest.yml: -------------------------------------------------------------------------------- 1 | name: Python Tests 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | test: 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - uses: actions/checkout@v3 15 | 16 | - name: Set up Python 17 | uses: actions/setup-python@v4 18 | with: 19 | python-version: '3.12' 20 | 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install uv 25 | uv venv .venv 26 | uv sync 27 | 28 | - name: Debug environment 29 | run: | 30 | echo "Current directory: $(pwd)" 31 | echo "Python path: $(which python)" 32 | echo "Python version: $(python --version)" 33 | echo "Pytest version: $(.venv/bin/pytest --version)" 34 | echo "Listing test directory:" 35 | ls -la tests/ 36 | ls -la tests/lib/ 37 | 38 | - name: Test with pytest 39 | run: | 40 | .venv/bin/pytest tests/ -v --cov=src --cov-report=xml 41 | 42 | - name: Upload coverage to Codecov 43 | env: 44 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} 45 | uses: codecov/codecov-action@v3 46 | with: 47 | file: ./coverage.xml 48 | fail_ci_if_error: false 49 | -------------------------------------------------------------------------------- /src/web_service/services/admin_service.py: -------------------------------------------------------------------------------- 1 | """Admin service for the web service.""" 2 | 3 | import uuid 4 | 5 | from rq import Queue 6 | 7 | from src.common.logger import get_logger 8 | 9 | # Get logger for this module 10 | logger = get_logger(__name__) 11 | 12 | 13 | async def delete_docs( 14 | queue: Queue, 15 | tags: list[str] | None = None, 16 | domain: str | None = None, 17 | page_ids: list[str] | None = None, 18 | ) -> str: 19 | """Delete documents from the database based on filters. 
20 | 21 | Args: 22 | queue: Redis queue for job processing 23 | tags: Optional list of tags to filter by 24 | domain: Optional domain substring to filter by 25 | page_ids: Optional list of specific page IDs to delete 26 | 27 | Returns: 28 | str: The task ID for tracking 29 | 30 | """ 31 | logger.info( 32 | f"Enqueueing delete task with filters: tags={tags}, domain={domain}, page_ids={page_ids}", 33 | ) 34 | 35 | # Generate a task ID for tracking logs 36 | task_id = str(uuid.uuid4()) 37 | 38 | # Enqueue the delete task 39 | queue.enqueue( 40 | "src.crawl_worker.tasks.delete_docs", 41 | task_id, 42 | tags, 43 | domain, 44 | page_ids, 45 | ) 46 | 47 | logger.info(f"Enqueued delete task with ID: {task_id}") 48 | 49 | return task_id 50 | -------------------------------------------------------------------------------- /src/lib/database/utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions for the database module. 2 | 3 | This module contains helper functions for common data transformations or 4 | operations used within the database package, such as serializing and 5 | deserializing tags. 6 | """ 7 | 8 | import json 9 | 10 | from src.common.logger import get_logger 11 | 12 | logger = get_logger(__name__) 13 | 14 | 15 | def serialize_tags(tags: list[str] | None) -> str: 16 | """Serialize a list of tags to a JSON string. 17 | 18 | Args: 19 | tags: A list of tag strings, or None. 20 | Returns: 21 | str: A JSON string representation of the tags list. Returns an empty list '[]' if tags is None. 22 | """ 23 | if tags is None: 24 | return json.dumps([]) 25 | return json.dumps(tags) 26 | 27 | 28 | def deserialize_tags(tags_json: str | None) -> list[str]: 29 | """Deserialize a tags JSON string to a list of strings. 30 | 31 | Args: 32 | tags_json: The JSON string containing the tags, or None. 33 | Returns: 34 | list[str]: A list of tag strings. Returns an empty list if input is None, empty, or invalid JSON. 
35 | """ 36 | if not tags_json: 37 | return [] 38 | try: 39 | return json.loads(tags_json) 40 | except json.JSONDecodeError: 41 | logger.warning(f"Could not decode tags JSON: {tags_json!r}") # Added !r for better logging 42 | return [] 43 | -------------------------------------------------------------------------------- /src/lib/database/migrations/001_add_hierarchy.sql: -------------------------------------------------------------------------------- 1 | -- Migration: Add hierarchy columns to pages table 2 | -- This migration adds columns to track page relationships for the maps feature 3 | 4 | -- Add hierarchy columns to pages table 5 | ALTER TABLE pages ADD COLUMN IF NOT EXISTS parent_page_id VARCHAR; 6 | ALTER TABLE pages ADD COLUMN IF NOT EXISTS root_page_id VARCHAR; 7 | ALTER TABLE pages ADD COLUMN IF NOT EXISTS depth INTEGER DEFAULT 0; 8 | ALTER TABLE pages ADD COLUMN IF NOT EXISTS path TEXT; 9 | ALTER TABLE pages ADD COLUMN IF NOT EXISTS title TEXT; 10 | 11 | -- Add foreign key constraints (DuckDB doesn't enforce these, but good for documentation) 12 | -- ALTER TABLE pages ADD CONSTRAINT fk_parent_page FOREIGN KEY (parent_page_id) REFERENCES pages(id); 13 | -- ALTER TABLE pages ADD CONSTRAINT fk_root_page FOREIGN KEY (root_page_id) REFERENCES pages(id); 14 | 15 | -- Create indexes for better performance 16 | CREATE INDEX IF NOT EXISTS idx_pages_parent_page_id ON pages(parent_page_id); 17 | CREATE INDEX IF NOT EXISTS idx_pages_root_page_id ON pages(root_page_id); 18 | CREATE INDEX IF NOT EXISTS idx_pages_depth ON pages(depth); 19 | 20 | -- Update existing pages to have sensible defaults 21 | -- Set root_page_id to self for existing pages (they become roots) 22 | UPDATE pages 23 | SET root_page_id = id 24 | WHERE root_page_id IS NULL; 25 | 26 | -- Set depth to 0 for existing pages 27 | UPDATE pages 28 | SET depth = 0 29 | WHERE depth IS NULL; 30 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3.8' 2 | 3 | services: 4 | base: 5 | build: 6 | context: . 7 | dockerfile: Dockerfile.base 8 | image: doctor-base:latest 9 | 10 | 11 | redis: 12 | image: redis:alpine 13 | ports: 14 | - "6379:6379" 15 | networks: 16 | - doctor-net 17 | healthcheck: 18 | test: ["CMD", "redis-cli", "ping"] 19 | interval: 10s 20 | timeout: 5s 21 | retries: 5 22 | 23 | crawl_worker: 24 | build: 25 | context: . 26 | dockerfile: src/crawl_worker/Dockerfile 27 | volumes: 28 | - ./src:/app/src 29 | - ./data:/app/data 30 | environment: 31 | - REDIS_URI=redis://redis:6379 32 | - OPENAI_API_KEY=${OPENAI_API_KEY} 33 | - PYTHONPATH=/app 34 | depends_on: 35 | redis: 36 | condition: service_healthy 37 | base: 38 | condition: service_completed_successfully 39 | 40 | networks: 41 | - doctor-net 42 | 43 | web_service: 44 | build: 45 | context: . 
46 | dockerfile: src/web_service/Dockerfile 47 | ports: 48 | - "9111:9111" 49 | volumes: 50 | - ./src:/app/src 51 | - ./data:/app/data 52 | environment: 53 | - REDIS_URI=redis://redis:6379 54 | - PYTHONPATH=/app 55 | - OPENAI_API_KEY=${OPENAI_API_KEY} 56 | 57 | depends_on: 58 | redis: 59 | condition: service_healthy 60 | base: 61 | condition: service_completed_successfully 62 | 63 | networks: 64 | - doctor-net 65 | 66 | networks: 67 | doctor-net: 68 | driver: bridge 69 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "doctor" 3 | version = "0.2.0" 4 | description = "LLM agent system for discovering, crawling, and indexing web sites" 5 | readme = "README.md" 6 | requires-python = ">=3.12" 7 | dependencies = [ 8 | "fastapi", 9 | "uvicorn[standard]", 10 | "redis", 11 | "rq", 12 | "crawl4ai==0.6.0", 13 | "langchain-text-splitters", 14 | "litellm", 15 | "openai", 16 | "duckdb", 17 | "pydantic>=2.0.0", 18 | "httpx", 19 | "python-multipart", 20 | "uuid", 21 | "fastapi-mcp==0.3.2", 22 | "mcp==1.7.1", 23 | "nest-asyncio>=1.6.0", 24 | "markdown>=3.8", 25 | ] 26 | 27 | [build-system] 28 | requires = ["hatchling"] 29 | build-backend = "hatchling.build" 30 | 31 | [tool.ruff] 32 | line-length = 100 33 | target-version = "py312" 34 | 35 | [tool.black] 36 | line-length = 100 37 | target-version = ["py312"] 38 | 39 | [tool.hatch.build.targets.wheel] 40 | # Tell hatch where to find the packages for the 'doctor' project 41 | packages = ["src/common", "src/crawl_worker", "src/web_service", "src/lib"] 42 | 43 | [dependency-groups] 44 | dev = [ 45 | "ruff>=0.11.8", 46 | "pytest-asyncio>=0.21.0", 47 | "pytest-cov>=4.1.0", 48 | "pytest>=8.3.5", 49 | "pre-commit>=4.2.0", 50 | "coverage>=7.8.0", 51 | ] 52 | 53 | [tool.pytest.ini_options] 54 | testpaths = ["tests"] 55 | python_files = "test_*.py" 56 | python_classes = "Test*" 57 | python_functions = "test_*" 58 | markers = [ 59 | "unit: marks tests as unit tests", 60 | "integration: marks tests as integration tests", 61 | "async_test: marks tests as asynchronous tests", 62 | ] 63 | -------------------------------------------------------------------------------- /src/lib/embedder.py: -------------------------------------------------------------------------------- 1 | """Text embedding functionality using LiteLLM.""" 2 | 3 | from typing import Literal 4 | 5 | import litellm 6 | 7 | from src.common.config import DOC_EMBEDDING_MODEL, QUERY_EMBEDDING_MODEL 8 | from src.common.logger import get_logger 9 | 10 | # Configure logging 11 | logger = get_logger(__name__) 12 | 13 | 14 | async def generate_embedding( 15 | text: str, 16 | model: str = None, 17 | timeout: int = 30, 18 | text_type: Literal["doc", "query"] = "doc", 19 | ) -> list[float]: 20 | """Generate an embedding for a text chunk. 
21 | 22 | Args: 23 | text: The text to embed 24 | model: The embedding model to use (defaults to config value) 25 | timeout: Timeout in seconds for the embedding API call 26 | text_type: The type of text to embed (defaults to "doc") 27 | 28 | Returns: 29 | The generated embedding as a list of floats 30 | 31 | """ 32 | if not text.strip(): 33 | logger.warning("Received empty text for embedding, cannot proceed") 34 | raise ValueError("Cannot generate embedding for empty text") 35 | 36 | model_name = model or (DOC_EMBEDDING_MODEL if text_type == "doc" else QUERY_EMBEDDING_MODEL) 37 | logger.debug(f"Generating embedding for text of length {len(text)} using model {model_name}") 38 | 39 | try: 40 | embedding_response = await litellm.aembedding( 41 | model=model_name, 42 | input=[text], 43 | timeout=timeout, 44 | ) 45 | 46 | # Extract the embedding vector from the response 47 | embedding = embedding_response["data"][0]["embedding"] 48 | logger.debug(f"Successfully generated embedding of dimension {len(embedding)}") 49 | 50 | return embedding 51 | 52 | except Exception as e: 53 | logger.error(f"Error generating embedding: {e!s}") 54 | raise 55 | -------------------------------------------------------------------------------- /src/web_service/api/admin.py: -------------------------------------------------------------------------------- 1 | """Admin API routes for the web service.""" 2 | 3 | import redis 4 | from fastapi import APIRouter, Depends, status 5 | from rq import Queue 6 | 7 | from src.common.config import REDIS_URI 8 | from src.common.logger import get_logger 9 | from src.common.models import ( 10 | DeleteDocsRequest, 11 | ) 12 | from src.web_service.services.admin_service import ( 13 | delete_docs, 14 | ) 15 | 16 | # Get logger for this module 17 | logger = get_logger(__name__) 18 | 19 | # Create router 20 | router = APIRouter(tags=["admin"]) 21 | 22 | 23 | @router.post("/delete_docs", status_code=status.HTTP_204_NO_CONTENT, operation_id="delete_docs") 24 | async def delete_docs_endpoint( 25 | request: DeleteDocsRequest, 26 | queue: Queue = Depends(lambda: Queue("worker", connection=redis.from_url(REDIS_URI))), 27 | ) -> None: 28 | """Deletes documents from the database based on filters. 29 | 30 | Args: 31 | request: The delete request with optional filters. 32 | queue: The RQ queue for enqueueing the delete task. 33 | 34 | Returns: 35 | None: Returns a 204 No Content response upon successful enqueueing. 
36 | 37 | """ 38 | logger.info( 39 | f"API: Deleting docs with filters: tags={request.tags}, domain={request.domain}, page_ids={request.page_ids}", 40 | ) 41 | 42 | try: 43 | # Call the service function 44 | await delete_docs( 45 | queue=queue, 46 | tags=request.tags, 47 | domain=request.domain, 48 | page_ids=request.page_ids, 49 | ) 50 | 51 | # Return 204 No Content 52 | return 53 | except Exception as e: 54 | logger.error(f"Error deleting documents: {e!s}") 55 | # Since this is an asynchronous operation, we still return 204 56 | # The actual deletion happens in the background 57 | return 58 | -------------------------------------------------------------------------------- /src/lib/chunker.py: -------------------------------------------------------------------------------- 1 | """Text chunking functionality using LangChain.""" 2 | 3 | from langchain_text_splitters import RecursiveCharacterTextSplitter 4 | 5 | from src.common.config import CHUNK_OVERLAP, CHUNK_SIZE 6 | from src.common.logger import get_logger 7 | 8 | # Configure logging 9 | logger = get_logger(__name__) 10 | 11 | 12 | class TextChunker: 13 | """Class for splitting text into semantic chunks.""" 14 | 15 | def __init__(self, chunk_size: int = None, chunk_overlap: int = None): 16 | """Initialize the text chunker. 17 | 18 | Args: 19 | chunk_size: Size of each text chunk (defaults to config value) 20 | chunk_overlap: Overlap between chunks (defaults to config value) 21 | 22 | """ 23 | self.chunk_size = chunk_size or CHUNK_SIZE 24 | self.chunk_overlap = chunk_overlap or CHUNK_OVERLAP 25 | 26 | self.text_splitter = RecursiveCharacterTextSplitter( 27 | chunk_size=self.chunk_size, 28 | chunk_overlap=self.chunk_overlap, 29 | length_function=len, 30 | ) 31 | 32 | logger.debug( 33 | f"Initialized TextChunker with chunk_size={self.chunk_size}, chunk_overlap={self.chunk_overlap}", 34 | ) 35 | 36 | def split_text(self, text: str) -> list[str]: 37 | """Split the text into chunks. 38 | 39 | Args: 40 | text: The text to split 41 | 42 | Returns: 43 | List of text chunks 44 | 45 | """ 46 | if not text or not text.strip(): 47 | logger.warning("Received empty text for chunking, returning empty list") 48 | return [] 49 | 50 | chunks = self.text_splitter.split_text(text) 51 | logger.debug(f"Split text of length {len(text)} into {len(chunks)} chunks") 52 | 53 | # Filter out empty chunks 54 | non_empty_chunks = [chunk for chunk in chunks if chunk.strip()] 55 | if len(non_empty_chunks) < len(chunks): 56 | logger.debug(f"Filtered out {len(chunks) - len(non_empty_chunks)} empty chunks") 57 | 58 | return non_empty_chunks 59 | -------------------------------------------------------------------------------- /src/common/logger.py: -------------------------------------------------------------------------------- 1 | """Centralized logging configuration for the Doctor project.""" 2 | 3 | import logging 4 | import os 5 | 6 | 7 | def configure_logging( 8 | name: str = None, 9 | level: str = None, 10 | log_file: str | None = None, 11 | ) -> logging.Logger: 12 | """Configure and return a logger with consistent formatting. 
13 | 14 | Args: 15 | name: Logger name (typically __name__ from the calling module) 16 | level: Log level (DEBUG, INFO, WARNING, ERROR, CRITICAL) 17 | Falls back to DOCTOR_LOG_LEVEL env var, defaults to INFO 18 | log_file: Optional path to log file 19 | 20 | Returns: 21 | Configured logger instance 22 | 23 | """ 24 | # Get log level from params, env var, or default to INFO 25 | if level is None: 26 | level = os.getenv("DOCTOR_LOG_LEVEL", "INFO") 27 | 28 | level = getattr(logging, level.upper()) 29 | 30 | # Configure root logger if not already configured 31 | if not logging.getLogger().handlers: 32 | # Basic configuration with consistent format 33 | logging.basicConfig( 34 | level=level, 35 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", 36 | datefmt="%Y-%m-%d %H:%M:%S", 37 | ) 38 | 39 | # Add file handler if specified 40 | if log_file: 41 | file_handler = logging.FileHandler(log_file) 42 | file_handler.setFormatter( 43 | logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"), 44 | ) 45 | logging.getLogger().addHandler(file_handler) 46 | 47 | # Get logger for the specific module 48 | logger = logging.getLogger(name) 49 | logger.setLevel(level) 50 | 51 | return logger 52 | 53 | 54 | def get_logger(name: str = None) -> logging.Logger: 55 | """Get a logger with the specified name. 56 | 57 | Args: 58 | name: Logger name (typically __name__ from the calling module) 59 | 60 | Returns: 61 | Configured logger instance 62 | 63 | """ 64 | return configure_logging(name) 65 | -------------------------------------------------------------------------------- /llms.txt: -------------------------------------------------------------------------------- 1 | # Doctor 2 | 3 | > Doctor is a system that lets LLM agents discover, crawl, and index web sites for better and more up-to-date reasoning and code generation. 4 | 5 | Doctor provides a complete stack for crawling, indexing, and searching web content to enhance LLM capabilities. It handles web crawling, text chunking, embedding generation, and semantic search, making this functionality available to large language models through a Model Context Protocol (MCP) server. 
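As a rough illustration of that pipeline, the sketch below composes the library pieces that appear later in this document (crawler, chunker, embedder, and the vector indexer exercised by the integration tests). The individual imports and call signatures are taken from the repository; the way they are chained here, and the `crawl_and_index` helper itself, are editorial assumptions rather than the crawl worker's actual implementation.

```python
# Hedged sketch of the crawl -> chunk -> embed -> index flow; not the worker's real code.
import duckdb

from src.common.config import DUCKDB_PATH
from src.common.indexer import VectorIndexer
from src.lib.chunker import TextChunker
from src.lib.crawler import crawl_url, extract_page_text
from src.lib.embedder import generate_embedding


async def crawl_and_index(url: str, job_id: str) -> None:
    """Hypothetical end-to-end helper chaining the documented building blocks."""
    conn = duckdb.connect(DUCKDB_PATH)          # DuckDB file under DATA_DIR
    indexer = VectorIndexer(connection=conn)    # vector storage with HNSW search
    chunker = TextChunker()                     # CHUNK_SIZE=1000, CHUNK_OVERLAP=200 defaults

    results = await crawl_url(url, max_pages=100)
    for result in results:
        text = extract_page_text(result)        # prefers markdown, falls back to raw HTML
        for chunk in chunker.split_text(text):
            vector = await generate_embedding(chunk, text_type="doc")
            await indexer.index_vector(vector, {"text": chunk, "url": result.url, "job_id": job_id})
```

In the repository itself this work is queued through Redis/RQ and executed by the crawl worker, which also records job progress and page hierarchy (used by the map feature) in DuckDB.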
6 | 7 | ## Core Components 8 | 9 | - [Crawl Worker](/src/crawl_worker) - Processes crawl jobs, chunks text, and creates embeddings 10 | - [Web Service](/src/web_service) - FastAPI service exposing endpoints for fetching, searching, and viewing data, and exposing MCP server 11 | - [Common](/src/common) - Shared code, models, and database utilities 12 | 13 | ## Infrastructure 14 | 15 | - DuckDB - Database for storing document data and embeddings with vector search capabilities 16 | - Redis - Message broker for asynchronous task processing 17 | - Docker - Container orchestration for deploying the complete stack 18 | 19 | ## Models 20 | 21 | - OpenAI Text Embeddings - Used for generating vector embeddings of text chunks 22 | - Implementation: text-embedding-3-large (3072 dimensions) 23 | - Integration: Accessed via litellm library 24 | 25 | ## Libraries 26 | 27 | - crawl4ai - For web page crawling 28 | - langchain_text_splitters - For chunking text content 29 | - litellm - Wrapper for accessing embedding models 30 | - fastapi - Web service framework 31 | - fastapi-mcp - MCP server implementation 32 | - duckdb-vss - Vector similarity search extension for DuckDB 33 | 34 | ## Technical Requirements 35 | 36 | - Docker and Docker Compose - For running the complete stack 37 | - Python 3.12+ - Primary programming language 38 | - OpenAI API key - Required for embedding generation 39 | 40 | ## API Configuration 41 | 42 | - OpenAI API key must be provided via environment variable: OPENAI_API_KEY 43 | - Additional environment variables: 44 | - REDIS_URI - URI for Redis connection 45 | - DATA_DIR - Directory for storing DuckDB data files (default: "data") 46 | -------------------------------------------------------------------------------- /src/common/config.py: -------------------------------------------------------------------------------- 1 | """Configuration settings for the Doctor project.""" 2 | 3 | import os 4 | 5 | from src.common.logger import get_logger 6 | 7 | # Get logger for this module 8 | logger = get_logger(__name__) 9 | 10 | # Vector settings 11 | VECTOR_SIZE = 3072  # OpenAI text-embedding-3-large embedding size 12 | 13 | # Redis settings 14 | REDIS_URI = os.getenv("REDIS_URI", "redis://localhost:6379") 15 | 16 | # DuckDB settings 17 | DATA_DIR = os.getenv("DATA_DIR", "data") 18 | DUCKDB_PATH = os.path.join(DATA_DIR, "doctor.duckdb") 19 | DUCKDB_EMBEDDINGS_TABLE = "document_embeddings" 20 | DB_RETRY_ATTEMPTS = int(os.getenv("DB_RETRY_ATTEMPTS", "5")) 21 | DB_RETRY_DELAY_SEC = float(os.getenv("DB_RETRY_DELAY_SEC", "0.5")) 22 | 23 | # OpenAI settings 24 | OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") 25 | DOC_EMBEDDING_MODEL = "openai/text-embedding-3-large" 26 | QUERY_EMBEDDING_MODEL = "openai/text-embedding-3-large" 27 | 28 | # Web service settings 29 | WEB_SERVICE_HOST = os.getenv("WEB_SERVICE_HOST", "0.0.0.0") 30 | WEB_SERVICE_PORT = int(os.getenv("WEB_SERVICE_PORT", "9111")) 31 | 32 | # Crawl settings 33 | DEFAULT_MAX_PAGES = 100 34 | CHUNK_SIZE = 1000 35 | CHUNK_OVERLAP = 200 36 | 37 | # Search settings 38 | RETURN_FULL_DOCUMENT_TEXT = True 39 | 40 | # MCP Server settings 41 | DOCTOR_BASE_URL = os.getenv("DOCTOR_BASE_URL", "http://localhost:9111") 42 | 43 | 44 | def validate_config() -> list[str]: 45 | """Validate the configuration and return a list of any issues.""" 46 | issues = [] 47 | 48 | # Only check for valid API key in production 49 | if not os.getenv("OPENAI_API_KEY"): 50 | issues.append("OPENAI_API_KEY environment variable is not set") 51 | 52 | # Ensure data directory exists 
53 | os.makedirs(DATA_DIR, exist_ok=True) 54 | 55 | return issues 56 | 57 | 58 | def check_config() -> bool: 59 | """Check if the configuration is valid and log any issues.""" 60 | issues = validate_config() 61 | 62 | if issues: 63 | for issue in issues: 64 | logger.error(f"Configuration error: {issue}") 65 | return False 66 | 67 | return True 68 | 69 | 70 | if __name__ == "__main__": 71 | # When run directly, validate the configuration 72 | check_config() 73 | -------------------------------------------------------------------------------- /tests/lib/test_embedder.py: -------------------------------------------------------------------------------- 1 | """Tests for the embedder module.""" 2 | 3 | from unittest.mock import AsyncMock, patch 4 | 5 | import pytest 6 | 7 | from src.lib.embedder import generate_embedding 8 | 9 | 10 | @pytest.fixture 11 | def mock_embedding_response(): 12 | """Mock response from litellm.aembedding.""" 13 | return {"data": [{"embedding": [0.1, 0.2, 0.3, 0.4, 0.5]}]} 14 | 15 | 16 | @pytest.mark.unit 17 | @pytest.mark.async_test 18 | async def test_generate_embedding(sample_text, mock_embedding_response): 19 | """Test generating an embedding for a text chunk.""" 20 | with patch("src.lib.embedder.litellm.aembedding", new_callable=AsyncMock) as mock_aembedding: 21 | mock_aembedding.return_value = mock_embedding_response 22 | 23 | # Test with default model 24 | with patch("src.lib.embedder.DOC_EMBEDDING_MODEL", "text-embedding-3-small"): 25 | embedding = await generate_embedding(sample_text) 26 | assert embedding == mock_embedding_response["data"][0]["embedding"] 27 | mock_aembedding.assert_called_once_with( 28 | model="text-embedding-3-small", 29 | input=[sample_text], 30 | timeout=30, 31 | ) 32 | 33 | # Test with custom model and timeout 34 | mock_aembedding.reset_mock() 35 | embedding = await generate_embedding(sample_text, model="custom-model", timeout=60) 36 | 37 | # Check that aembedding was called with the correct arguments 38 | mock_aembedding.assert_called_once_with( 39 | model="custom-model", 40 | input=[sample_text], 41 | timeout=60, 42 | ) 43 | 44 | 45 | @pytest.mark.unit 46 | @pytest.mark.async_test 47 | async def test_generate_embedding_with_empty_text(): 48 | """Test generating an embedding with empty text.""" 49 | with pytest.raises(ValueError, match="Cannot generate embedding for empty text"): 50 | await generate_embedding("") 51 | 52 | with pytest.raises(ValueError, match="Cannot generate embedding for empty text"): 53 | await generate_embedding(" ") 54 | 55 | 56 | @pytest.mark.unit 57 | @pytest.mark.async_test 58 | async def test_generate_embedding_error_handling(): 59 | """Test error handling when generating an embedding.""" 60 | with patch("src.lib.embedder.litellm.aembedding", new_callable=AsyncMock) as mock_aembedding: 61 | # Simulate an API error 62 | mock_aembedding.side_effect = Exception("API error") 63 | 64 | with pytest.raises(Exception, match="API error"): 65 | await generate_embedding("Some text") 66 | -------------------------------------------------------------------------------- /tests/integration/common/test_vector_indexer.py: -------------------------------------------------------------------------------- 1 | """Integration tests for the VectorIndexer class with real DuckDB.""" 2 | 3 | import pytest 4 | 5 | from src.common.config import VECTOR_SIZE 6 | from src.common.indexer import VectorIndexer 7 | 8 | 9 | @pytest.mark.integration 10 | @pytest.mark.async_test 11 | async def 
test_vectorindexer_with_real_duckdb(in_memory_duckdb_connection): 12 | """Test the DuckDB implementation of VectorIndexer with a real in-memory database.""" 13 | # Create VectorIndexer with the in-memory test connection 14 | indexer = VectorIndexer(connection=in_memory_duckdb_connection) 15 | 16 | # Test vector with random values (correct dimension) 17 | test_vector = [0.1] * VECTOR_SIZE 18 | test_payload = { 19 | "text": "Test chunk", 20 | "page_id": "test-page-id", 21 | "url": "https://example.com", 22 | "domain": "example.com", 23 | "tags": ["test", "example"], 24 | "job_id": "test-job", 25 | } 26 | 27 | # Test index_vector 28 | point_id = await indexer.index_vector(test_vector, test_payload) 29 | assert point_id is not None 30 | 31 | # Test search - should find the vector we just indexed 32 | results = await indexer.search(test_vector, limit=1) 33 | assert len(results) == 1 34 | assert results[0]["id"] == point_id 35 | assert results[0]["payload"]["text"] == "Test chunk" 36 | assert results[0]["payload"]["tags"] == ["test", "example"] 37 | 38 | # Test tag filtering 39 | filter_payload = {"must": [{"key": "tags", "match": {"any": ["test"]}}]} 40 | 41 | results = await indexer.search(test_vector, limit=10, filter_payload=filter_payload) 42 | assert len(results) == 1 43 | 44 | # Test with non-matching tag filter 45 | filter_payload = {"must": [{"key": "tags", "match": {"any": ["nonexistent"]}}]} 46 | 47 | results = await indexer.search(test_vector, limit=10, filter_payload=filter_payload) 48 | assert len(results) == 0 49 | 50 | # Add another vector with different tags 51 | test_vector2 = [0.2] * VECTOR_SIZE 52 | test_payload2 = { 53 | "text": "Another test chunk", 54 | "page_id": "test-page-id-2", 55 | "url": "https://example.org", 56 | "domain": "example.org", 57 | "tags": ["different", "example"], 58 | "job_id": "test-job", 59 | } 60 | 61 | await indexer.index_vector(test_vector2, test_payload2) 62 | 63 | # Test filtering by the "example" tag which both vectors have 64 | filter_payload = {"must": [{"key": "tags", "match": {"any": ["example"]}}]} 65 | 66 | results = await indexer.search(test_vector, limit=10, filter_payload=filter_payload) 67 | assert len(results) == 2 68 | -------------------------------------------------------------------------------- /src/web_service/main.py: -------------------------------------------------------------------------------- 1 | """Main module for the Doctor Web Service.""" 2 | 3 | from collections.abc import AsyncIterator 4 | from contextlib import asynccontextmanager 5 | 6 | import redis 7 | from fastapi import FastAPI 8 | from fastapi_mcp import FastApiMCP 9 | 10 | from src.common.config import ( 11 | REDIS_URI, 12 | WEB_SERVICE_HOST, 13 | WEB_SERVICE_PORT, 14 | check_config, 15 | ) 16 | from src.common.logger import get_logger 17 | from src.lib.database import DatabaseOperations 18 | from src.web_service.api import api_router 19 | 20 | # Get logger for this module 21 | logger = get_logger(__name__) 22 | 23 | 24 | @asynccontextmanager 25 | async def lifespan(app: FastAPI) -> AsyncIterator[None]: 26 | """Lifespan context manager for the FastAPI application. 27 | 28 | Handles startup and shutdown events. 29 | 30 | Args: 31 | app: The FastAPI application instance. 32 | 33 | Yields: 34 | None: Indicates the application is ready. 35 | 36 | """ 37 | # Initialize databases for the web service 38 | DatabaseOperations() 39 | # The db.db.initialize() is called within DatabaseOperations constructor using a context manager. 
40 | # The db.db.close() is also handled by the context manager in initialize. 41 | logger.info("Database initialization complete") 42 | if not check_config(): 43 | logger.error("Invalid configuration. Exiting.") 44 | exit(1) 45 | yield 46 | logger.info("Shutting down application") 47 | 48 | 49 | def create_application() -> FastAPI: 50 | """Creates and configures the FastAPI application. 51 | 52 | Returns: 53 | FastAPI: Configured FastAPI application 54 | 55 | """ 56 | app = FastAPI( 57 | title="Doctor API", 58 | description="API for the Doctor web crawling and indexing system", 59 | version="0.2.0", 60 | lifespan=lifespan, 61 | ) 62 | 63 | # Include the API router 64 | app.include_router(api_router) 65 | 66 | # Set up MCP 67 | mcp_server_description = """ 68 | Search for documents using semantic search. 69 | 1. Use the `list_tags` endpoint to get a list of all available tags. 70 | 2. Use the `search_docs` endpoint to search for documents using semantic search, optionally filtered by tag. 71 | 3. Use the `get_doc_page` endpoint to get the full text of a document page. 72 | 4. You can also use the `list_doc_pages` endpoint to get a list of all available document pages. 73 | """ 74 | mcp = FastApiMCP( 75 | app, 76 | name="Doctor", 77 | description=mcp_server_description, 78 | exclude_operations=["fetch_url", "job_progress", "delete_docs"], 79 | ) 80 | 81 | mcp.mount() 82 | 83 | return app 84 | 85 | 86 | # Create the application 87 | app = create_application() 88 | 89 | # Create Redis connection 90 | redis_conn = redis.from_url(REDIS_URI) 91 | 92 | 93 | if __name__ == "__main__": 94 | import uvicorn 95 | 96 | uvicorn.run(app, host=WEB_SERVICE_HOST, port=WEB_SERVICE_PORT) 97 | -------------------------------------------------------------------------------- /src/crawl_worker/main.py: -------------------------------------------------------------------------------- 1 | """Main module for the Doctor Crawl Worker.""" 2 | 3 | import redis 4 | from rq import Worker 5 | 6 | from src.common.config import REDIS_URI, check_config 7 | from src.common.logger import get_logger 8 | from src.lib.database import DatabaseOperations 9 | 10 | # Get logger for this module 11 | logger = get_logger(__name__) 12 | 13 | 14 | def main() -> int: 15 | """Main entry point for the Crawl Worker. 16 | 17 | Returns: 18 | int: The exit code (0 for success, 1 for failure). 19 | 20 | """ 21 | # Validate configuration 22 | if not check_config(): 23 | logger.error("Invalid configuration. Exiting.") 24 | return 1 25 | 26 | # Initialize databases with write access 27 | try: 28 | logger.info("Initializing databases for the crawl worker...") 29 | # DatabaseOperations() constructor now calls initialize() internally using a context manager. 30 | db_ops = DatabaseOperations() 31 | logger.info( 32 | "Database initialization (via DatabaseOperations constructor) completed successfully." 
33 | ) 34 | 35 | # Double-check that the document_embeddings table exists 36 | # Use a new context-managed connection for this check 37 | with db_ops.db as conn_manager: 38 | actual_conn = conn_manager.conn 39 | if not actual_conn: 40 | logger.error("Failed to get DuckDB connection for table verification.") 41 | return 1 42 | 43 | result = actual_conn.execute( 44 | "SELECT count(*) FROM information_schema.tables WHERE table_name = 'document_embeddings'", 45 | ).fetchone() 46 | 47 | if result is None: 48 | logger.error("Failed to execute query to check for document_embeddings table") 49 | return 1 # conn_manager will close connection 50 | 51 | table_count = result[0] 52 | 53 | if table_count == 0: 54 | logger.error("document_embeddings table is still missing after initialization!") 55 | return 1 # conn_manager will close connection 56 | logger.info("Verified document_embeddings table exists") 57 | # Connection is closed by conn_manager context exit 58 | 59 | except Exception as e: 60 | logger.error(f"Database initialization or verification failed: {e!s}") 61 | return 1 62 | 63 | # Connect to Redis 64 | try: 65 | logger.info(f"Connecting to Redis at {REDIS_URI}") 66 | redis_conn = redis.from_url(REDIS_URI) 67 | 68 | # Start worker 69 | logger.info("Starting worker, listening on queue: worker") 70 | worker = Worker(["worker"], connection=redis_conn) 71 | worker.work(with_scheduler=True) 72 | return 0 # Return success if worker completes normally 73 | except Exception as redis_error: 74 | logger.error(f"Redis worker error: {redis_error!s}") 75 | return 1 76 | 77 | 78 | if __name__ == "__main__": 79 | main() 80 | -------------------------------------------------------------------------------- /tests/lib/test_crawler.py: -------------------------------------------------------------------------------- 1 | """Tests for the crawler module.""" 2 | 3 | from unittest.mock import AsyncMock, MagicMock, patch 4 | 5 | import pytest 6 | 7 | from src.lib.crawler import crawl_url, extract_page_text 8 | 9 | 10 | @pytest.mark.unit 11 | @pytest.mark.async_test 12 | async def test_crawl_url(sample_url): 13 | """Test crawling a URL.""" 14 | # Create mock results 15 | mock_results = [ 16 | MagicMock(url=sample_url), 17 | MagicMock(url=f"{sample_url}/page1"), 18 | MagicMock(url=f"{sample_url}/page2"), 19 | ] 20 | 21 | # Create a mock AsyncWebCrawler class 22 | mock_crawler_instance = AsyncMock() 23 | mock_crawler_instance.arun = AsyncMock(return_value=mock_results) 24 | 25 | # Create a mock context manager 26 | mock_crawler_class = MagicMock() 27 | mock_crawler_class.return_value.__aenter__.return_value = mock_crawler_instance 28 | 29 | # Patch the AsyncWebCrawler 30 | with patch("src.lib.crawler.AsyncWebCrawler", mock_crawler_class): 31 | results = await crawl_url(sample_url, max_pages=3, max_depth=2) 32 | 33 | # Check that arun was called with the correct arguments 34 | mock_crawler_instance.arun.assert_called_once() 35 | args, kwargs = mock_crawler_instance.arun.call_args 36 | assert kwargs["url"] == sample_url 37 | assert "config" in kwargs 38 | 39 | # Check that we got the expected results 40 | assert results == mock_results 41 | assert len(results) == 3 42 | 43 | 44 | @pytest.mark.unit 45 | def test_extract_page_text_with_markdown(sample_crawl_result): 46 | """Test extracting text from a crawl result with markdown.""" 47 | text = extract_page_text(sample_crawl_result) 48 | assert text == "# Example Page\n\nThis is some example content." 
49 | 50 | 51 | @pytest.mark.unit 52 | def test_extract_page_text_with_extracted_content(): 53 | """Test extracting text from a crawl result with extracted content.""" 54 | # Create a mock result with only extracted content 55 | mock_markdown = MagicMock() 56 | mock_markdown.fit_markdown = "Example Page. This is some example content." 57 | 58 | mock_result = MagicMock( 59 | url="https://example.com", 60 | markdown=mock_markdown, 61 | _markdown=None, 62 | extracted_content="Fallback content that should not be used", 63 | html="...", 64 | ) 65 | 66 | text = extract_page_text(mock_result) 67 | assert text == "Example Page. This is some example content." 68 | 69 | 70 | @pytest.mark.unit 71 | def test_extract_page_text_with_html_only(): 72 | """Test extracting text from a crawl result with only HTML.""" 73 | # Create a mock result with only HTML 74 | html_content = "<html><body><h1>Example</h1></body></html>
" 75 | 76 | # Create mock with no markdown or extracted content, only HTML 77 | mock_result = MagicMock( 78 | url="https://example.com", 79 | markdown=None, 80 | _markdown=None, 81 | extracted_content=None, 82 | html=html_content, 83 | ) 84 | 85 | text = extract_page_text(mock_result) 86 | assert text == html_content 87 | -------------------------------------------------------------------------------- /tests/lib/test_chunker.py: -------------------------------------------------------------------------------- 1 | """Tests for the chunker module.""" 2 | 3 | from unittest.mock import MagicMock, patch 4 | 5 | import pytest 6 | 7 | from src.lib.chunker import TextChunker 8 | 9 | 10 | @pytest.fixture 11 | def mock_text_splitter(): 12 | """Mock for the RecursiveCharacterTextSplitter.""" 13 | mock = MagicMock() 14 | mock.split_text.return_value = [ 15 | "This is chunk 1.", 16 | "This is chunk 2.", 17 | "This is chunk 3.", 18 | " ", # Empty chunk (with whitespace) that should be filtered 19 | ] 20 | return mock 21 | 22 | 23 | @pytest.mark.unit 24 | def test_text_chunker_initialization(): 25 | """Test TextChunker initialization with default values.""" 26 | with ( 27 | patch("src.lib.chunker.CHUNK_SIZE", 1000), 28 | patch("src.lib.chunker.CHUNK_OVERLAP", 100), 29 | patch("src.lib.chunker.RecursiveCharacterTextSplitter") as mock_splitter_class, 30 | ): 31 | chunker = TextChunker() 32 | 33 | # Check that the chunker was initialized with the correct values 34 | assert chunker.chunk_size == 1000 35 | assert chunker.chunk_overlap == 100 36 | 37 | # Check that the text splitter was initialized with the correct values 38 | mock_splitter_class.assert_called_once_with( 39 | chunk_size=1000, 40 | chunk_overlap=100, 41 | length_function=len, 42 | ) 43 | 44 | 45 | @pytest.mark.unit 46 | def test_text_chunker_initialization_with_custom_values(): 47 | """Test TextChunker initialization with custom values.""" 48 | with patch("src.lib.chunker.RecursiveCharacterTextSplitter") as mock_splitter_class: 49 | chunker = TextChunker(chunk_size=500, chunk_overlap=50) 50 | 51 | # Check that the chunker was initialized with the correct values 52 | assert chunker.chunk_size == 500 53 | assert chunker.chunk_overlap == 50 54 | 55 | # Check that the text splitter was initialized with the correct values 56 | mock_splitter_class.assert_called_once_with( 57 | chunk_size=500, 58 | chunk_overlap=50, 59 | length_function=len, 60 | ) 61 | 62 | 63 | @pytest.mark.unit 64 | def test_split_text(sample_text, mock_text_splitter): 65 | """Test splitting text into chunks.""" 66 | with patch("src.lib.chunker.RecursiveCharacterTextSplitter", return_value=mock_text_splitter): 67 | chunker = TextChunker() 68 | chunks = chunker.split_text(sample_text) 69 | 70 | # Check that the text splitter was called with the correct input 71 | mock_text_splitter.split_text.assert_called_once_with(sample_text) 72 | 73 | # Check that we got the expected chunks (empty chunks filtered out) 74 | assert chunks == ["This is chunk 1.", "This is chunk 2.", "This is chunk 3."] 75 | assert len(chunks) == 3 76 | 77 | 78 | @pytest.mark.unit 79 | def test_split_text_empty_input(): 80 | """Test splitting empty text.""" 81 | chunker = TextChunker() 82 | 83 | # Test with empty string 84 | chunks = chunker.split_text("") 85 | assert chunks == [] 86 | 87 | # Test with whitespace only 88 | chunks = chunker.split_text(" \n ") 89 | assert chunks == [] 90 | 91 | 92 | @pytest.mark.unit 93 | def test_split_text_all_empty_chunks(): 94 | """Test when all chunks are empty after filtering.""" 
95 | with patch("src.lib.chunker.RecursiveCharacterTextSplitter") as mock_splitter_class: 96 | # Mock splitter to return only empty chunks 97 | mock_splitter = MagicMock() 98 | mock_splitter.split_text.return_value = [" ", "\n", " \t "] 99 | mock_splitter_class.return_value = mock_splitter 100 | 101 | chunker = TextChunker() 102 | chunks = chunker.split_text("Some text") 103 | 104 | # Check that we got an empty list 105 | assert chunks == [] 106 | -------------------------------------------------------------------------------- /src/web_service/services/debug_bm25.py: -------------------------------------------------------------------------------- 1 | """BM25 diagnostic function.""" 2 | 3 | import duckdb 4 | 5 | from src.common.logger import get_logger 6 | 7 | # Get logger for this module 8 | logger = get_logger(__name__) 9 | 10 | 11 | async def debug_bm25_search(conn: duckdb.DuckDBPyConnection, query: str) -> dict: 12 | """Diagnose BM25/FTS search issues by always attempting to query the FTS index using DuckDB's FTS extension. 13 | Reports the result of the FTS query attempt, any errors, and explains FTS index visibility. 14 | """ 15 | logger.info(f"Running BM25/FTS search diagnostics for query: '{query}' (DuckDB-specific)") 16 | results = {} 17 | 18 | try: 19 | # Check if FTS extension is loaded 20 | fts_loaded = conn.execute(""" 21 | SELECT * FROM duckdb_extensions() 22 | WHERE extension_name = 'fts' AND loaded = true 23 | """).fetchall() 24 | results["fts_extension_loaded"] = bool(fts_loaded) 25 | logger.info(f"FTS extension loaded: {results['fts_extension_loaded']}") 26 | 27 | # Count records in pages table 28 | try: 29 | pages_count = conn.execute("SELECT COUNT(*) FROM pages").fetchone()[0] 30 | results["pages_count"] = pages_count 31 | logger.info(f"Pages table record count: {pages_count}") 32 | except Exception as e: 33 | results["pages_count"] = None 34 | results["pages_count_error"] = str(e) 35 | logger.warning(f"Could not count pages: {e}") 36 | 37 | # Always attempt FTS/BM25 search using DuckDB FTS extension 38 | try: 39 | escaped_query = query.replace("'", "''") 40 | bm25_sql = f""" 41 | SELECT p.id, p.url, fts_main_pages.match_bm25(p.id, '{escaped_query}') AS score 42 | FROM pages p 43 | WHERE score IS NOT NULL 44 | ORDER BY score DESC 45 | LIMIT 5 46 | """ 47 | logger.info(f"Attempting FTS/BM25 search: {bm25_sql}") 48 | bm25_results = conn.execute(bm25_sql).fetchall() 49 | results["fts_bm25_search"] = { 50 | "success": True, 51 | "count": len(bm25_results), 52 | "samples": [ 53 | {"id": row[0], "url": row[1], "score": row[2]} for row in bm25_results[:3] 54 | ], 55 | "error": None, 56 | } 57 | logger.info(f"FTS/BM25 search returned {len(bm25_results)} results") 58 | except Exception as e_bm25: 59 | results["fts_bm25_search"] = { 60 | "success": False, 61 | "count": 0, 62 | "samples": [], 63 | "error": str(e_bm25), 64 | } 65 | logger.warning(f"FTS/BM25 search failed: {e_bm25}") 66 | 67 | # Add a note about FTS index visibility in DuckDB 68 | results["fts_index_visibility_note"] = ( 69 | "DuckDB FTS indexes are not visible in sqlite_master or information_schema.tables. " 70 | "If FTS index creation returned an 'already exists' error, the index is present and can be queried " 71 | "using fts_main_.match_bm25 or similar functions, even if not listed in system tables." 
72 | ) 73 | 74 | # Add table/column info for clarity 75 | try: 76 | tables = conn.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall() 77 | table_names = [row[0] for row in tables] 78 | results["tables"] = table_names 79 | if "pages" in table_names: 80 | columns = conn.execute("PRAGMA table_info('pages')").fetchall() 81 | col_names = [row[1] for row in columns] 82 | results["pages_columns"] = col_names 83 | except Exception as e_schema: 84 | results["schema_check_error"] = str(e_schema) 85 | 86 | except Exception as e: 87 | logger.error(f"BM25/FTS diagnostics failed: {e}") 88 | results["error"] = str(e) 89 | 90 | return results 91 | -------------------------------------------------------------------------------- /src/web_service/api/map.py: -------------------------------------------------------------------------------- 1 | """API endpoints for the site map feature.""" 2 | 3 | from fastapi import APIRouter, HTTPException, Response 4 | from fastapi.responses import HTMLResponse 5 | 6 | from src.common.logger import get_logger 7 | from src.web_service.services.map_service import MapService 8 | 9 | logger = get_logger(__name__) 10 | 11 | router = APIRouter(tags=["map"]) 12 | 13 | 14 | @router.get("/map", response_class=HTMLResponse) 15 | async def get_site_index() -> str: 16 | """Get an index of all crawled sites. 17 | 18 | Returns: 19 | HTML page listing all root pages/sites. 20 | """ 21 | try: 22 | service = MapService() 23 | sites = await service.get_all_sites() 24 | # Log the sites data for debugging 25 | logger.debug(f"Retrieved {len(sites)} sites from database") 26 | if sites: 27 | logger.debug(f"First site data: {sites[0]}") 28 | return service.format_site_list(sites) 29 | except Exception as e: 30 | logger.error(f"Error getting site index: {e}", exc_info=True) 31 | raise HTTPException(status_code=500, detail=str(e)) 32 | 33 | 34 | @router.get("/map/site/{root_page_id}", response_class=HTMLResponse) 35 | async def get_site_tree(root_page_id: str) -> str: 36 | """Get the hierarchical tree view for a specific site. 37 | 38 | Args: 39 | root_page_id: The ID of the root page. 40 | 41 | Returns: 42 | HTML page showing the site's page hierarchy. 43 | """ 44 | try: 45 | service = MapService() 46 | tree = await service.build_page_tree(root_page_id) 47 | logger.debug(f"Built tree for {root_page_id}: {tree}") 48 | return service.format_site_tree(tree) 49 | except Exception as e: 50 | logger.error(f"Error getting site tree for {root_page_id}: {e}", exc_info=True) 51 | raise HTTPException(status_code=500, detail=str(e)) 52 | 53 | 54 | @router.get("/map/page/{page_id}", response_class=HTMLResponse) 55 | async def view_page(page_id: str) -> str: 56 | """View a specific page with navigation. 57 | 58 | Args: 59 | page_id: The ID of the page to view. 60 | 61 | Returns: 62 | HTML page with the page content and navigation. 
63 | """ 64 | try: 65 | service = MapService() 66 | 67 | # Get the page 68 | page = await service.db_ops.get_page_by_id(page_id) 69 | if not page: 70 | raise HTTPException(status_code=404, detail="Page not found") 71 | 72 | # Get navigation context 73 | navigation = await service.get_navigation_context(page_id) 74 | 75 | # Render the page 76 | return service.render_page_html(page, navigation) 77 | except HTTPException: 78 | raise 79 | except Exception as e: 80 | logger.error(f"Error viewing page {page_id}: {e}") 81 | raise HTTPException(status_code=500, detail=str(e)) 82 | 83 | 84 | @router.get("/map/page/{page_id}/raw", response_class=Response) 85 | async def get_page_raw(page_id: str) -> Response: 86 | """Get the raw markdown content of a page. 87 | 88 | Args: 89 | page_id: The ID of the page. 90 | 91 | Returns: 92 | Raw markdown content. 93 | """ 94 | try: 95 | service = MapService() 96 | 97 | # Get the page 98 | page = await service.db_ops.get_page_by_id(page_id) 99 | if not page: 100 | raise HTTPException(status_code=404, detail="Page not found") 101 | 102 | # Return raw text as markdown 103 | return Response( 104 | content=page.get("raw_text", ""), 105 | media_type="text/markdown", 106 | headers={"Content-Disposition": f'inline; filename="{page.get("title", "page")}.md"'}, 107 | ) 108 | except HTTPException: 109 | raise 110 | except Exception as e: 111 | logger.error(f"Error getting raw content for page {page_id}: {e}") 112 | raise HTTPException(status_code=500, detail=str(e)) 113 | -------------------------------------------------------------------------------- /src/lib/database/migrations.py: -------------------------------------------------------------------------------- 1 | """Database migration management for the Doctor project. 2 | 3 | This module handles applying database migrations to ensure the schema is up-to-date. 4 | """ 5 | 6 | import pathlib 7 | 8 | import duckdb 9 | 10 | from src.common.logger import get_logger 11 | 12 | logger = get_logger(__name__) 13 | 14 | 15 | class MigrationRunner: 16 | """Manages and executes database migrations.""" 17 | 18 | def __init__(self, conn: duckdb.DuckDBPyConnection) -> None: 19 | """Initialize the migration runner. 20 | 21 | Args: 22 | conn: Active DuckDB connection. 23 | 24 | Returns: 25 | None. 26 | """ 27 | self.conn = conn 28 | self.migrations_dir = pathlib.Path(__file__).parent / "migrations" 29 | 30 | def _ensure_migration_table(self) -> None: 31 | """Create the migrations tracking table if it doesn't exist. 32 | 33 | Args: 34 | None. 35 | 36 | Returns: 37 | None. 38 | """ 39 | self.conn.execute(""" 40 | CREATE TABLE IF NOT EXISTS _migrations ( 41 | migration_name VARCHAR PRIMARY KEY, 42 | applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP 43 | ) 44 | """) 45 | 46 | def _get_applied_migrations(self) -> set[str]: 47 | """Get the set of already applied migrations. 48 | 49 | Args: 50 | None. 51 | 52 | Returns: 53 | set[str]: Set of migration names that have been applied. 54 | """ 55 | result = self.conn.execute("SELECT migration_name FROM _migrations").fetchall() 56 | return {row[0] for row in result} 57 | 58 | def _apply_migration(self, migration_path: pathlib.Path) -> None: 59 | """Apply a single migration file. 60 | 61 | Args: 62 | migration_path: Path to the migration SQL file. 63 | 64 | Returns: 65 | None. 66 | 67 | Raises: 68 | Exception: If migration fails. 
69 | """ 70 | migration_name = migration_path.name 71 | logger.info(f"Applying migration: {migration_name}") 72 | 73 | try: 74 | # Read and execute the migration 75 | sql_content = migration_path.read_text() 76 | 77 | # Split by semicolons and execute each statement 78 | # Filter out empty statements 79 | statements = [s.strip() for s in sql_content.split(";") if s.strip()] 80 | 81 | for statement in statements: 82 | self.conn.execute(statement) 83 | 84 | # Record the migration as applied 85 | self.conn.execute( 86 | "INSERT INTO _migrations (migration_name) VALUES (?)", [migration_name] 87 | ) 88 | 89 | logger.info(f"Successfully applied migration: {migration_name}") 90 | 91 | except Exception as e: 92 | logger.error(f"Failed to apply migration {migration_name}: {e}") 93 | raise 94 | 95 | def run_migrations(self) -> None: 96 | """Run all pending migrations in order. 97 | 98 | Args: 99 | None. 100 | 101 | Returns: 102 | None. 103 | """ 104 | # Ensure migration tracking table exists 105 | self._ensure_migration_table() 106 | 107 | # Get already applied migrations 108 | applied = self._get_applied_migrations() 109 | 110 | # Find all migration files 111 | if not self.migrations_dir.exists(): 112 | logger.debug("No migrations directory found") 113 | return 114 | 115 | migration_files = sorted(self.migrations_dir.glob("*.sql")) 116 | 117 | # Apply pending migrations 118 | pending_count = 0 119 | for migration_file in migration_files: 120 | if migration_file.name not in applied: 121 | self._apply_migration(migration_file) 122 | pending_count += 1 123 | 124 | if pending_count == 0: 125 | logger.debug("No pending migrations to apply") 126 | else: 127 | logger.info(f"Applied {pending_count} migration(s)") 128 | -------------------------------------------------------------------------------- /src/lib/database/schema.py: -------------------------------------------------------------------------------- 1 | """Schema definitions for Doctor project database tables.""" 2 | # Table creation SQLs and schema helpers for Doctor DB 3 | 4 | from src.common.config import VECTOR_SIZE 5 | 6 | # Table creation SQL statements 7 | # These constants define the SQL for creating the main tables in the Doctor database. 
8 | CREATE_JOBS_TABLE_SQL = """ 9 | CREATE TABLE IF NOT EXISTS jobs ( 10 | job_id VARCHAR PRIMARY KEY, 11 | start_url VARCHAR, 12 | status VARCHAR, 13 | pages_discovered INTEGER DEFAULT 0, 14 | pages_crawled INTEGER DEFAULT 0, 15 | max_pages INTEGER, 16 | tags VARCHAR, -- JSON string array 17 | created_at TIMESTAMP, 18 | updated_at TIMESTAMP, 19 | error_message VARCHAR 20 | ) 21 | """ 22 | 23 | CREATE_PAGES_TABLE_SQL = """ 24 | CREATE TABLE IF NOT EXISTS pages ( 25 | id VARCHAR PRIMARY KEY, 26 | url VARCHAR, 27 | domain VARCHAR, 28 | raw_text TEXT, 29 | crawl_date TIMESTAMP, 30 | tags VARCHAR, -- JSON string array 31 | job_id VARCHAR, -- Reference to the job that crawled this page 32 | parent_page_id VARCHAR, -- Reference to parent page for hierarchy 33 | root_page_id VARCHAR, -- Reference to root page of the site 34 | depth INTEGER DEFAULT 0, -- Distance from root page 35 | path TEXT, -- Relative path from root page 36 | title TEXT -- Extracted page title 37 | ) 38 | """ 39 | 40 | CREATE_DOCUMENT_EMBEDDINGS_TABLE_SQL = f""" 41 | CREATE TABLE IF NOT EXISTS document_embeddings ( 42 | id VARCHAR PRIMARY KEY, 43 | embedding FLOAT[{VECTOR_SIZE}] NOT NULL, 44 | text_chunk VARCHAR, 45 | page_id VARCHAR, 46 | url VARCHAR, 47 | domain VARCHAR, 48 | tags VARCHAR[], 49 | job_id VARCHAR 50 | ); 51 | """ 52 | 53 | # Extension management SQL 54 | # These constants are used to manage DuckDB extensions. 55 | CHECK_EXTENSION_LOADED_SQL = "SELECT * FROM duckdb_loaded_extensions() WHERE name = '{0}';" 56 | INSTALL_EXTENSION_SQL = "INSTALL {0};" 57 | LOAD_EXTENSION_SQL = "LOAD {0};" 58 | 59 | # FTS (Full-Text Search) related SQL 60 | # These constants are used for managing FTS indexes and tables. 61 | CREATE_FTS_INDEX_SQL = "PRAGMA create_fts_index('pages', 'id', 'raw_text', overwrite=1);" 62 | CHECK_FTS_INDEXES_SQL = ( 63 | "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'fts_idx_%'" 64 | ) 65 | CHECK_FTS_MAIN_PAGES_TABLE_SQL = ( 66 | "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'fts_main_pages'" 67 | ) 68 | DROP_FTS_MAIN_PAGES_TABLE_SQL = "DROP TABLE IF EXISTS fts_main_pages;" 69 | 70 | # HNSW (Vector Search) related SQL 71 | # These constants are used for managing HNSW vector search indexes. 72 | SET_HNSW_PERSISTENCE_SQL = "SET hnsw_enable_experimental_persistence = true;" 73 | CHECK_HNSW_INDEX_SQL = ( 74 | "SELECT count(*) FROM duckdb_indexes() WHERE index_name = 'hnsw_index_on_embeddings'" 75 | ) 76 | CREATE_HNSW_INDEX_SQL = """ 77 | CREATE INDEX hnsw_index_on_embeddings 78 | ON document_embeddings 79 | USING HNSW (embedding) 80 | WITH (metric = 'cosine'); 81 | """ 82 | 83 | # VSS (Vector Similarity Search) verification SQL 84 | # These constants are used to verify VSS extension functionality. 85 | VSS_ARRAY_TO_STRING_TEST_SQL = "SELECT array_to_string([0.1, 0.2]::FLOAT[], ', ');" 86 | VSS_COSINE_SIMILARITY_TEST_SQL = "SELECT list_cosine_similarity([0.1,0.2],[0.2,0.3]);" 87 | 88 | # Table existence check SQL 89 | # Used to check if a table exists in the database. 90 | CHECK_TABLE_EXISTS_SQL = "SELECT count(*) FROM information_schema.tables WHERE table_name = '{0}'" 91 | 92 | # Transaction management SQL 93 | # Used for managing transactions and checkpoints. 94 | BEGIN_TRANSACTION_SQL = "BEGIN TRANSACTION" 95 | CHECKPOINT_SQL = "CHECKPOINT" 96 | TEST_CONNECTION_SQL = "SELECT 1" 97 | 98 | # DML (Data Manipulation Language) SQL 99 | # Used for inserting and updating data in tables. 
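# Illustrative usage sketch (comment only, not from the original module): INSERT_PAGE_SQL
# below expects positional parameters in the same order as its column list, e.g.
#   conn.execute(INSERT_PAGE_SQL, [page_id, url, domain, raw_text, crawl_date, tags_json,
#                                  job_id, parent_page_id, root_page_id, depth, path, title])
# where the variable names are hypothetical placeholders for the caller's values.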
100 | INSERT_PAGE_SQL = """ 101 | INSERT INTO pages (id, url, domain, raw_text, crawl_date, tags, job_id, 102 | parent_page_id, root_page_id, depth, path, title) 103 | VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 104 | """ 105 | 106 | # Base for dynamic UPDATE job query 107 | UPDATE_JOB_STATUS_BASE_SQL = "UPDATE jobs SET status = ?, updated_at = ?" 108 | -------------------------------------------------------------------------------- /src/lib/crawler.py: -------------------------------------------------------------------------------- 1 | """Web crawling functionality using crawl4ai.""" 2 | 3 | from typing import Any 4 | 5 | from crawl4ai import AsyncWebCrawler, CrawlerRunConfig 6 | from crawl4ai.content_filter_strategy import PruningContentFilter 7 | from crawl4ai.deep_crawling import BFSDeepCrawlStrategy 8 | from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator 9 | 10 | from src.common.logger import get_logger 11 | 12 | # Configure logging 13 | logger = get_logger(__name__) 14 | 15 | 16 | async def crawl_url( 17 | url: str, 18 | max_pages: int = 100, 19 | max_depth: int = 2, 20 | strip_urls: bool = True, 21 | ) -> list[Any]: 22 | """Crawl a URL and return the results. 23 | 24 | Args: 25 | url: The URL to start crawling from 26 | max_pages: Maximum number of pages to crawl 27 | max_depth: Maximum depth for the BFS crawl 28 | strip_urls: Whether to strip URLs from the returned markdown 29 | 30 | Returns: 31 | List of crawled page results 32 | 33 | """ 34 | logger.info(f"Starting crawl for URL: {url} with max_pages={max_pages}") 35 | 36 | # Create content filter to remove navigation elements and other non-essential content 37 | content_filter = PruningContentFilter(threshold=0.6, threshold_type="fixed") 38 | 39 | # Configure markdown generator to ignore links and navigation elements 40 | markdown_generator = DefaultMarkdownGenerator( 41 | content_filter=content_filter, 42 | options={ 43 | "ignore_links": strip_urls, 44 | "body_width": 0, 45 | "ignore_images": True, 46 | "single_line_break": True, 47 | }, 48 | ) 49 | 50 | config = CrawlerRunConfig( 51 | deep_crawl_strategy=BFSDeepCrawlStrategy( 52 | max_depth=max_depth, 53 | max_pages=max_pages, 54 | logger=get_logger("crawl4ai"), 55 | include_external=False, 56 | ), 57 | markdown_generator=markdown_generator, 58 | excluded_tags=["nav", "footer", "aside", "header"], 59 | remove_overlay_elements=True, 60 | verbose=True, 61 | ) 62 | 63 | # Initialize the crawler 64 | async with AsyncWebCrawler() as crawler: 65 | crawl_results = await crawler.arun(url=url, config=config) 66 | logger.info(f"Deep crawl discovered {len(crawl_results)} pages") 67 | return crawl_results 68 | 69 | 70 | def extract_page_text(page_result: Any) -> str: 71 | """Extract the text content from a crawl4ai page result. 
72 | 73 | Args: 74 | page_result: The crawl result for the page 75 | 76 | Returns: 77 | The extracted text content 78 | 79 | """ 80 | # Use filtered markdown if available, otherwise use raw markdown, 81 | # extracted content, or HTML as fallbacks 82 | if hasattr(page_result, "markdown") and page_result.markdown: 83 | if hasattr(page_result.markdown, "fit_markdown") and page_result.markdown.fit_markdown: 84 | page_text = page_result.markdown.fit_markdown 85 | logger.debug(f"Using fit markdown text of length {len(page_text)}") 86 | elif hasattr(page_result.markdown, "raw_markdown"): 87 | page_text = page_result.markdown.raw_markdown 88 | logger.debug(f"Using raw markdown text of length {len(page_text)}") 89 | else: 90 | # Handle string-like markdown (backward compatibility) 91 | page_text = str(page_result.markdown) 92 | logger.debug(f"Using string markdown text of length {len(page_text)}") 93 | elif hasattr(page_result, "_markdown") and page_result._markdown: 94 | if hasattr(page_result._markdown, "fit_markdown") and page_result._markdown.fit_markdown: 95 | page_text = page_result._markdown.fit_markdown 96 | logger.debug(f"Using fit markdown text from _markdown of length {len(page_text)}") 97 | else: 98 | page_text = page_result._markdown.raw_markdown 99 | logger.debug(f"Using raw markdown text from _markdown of length {len(page_text)}") 100 | elif page_result.extracted_content: 101 | page_text = page_result.extracted_content 102 | logger.debug(f"Using extracted content of length {len(page_text)}") 103 | else: 104 | page_text = page_result.html 105 | logger.debug(f"Using HTML content of length {len(page_text)}") 106 | 107 | return page_text 108 | -------------------------------------------------------------------------------- /docs/maps_api_examples.md: -------------------------------------------------------------------------------- 1 | # Maps Feature API Examples 2 | 3 | This document provides examples of using the Maps feature API endpoints. 4 | 5 | ## Overview 6 | 7 | The Maps feature tracks page hierarchies during crawling, allowing you to navigate crawled sites as they were originally structured. This makes it easy for LLMs to understand site organization and find related content. 8 | 9 | ## API Endpoints 10 | 11 | ### 1. Get All Sites - `/map` 12 | 13 | Lists all crawled root pages (sites). 14 | 15 | **Request:** 16 | ```http 17 | GET /map 18 | ``` 19 | 20 | **Response:** 21 | HTML page listing all sites with links to their tree views. 22 | 23 | **Example Usage:** 24 | ```bash 25 | curl http://localhost:9111/map 26 | ``` 27 | 28 | ### 2. View Site Tree - `/map/site/{root_page_id}` 29 | 30 | Shows the hierarchical structure of a specific site. 31 | 32 | **Request:** 33 | ```http 34 | GET /map/site/550e8400-e29b-41d4-a716-446655440000 35 | ``` 36 | 37 | **Response:** 38 | HTML page with collapsible tree view of all pages in the site. 39 | 40 | **Example Usage:** 41 | ```bash 42 | curl http://localhost:9111/map/site/550e8400-e29b-41d4-a716-446655440000 43 | ``` 44 | 45 | ### 3. View Page - `/map/page/{page_id}` 46 | 47 | Displays a specific page with navigation links. 
48 | 49 | **Request:** 50 | ```http 51 | GET /map/page/660e8400-e29b-41d4-a716-446655440001 52 | ``` 53 | 54 | **Response:** 55 | HTML page with: 56 | - Breadcrumb navigation 57 | - Rendered markdown content 58 | - Links to parent, siblings, and child pages 59 | - Link back to site map 60 | 61 | **Example Usage:** 62 | ```bash 63 | curl http://localhost:9111/map/page/660e8400-e29b-41d4-a716-446655440001 64 | ``` 65 | 66 | ### 4. Get Raw Content - `/map/page/{page_id}/raw` 67 | 68 | Returns the raw markdown content of a page. 69 | 70 | **Request:** 71 | ```http 72 | GET /map/page/660e8400-e29b-41d4-a716-446655440001/raw 73 | ``` 74 | 75 | **Response:** 76 | ```markdown 77 | # Page Title 78 | 79 | Raw markdown content of the page... 80 | ``` 81 | 82 | **Example Usage:** 83 | ```bash 84 | curl http://localhost:9111/map/page/660e8400-e29b-41d4-a716-446655440001/raw > page.md 85 | ``` 86 | 87 | ## Workflow Example 88 | 89 | ### 1. Crawl a Website 90 | 91 | First, crawl a website to populate the hierarchy: 92 | 93 | ```bash 94 | curl -X POST http://localhost:9111/fetch_url \ 95 | -H "Content-Type: application/json" \ 96 | -d '{ 97 | "url": "https://docs.example.com", 98 | "max_pages": 100, 99 | "tags": ["documentation"] 100 | }' 101 | ``` 102 | 103 | ### 2. View All Sites 104 | 105 | Visit the map index to see all crawled sites: 106 | 107 | ```bash 108 | curl http://localhost:9111/map 109 | ``` 110 | 111 | ### 3. Explore Site Structure 112 | 113 | Click on a site or use its ID to view the tree structure: 114 | 115 | ```bash 116 | curl http://localhost:9111/map/site/{root_page_id} 117 | ``` 118 | 119 | ### 4. Navigate Pages 120 | 121 | Click through pages to explore content with full navigation context. 122 | 123 | ## Database Schema 124 | 125 | The hierarchy is tracked using these fields in the `pages` table: 126 | 127 | - `parent_page_id`: ID of the parent page (NULL for root pages) 128 | - `root_page_id`: ID of the root page for this site 129 | - `depth`: Distance from the root (0 for root pages) 130 | - `path`: Relative path from the root 131 | - `title`: Extracted page title 132 | 133 | ## Integration with LLMs 134 | 135 | The Maps feature is particularly useful for LLMs because: 136 | 137 | 1. **Contextual Understanding**: LLMs can understand how pages relate to each other 138 | 2. **Efficient Navigation**: Direct links to related content without searching 139 | 3. **Site Structure**: Understanding the organization helps with better responses 140 | 4. 
**No JavaScript**: Clean HTML that's easy for LLMs to parse 141 | 142 | ## Python Client Example 143 | 144 | ```python 145 | import httpx 146 | import asyncio 147 | 148 | async def explore_site_map(): 149 | async with httpx.AsyncClient() as client: 150 | # Get all sites 151 | sites_response = await client.get("http://localhost:9111/map") 152 | 153 | # Parse the HTML to extract site IDs (simplified example) 154 | # In practice, use an HTML parser like BeautifulSoup 155 | 156 | # Get a specific site's tree 157 | tree_response = await client.get( 158 | "http://localhost:9111/map/site/550e8400-e29b-41d4-a716-446655440000" 159 | ) 160 | 161 | # Get raw content of a page 162 | raw_response = await client.get( 163 | "http://localhost:9111/map/page/660e8400-e29b-41d4-a716-446655440001/raw" 164 | ) 165 | 166 | return raw_response.text 167 | 168 | # Run the example 169 | content = asyncio.run(explore_site_map()) 170 | print(content) 171 | ``` 172 | 173 | ## Notes 174 | 175 | - The Maps feature automatically tracks hierarchy during crawling 176 | - No additional configuration is needed - it works with existing crawl jobs 177 | - Pages maintain their relationships even after crawling is complete 178 | - The UI requires no JavaScript, making it accessible and fast 179 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | """Common fixtures for tests.""" 2 | 3 | import asyncio 4 | import random 5 | 6 | import duckdb 7 | import pytest 8 | 9 | from src.common.config import VECTOR_SIZE 10 | 11 | 12 | @pytest.fixture 13 | def event_loop(): 14 | """Create an instance of the default event loop for each test.""" 15 | loop = asyncio.get_event_loop_policy().new_event_loop() 16 | yield loop 17 | loop.close() 18 | 19 | 20 | @pytest.fixture 21 | def sample_url(): 22 | """Sample URL for testing.""" 23 | return "https://example.com" 24 | 25 | 26 | @pytest.fixture 27 | def sample_text(): 28 | """Sample text content for testing.""" 29 | return """ 30 | Lorem ipsum dolor sit amet, consectetur adipiscing elit. Nullam auctor, 31 | nisl nec ultricies lacinia, nisl nisl aliquet nisl, nec ultricies nisl 32 | nisl nec ultricies lacinia, nisl nisl aliquet nisl, nec ultricies nisl 33 | nisl nec ultricies lacinia, nisl nisl aliquet nisl, nec ultricies nisl. 34 | 35 | Pellentesque habitant morbi tristique senectus et netus et malesuada 36 | fames ac turpis egestas. Sed euismod, nisl nec ultricies lacinia, nisl 37 | nisl aliquet nisl, nec ultricies nisl nisl nec ultricies lacinia. 
38 | """ 39 | 40 | 41 | @pytest.fixture 42 | def sample_embedding(): 43 | """Sample embedding vector for testing (legacy version: 384 dim).""" 44 | random.seed(42) # For reproducibility 45 | return [random.random() for _ in range(384)] 46 | 47 | 48 | @pytest.fixture 49 | def sample_embedding_full_size() -> list[float]: 50 | """Sample embedding vector with the full VECTOR_SIZE dimension for DuckDB tests.""" 51 | random.seed(42) # For reproducibility 52 | return [random.random() for _ in range(VECTOR_SIZE)] 53 | 54 | 55 | @pytest.fixture 56 | def sample_crawl_result(): 57 | """Sample crawl result for testing.""" 58 | 59 | class MockCrawlResult: 60 | def __init__(self, url, markdown=None, extracted_content=None, html=None): 61 | self.url = url 62 | self._markdown = markdown 63 | self.extracted_content = extracted_content 64 | self.html = html 65 | 66 | # Create a mock _markdown attribute if provided 67 | if markdown: 68 | 69 | class MockMarkdown: 70 | def __init__(self, raw_markdown): 71 | self.raw_markdown = raw_markdown 72 | 73 | self._markdown = MockMarkdown(markdown) 74 | 75 | return MockCrawlResult( 76 | url="https://example.com", 77 | markdown="# Example Page\n\nThis is some example content.", 78 | extracted_content="Example Page. This is some example content.", 79 | html="Example

Example Page

This is some example content.

", 80 | ) 81 | 82 | 83 | @pytest.fixture 84 | def job_id(): 85 | """Sample job ID for testing.""" 86 | return "test-job-123" 87 | 88 | 89 | @pytest.fixture 90 | def page_id(): 91 | """Sample page ID for testing.""" 92 | return "test-page-456" 93 | 94 | 95 | @pytest.fixture 96 | def sample_tags(): 97 | """Sample tags for testing.""" 98 | return ["test", "example", "documentation"] 99 | 100 | 101 | @pytest.fixture 102 | def in_memory_duckdb_connection(): 103 | """Create an in-memory DuckDB connection for testing. 104 | 105 | This connection has the proper setup for vector search using the same 106 | setup logic as the main application: 107 | - VSS extension loaded 108 | - document_embeddings table created with the proper schema 109 | - HNSW index created 110 | - pages table created (for document service tests) 111 | 112 | Usage: 113 | def test_something(in_memory_duckdb_connection): 114 | # Use the connection for testing 115 | ... 116 | """ 117 | from src.lib.database import DatabaseOperations 118 | 119 | # Create in-memory connection 120 | conn = duckdb.connect(":memory:") 121 | 122 | try: 123 | # Use the Database class to set up the connection 124 | db = DatabaseOperations() 125 | db.conn = conn 126 | 127 | # Create tables and set up VSS extension 128 | db.db.ensure_tables() 129 | db.db.ensure_vss_extension() 130 | 131 | yield conn 132 | finally: 133 | conn.close() 134 | 135 | 136 | @pytest.fixture(scope="session", autouse=True) 137 | def ensure_duckdb_database(): 138 | """Ensure the DuckDB database file exists before running tests. 139 | 140 | This fixture runs once per test session and initializes the database 141 | if it doesn't exist yet, which is especially important for CI environments. 142 | """ 143 | import os 144 | 145 | from src.common.config import DUCKDB_PATH 146 | from src.lib.database import DatabaseOperations 147 | 148 | # Only initialize if the file doesn't exist 149 | if not os.path.exists(DUCKDB_PATH): 150 | print(f"Database file {DUCKDB_PATH} does not exist. Creating it for tests...") 151 | try: 152 | db = DatabaseOperations(read_only=False) 153 | db.db.initialize() 154 | db.db.close() 155 | print(f"Successfully created database at {DUCKDB_PATH}") 156 | except Exception as e: 157 | print(f"Warning: Failed to create database: {e}") 158 | # Don't fail the tests if we can't create the DB - individual tests 159 | # that need it can handle the situation appropriately 160 | -------------------------------------------------------------------------------- /src/lib/crawler_enhanced.py: -------------------------------------------------------------------------------- 1 | """Enhanced web crawler with hierarchy tracking for the maps feature.""" 2 | 3 | from typing import Any, Optional 4 | from urllib.parse import urlparse 5 | import re 6 | 7 | from src.common.logger import get_logger 8 | from src.lib.crawler import crawl_url as base_crawl_url, extract_page_text 9 | 10 | logger = get_logger(__name__) 11 | 12 | 13 | class CrawlResultWithHierarchy: 14 | """Enhanced crawl result that includes hierarchy information.""" 15 | 16 | def __init__(self, base_result: Any) -> None: 17 | """Initialize with a base crawl4ai result. 18 | 19 | Args: 20 | base_result: The original crawl4ai result object. 21 | 22 | Returns: 23 | None. 
24 | """ 25 | self.base_result = base_result 26 | self.url = base_result.url 27 | self.parent_url: Optional[str] = None 28 | self.root_url: Optional[str] = None 29 | self.depth: int = 0 30 | self.relative_path: str = "" 31 | self.title: Optional[str] = None 32 | 33 | # Extract hierarchy info from metadata if available 34 | if hasattr(base_result, "metadata"): 35 | self.parent_url = base_result.metadata.get("parent_url") 36 | self.depth = base_result.metadata.get("depth", 0) 37 | 38 | # Extract title from the page 39 | self._extract_title() 40 | 41 | def _extract_title(self) -> None: 42 | """Extract the page title from the HTML or markdown content. 43 | 44 | Args: 45 | None. 46 | 47 | Returns: 48 | None. 49 | """ 50 | try: 51 | # Try to extract from HTML title tag first 52 | if hasattr(self.base_result, "html") and self.base_result.html: 53 | title_match = re.search( 54 | r"]*>(.*?)", self.base_result.html, re.IGNORECASE | re.DOTALL 55 | ) 56 | if title_match: 57 | self.title = title_match.group(1).strip() 58 | # Clean up common HTML entities 59 | self.title = ( 60 | self.title.replace("&", "&").replace("<", "<").replace(">", ">") 61 | ) 62 | return 63 | 64 | # Fallback to first H1 in markdown 65 | text = extract_page_text(self.base_result) 66 | if text: 67 | # Look for first markdown heading 68 | h1_match = re.search(r"^#\s+(.+)$", text, re.MULTILINE) 69 | if h1_match: 70 | self.title = h1_match.group(1).strip() 71 | return 72 | 73 | # Look for first line of non-empty text as last resort 74 | lines = text.strip().split("\n") 75 | for line in lines[:5]: # Check first 5 lines 76 | if line.strip() and len(line.strip()) > 10: 77 | self.title = line.strip()[:100] # Limit to 100 chars 78 | break 79 | 80 | except Exception as e: 81 | logger.warning(f"Error extracting title from {self.url}: {e}") 82 | 83 | # Default to URL path if no title found 84 | if not self.title: 85 | path = urlparse(self.url).path 86 | self.title = path.strip("/").split("/")[-1] or "Home" 87 | 88 | 89 | async def crawl_url_with_hierarchy( 90 | url: str, 91 | max_pages: int = 100, 92 | max_depth: int = 2, 93 | strip_urls: bool = True, 94 | ) -> list[CrawlResultWithHierarchy]: 95 | """Crawl a URL and return results with hierarchy information. 96 | 97 | Args: 98 | url: The URL to start crawling from. 99 | max_pages: Maximum number of pages to crawl. 100 | max_depth: Maximum depth for the BFS crawl. 101 | strip_urls: Whether to strip URLs from the returned markdown. 102 | 103 | Returns: 104 | List of crawl results with hierarchy information. 
105 | """ 106 | # Get base crawl results 107 | base_results = await base_crawl_url(url, max_pages, max_depth, strip_urls) 108 | 109 | # Enhance results with hierarchy information 110 | enhanced_results = [] 111 | root_url = url # The starting URL is the root 112 | 113 | # Build a map of URLs to results for quick lookup 114 | url_to_result = {} 115 | for result in base_results: 116 | enhanced = CrawlResultWithHierarchy(result) 117 | enhanced.root_url = root_url 118 | url_to_result[enhanced.url] = enhanced 119 | enhanced_results.append(enhanced) 120 | 121 | # Calculate relative paths 122 | for enhanced in enhanced_results: 123 | if enhanced.parent_url and enhanced.parent_url in url_to_result: 124 | parent = url_to_result[enhanced.parent_url] 125 | # Calculate relative path based on parent's path 126 | if parent.relative_path == "/": 127 | enhanced.relative_path = enhanced.title 128 | elif parent.relative_path: 129 | enhanced.relative_path = f"{parent.relative_path}/{enhanced.title}" 130 | else: 131 | enhanced.relative_path = enhanced.title 132 | elif enhanced.url == root_url: 133 | enhanced.relative_path = "/" 134 | else: 135 | # Fallback to URL-based path 136 | enhanced.relative_path = urlparse(enhanced.url).path 137 | 138 | logger.info(f"Enhanced {len(enhanced_results)} pages with hierarchy information") 139 | return enhanced_results 140 | -------------------------------------------------------------------------------- /tests/services/test_admin_service.py: -------------------------------------------------------------------------------- 1 | """Tests for the admin service.""" 2 | 3 | from unittest.mock import MagicMock, patch 4 | 5 | import pytest 6 | from rq import Queue 7 | 8 | from src.web_service.services.admin_service import delete_docs 9 | 10 | 11 | @pytest.fixture 12 | def mock_queue(): 13 | """Mock Redis queue.""" 14 | mock = MagicMock(spec=Queue) 15 | mock.enqueue = MagicMock() 16 | return mock 17 | 18 | 19 | @pytest.mark.unit 20 | @pytest.mark.async_test 21 | async def test_delete_docs_basic(mock_queue): 22 | """Test basic document deletion without filters.""" 23 | # Set a fixed UUID for testing 24 | with patch("uuid.uuid4", return_value="delete-task-123"): 25 | # Call the function without filters 26 | task_id = await delete_docs(queue=mock_queue) 27 | 28 | # Verify queue.enqueue was called with the right parameters 29 | mock_queue.enqueue.assert_called_once() 30 | call_args = mock_queue.enqueue.call_args[0] 31 | assert call_args[0] == "src.crawl_worker.tasks.delete_docs" 32 | assert call_args[1] == "delete-task-123" # task_id 33 | assert call_args[2] is None # tags 34 | assert call_args[3] is None # domain 35 | assert call_args[4] is None # page_ids 36 | 37 | # Check that the function returned the expected task ID 38 | assert task_id == "delete-task-123" 39 | 40 | 41 | @pytest.mark.unit 42 | @pytest.mark.async_test 43 | async def test_delete_docs_with_tags(mock_queue): 44 | """Test document deletion with tag filter.""" 45 | # Set a fixed UUID for testing 46 | with patch("uuid.uuid4", return_value="delete-task-456"): 47 | # Call the function with tags filter 48 | task_id = await delete_docs( 49 | queue=mock_queue, 50 | tags=["test", "example"], 51 | ) 52 | 53 | # Verify queue.enqueue was called with the right parameters 54 | mock_queue.enqueue.assert_called_once() 55 | call_args = mock_queue.enqueue.call_args[0] 56 | assert call_args[0] == "src.crawl_worker.tasks.delete_docs" 57 | assert call_args[1] == "delete-task-456" # task_id 58 | assert call_args[2] == ["test", "example"] # 
tags 59 | assert call_args[3] is None # domain 60 | assert call_args[4] is None # page_ids 61 | 62 | # Check that the function returned the expected task ID 63 | assert task_id == "delete-task-456" 64 | 65 | 66 | @pytest.mark.unit 67 | @pytest.mark.async_test 68 | async def test_delete_docs_with_domain(mock_queue): 69 | """Test document deletion with domain filter.""" 70 | # Set a fixed UUID for testing 71 | with patch("uuid.uuid4", return_value="delete-task-789"): 72 | # Call the function with domain filter 73 | task_id = await delete_docs( 74 | queue=mock_queue, 75 | domain="example.com", 76 | ) 77 | 78 | # Verify queue.enqueue was called with the right parameters 79 | mock_queue.enqueue.assert_called_once() 80 | call_args = mock_queue.enqueue.call_args[0] 81 | assert call_args[0] == "src.crawl_worker.tasks.delete_docs" 82 | assert call_args[1] == "delete-task-789" # task_id 83 | assert call_args[2] is None # tags 84 | assert call_args[3] == "example.com" # domain 85 | assert call_args[4] is None # page_ids 86 | 87 | # Check that the function returned the expected task ID 88 | assert task_id == "delete-task-789" 89 | 90 | 91 | @pytest.mark.unit 92 | @pytest.mark.async_test 93 | async def test_delete_docs_with_page_ids(mock_queue): 94 | """Test document deletion with specific page IDs.""" 95 | # Set a fixed UUID for testing 96 | with patch("uuid.uuid4", return_value="delete-task-abc"): 97 | # Call the function with page_ids filter 98 | page_ids = ["page-1", "page-2", "page-3"] 99 | task_id = await delete_docs( 100 | queue=mock_queue, 101 | page_ids=page_ids, 102 | ) 103 | 104 | # Verify queue.enqueue was called with the right parameters 105 | mock_queue.enqueue.assert_called_once() 106 | call_args = mock_queue.enqueue.call_args[0] 107 | assert call_args[0] == "src.crawl_worker.tasks.delete_docs" 108 | assert call_args[1] == "delete-task-abc" # task_id 109 | assert call_args[2] is None # tags 110 | assert call_args[3] is None # domain 111 | assert call_args[4] == page_ids # page_ids 112 | 113 | # Check that the function returned the expected task ID 114 | assert task_id == "delete-task-abc" 115 | 116 | 117 | @pytest.mark.unit 118 | @pytest.mark.async_test 119 | async def test_delete_docs_with_all_filters(mock_queue): 120 | """Test document deletion with all filters.""" 121 | # Set a fixed UUID for testing 122 | with patch("uuid.uuid4", return_value="delete-task-all"): 123 | # Call the function with all filters 124 | page_ids = ["page-1", "page-2"] 125 | task_id = await delete_docs( 126 | queue=mock_queue, 127 | tags=["test"], 128 | domain="example.com", 129 | page_ids=page_ids, 130 | ) 131 | 132 | # Verify queue.enqueue was called with the right parameters 133 | mock_queue.enqueue.assert_called_once() 134 | call_args = mock_queue.enqueue.call_args[0] 135 | assert call_args[0] == "src.crawl_worker.tasks.delete_docs" 136 | assert call_args[1] == "delete-task-all" # task_id 137 | assert call_args[2] == ["test"] # tags 138 | assert call_args[3] == "example.com" # domain 139 | assert call_args[4] == page_ids # page_ids 140 | 141 | # Check that the function returned the expected task ID 142 | assert task_id == "delete-task-all" 143 | -------------------------------------------------------------------------------- /src/web_service/services/job_service.py: -------------------------------------------------------------------------------- 1 | """Job service for the web service.""" 2 | 3 | import uuid 4 | 5 | import duckdb 6 | from rq import Queue 7 | 8 | from src.common.logger import get_logger 9 | 
from src.common.models import ( 10 | FetchUrlResponse, 11 | JobProgressResponse, 12 | ) 13 | 14 | # Get logger for this module 15 | logger = get_logger(__name__) 16 | 17 | 18 | async def fetch_url( 19 | queue: Queue, 20 | url: str, 21 | tags: list[str] | None = None, 22 | max_pages: int = 100, 23 | ) -> FetchUrlResponse: 24 | """Initiate a fetch job to crawl a website. 25 | 26 | Args: 27 | queue: Redis queue for job processing 28 | url: The URL to start indexing from 29 | tags: Optional tags to assign this website 30 | max_pages: How many pages to index 31 | 32 | Returns: 33 | FetchUrlResponse: The job ID 34 | 35 | """ 36 | # Generate a temporary job ID 37 | job_id = str(uuid.uuid4()) 38 | 39 | # Enqueue the job creation task 40 | # The crawl worker will handle both creating the job record and enqueueing the crawl task 41 | queue.enqueue( 42 | "src.crawl_worker.tasks.create_job", 43 | url, 44 | job_id, 45 | tags=tags, 46 | max_pages=max_pages, 47 | ) 48 | 49 | logger.info(f"Enqueued job creation for URL: {url}, job_id: {job_id}") 50 | 51 | return FetchUrlResponse(job_id=job_id) 52 | 53 | 54 | async def get_job_progress( 55 | conn: duckdb.DuckDBPyConnection, 56 | job_id: str, 57 | ) -> JobProgressResponse | None: 58 | """Check the progress of a job. 59 | 60 | Args: 61 | conn: Connected DuckDB connection 62 | job_id: The job ID to check progress for 63 | 64 | Returns: 65 | JobProgressResponse: Job progress information or None if job not found 66 | 67 | """ 68 | logger.info(f"Checking progress for job {job_id}") 69 | 70 | # First try exact job_id match 71 | logger.info(f"Looking for exact match for job ID: {job_id}") 72 | result = conn.execute( 73 | """ 74 | SELECT job_id, start_url, status, pages_discovered, pages_crawled, 75 | max_pages, tags, created_at, updated_at, error_message 76 | FROM jobs 77 | WHERE job_id = ? 78 | """, 79 | (job_id,), 80 | ).fetchone() 81 | 82 | # If not found, try partial match (useful if client received truncated UUID) 83 | if not result and len(job_id) >= 8: 84 | logger.info(f"Exact match not found, trying partial match for job ID: {job_id}") 85 | result = conn.execute( 86 | """ 87 | SELECT job_id, start_url, status, pages_discovered, pages_crawled, 88 | max_pages, tags, created_at, updated_at, error_message 89 | FROM jobs 90 | WHERE job_id LIKE ? 
91 | ORDER BY created_at DESC 92 | LIMIT 1 93 | """, 94 | (f"{job_id}%",), 95 | ).fetchone() 96 | 97 | if not result: 98 | logger.warning(f"Job {job_id} not found") 99 | return None 100 | 101 | # --- Job found in database --- 102 | ( 103 | job_id, 104 | url, 105 | status, 106 | pages_discovered, 107 | pages_crawled, 108 | max_pages, 109 | tags_json, 110 | created_at, 111 | updated_at, 112 | error_message, 113 | ) = result 114 | 115 | logger.info( 116 | f"Found job with ID: {job_id}, status: {status}, discovered: {pages_discovered}, crawled: {pages_crawled}, updated: {updated_at}", 117 | ) 118 | 119 | # Determine if job is completed 120 | completed = status in ["completed", "failed"] 121 | 122 | # Calculate progress percentage 123 | progress_percent = 0 124 | if max_pages > 0 and pages_discovered > 0: 125 | progress_percent = min(100, int((pages_crawled / min(pages_discovered, max_pages)) * 100)) 126 | 127 | logger.info( 128 | f"Job {job_id} progress: {pages_crawled}/{pages_discovered} pages, {progress_percent}% complete, status: {status}", 129 | ) 130 | 131 | return JobProgressResponse( 132 | pages_crawled=pages_crawled, 133 | pages_total=pages_discovered, 134 | completed=completed, 135 | status=status, 136 | error_message=error_message, 137 | progress_percent=progress_percent, 138 | url=url, 139 | max_pages=max_pages, 140 | created_at=created_at, 141 | updated_at=updated_at, 142 | ) 143 | 144 | 145 | async def get_job_count() -> int: 146 | """Get the total number of jobs in the database. 147 | Used for debugging purposes. 148 | 149 | Returns: 150 | int: Total number of jobs 151 | 152 | """ 153 | from src.lib.database import DatabaseOperations 154 | 155 | db_ops = DatabaseOperations() 156 | try: 157 | # Use DuckDBConnectionManager as a context manager 158 | with db_ops.db as conn_manager: 159 | actual_conn = conn_manager.conn 160 | if not actual_conn: 161 | logger.error("Failed to obtain database connection for get_job_count.") 162 | return -1 # Or raise an error 163 | job_count_result = actual_conn.execute("SELECT COUNT(*) FROM jobs").fetchone() 164 | if job_count_result: 165 | job_count = job_count_result[0] 166 | logger.info(f"Database contains {job_count} total jobs.") 167 | return job_count 168 | logger.warning("Could not retrieve job count from database.") 169 | return -1 170 | except Exception as count_error: 171 | logger.warning(f"Failed to count jobs in database: {count_error!s}") 172 | return -1 173 | # No finally block needed to close connection, context manager handles it. 174 | -------------------------------------------------------------------------------- /src/common/processor.py: -------------------------------------------------------------------------------- 1 | """Page processing pipeline combining crawling, chunking, embedding, and indexing.""" 2 | 3 | import asyncio 4 | from itertools import islice 5 | from typing import Any 6 | from urllib.parse import urlparse 7 | 8 | from src.common.indexer import VectorIndexer 9 | from src.common.logger import get_logger 10 | from src.lib.chunker import TextChunker 11 | from src.lib.crawler import extract_page_text 12 | from src.lib.database import DatabaseOperations 13 | from src.lib.embedder import generate_embedding 14 | 15 | # Configure logging 16 | logger = get_logger(__name__) 17 | 18 | 19 | async def process_crawl_result( 20 | page_result: Any, 21 | job_id: str, 22 | tags: list[str] | None = None, 23 | max_concurrent_embeddings: int = 5, 24 | ) -> str: 25 | """Process a single crawled page result through the entire pipeline. 
26 | 27 | Args: 28 | page_result: The crawl result for the page 29 | job_id: The ID of the crawl job 30 | tags: Optional tags to associate with the page 31 | max_concurrent_embeddings: Maximum number of concurrent embedding generations 32 | 33 | Returns: 34 | The ID of the processed page 35 | 36 | """ 37 | if tags is None: 38 | tags = [] # Initialize as empty list instead of None 39 | 40 | try: 41 | logger.info(f"Processing page: {page_result.url}") 42 | 43 | # Extract text from the crawl result 44 | page_text = extract_page_text(page_result) 45 | 46 | # Store the page in the database 47 | db_ops = DatabaseOperations() 48 | page_id = await db_ops.store_page( # store_page now handles its own connection 49 | url=page_result.url, 50 | text=page_text, 51 | job_id=job_id, 52 | tags=tags, 53 | ) 54 | # No explicit close needed as store_page uses context manager internally 55 | 56 | # Initialize components 57 | chunker = TextChunker() 58 | indexer = VectorIndexer() # VectorIndexer will now create and manage its own connection. 59 | 60 | # Split text into chunks 61 | chunks = chunker.split_text(page_text) 62 | logger.info(f"Split page into {len(chunks)} chunks") 63 | 64 | # Extract domain from URL 65 | domain = urlparse(page_result.url).netloc 66 | 67 | # Process chunks in parallel batches 68 | successful_chunks = 0 69 | 70 | async def process_chunk(chunk_text): 71 | try: 72 | # Generate embedding 73 | embedding = await generate_embedding(chunk_text) 74 | 75 | # Prepare payload 76 | payload = { 77 | "text": chunk_text, 78 | "page_id": page_id, 79 | "url": page_result.url, 80 | "domain": domain, 81 | "tags": tags, 82 | "job_id": job_id, 83 | } 84 | 85 | # Index the vector 86 | await indexer.index_vector(embedding, payload) 87 | return True 88 | except Exception as chunk_error: 89 | logger.error(f"Error processing chunk: {chunk_error!s}") 90 | return False 91 | 92 | # Process chunks in batches with limited concurrency 93 | i = 0 94 | while i < len(chunks): 95 | # Take up to max_concurrent_embeddings chunks 96 | batch_chunks = list(islice(chunks, i, i + max_concurrent_embeddings)) 97 | i += max_concurrent_embeddings 98 | 99 | # Process this batch in parallel 100 | results = await asyncio.gather( 101 | *[process_chunk(chunk) for chunk in batch_chunks], 102 | return_exceptions=False, 103 | ) 104 | 105 | successful_chunks += sum(1 for result in results if result) 106 | 107 | logger.info( 108 | f"Successfully indexed {successful_chunks}/{len(chunks)} chunks for page {page_id}", 109 | ) 110 | 111 | return page_id 112 | 113 | except Exception as e: 114 | logger.error(f"Error processing page {page_result.url}: {e!s}") 115 | raise 116 | finally: 117 | # VectorIndexer now manages its own connection if instantiated without one. 118 | # Its __del__ method will attempt to close its connection if self._own_connection is True. 119 | # No explicit action needed here for the indexer's connection. 120 | pass 121 | 122 | 123 | async def process_page_batch( 124 | page_results: list[Any], 125 | job_id: str, 126 | tags: list[str] | None = None, 127 | batch_size: int = 10, 128 | ) -> list[str]: 129 | """Process a batch of crawled pages. 
130 | 131 | Args: 132 | page_results: List of crawl results 133 | job_id: The ID of the crawl job 134 | tags: Optional tags to associate with the pages 135 | batch_size: Size of batches for processing 136 | 137 | Returns: 138 | List of processed page IDs 139 | 140 | """ 141 | if tags is None: 142 | tags = [] 143 | 144 | processed_page_ids = [] 145 | 146 | # Process pages in smaller batches to avoid memory issues 147 | for i in range(0, len(page_results), batch_size): 148 | batch = page_results[i : i + batch_size] 149 | logger.info(f"Processing batch of {len(batch)} pages (batch {i // batch_size + 1})") 150 | 151 | for page_result in batch: 152 | try: 153 | page_id = await process_crawl_result(page_result, job_id, tags) 154 | processed_page_ids.append(page_id) 155 | 156 | # Update job progress 157 | db_ops_status = DatabaseOperations() 158 | await ( 159 | db_ops_status.update_job_status( # update_job_status handles its own connection 160 | job_id=job_id, 161 | status="running", 162 | pages_discovered=len(page_results), 163 | pages_crawled=len(processed_page_ids), 164 | ) 165 | ) 166 | 167 | except Exception as page_error: 168 | logger.error(f"Error in batch processing for {page_result.url}: {page_error!s}") 169 | continue 170 | 171 | return processed_page_ids 172 | -------------------------------------------------------------------------------- /src/web_service/api/jobs.py: -------------------------------------------------------------------------------- 1 | """Job API routes for the web service.""" 2 | 3 | import asyncio 4 | 5 | import redis 6 | from fastapi import APIRouter, Depends, HTTPException, Query 7 | from rq import Queue 8 | 9 | from src.common.config import REDIS_URI 10 | from src.common.logger import get_logger 11 | from src.common.models import ( 12 | FetchUrlRequest, 13 | FetchUrlResponse, 14 | JobProgressResponse, 15 | ) 16 | from src.lib.database import DatabaseOperations 17 | from src.web_service.services.job_service import ( 18 | fetch_url, 19 | get_job_count, 20 | get_job_progress, 21 | ) 22 | 23 | # Get logger for this module 24 | logger = get_logger(__name__) 25 | 26 | # Create router 27 | router = APIRouter(tags=["jobs"]) 28 | 29 | 30 | @router.post("/fetch_url", response_model=FetchUrlResponse, operation_id="fetch_url") 31 | async def fetch_url_endpoint( 32 | request: FetchUrlRequest, 33 | queue: Queue = Depends(lambda: Queue("worker", connection=redis.from_url(REDIS_URI))), 34 | ): 35 | """Initiate a fetch job to crawl a website. 36 | 37 | Args: 38 | request: The fetch URL request 39 | 40 | Returns: 41 | The job ID 42 | 43 | """ 44 | logger.info(f"API: Initiating fetch for URL: {request.url}") 45 | 46 | try: 47 | # Call the service function 48 | response = await fetch_url( 49 | queue=queue, 50 | url=request.url, 51 | tags=request.tags, 52 | max_pages=request.max_pages, 53 | ) 54 | return response 55 | except Exception as e: 56 | logger.error(f"Error initiating fetch: {e!s}") 57 | raise HTTPException(status_code=500, detail=f"Error initiating fetch: {e!s}") 58 | 59 | 60 | @router.get("/job_progress", response_model=JobProgressResponse, operation_id="job_progress") 61 | async def job_progress_endpoint( 62 | job_id: str = Query(..., description="The job ID to check progress for"), 63 | ): 64 | """Check the progress of a job. 
65 | 66 | Args: 67 | job_id: The job ID to check progress for 68 | 69 | Returns: 70 | Job progress information 71 | 72 | """ 73 | logger.info(f"API: BEGIN job_progress for job {job_id}") 74 | 75 | # Create a fresh connection to get the latest data 76 | # This ensures we always get the most recent job state 77 | attempts = 0 78 | max_attempts = 3 79 | retry_delay = 0.1 # seconds 80 | 81 | logger.info(f"Starting check loop for job {job_id} (max_attempts={max_attempts})") 82 | while attempts < max_attempts: 83 | attempts += 1 84 | db_ops = DatabaseOperations() # Instantiate inside the loop 85 | 86 | try: 87 | with db_ops.db as conn_manager: # Use DuckDBConnectionManager as context manager 88 | actual_conn = conn_manager.conn 89 | if not actual_conn: 90 | # This case should ideally be handled by DuckDBConnectionManager.connect() raising an error 91 | logger.error( 92 | f"Attempt {attempts}: Failed to obtain DB connection from manager for job {job_id}." 93 | ) 94 | if attempts >= max_attempts: 95 | raise HTTPException( 96 | status_code=500, detail="Database connection error after retries." 97 | ) 98 | await asyncio.sleep(retry_delay) 99 | retry_delay *= 2 100 | continue # To next attempt 101 | 102 | logger.info( 103 | f"Established fresh connection to database (attempt {attempts}) for job {job_id}" 104 | ) 105 | 106 | # Call the service function 107 | result = await get_job_progress(actual_conn, job_id) 108 | 109 | if not result: 110 | # Job not found on this attempt 111 | logger.warning( 112 | f"Job {job_id} not found (attempt {attempts}/{max_attempts}). Retrying in {retry_delay}s...", 113 | ) 114 | # Connection is closed by conn_manager context exit 115 | if attempts >= max_attempts: 116 | logger.warning(f"Job {job_id} not found after {max_attempts} attempts.") 117 | break # Exit the while loop to raise 404 118 | await asyncio.sleep(retry_delay) 119 | retry_delay *= 2 120 | continue # Go to next iteration of the while loop 121 | 122 | # Job found, return the result 123 | return result 124 | 125 | except HTTPException: 126 | # Re-raise HTTP exceptions as-is 127 | logger.warning( 128 | f"HTTPException occurred during attempt {attempts} for job {job_id}, re-raising.", 129 | ) 130 | raise 131 | except Exception as e: # Includes duckdb.Error if connect fails in conn_manager 132 | logger.error( 133 | f"Error checking job progress (attempt {attempts}) for job {job_id}: {e!s}", 134 | ) 135 | if attempts < max_attempts: 136 | logger.info(f"Non-fatal error on attempt {attempts}, will retry after delay.") 137 | # Connection (if opened by conn_manager) is closed on context exit 138 | await asyncio.sleep(retry_delay) 139 | retry_delay *= 2 140 | continue 141 | logger.error(f"Error on final attempt ({attempts}) for job {job_id}. Raising 500.") 142 | raise HTTPException( 143 | status_code=500, 144 | detail=f"Database error after retries: {e!s}", 145 | ) 146 | # The 'finally' block for closing 'conn' is removed as conn_manager handles it. 147 | 148 | # If the loop finished without finding the job (i.e., break was hit after max attempts) 149 | # raise the 404 error. 
150 | # Check job count for debugging before raising 404 151 | job_count = await get_job_count() 152 | logger.info(f"Database contains {job_count} total jobs during final 404 check.") 153 | 154 | logger.warning(f"Raising 404 for job {job_id} after all retries.") # Add log before raising 155 | raise HTTPException(status_code=404, detail=f"Job {job_id} not found after retries") 156 | -------------------------------------------------------------------------------- /src/common/models.py: -------------------------------------------------------------------------------- 1 | """Pydantic models for the Doctor project.""" 2 | 3 | from datetime import datetime 4 | 5 | from pydantic import BaseModel, Field 6 | 7 | 8 | class FetchUrlRequest(BaseModel): 9 | """Request model for the /fetch_url endpoint.""" 10 | 11 | url: str = Field(..., description="The URL to start indexing from") 12 | tags: list[str] | None = Field(default=None, description="Tags to assign this website") 13 | max_pages: int = Field(default=100, description="How many pages to index", ge=1, le=1000) 14 | 15 | 16 | class FetchUrlResponse(BaseModel): 17 | """Response model for the /fetch_url endpoint.""" 18 | 19 | job_id: str = Field(..., description="The job ID of the index job") 20 | 21 | 22 | class SearchDocsRequest(BaseModel): 23 | """Request model for the /search_docs endpoint.""" 24 | 25 | query: str = Field(..., description="The search string to query the database with") 26 | tags: list[str] | None = Field(default=None, description="Tags to limit the search with") 27 | max_results: int = Field( 28 | default=10, 29 | description="Maximum number of results to return", 30 | ge=1, 31 | le=100, 32 | ) 33 | 34 | 35 | class SearchResult(BaseModel): 36 | """A single search result.""" 37 | 38 | chunk_text: str = Field(..., description="The text of the chunk") 39 | page_id: str = Field(..., description="Reference to the original page") 40 | tags: list[str] = Field(default_factory=list, description="Tags associated with the chunk") 41 | score: float = Field(..., description="Similarity score") 42 | url: str = Field(..., description="Original URL of the page") 43 | 44 | 45 | class SearchDocsResponse(BaseModel): 46 | """Response model for the /search_docs endpoint.""" 47 | 48 | results: list[SearchResult] = Field(default_factory=list, description="Search results") 49 | 50 | 51 | class JobProgressRequest(BaseModel): 52 | """Request model for the /job_progress endpoint.""" 53 | 54 | job_id: str = Field(..., description="The job ID to check progress for") 55 | 56 | 57 | class JobProgressResponse(BaseModel): 58 | """Response model for the /job_progress endpoint.""" 59 | 60 | pages_crawled: int = Field(..., description="Number of pages crawled so far") 61 | pages_total: int = Field(..., description="Total number of pages discovered") 62 | completed: bool = Field(..., description="Whether the job is completed") 63 | status: str = Field(..., description="Current job status") 64 | error_message: str | None = Field(default=None, description="Error message if job failed") 65 | progress_percent: int | None = Field( 66 | default=None, 67 | description="Percentage of crawl completed", 68 | ) 69 | url: str | None = Field(default=None, description="URL being crawled") 70 | max_pages: int | None = Field(default=None, description="Maximum pages to crawl") 71 | created_at: datetime | None = Field(default=None, description="When the job was created") 72 | updated_at: datetime | None = Field( 73 | default=None, 74 | description="When the job was last updated", 75 
| ) 76 | 77 | 78 | class ListDocPagesRequest(BaseModel): 79 | """Request model for the /list_doc_pages endpoint.""" 80 | 81 | page: int = Field(default=1, description="Page number", ge=1) 82 | tags: list[str] | None = Field(default=None, description="Tags to filter by") 83 | 84 | 85 | class DocPageSummary(BaseModel): 86 | """Summary information about a document page.""" 87 | 88 | page_id: str = Field(..., description="Unique page ID") 89 | domain: str = Field(..., description="Domain of the page") 90 | tags: list[str] = Field(default_factory=list, description="Tags associated with the page") 91 | crawl_date: datetime = Field(..., description="When the page was crawled") 92 | url: str = Field(..., description="URL of the page") 93 | 94 | 95 | class ListDocPagesResponse(BaseModel): 96 | """Response model for the /list_doc_pages endpoint.""" 97 | 98 | doc_pages: list[DocPageSummary] = Field( 99 | default_factory=list, 100 | description="List of document pages", 101 | ) 102 | total_pages: int = Field(..., description="Total number of pages matching the query") 103 | current_page: int = Field(..., description="Current page number") 104 | pages_per_page: int = Field(default=100, description="Number of items per page") 105 | 106 | 107 | class GetDocPageRequest(BaseModel): 108 | """Request model for the /get_doc_page endpoint.""" 109 | 110 | page_id: str = Field(..., description="The page ID to retrieve") 111 | starting_line: int = Field(default=1, description="Line to view from", ge=1) 112 | ending_line: int = Field(default=100, description="Line to view up to", ge=1) 113 | 114 | 115 | class GetDocPageResponse(BaseModel): 116 | """Response model for the /get_doc_page endpoint.""" 117 | 118 | text: str = Field(..., description="The document page text") 119 | total_lines: int = Field(..., description="Total number of lines in the document") 120 | 121 | 122 | class Job(BaseModel): 123 | """Internal model for a crawl job.""" 124 | 125 | job_id: str 126 | start_url: str 127 | status: str = "pending" # pending, running, completed, failed 128 | pages_discovered: int = 0 129 | pages_crawled: int = 0 130 | max_pages: int 131 | tags: list[str] = Field(default_factory=list) 132 | created_at: datetime = Field(default_factory=datetime.now) 133 | updated_at: datetime = Field(default_factory=datetime.now) 134 | error_message: str | None = None 135 | 136 | 137 | class Page(BaseModel): 138 | """Internal model for a crawled page.""" 139 | 140 | id: str 141 | url: str 142 | domain: str 143 | raw_text: str 144 | crawl_date: datetime = Field(default_factory=datetime.now) 145 | tags: list[str] = Field(default_factory=list) 146 | 147 | 148 | class Chunk(BaseModel): 149 | """Internal model for a text chunk with its embedding.""" 150 | 151 | id: str 152 | text: str 153 | page_id: str 154 | domain: str 155 | tags: list[str] = Field(default_factory=list) 156 | embedding: list[float] 157 | 158 | 159 | class DeleteDocsRequest(BaseModel): 160 | """Request model for the /delete_docs endpoint.""" 161 | 162 | tags: list[str] | None = Field(default=None, description="Tags to filter by") 163 | domain: str | None = Field(default=None, description="Domain substring to filter by") 164 | page_ids: list[str] | None = Field(default=None, description="Specific page IDs to delete") 165 | 166 | 167 | class ListTagsResponse(BaseModel): 168 | """Response model for the /list_tags endpoint.""" 169 | 170 | tags: list[str] = Field( 171 | default_factory=list, 172 | description="List of all unique tags in the database", 173 | ) 174 | 
-------------------------------------------------------------------------------- /tests/api/test_map_api.py: -------------------------------------------------------------------------------- 1 | """Tests for the map API endpoints.""" 2 | 3 | import pytest 4 | from fastapi.testclient import TestClient 5 | from unittest.mock import AsyncMock, patch 6 | 7 | from src.web_service.main import app 8 | 9 | 10 | @pytest.fixture 11 | def client(): 12 | """Create a test client for the FastAPI app. 13 | 14 | Args: 15 | None. 16 | 17 | Returns: 18 | TestClient: Test client instance. 19 | """ 20 | return TestClient(app) 21 | 22 | 23 | class TestMapAPI: 24 | """Test the map API endpoints.""" 25 | 26 | def test_get_site_index(self, client: TestClient) -> None: 27 | """Test the /map endpoint for site index. 28 | 29 | Args: 30 | client: The test client. 31 | 32 | Returns: 33 | None. 34 | """ 35 | # Mock the map service 36 | with patch("src.web_service.api.map.MapService") as MockService: 37 | mock_service = MockService.return_value 38 | mock_service.get_all_sites = AsyncMock( 39 | return_value=[ 40 | {"id": "site1", "title": "Site 1", "url": "https://example1.com"}, 41 | {"id": "site2", "title": "Site 2", "url": "https://example2.com"}, 42 | ] 43 | ) 44 | mock_service.format_site_list.return_value = "Site List" 45 | 46 | response = client.get("/map") 47 | 48 | assert response.status_code == 200 49 | assert response.headers["content-type"] == "text/html; charset=utf-8" 50 | assert response.text == "Site List" 51 | 52 | def test_get_site_index_error(self, client: TestClient) -> None: 53 | """Test /map endpoint error handling. 54 | 55 | Args: 56 | client: The test client. 57 | 58 | Returns: 59 | None. 60 | """ 61 | with patch("src.web_service.api.map.MapService") as MockService: 62 | mock_service = MockService.return_value 63 | mock_service.get_all_sites = AsyncMock(side_effect=Exception("Database error")) 64 | 65 | response = client.get("/map") 66 | 67 | assert response.status_code == 500 68 | assert "Database error" in response.json()["detail"] 69 | 70 | def test_get_site_tree(self, client: TestClient) -> None: 71 | """Test the /map/site/{root_page_id} endpoint. 72 | 73 | Args: 74 | client: The test client. 75 | 76 | Returns: 77 | None. 78 | """ 79 | root_id = "root-123" 80 | 81 | with patch("src.web_service.api.map.MapService") as MockService: 82 | mock_service = MockService.return_value 83 | mock_service.build_page_tree = AsyncMock( 84 | return_value={"id": root_id, "title": "Test Site", "children": []} 85 | ) 86 | mock_service.format_site_tree.return_value = "Site Tree" 87 | 88 | response = client.get(f"/map/site/{root_id}") 89 | 90 | assert response.status_code == 200 91 | assert response.headers["content-type"] == "text/html; charset=utf-8" 92 | assert response.text == "Site Tree" 93 | 94 | def test_view_page(self, client: TestClient) -> None: 95 | """Test the /map/page/{page_id} endpoint. 96 | 97 | Args: 98 | client: The test client. 99 | 100 | Returns: 101 | None. 
102 | """ 103 | page_id = "page-123" 104 | 105 | with patch("src.web_service.api.map.MapService") as MockService: 106 | mock_service = MockService.return_value 107 | mock_service.db_ops.get_page_by_id = AsyncMock( 108 | return_value={ 109 | "id": page_id, 110 | "title": "Test Page", 111 | "raw_text": "# Test Content", 112 | } 113 | ) 114 | mock_service.get_navigation_context = AsyncMock( 115 | return_value={ 116 | "current_page": {"id": page_id, "title": "Test Page"}, 117 | "parent": None, 118 | "siblings": [], 119 | "children": [], 120 | "root": None, 121 | } 122 | ) 123 | mock_service.render_page_html.return_value = "Rendered Page" 124 | 125 | response = client.get(f"/map/page/{page_id}") 126 | 127 | assert response.status_code == 200 128 | assert response.headers["content-type"] == "text/html; charset=utf-8" 129 | assert response.text == "Rendered Page" 130 | 131 | def test_view_page_not_found(self, client: TestClient) -> None: 132 | """Test viewing a non-existent page. 133 | 134 | Args: 135 | client: The test client. 136 | 137 | Returns: 138 | None. 139 | """ 140 | page_id = "nonexistent" 141 | 142 | with patch("src.web_service.api.map.MapService") as MockService: 143 | mock_service = MockService.return_value 144 | mock_service.db_ops.get_page_by_id = AsyncMock(return_value=None) 145 | 146 | response = client.get(f"/map/page/{page_id}") 147 | 148 | assert response.status_code == 404 149 | assert response.json()["detail"] == "Page not found" 150 | 151 | def test_get_page_raw(self, client: TestClient) -> None: 152 | """Test the /map/page/{page_id}/raw endpoint. 153 | 154 | Args: 155 | client: The test client. 156 | 157 | Returns: 158 | None. 159 | """ 160 | page_id = "page-123" 161 | raw_content = "# Test Page\n\nThis is markdown content." 162 | 163 | with patch("src.web_service.api.map.MapService") as MockService: 164 | mock_service = MockService.return_value 165 | mock_service.db_ops.get_page_by_id = AsyncMock( 166 | return_value={ 167 | "id": page_id, 168 | "title": "Test Page", 169 | "raw_text": raw_content, 170 | } 171 | ) 172 | 173 | response = client.get(f"/map/page/{page_id}/raw") 174 | 175 | assert response.status_code == 200 176 | assert response.headers["content-type"] == "text/markdown; charset=utf-8" 177 | assert response.text == raw_content 178 | assert 'filename="Test Page.md"' in response.headers["content-disposition"] 179 | 180 | def test_get_page_raw_not_found(self, client: TestClient) -> None: 181 | """Test getting raw content for non-existent page. 182 | 183 | Args: 184 | client: The test client. 185 | 186 | Returns: 187 | None. 
188 | """ 189 | page_id = "nonexistent" 190 | 191 | with patch("src.web_service.api.map.MapService") as MockService: 192 | mock_service = MockService.return_value 193 | mock_service.db_ops.get_page_by_id = AsyncMock(return_value=None) 194 | 195 | response = client.get(f"/map/page/{page_id}/raw") 196 | 197 | assert response.status_code == 404 198 | assert response.json()["detail"] == "Page not found" 199 | -------------------------------------------------------------------------------- /tests/services/test_map_service_legacy.py: -------------------------------------------------------------------------------- 1 | """Tests for the map service legacy page handling.""" 2 | 3 | import pytest 4 | from unittest.mock import AsyncMock, patch 5 | import datetime 6 | 7 | from src.web_service.services.map_service import MapService 8 | 9 | 10 | @pytest.mark.asyncio 11 | class TestMapServiceLegacy: 12 | """Test the MapService legacy page handling.""" 13 | 14 | async def test_get_all_sites_with_legacy(self) -> None: 15 | """Test getting all sites including legacy domain groups. 16 | 17 | Args: 18 | None. 19 | 20 | Returns: 21 | None. 22 | """ 23 | service = MapService() 24 | 25 | # Mock root pages (pages with hierarchy) 26 | mock_root_pages = [ 27 | { 28 | "id": "site1", 29 | "url": "https://docs.example.com", 30 | "title": "Documentation Site", 31 | "domain": "docs.example.com", 32 | "crawl_date": datetime.datetime(2024, 1, 2), 33 | } 34 | ] 35 | 36 | # Mock legacy pages 37 | mock_legacy_pages = [ 38 | { 39 | "id": "page1", 40 | "url": "https://blog.example.com/post1", 41 | "domain": "blog.example.com", 42 | "title": "Blog Post 1", 43 | "crawl_date": datetime.datetime(2024, 1, 1), 44 | "root_page_id": None, 45 | "parent_page_id": None, 46 | }, 47 | { 48 | "id": "page2", 49 | "url": "https://blog.example.com/post2", 50 | "domain": "blog.example.com", 51 | "title": "Blog Post 2", 52 | "crawl_date": datetime.datetime(2024, 1, 3), 53 | "root_page_id": None, 54 | "parent_page_id": None, 55 | }, 56 | { 57 | "id": "page3", 58 | "url": "https://shop.example.com/product1", 59 | "domain": "shop.example.com", 60 | "title": "Product 1", 61 | "crawl_date": datetime.datetime(2024, 1, 1), 62 | "root_page_id": None, 63 | "parent_page_id": None, 64 | }, 65 | ] 66 | 67 | with patch.object( 68 | service.db_ops, "get_root_pages", new_callable=AsyncMock 69 | ) as mock_get_root: 70 | mock_get_root.return_value = mock_root_pages 71 | 72 | with patch.object( 73 | service.db_ops, "get_legacy_pages", new_callable=AsyncMock 74 | ) as mock_get_legacy: 75 | mock_get_legacy.return_value = mock_legacy_pages 76 | 77 | sites = await service.get_all_sites() 78 | 79 | # Should have 1 root page + 2 domain groups 80 | assert len(sites) == 3 81 | 82 | # Check that we have the regular site 83 | regular_sites = [s for s in sites if not s.get("is_synthetic")] 84 | assert len(regular_sites) == 1 85 | assert regular_sites[0]["title"] == "Documentation Site" 86 | 87 | # Check that we have domain groups 88 | domain_groups = [s for s in sites if s.get("is_synthetic")] 89 | assert len(domain_groups) == 2 90 | 91 | # Check blog domain group 92 | blog_group = next(s for s in domain_groups if "blog.example.com" in s["title"]) 93 | assert blog_group["id"] == "legacy-domain-blog.example.com" 94 | assert blog_group["page_count"] == 2 95 | assert blog_group["is_synthetic"] is True 96 | 97 | # Check shop domain group 98 | shop_group = next(s for s in domain_groups if "shop.example.com" in s["title"]) 99 | assert shop_group["id"] == 
"legacy-domain-shop.example.com" 100 | assert shop_group["page_count"] == 1 101 | 102 | async def test_build_domain_tree(self) -> None: 103 | """Test building a tree for a domain group. 104 | 105 | Args: 106 | None. 107 | 108 | Returns: 109 | None. 110 | """ 111 | service = MapService() 112 | domain = "blog.example.com" 113 | 114 | # Mock pages for the domain 115 | mock_domain_pages = [ 116 | { 117 | "id": "page1", 118 | "url": "https://blog.example.com/post1", 119 | "title": "Post 1", 120 | "root_page_id": None, 121 | }, 122 | { 123 | "id": "page2", 124 | "url": "https://blog.example.com/post2", 125 | "title": "Post 2", 126 | "root_page_id": None, 127 | }, 128 | { 129 | "id": "page3", 130 | "url": "https://blog.example.com/post3", 131 | "title": "Post 3", 132 | "root_page_id": "page3", # This is a proper root, not legacy 133 | "parent_page_id": None, 134 | "depth": 0, 135 | }, 136 | ] 137 | 138 | with patch.object( 139 | service.db_ops, "get_pages_by_domain", new_callable=AsyncMock 140 | ) as mock_get: 141 | mock_get.return_value = mock_domain_pages 142 | 143 | tree = await service.build_page_tree(f"legacy-domain-{domain}") 144 | 145 | # Check the synthetic root 146 | assert tree["id"] == f"legacy-domain-{domain}" 147 | assert tree["is_synthetic"] is True 148 | assert tree["title"] == f"{domain} (3 Pages)" 149 | 150 | # Should only include legacy pages (not page3) 151 | assert len(tree["children"]) == 2 152 | assert all(child["id"] in ["page1", "page2"] for child in tree["children"]) 153 | 154 | # Children should be sorted by URL 155 | assert tree["children"][0]["url"] < tree["children"][1]["url"] 156 | 157 | def test_format_site_list_with_legacy(self) -> None: 158 | """Test formatting site list with legacy domain groups. 159 | 160 | Args: 161 | None. 162 | 163 | Returns: 164 | None. 
165 | """ 166 | service = MapService() 167 | 168 | sites = [ 169 | { 170 | "id": "site1", 171 | "url": "https://docs.example.com", 172 | "title": "Documentation", 173 | "crawl_date": datetime.datetime(2024, 1, 1), 174 | "is_synthetic": False, 175 | }, 176 | { 177 | "id": "legacy-domain-blog.example.com", 178 | "url": "https://blog.example.com", 179 | "title": "blog.example.com (Legacy Pages)", 180 | "crawl_date": datetime.datetime(2024, 1, 1), 181 | "is_synthetic": True, 182 | "page_count": 5, 183 | }, 184 | ] 185 | 186 | html = service.format_site_list(sites) 187 | 188 | # Check regular site formatting 189 | assert "Documentation" in html 190 | assert "Crawled: 2024-01-01" in html 191 | 192 | # Check legacy domain group formatting 193 | assert "blog.example.com (Legacy Pages)" in html 194 | assert "Domain Group • 5 pages" in html 195 | assert "First crawled: 2024-01-01" in html 196 | -------------------------------------------------------------------------------- /src/common/processor_enhanced.py: -------------------------------------------------------------------------------- 1 | """Enhanced page processing pipeline with hierarchy tracking.""" 2 | 3 | import asyncio 4 | from itertools import islice 5 | from urllib.parse import urlparse 6 | 7 | from src.common.indexer import VectorIndexer 8 | from src.common.logger import get_logger 9 | from src.lib.chunker import TextChunker 10 | from src.lib.crawler import extract_page_text 11 | from src.lib.crawler_enhanced import CrawlResultWithHierarchy, crawl_url_with_hierarchy 12 | from src.lib.database import DatabaseOperations 13 | from src.lib.embedder import generate_embedding 14 | 15 | # Configure logging 16 | logger = get_logger(__name__) 17 | 18 | 19 | async def process_crawl_result_with_hierarchy( 20 | page_result: CrawlResultWithHierarchy, 21 | job_id: str, 22 | tags: list[str] | None = None, 23 | max_concurrent_embeddings: int = 5, 24 | url_to_page_id: dict[str, str] | None = None, 25 | ) -> str: 26 | """Process a single crawled page result with hierarchy information. 27 | 28 | Args: 29 | page_result: The enhanced crawl result with hierarchy info. 30 | job_id: The ID of the crawl job. 31 | tags: Optional tags to associate with the page. 32 | max_concurrent_embeddings: Maximum number of concurrent embedding generations. 33 | url_to_page_id: Optional mapping of URLs to page IDs for parent lookup. 34 | 35 | Returns: 36 | The ID of the processed page. 
37 | """ 38 | if tags is None: 39 | tags = [] 40 | if url_to_page_id is None: 41 | url_to_page_id = {} 42 | 43 | try: 44 | logger.info(f"Processing page with hierarchy: {page_result.url}") 45 | 46 | # Extract text from the crawl result 47 | page_text = extract_page_text(page_result.base_result) 48 | 49 | # Look up parent page ID if we have a parent URL 50 | parent_page_id = None 51 | if page_result.parent_url and page_result.parent_url in url_to_page_id: 52 | parent_page_id = url_to_page_id[page_result.parent_url] 53 | 54 | # Look up root page ID 55 | root_page_id = None 56 | if page_result.root_url and page_result.root_url in url_to_page_id: 57 | root_page_id = url_to_page_id[page_result.root_url] 58 | 59 | # Store the page in the database with hierarchy info 60 | db_ops = DatabaseOperations() 61 | page_id = await db_ops.store_page( 62 | url=page_result.url, 63 | text=page_text, 64 | job_id=job_id, 65 | tags=tags, 66 | parent_page_id=parent_page_id, 67 | root_page_id=root_page_id, 68 | depth=page_result.depth, 69 | path=page_result.relative_path, 70 | title=page_result.title, 71 | ) 72 | 73 | # Store the mapping for future lookups 74 | url_to_page_id[page_result.url] = page_id 75 | 76 | # Initialize components 77 | chunker = TextChunker() 78 | indexer = VectorIndexer() 79 | 80 | # Split text into chunks 81 | chunks = chunker.split_text(page_text) 82 | logger.info(f"Split page into {len(chunks)} chunks") 83 | 84 | # Extract domain from URL 85 | domain = urlparse(page_result.url).netloc 86 | 87 | # Process chunks in parallel batches 88 | successful_chunks = 0 89 | 90 | async def process_chunk(chunk_text): 91 | try: 92 | # Generate embedding 93 | embedding = await generate_embedding(chunk_text) 94 | 95 | # Prepare payload 96 | payload = { 97 | "text": chunk_text, 98 | "page_id": page_id, 99 | "url": page_result.url, 100 | "domain": domain, 101 | "tags": tags, 102 | "job_id": job_id, 103 | } 104 | 105 | # Index the vector 106 | await indexer.index_vector(embedding, payload) 107 | return True 108 | except Exception as chunk_error: 109 | logger.error(f"Error processing chunk: {chunk_error!s}") 110 | return False 111 | 112 | # Process chunks in batches with limited concurrency 113 | i = 0 114 | while i < len(chunks): 115 | # Take up to max_concurrent_embeddings chunks 116 | batch_chunks = list(islice(chunks, i, i + max_concurrent_embeddings)) 117 | i += max_concurrent_embeddings 118 | 119 | # Process this batch in parallel 120 | results = await asyncio.gather( 121 | *[process_chunk(chunk) for chunk in batch_chunks], 122 | return_exceptions=False, 123 | ) 124 | 125 | successful_chunks += sum(1 for result in results if result) 126 | 127 | logger.info( 128 | f"Successfully indexed {successful_chunks}/{len(chunks)} chunks for page {page_id}", 129 | ) 130 | 131 | return page_id 132 | 133 | except Exception as e: 134 | logger.error(f"Error processing page {page_result.url}: {e!s}") 135 | raise 136 | 137 | 138 | async def process_crawl_with_hierarchy( 139 | url: str, 140 | job_id: str, 141 | tags: list[str] | None = None, 142 | max_pages: int = 100, 143 | max_depth: int = 2, 144 | strip_urls: bool = True, 145 | ) -> list[str]: 146 | """Crawl a URL and process all pages with hierarchy tracking. 147 | 148 | Args: 149 | url: The URL to start crawling from. 150 | job_id: The ID of the crawl job. 151 | tags: Optional tags to associate with the pages. 152 | max_pages: Maximum number of pages to crawl. 153 | max_depth: Maximum depth for the BFS crawl. 
154 | strip_urls: Whether to strip URLs from the returned markdown. 155 | 156 | Returns: 157 | List of processed page IDs. 158 | """ 159 | if tags is None: 160 | tags = [] 161 | 162 | # Crawl with hierarchy tracking 163 | crawl_results = await crawl_url_with_hierarchy(url, max_pages, max_depth, strip_urls) 164 | 165 | # Sort results by depth to ensure parents are processed before children 166 | crawl_results.sort(key=lambda r: r.depth) 167 | 168 | # Map to track URL to page ID relationships 169 | url_to_page_id = {} 170 | processed_page_ids = [] 171 | 172 | # Process pages in order 173 | for i, page_result in enumerate(crawl_results): 174 | try: 175 | page_id = await process_crawl_result_with_hierarchy( 176 | page_result, job_id, tags, url_to_page_id=url_to_page_id 177 | ) 178 | processed_page_ids.append(page_id) 179 | 180 | # Update job progress 181 | db_ops_status = DatabaseOperations() 182 | await db_ops_status.update_job_status( 183 | job_id=job_id, 184 | status="running", 185 | pages_discovered=len(crawl_results), 186 | pages_crawled=len(processed_page_ids), 187 | ) 188 | 189 | except Exception as page_error: 190 | logger.error(f"Error processing page {page_result.url}: {page_error!s}") 191 | continue 192 | 193 | return processed_page_ids 194 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | Doctor Logo 4 | 5 |
6 | 7 |

# 🩺 Doctor

8 | 9 | [![Python Version](https://img.shields.io/badge/python-%3E=3.12-3776ab?style=flat&labelColor=333333&logo=python&logoColor=white)](https://github.com/sisig-ai/doctor) 10 | [![License](https://img.shields.io/badge/license-MIT-00acc1?style=flat&labelColor=333333&logo=open-source-initiative&logoColor=white)](LICENSE.md) 11 | [![Python Tests](https://github.com/sisig-ai/doctor/actions/workflows/pytest.yml/badge.svg)](https://github.com/sisig-ai/doctor/actions/workflows/pytest.yml) 12 | [![codecov](https://codecov.io/gh/sisig-ai/doctor/branch/main/graph/badge.svg)](https://codecov.io/gh/sisig-ai/doctor) 13 | 14 | A tool for discovering, crawling, and indexing web sites, exposed as an MCP server that gives LLM agents better and more up-to-date context for reasoning and code generation. 15 | 16 |
17 | 18 | --- 19 | 20 | ### 🔍 Overview 21 | 22 | Doctor provides a complete stack for: 23 | - Crawling web pages using crawl4ai with hierarchy tracking 24 | - Chunking text with LangChain 25 | - Creating embeddings with OpenAI via litellm 26 | - Storing data in DuckDB with vector search support 27 | - Exposing search functionality via a FastAPI web service 28 | - Making these capabilities available to LLMs through an MCP server 29 | - Navigating crawled sites with hierarchical site maps 30 | 31 | --- 32 | 33 | ### 🏗️ Core Infrastructure 34 | 35 | #### 🗄️ DuckDB 36 | - Database for storing document data and embeddings with vector search capabilities 37 | - Managed by unified Database class 38 | 39 | #### 📨 Redis 40 | - Message broker for asynchronous task processing 41 | 42 | #### 🕸️ Crawl Worker 43 | - Processes crawl jobs 44 | - Chunks text 45 | - Creates embeddings 46 | 47 | #### 🌐 Web Server 48 | - FastAPI service exposing endpoints 49 | - Fetching, searching, and viewing data 50 | - Exposing the MCP server 51 | 52 | --- 53 | 54 | ### 💻 Setup 55 | 56 | #### ⚙️ Prerequisites 57 | - Docker and Docker Compose 58 | - Python 3.12+ 59 | - uv (Python package manager) 60 | - OpenAI API key 61 | 62 | #### 📦 Installation 63 | 1. Clone this repository 64 | 2. Set up environment variables: 65 | ``` 66 | export OPENAI_API_KEY=your-openai-key 67 | ``` 68 | 3. Run the stack: 69 | ``` 70 | docker compose up 71 | ``` 72 | 73 | --- 74 | 75 | ### 👁 Usage 76 | 1. Go to http://localhost:9111/docs to see the OpenAPI docs 77 | 2. Look for the `/fetch_url` endpoint and start a crawl job by providing a URL 78 | 3. Use `/job_progress` to see the current job status 79 | 4. Configure your editor to use `http://localhost:9111/mcp` as an MCP server 80 | 81 | --- 82 | 83 | ### ☁️ Web API 84 | 85 | #### Core Endpoints 86 | - `POST /fetch_url`: Start crawling a URL 87 | - `GET /search_docs`: Search indexed documents 88 | - `GET /job_progress`: Check crawl job progress 89 | - `GET /list_doc_pages`: List indexed pages 90 | - `GET /get_doc_page`: Get full text of a page 91 | 92 | #### Site Map Feature 93 | The Maps feature provides a hierarchical view of crawled websites, making it easy to navigate and explore the structure of indexed sites. 94 | 95 | **Endpoints:** 96 | - `GET /map`: View an index of all crawled sites 97 | - `GET /map/site/{root_page_id}`: View the hierarchical tree structure of a specific site 98 | - `GET /map/page/{page_id}`: View a specific page with navigation (parent, siblings, children) 99 | - `GET /map/page/{page_id}/raw`: Get the raw markdown content of a page 100 | 101 | **Features:** 102 | - **Hierarchical Navigation**: Pages maintain parent-child relationships, allowing you to navigate through the site structure 103 | - **Domain Grouping**: Pages from the same domain crawled individually are automatically grouped together 104 | - **Automatic Title Extraction**: Page titles are extracted from HTML or markdown content 105 | - **Breadcrumb Navigation**: Easy navigation with breadcrumbs showing the path from root to current page 106 | - **Sibling Navigation**: Quick access to pages at the same level in the hierarchy 107 | - **Legacy Page Support**: Pages crawled before hierarchy tracking are grouped by domain for easy access 108 | - **No JavaScript Required**: All navigation works with pure HTML and CSS for maximum compatibility 109 | 110 | **Usage Example:** 111 | 1. Crawl a website using the `/fetch_url` endpoint 112 | 2. Visit `/map` to see all crawled sites 113 | 3.
Click on a site to view its hierarchical structure 114 | 4. Navigate through pages using the provided links 115 | 116 | --- 117 | 118 | ### 🔧 MCP Integration 119 | Ensure that your Docker Compose stack is up, and then add to your Cursor or VSCode MCP Servers configuration: 120 | 121 | ```json 122 | "doctor": { 123 | "type": "sse", 124 | "url": "http://localhost:9111/mcp" 125 | } 126 | ``` 127 | 128 | --- 129 | 130 | ### 🧪 Testing 131 | 132 | #### Running Tests 133 | To run all tests: 134 | ```bash 135 | # Run all tests with coverage report 136 | pytest 137 | ``` 138 | 139 | To run specific test categories: 140 | ```bash 141 | # Run only unit tests 142 | pytest -m unit 143 | 144 | # Run only async tests 145 | pytest -m async_test 146 | 147 | # Run tests for a specific component 148 | pytest tests/lib/test_crawler.py 149 | ``` 150 | 151 | #### Test Coverage 152 | The project is configured to generate coverage reports automatically: 153 | ```bash 154 | # Run tests with detailed coverage report 155 | pytest --cov=src --cov-report=term-missing 156 | ``` 157 | 158 | #### Test Structure 159 | - `tests/conftest.py`: Common fixtures for all tests 160 | - `tests/lib/`: Tests for library components 161 | - `test_crawler.py`: Tests for the crawler module 162 | - `test_crawler_enhanced.py`: Tests for enhanced crawler with hierarchy tracking 163 | - `test_chunker.py`: Tests for the chunker module 164 | - `test_embedder.py`: Tests for the embedder module 165 | - `test_database.py`: Tests for the unified Database class 166 | - `test_database_hierarchy.py`: Tests for database hierarchy operations 167 | - `tests/common/`: Tests for common modules 168 | - `tests/services/`: Tests for service layer 169 | - `test_map_service.py`: Tests for the map service 170 | - `tests/api/`: Tests for API endpoints 171 | - `test_map_api.py`: Tests for map API endpoints 172 | - `tests/integration/`: Integration tests 173 | - `test_processor_enhanced.py`: Tests for enhanced processor with hierarchy 174 | 175 | --- 176 | 177 | ### 🐞 Code Quality 178 | 179 | #### Pre-commit Hooks 180 | The project is configured with pre-commit hooks that run automatically before each commit: 181 | - `ruff check --fix`: Lints code and automatically fixes issues 182 | - `ruff format`: Formats code according to project style 183 | - Trailing whitespace removal 184 | - End-of-file fixing 185 | - YAML validation 186 | - Large file checks 187 | 188 | #### Setup Pre-commit 189 | To set up pre-commit hooks: 190 | ```bash 191 | # Install pre-commit 192 | uv pip install pre-commit 193 | 194 | # Install the git hooks 195 | pre-commit install 196 | ``` 197 | 198 | #### Running Pre-commit Manually 199 | You can run the pre-commit hooks manually on all files: 200 | ```bash 201 | # Run all pre-commit hooks 202 | pre-commit run --all-files 203 | ``` 204 | 205 | Or on staged files only: 206 | ```bash 207 | # Run on staged files 208 | pre-commit run 209 | ``` 210 | 211 | --- 212 | 213 | ### ⚖️ License 214 | This project is licensed under the MIT License - see the [LICENSE.md](LICENSE.md) file for details. 
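As a worked illustration of the Usage and Web API sections above, here is a hedged client sketch (not part of the repository) that starts a crawl and polls its progress. It assumes the Docker Compose stack is listening on `localhost:9111`, that the `requests` package is available, and that `/job_progress` accepts the job ID as a `job_id` query parameter:

```python
# Hypothetical client sketch for the /fetch_url and /job_progress endpoints (not repo code).
import time

import requests

BASE = "http://localhost:9111"

# Start a crawl job; the payload mirrors the FetchUrlRequest model.
resp = requests.post(
    f"{BASE}/fetch_url",
    json={"url": "https://docs.example.com", "tags": ["example"], "max_pages": 25},
)
resp.raise_for_status()
job_id = resp.json()["job_id"]

# Poll until the job reports completion (fields mirror JobProgressResponse).
while True:
    progress = requests.get(f"{BASE}/job_progress", params={"job_id": job_id}).json()
    print(f"{progress['pages_crawled']}/{progress['pages_total']} pages, status={progress['status']}")
    if progress["completed"]:
        break
    time.sleep(5)
```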
215 | -------------------------------------------------------------------------------- /src/web_service/api/documents.py: -------------------------------------------------------------------------------- 1 | """Document API routes for the web service.""" 2 | 3 | from fastapi import APIRouter, HTTPException, Query 4 | 5 | from src.common.config import RETURN_FULL_DOCUMENT_TEXT 6 | from src.common.logger import get_logger 7 | from src.common.models import ( 8 | GetDocPageResponse, 9 | ListDocPagesResponse, 10 | ListTagsResponse, 11 | SearchDocsResponse, 12 | ) 13 | from src.lib.database import DatabaseOperations 14 | from src.web_service.services.document_service import ( 15 | get_doc_page, 16 | list_doc_pages, 17 | list_tags, 18 | search_docs, 19 | ) 20 | 21 | # Get logger for this module 22 | logger = get_logger(__name__) 23 | 24 | # Create router 25 | router = APIRouter(tags=["documents"]) 26 | 27 | 28 | @router.get("/search_docs", response_model=SearchDocsResponse, operation_id="search_docs") 29 | async def search_docs_endpoint( 30 | query: str = Query(..., description="The search string to query the database with"), 31 | tags: list[str] | None = Query(None, description="Tags to limit the search with"), 32 | max_results: int = Query(10, description="Maximum number of results to return", ge=1, le=100), 33 | return_full_document_text: bool = Query( 34 | RETURN_FULL_DOCUMENT_TEXT, 35 | description="Whether to return the full document text instead of the matching chunks only", 36 | ), 37 | ): 38 | """Search for documents using semantic search. Use `get_doc_page` to get the full text of a document page. 39 | 40 | Args: 41 | query: The search query 42 | tags: Optional tags to filter by 43 | max_results: Maximum number of results to return 44 | 45 | Returns: 46 | Search results 47 | 48 | """ 49 | logger.info( 50 | f"API: Searching docs with query: '{query}', tags: {tags}, max_results: {max_results}, return_full_document_text: {return_full_document_text}", 51 | ) 52 | 53 | db_ops = DatabaseOperations() 54 | try: 55 | # Use DuckDBConnectionManager as a context manager 56 | with db_ops.db as conn_manager: 57 | actual_conn = conn_manager.conn 58 | if not actual_conn: 59 | logger.error("Failed to obtain database connection for search_docs.") 60 | raise HTTPException(status_code=500, detail="Database connection error.") 61 | # Call the service function 62 | response = await search_docs( 63 | actual_conn, query, tags, max_results, return_full_document_text 64 | ) 65 | return response 66 | except HTTPException: # Re-raise HTTP exceptions directly 67 | raise 68 | except ( 69 | Exception 70 | ) as e: # Catch other exceptions (like DB errors from service or connection issues) 71 | logger.error(f"Error searching documents: {e!s}") 72 | # Log the specific error, but return a generic server error to the client 73 | raise HTTPException(status_code=500, detail=f"Search error: {e!s}") 74 | # No finally block to close connection, context manager handles it. 75 | 76 | 77 | @router.get("/list_doc_pages", response_model=ListDocPagesResponse, operation_id="list_doc_pages") 78 | async def list_doc_pages_endpoint( 79 | page: int = Query(1, description="Page number", ge=1), 80 | tags: list[str] | None = Query(None, description="Tags to filter by"), 81 | ): 82 | """List all available indexed pages. 
83 | 84 | Args: 85 | page: Page number (1-based) 86 | tags: Optional tags to filter by 87 | 88 | Returns: 89 | List of document pages 90 | 91 | """ 92 | logger.info(f"API: Listing document pages (page={page}, tags={tags})") 93 | 94 | db_ops = DatabaseOperations() 95 | try: 96 | with db_ops.db as conn_manager: 97 | actual_conn = conn_manager.conn 98 | if not actual_conn: 99 | logger.error("Failed to obtain database connection for list_doc_pages.") 100 | raise HTTPException(status_code=500, detail="Database connection error.") 101 | # Call the service function 102 | response = await list_doc_pages(actual_conn, page, tags) 103 | return response 104 | except HTTPException: 105 | raise 106 | except Exception as e: 107 | logger.error(f"Error listing document pages: {e!s}") 108 | raise HTTPException(status_code=500, detail=f"Database error: {e!s}") 109 | # No finally block to close connection 110 | 111 | 112 | @router.get("/get_doc_page", response_model=GetDocPageResponse, operation_id="get_doc_page") 113 | async def get_doc_page_endpoint( 114 | page_id: str = Query(..., description="The page ID to retrieve"), 115 | starting_line: int = Query(1, description="Line to view from", ge=1), 116 | ending_line: int = Query( 117 | -1, 118 | description="Line to view up to. Set to -1 to view the entire page.", 119 | ), 120 | ): 121 | """Get the full text of a document page. Use `search_docs` or `list_doc_pages` to get the page IDs. 122 | 123 | Args: 124 | page_id: The page ID 125 | starting_line: Line to view from (1-based) 126 | ending_line: Line to view up to (1-based). Set to -1 to view the entire page. 127 | 128 | Returns: 129 | Document page text 130 | 131 | """ 132 | logger.info(f"API: Retrieving document page {page_id} (lines {starting_line}-{ending_line})") 133 | 134 | db_ops = DatabaseOperations() 135 | try: 136 | with db_ops.db as conn_manager: 137 | actual_conn = conn_manager.conn 138 | if not actual_conn: 139 | logger.error("Failed to obtain database connection for get_doc_page.") 140 | raise HTTPException(status_code=500, detail="Database connection error.") 141 | # Call the service function 142 | response = await get_doc_page(actual_conn, page_id, starting_line, ending_line) 143 | if response is None: 144 | raise HTTPException(status_code=404, detail=f"Page {page_id} not found") 145 | return response 146 | except HTTPException: 147 | # Re-raise HTTP exceptions as-is 148 | raise 149 | except Exception as e: 150 | logger.error(f"Error retrieving document page {page_id}: {e!s}") 151 | raise HTTPException(status_code=500, detail=f"Database error: {e!s}") 152 | # No finally block to close connection 153 | 154 | 155 | @router.get("/list_tags", response_model=ListTagsResponse, operation_id="list_tags") 156 | async def list_tags_endpoint( 157 | search_substring: str | None = Query( 158 | None, 159 | description="Optional substring to filter tags (case-insensitive fuzzy matching)", 160 | ), 161 | ): 162 | """List all unique tags available in the document database. Use `search_docs` or `list_doc_pages` to get the page IDs using the tags. 
163 | 164 | Args: 165 | search_substring: Optional substring to filter tags using case-insensitive fuzzy matching 166 | 167 | Returns: 168 | List of unique tags 169 | 170 | """ 171 | logger.info( 172 | f"API: Listing all unique document tags{' with filter: ' + search_substring if search_substring else ''}", 173 | ) 174 | 175 | db_ops = DatabaseOperations() 176 | try: 177 | with db_ops.db as conn_manager: 178 | actual_conn = conn_manager.conn 179 | if not actual_conn: 180 | logger.error("Failed to obtain database connection for list_tags.") 181 | raise HTTPException(status_code=500, detail="Database connection error.") 182 | # Call the service function 183 | response = await list_tags(actual_conn, search_substring) 184 | return response 185 | except HTTPException: 186 | raise 187 | except Exception as e: 188 | logger.error(f"Error listing document tags: {e!s}") 189 | raise HTTPException(status_code=500, detail=f"Database error: {e!s}") 190 | # No finally block to close connection 191 | -------------------------------------------------------------------------------- /tests/common/test_processor.py: -------------------------------------------------------------------------------- 1 | """Tests for the processor module.""" 2 | 3 | from unittest.mock import AsyncMock, MagicMock, patch 4 | 5 | import pytest 6 | 7 | from src.common.processor import process_crawl_result, process_page_batch 8 | 9 | 10 | @pytest.fixture 11 | def mock_processor_dependencies(): 12 | """Setup mock dependencies for processor tests.""" 13 | extract_page_text_mock = MagicMock() 14 | extract_page_text_mock.return_value = "Extracted text content" 15 | 16 | # Mock for DatabaseOperations class store_page method 17 | store_page_mock = AsyncMock() 18 | store_page_mock.return_value = "test-page-123" 19 | 20 | # Mock for DatabaseOperations class update_job_status method 21 | update_job_status_mock = MagicMock() 22 | 23 | # Create a DatabaseOperations class mock that can be both used directly 24 | # and as a context manager 25 | database_mock = MagicMock( 26 | __enter__=lambda self: self, 27 | __exit__=lambda *args: None, 28 | store_page=store_page_mock, 29 | update_job_status=update_job_status_mock, 30 | ) 31 | 32 | generate_embedding_mock = AsyncMock() 33 | generate_embedding_mock.return_value = [0.1, 0.2, 0.3, 0.4, 0.5] 34 | 35 | mocks = { 36 | "extract_page_text": extract_page_text_mock, 37 | "database": database_mock, 38 | "store_page": store_page_mock, 39 | "TextChunker": MagicMock(), 40 | "generate_embedding": generate_embedding_mock, 41 | "VectorIndexer": MagicMock(), 42 | "update_job_status": update_job_status_mock, 43 | } 44 | 45 | chunker_instance = MagicMock() 46 | chunker_instance.split_text.return_value = ["Chunk 1", "Chunk 2", "Chunk 3"] 47 | mocks["TextChunker"].return_value = chunker_instance 48 | 49 | indexer_instance = MagicMock() 50 | indexer_instance.index_vector = AsyncMock(return_value="vector-id-123") 51 | mocks["VectorIndexer"].return_value = indexer_instance 52 | 53 | return mocks 54 | 55 | 56 | @pytest.mark.unit 57 | @pytest.mark.async_test 58 | async def test_process_crawl_result( 59 | sample_crawl_result, 60 | job_id, 61 | sample_tags, 62 | mock_processor_dependencies, 63 | ): 64 | """Test processing a single crawl result.""" 65 | with ( 66 | patch( 67 | "src.common.processor.extract_page_text", 68 | mock_processor_dependencies["extract_page_text"], 69 | ), 70 | patch( 71 | "src.common.processor.DatabaseOperations", 72 | return_value=mock_processor_dependencies["database"], 73 | ), 74 | 
patch("src.common.processor.TextChunker", mock_processor_dependencies["TextChunker"]), 75 | patch( 76 | "src.common.processor.generate_embedding", 77 | mock_processor_dependencies["generate_embedding"], 78 | ), 79 | patch("src.common.processor.VectorIndexer", mock_processor_dependencies["VectorIndexer"]), 80 | ): 81 | page_id = await process_crawl_result( 82 | page_result=sample_crawl_result, 83 | job_id=job_id, 84 | tags=sample_tags, 85 | ) 86 | 87 | mock_processor_dependencies["extract_page_text"].assert_called_once_with( 88 | sample_crawl_result, 89 | ) 90 | mock_processor_dependencies["store_page"].assert_called_once_with( 91 | url=sample_crawl_result.url, 92 | text="Extracted text content", 93 | job_id=job_id, 94 | tags=sample_tags, 95 | ) 96 | chunker_instance = mock_processor_dependencies["TextChunker"].return_value 97 | chunker_instance.split_text.assert_called_once_with("Extracted text content") 98 | assert mock_processor_dependencies["generate_embedding"].call_count == 3 99 | indexer_instance = mock_processor_dependencies["VectorIndexer"].return_value 100 | assert indexer_instance.index_vector.call_count == 3 101 | assert page_id == "test-page-123" 102 | 103 | 104 | @pytest.mark.unit 105 | @pytest.mark.async_test 106 | async def test_process_crawl_result_with_errors( 107 | sample_crawl_result, 108 | job_id, 109 | sample_tags, 110 | mock_processor_dependencies, 111 | ): 112 | """Test processing a crawl result with errors during chunking/embedding.""" 113 | with ( 114 | patch( 115 | "src.common.processor.extract_page_text", 116 | mock_processor_dependencies["extract_page_text"], 117 | ), 118 | patch( 119 | "src.common.processor.DatabaseOperations", 120 | return_value=mock_processor_dependencies["database"], 121 | ), 122 | patch("src.common.processor.TextChunker", mock_processor_dependencies["TextChunker"]), 123 | patch("src.common.processor.generate_embedding", side_effect=Exception("Embedding error")), 124 | patch("src.common.processor.VectorIndexer", mock_processor_dependencies["VectorIndexer"]), 125 | ): 126 | page_id = await process_crawl_result( 127 | page_result=sample_crawl_result, 128 | job_id=job_id, 129 | tags=sample_tags, 130 | ) 131 | assert page_id == "test-page-123" 132 | mock_processor_dependencies["store_page"].assert_called_once() 133 | indexer_instance = mock_processor_dependencies["VectorIndexer"].return_value 134 | assert indexer_instance.index_vector.call_count == 0 135 | 136 | 137 | @pytest.mark.unit 138 | @pytest.mark.async_test 139 | async def test_process_page_batch(mock_processor_dependencies): 140 | """Test processing a batch of crawl results.""" 141 | mock_results = [ 142 | MagicMock(url="https://example.com/page1"), 143 | MagicMock(url="https://example.com/page2"), 144 | MagicMock(url="https://example.com/page3"), 145 | ] 146 | mock_process_result = AsyncMock(side_effect=["page-1", "page-2", "page-3"]) 147 | 148 | with ( 149 | patch("src.common.processor.process_crawl_result", mock_process_result), 150 | patch( 151 | "src.common.processor.DatabaseOperations", 152 | return_value=mock_processor_dependencies["database"], 153 | ), 154 | ): 155 | job_id = "test-job" 156 | tags = ["test", "batch"] 157 | page_ids = await process_page_batch( 158 | page_results=mock_results, 159 | job_id=job_id, 160 | tags=tags, 161 | batch_size=2, 162 | ) 163 | assert mock_process_result.call_count == 3 164 | mock_process_result.assert_any_call(mock_results[0], job_id, tags) 165 | mock_process_result.assert_any_call(mock_results[1], job_id, tags) 166 | 
mock_process_result.assert_any_call(mock_results[2], job_id, tags) 167 | assert mock_processor_dependencies["update_job_status"].call_count == 3 168 | assert page_ids == ["page-1", "page-2", "page-3"] 169 | 170 | 171 | @pytest.mark.unit 172 | @pytest.mark.async_test 173 | async def test_process_page_batch_with_errors(mock_processor_dependencies): 174 | """Test processing a batch with errors on some pages.""" 175 | mock_results = [ 176 | MagicMock(url="https://example.com/page1"), 177 | MagicMock(url="https://example.com/page2"), 178 | MagicMock(url="https://example.com/page3"), 179 | ] 180 | mock_process_result = AsyncMock( 181 | side_effect=["page-1", Exception("Error processing page 2"), "page-3"], 182 | ) 183 | 184 | with ( 185 | patch("src.common.processor.process_crawl_result", mock_process_result), 186 | patch( 187 | "src.common.processor.DatabaseOperations", 188 | return_value=mock_processor_dependencies["database"], 189 | ), 190 | ): 191 | job_id = "test-job" 192 | tags = ["test", "batch"] 193 | page_ids = await process_page_batch(page_results=mock_results, job_id=job_id, tags=tags) 194 | assert mock_process_result.call_count == 3 195 | assert mock_processor_dependencies["update_job_status"].call_count == 2 196 | assert page_ids == ["page-1", "page-3"] 197 | 198 | 199 | @pytest.mark.unit 200 | @pytest.mark.async_test 201 | async def test_process_page_batch_empty(): 202 | """Test processing an empty batch of pages.""" 203 | page_ids = await process_page_batch(page_results=[], job_id="test-job", tags=["test"]) 204 | assert page_ids == [] 205 | -------------------------------------------------------------------------------- /tests/lib/test_crawler_enhanced.py: -------------------------------------------------------------------------------- 1 | """Tests for the enhanced crawler with hierarchy tracking.""" 2 | 3 | import pytest 4 | from unittest.mock import Mock, AsyncMock, patch 5 | from src.lib.crawler_enhanced import CrawlResultWithHierarchy, crawl_url_with_hierarchy 6 | 7 | 8 | class TestCrawlResultWithHierarchy: 9 | """Test the CrawlResultWithHierarchy class.""" 10 | 11 | def test_init_with_parent_url(self) -> None: 12 | """Test initialization with parent URL in metadata. 13 | 14 | Args: 15 | None. 16 | 17 | Returns: 18 | None. 19 | """ 20 | # Create mock base result 21 | base_result = Mock() 22 | base_result.url = "https://example.com/page1" 23 | base_result.html = "Test Page" 24 | base_result.metadata = {"parent_url": "https://example.com", "depth": 1} 25 | 26 | # Create enhanced result 27 | enhanced = CrawlResultWithHierarchy(base_result) 28 | 29 | assert enhanced.url == "https://example.com/page1" 30 | assert enhanced.parent_url == "https://example.com" 31 | assert enhanced.depth == 1 32 | assert enhanced.title == "Test Page" 33 | 34 | def test_init_without_metadata(self) -> None: 35 | """Test initialization without metadata. 36 | 37 | Args: 38 | None. 39 | 40 | Returns: 41 | None. 
42 | """ 43 | # Create mock base result without metadata 44 | base_result = Mock() 45 | base_result.url = "https://example.com" 46 | base_result.html = "Content" 47 | base_result.metadata = None 48 | 49 | # Mock hasattr to return False for metadata 50 | with patch("builtins.hasattr", side_effect=lambda obj, attr: attr != "metadata"): 51 | enhanced = CrawlResultWithHierarchy(base_result) 52 | 53 | assert enhanced.url == "https://example.com" 54 | assert enhanced.parent_url is None 55 | assert enhanced.depth == 0 56 | 57 | def test_extract_title_from_html(self) -> None: 58 | """Test extracting title from HTML. 59 | 60 | Args: 61 | None. 62 | 63 | Returns: 64 | None. 65 | """ 66 | base_result = Mock() 67 | base_result.url = "https://example.com" 68 | base_result.html = "My Test Page" 69 | base_result.metadata = {} 70 | 71 | enhanced = CrawlResultWithHierarchy(base_result) 72 | assert enhanced.title == "My Test Page" 73 | 74 | def test_extract_title_from_markdown(self) -> None: 75 | """Test extracting title from markdown when no HTML title. 76 | 77 | Args: 78 | None. 79 | 80 | Returns: 81 | None. 82 | """ 83 | base_result = Mock() 84 | base_result.url = "https://example.com" 85 | base_result.html = "Content" 86 | base_result.metadata = {} 87 | 88 | # Mock extract_page_text to return markdown 89 | with patch( 90 | "src.lib.crawler_enhanced.extract_page_text", 91 | return_value="# Main Heading\n\nSome content", 92 | ): 93 | enhanced = CrawlResultWithHierarchy(base_result) 94 | 95 | assert enhanced.title == "Main Heading" 96 | 97 | def test_extract_title_fallback_to_url(self) -> None: 98 | """Test title fallback to URL path when no title found. 99 | 100 | Args: 101 | None. 102 | 103 | Returns: 104 | None. 105 | """ 106 | base_result = Mock() 107 | base_result.url = "https://example.com/docs/api" 108 | base_result.html = "" 109 | base_result.metadata = {} 110 | 111 | with patch("src.lib.crawler_enhanced.extract_page_text", return_value=""): 112 | enhanced = CrawlResultWithHierarchy(base_result) 113 | 114 | assert enhanced.title == "api" 115 | 116 | def test_extract_title_home_for_root(self) -> None: 117 | """Test title defaults to 'Home' for root URLs. 118 | 119 | Args: 120 | None. 121 | 122 | Returns: 123 | None. 124 | """ 125 | base_result = Mock() 126 | base_result.url = "https://example.com/" 127 | base_result.html = "" 128 | base_result.metadata = {} 129 | 130 | with patch("src.lib.crawler_enhanced.extract_page_text", return_value=""): 131 | enhanced = CrawlResultWithHierarchy(base_result) 132 | 133 | assert enhanced.title == "Home" 134 | 135 | 136 | @pytest.mark.asyncio 137 | class TestCrawlUrlWithHierarchy: 138 | """Test the crawl_url_with_hierarchy function.""" 139 | 140 | async def test_crawl_with_hierarchy(self) -> None: 141 | """Test crawling with hierarchy enhancement. 142 | 143 | Args: 144 | None. 145 | 146 | Returns: 147 | None. 
148 | """ 149 | # Mock base crawl results 150 | mock_results = [ 151 | Mock( 152 | url="https://example.com", 153 | metadata={"parent_url": None, "depth": 0}, 154 | html="Home", 155 | ), 156 | Mock( 157 | url="https://example.com/about", 158 | metadata={"parent_url": "https://example.com", "depth": 1}, 159 | html="About", 160 | ), 161 | Mock( 162 | url="https://example.com/contact", 163 | metadata={"parent_url": "https://example.com", "depth": 1}, 164 | html="Contact", 165 | ), 166 | ] 167 | 168 | with patch("src.lib.crawler_enhanced.base_crawl_url", new_callable=AsyncMock) as mock_crawl: 169 | mock_crawl.return_value = mock_results 170 | 171 | results = await crawl_url_with_hierarchy("https://example.com", max_pages=10) 172 | 173 | assert len(results) == 3 174 | assert all(isinstance(r, CrawlResultWithHierarchy) for r in results) 175 | 176 | # Check hierarchy information 177 | assert results[0].parent_url is None 178 | assert results[0].root_url == "https://example.com" 179 | assert results[0].depth == 0 180 | 181 | assert results[1].parent_url == "https://example.com" 182 | assert results[1].root_url == "https://example.com" 183 | assert results[1].depth == 1 184 | 185 | async def test_relative_path_calculation(self) -> None: 186 | """Test calculation of relative paths. 187 | 188 | Args: 189 | None. 190 | 191 | Returns: 192 | None. 193 | """ 194 | # Mock results with parent-child relationship 195 | mock_results = [ 196 | Mock( 197 | url="https://example.com", 198 | metadata={"parent_url": None, "depth": 0}, 199 | html="Home", 200 | ), 201 | Mock( 202 | url="https://example.com/docs", 203 | metadata={"parent_url": "https://example.com", "depth": 1}, 204 | html="Documentation", 205 | ), 206 | Mock( 207 | url="https://example.com/docs/api", 208 | metadata={"parent_url": "https://example.com/docs", "depth": 2}, 209 | html="API Reference", 210 | ), 211 | ] 212 | 213 | with patch("src.lib.crawler_enhanced.base_crawl_url", new_callable=AsyncMock) as mock_crawl: 214 | mock_crawl.return_value = mock_results 215 | 216 | results = await crawl_url_with_hierarchy("https://example.com") 217 | 218 | # Check relative paths 219 | assert results[0].relative_path == "/" 220 | assert results[1].relative_path == "Documentation" 221 | assert results[2].relative_path == "Documentation/API Reference" 222 | 223 | async def test_empty_crawl_results(self) -> None: 224 | """Test handling of empty crawl results. 225 | 226 | Args: 227 | None. 228 | 229 | Returns: 230 | None. 
231 | """ 232 | with patch("src.lib.crawler_enhanced.base_crawl_url", new_callable=AsyncMock) as mock_crawl: 233 | mock_crawl.return_value = [] 234 | 235 | results = await crawl_url_with_hierarchy("https://example.com") 236 | 237 | assert results == [] 238 | -------------------------------------------------------------------------------- /tests/services/test_job_service.py: -------------------------------------------------------------------------------- 1 | """Tests for the job service.""" 2 | 3 | import datetime 4 | from unittest.mock import MagicMock, patch 5 | 6 | import duckdb 7 | import pytest 8 | from rq import Queue 9 | 10 | from src.common.models import ( 11 | FetchUrlResponse, 12 | JobProgressResponse, 13 | ) 14 | from src.web_service.services.job_service import ( 15 | fetch_url, 16 | get_job_count, 17 | get_job_progress, 18 | ) 19 | 20 | 21 | @pytest.fixture 22 | def mock_duckdb_connection(): 23 | """Mock DuckDB connection.""" 24 | mock = MagicMock(spec=duckdb.DuckDBPyConnection) 25 | return mock 26 | 27 | 28 | @pytest.fixture 29 | def mock_queue(): 30 | """Mock Redis queue.""" 31 | mock = MagicMock(spec=Queue) 32 | mock.enqueue = MagicMock() 33 | return mock 34 | 35 | 36 | @pytest.mark.unit 37 | @pytest.mark.async_test 38 | async def test_fetch_url_basic(mock_queue): 39 | """Test fetching a URL with basic parameters.""" 40 | # Set a fixed UUID for testing 41 | with patch("uuid.uuid4", return_value="test-job-123"): 42 | # Call the function with just URL 43 | result = await fetch_url( 44 | queue=mock_queue, 45 | url="https://example.com", 46 | ) 47 | 48 | # Verify queue.enqueue was called with the right parameters 49 | mock_queue.enqueue.assert_called_once() 50 | call_args = mock_queue.enqueue.call_args[0] 51 | assert call_args[0] == "src.crawl_worker.tasks.create_job" 52 | assert call_args[1] == "https://example.com" 53 | assert call_args[2] == "test-job-123" 54 | 55 | # Should receive keyword arguments for tags and max_pages 56 | kwargs = mock_queue.enqueue.call_args[1] 57 | assert kwargs["tags"] is None 58 | assert kwargs["max_pages"] == 100 59 | 60 | # Check that the function returned the expected response 61 | assert isinstance(result, FetchUrlResponse) 62 | assert result.job_id == "test-job-123" 63 | 64 | 65 | @pytest.mark.unit 66 | @pytest.mark.async_test 67 | async def test_fetch_url_with_options(mock_queue): 68 | """Test fetching a URL with all parameters.""" 69 | # Set a fixed UUID for testing 70 | with patch("uuid.uuid4", return_value="test-job-456"): 71 | # Call the function with all parameters 72 | result = await fetch_url( 73 | queue=mock_queue, 74 | url="https://example.com", 75 | tags=["test", "example"], 76 | max_pages=50, 77 | ) 78 | 79 | # Verify queue.enqueue was called with the right parameters 80 | mock_queue.enqueue.assert_called_once() 81 | call_args = mock_queue.enqueue.call_args[0] 82 | assert call_args[0] == "src.crawl_worker.tasks.create_job" 83 | assert call_args[1] == "https://example.com" 84 | assert call_args[2] == "test-job-456" 85 | 86 | # Should receive keyword arguments with the provided values 87 | kwargs = mock_queue.enqueue.call_args[1] 88 | assert kwargs["tags"] == ["test", "example"] 89 | assert kwargs["max_pages"] == 50 90 | 91 | # Check that the function returned the expected response 92 | assert isinstance(result, FetchUrlResponse) 93 | assert result.job_id == "test-job-456" 94 | 95 | 96 | @pytest.mark.unit 97 | @pytest.mark.async_test 98 | async def test_get_job_progress_exact_match(mock_duckdb_connection): 99 | """Test getting job 
progress with exact job ID match.""" 100 | # Mock database results for successful job 101 | created_at = datetime.datetime(2023, 1, 1, 10, 0, 0) 102 | updated_at = datetime.datetime(2023, 1, 1, 10, 5, 0) 103 | 104 | mock_result = ( 105 | "test-job-123", # job_id 106 | "https://example.com", # start_url 107 | "running", # status 108 | 200, # pages_discovered 109 | 50, # pages_crawled 110 | 100, # max_pages 111 | "test,example", # tags_json 112 | created_at, # created_at 113 | updated_at, # updated_at 114 | None, # error_message 115 | ) 116 | 117 | # Set up mock execution and result 118 | mock_cursor = MagicMock() 119 | mock_cursor.fetchone.return_value = mock_result 120 | mock_duckdb_connection.execute.return_value = mock_cursor 121 | 122 | # Call the function 123 | result = await get_job_progress( 124 | conn=mock_duckdb_connection, 125 | job_id="test-job-123", 126 | ) 127 | 128 | # Verify database was queried with exact job ID 129 | mock_duckdb_connection.execute.assert_called_once() 130 | query = mock_duckdb_connection.execute.call_args[0][0] 131 | assert "SELECT job_id, start_url, status" in query 132 | assert "WHERE job_id = ?" in query 133 | 134 | # Check that results are returned correctly 135 | assert isinstance(result, JobProgressResponse) 136 | assert result.status == "running" 137 | assert result.pages_crawled == 50 138 | assert result.pages_total == 200 139 | assert result.completed is False 140 | assert result.progress_percent == 50 # 50/100 = 50% 141 | assert result.url == "https://example.com" 142 | assert result.max_pages == 100 143 | assert result.created_at == created_at 144 | assert result.updated_at == updated_at 145 | assert result.error_message is None 146 | 147 | 148 | @pytest.mark.unit 149 | @pytest.mark.async_test 150 | async def test_get_job_progress_partial_match(mock_duckdb_connection): 151 | """Test getting job progress with partial job ID match.""" 152 | # First attempt should return None to simulate no exact match 153 | mock_cursor1 = MagicMock() 154 | mock_cursor1.fetchone.return_value = None 155 | 156 | # Second attempt (partial match) should succeed 157 | created_at = datetime.datetime(2023, 1, 1, 10, 0, 0) 158 | updated_at = datetime.datetime(2023, 1, 1, 10, 5, 0) 159 | 160 | mock_result = ( 161 | "test-job-123-full", # job_id 162 | "https://example.com", # start_url 163 | "completed", # status 164 | 100, # pages_discovered 165 | 100, # pages_crawled 166 | 100, # max_pages 167 | "test,example", # tags_json 168 | created_at, # created_at 169 | updated_at, # updated_at 170 | None, # error_message 171 | ) 172 | 173 | mock_cursor2 = MagicMock() 174 | mock_cursor2.fetchone.return_value = mock_result 175 | 176 | # Set up side effect for consecutive calls 177 | mock_duckdb_connection.execute.side_effect = [mock_cursor1, mock_cursor2] 178 | 179 | # Call the function with partial job ID 180 | result = await get_job_progress( 181 | conn=mock_duckdb_connection, 182 | job_id="test-job", 183 | ) 184 | 185 | # Should make two database calls 186 | assert mock_duckdb_connection.execute.call_count == 2 187 | 188 | # Second call should use LIKE with wildcard 189 | query2 = mock_duckdb_connection.execute.call_args_list[1][0][0] 190 | assert "WHERE job_id LIKE ?" 
in query2 191 | params2 = mock_duckdb_connection.execute.call_args_list[1][0][1] 192 | assert params2[0] == "test-job%" 193 | 194 | # Check that results are returned correctly 195 | assert isinstance(result, JobProgressResponse) 196 | assert result.status == "completed" 197 | assert result.pages_crawled == 100 198 | assert result.pages_total == 100 199 | assert result.completed is True 200 | assert result.progress_percent == 100 # 100/100 = 100% 201 | 202 | 203 | @pytest.mark.unit 204 | @pytest.mark.async_test 205 | async def test_get_job_progress_not_found(mock_duckdb_connection): 206 | """Test getting job progress for a non-existent job.""" 207 | # Both exact and partial matches return None 208 | mock_cursor1 = MagicMock() 209 | mock_cursor1.fetchone.return_value = None 210 | 211 | mock_cursor2 = MagicMock() 212 | mock_cursor2.fetchone.return_value = None 213 | 214 | # Set up side effect for consecutive calls 215 | mock_duckdb_connection.execute.side_effect = [mock_cursor1, mock_cursor2] 216 | 217 | # Call the function 218 | result = await get_job_progress( 219 | conn=mock_duckdb_connection, 220 | job_id="nonexistent", 221 | ) 222 | 223 | # Should return None for non-existent job 224 | assert result is None 225 | 226 | 227 | @pytest.mark.unit 228 | @pytest.mark.async_test 229 | async def test_get_job_count(): 230 | """Test getting job count.""" 231 | # Mock database connection 232 | mock_conn = MagicMock() 233 | mock_conn.execute.return_value.fetchone.return_value = (42,) 234 | 235 | with patch("src.web_service.services.job_service.duckdb.connect", return_value=mock_conn): 236 | # Call the function 237 | result = await get_job_count() 238 | 239 | # Verify connection was closed 240 | # In new implementation, we close the database, not the connection directly 241 | 242 | # Check result 243 | assert result == 42 244 | 245 | 246 | @pytest.mark.unit 247 | @pytest.mark.async_test 248 | async def test_get_job_count_error(): 249 | """Test getting job count when an error occurs.""" 250 | # Mock database connection with an error 251 | mock_conn = MagicMock() 252 | mock_conn.execute.side_effect = duckdb.Error("Database error") 253 | 254 | with patch("src.web_service.services.job_service.duckdb.connect", return_value=mock_conn): 255 | # Call the function and assert duckdb.Error is raised 256 | with pytest.raises(duckdb.Error, match="Database error"): 257 | await get_job_count() 258 | 259 | # Should still close the database even with an error 260 | # In new implementation, we close the database, not the connection directly 261 | -------------------------------------------------------------------------------- /tests/integration/test_processor_enhanced.py: -------------------------------------------------------------------------------- 1 | """Integration tests for the enhanced processor with hierarchy tracking.""" 2 | 3 | import pytest 4 | from unittest.mock import Mock, AsyncMock, patch 5 | 6 | from src.common.processor_enhanced import ( 7 | process_crawl_result_with_hierarchy, 8 | process_crawl_with_hierarchy, 9 | ) 10 | from src.lib.crawler_enhanced import CrawlResultWithHierarchy 11 | 12 | 13 | @pytest.mark.asyncio 14 | @pytest.mark.integration 15 | class TestProcessorEnhancedIntegration: 16 | """Integration tests for enhanced processor.""" 17 | 18 | async def test_process_crawl_result_with_hierarchy(self) -> None: 19 | """Test processing a single crawl result with hierarchy. 20 | 21 | Args: 22 | None. 23 | 24 | Returns: 25 | None. 
26 | """ 27 | # Create a mock enhanced crawl result 28 | base_result = Mock() 29 | base_result.url = "https://example.com/docs" 30 | base_result.html = "Documentation" 31 | 32 | enhanced_result = CrawlResultWithHierarchy(base_result) 33 | enhanced_result.parent_url = "https://example.com" 34 | enhanced_result.root_url = "https://example.com" 35 | enhanced_result.depth = 1 36 | enhanced_result.relative_path = "/docs" 37 | enhanced_result.title = "Documentation" 38 | 39 | # Mock dependencies 40 | with patch("src.common.processor_enhanced.extract_page_text") as mock_extract: 41 | mock_extract.return_value = "# Documentation\n\nThis is the documentation." 42 | 43 | with patch("src.common.processor_enhanced.DatabaseOperations") as MockDB: 44 | mock_db = MockDB.return_value 45 | mock_db.store_page = AsyncMock(return_value="page-123") 46 | 47 | with patch("src.common.processor_enhanced.TextChunker") as MockChunker: 48 | mock_chunker = MockChunker.return_value 49 | mock_chunker.split_text.return_value = ["Chunk 1 content", "Chunk 2 content"] 50 | 51 | with patch("src.common.processor_enhanced.generate_embedding") as mock_embed: 52 | mock_embed.return_value = [0.1, 0.2, 0.3] 53 | 54 | with patch("src.common.processor_enhanced.VectorIndexer") as MockIndexer: 55 | mock_indexer = MockIndexer.return_value 56 | mock_indexer.index_vector = AsyncMock() 57 | 58 | # Process with hierarchy tracking 59 | url_to_page_id = {"https://example.com": "root-123"} 60 | page_id = await process_crawl_result_with_hierarchy( 61 | enhanced_result, 62 | job_id="test-job", 63 | tags=["test"], 64 | url_to_page_id=url_to_page_id, 65 | ) 66 | 67 | # Verify results 68 | assert page_id == "page-123" 69 | assert url_to_page_id[enhanced_result.url] == "page-123" 70 | 71 | # Verify store_page was called with hierarchy info 72 | mock_db.store_page.assert_called_once_with( 73 | url="https://example.com/docs", 74 | text="# Documentation\n\nThis is the documentation.", 75 | job_id="test-job", 76 | tags=["test"], 77 | parent_page_id="root-123", # Parent ID was looked up 78 | root_page_id="root-123", 79 | depth=1, 80 | path="/docs", 81 | title="Documentation", 82 | ) 83 | 84 | # Verify chunks were indexed 85 | assert mock_indexer.index_vector.call_count == 2 86 | 87 | async def test_process_crawl_with_hierarchy_full_flow(self) -> None: 88 | """Test the full crawl and process flow with hierarchy. 89 | 90 | Args: 91 | None. 92 | 93 | Returns: 94 | None. 
95 | """ 96 | # Mock crawl results with hierarchy 97 | mock_crawl_results = [] 98 | for i, (url, parent_url, depth) in enumerate( 99 | [ 100 | ("https://example.com", None, 0), 101 | ("https://example.com/about", "https://example.com", 1), 102 | ("https://example.com/docs", "https://example.com", 1), 103 | ("https://example.com/docs/api", "https://example.com/docs", 2), 104 | ] 105 | ): 106 | base = Mock() 107 | base.url = url 108 | base.html = f"Page {i}" 109 | 110 | enhanced = CrawlResultWithHierarchy(base) 111 | enhanced.url = url 112 | enhanced.parent_url = parent_url 113 | enhanced.root_url = "https://example.com" 114 | enhanced.depth = depth 115 | enhanced.title = f"Page {i}" 116 | enhanced.relative_path = url.replace("https://example.com", "") or "/" 117 | 118 | mock_crawl_results.append(enhanced) 119 | 120 | # Mock the enhanced crawler 121 | with patch("src.common.processor_enhanced.crawl_url_with_hierarchy") as mock_crawl: 122 | mock_crawl.return_value = mock_crawl_results 123 | 124 | # Mock other dependencies 125 | with patch("src.common.processor_enhanced.extract_page_text") as mock_extract: 126 | mock_extract.return_value = "Page content" 127 | 128 | with patch("src.common.processor_enhanced.DatabaseOperations") as MockDB: 129 | mock_db = MockDB.return_value 130 | # Return unique IDs for each page 131 | mock_db.store_page = AsyncMock( 132 | side_effect=["page-0", "page-1", "page-2", "page-3"] 133 | ) 134 | mock_db.update_job_status = AsyncMock() 135 | 136 | with patch("src.common.processor_enhanced.TextChunker") as MockChunker: 137 | mock_chunker = MockChunker.return_value 138 | mock_chunker.split_text.return_value = ["chunk"] 139 | 140 | with patch( 141 | "src.common.processor_enhanced.generate_embedding" 142 | ) as mock_embed: 143 | mock_embed.return_value = [0.1, 0.2] 144 | 145 | with patch( 146 | "src.common.processor_enhanced.VectorIndexer" 147 | ) as MockIndexer: 148 | mock_indexer = MockIndexer.return_value 149 | mock_indexer.index_vector = AsyncMock() 150 | 151 | # Process everything 152 | page_ids = await process_crawl_with_hierarchy( 153 | url="https://example.com", 154 | job_id="test-job", 155 | tags=["test"], 156 | max_pages=10, 157 | ) 158 | 159 | # Verify results 160 | assert len(page_ids) == 4 161 | assert page_ids == ["page-0", "page-1", "page-2", "page-3"] 162 | 163 | # Verify pages were stored in order (parent before child) 164 | assert mock_db.store_page.call_count == 4 165 | 166 | # Check that parent IDs were properly resolved 167 | # First page (root) has no parent 168 | first_call = mock_db.store_page.call_args_list[0] 169 | assert first_call.kwargs["parent_page_id"] is None 170 | 171 | # Second page's parent should be the root's ID 172 | second_call = mock_db.store_page.call_args_list[1] 173 | assert second_call.kwargs["parent_page_id"] == "page-0" 174 | 175 | # Fourth page's parent should be the third page's ID 176 | fourth_call = mock_db.store_page.call_args_list[3] 177 | assert fourth_call.kwargs["parent_page_id"] == "page-2" 178 | 179 | async def test_process_crawl_with_hierarchy_error_handling(self) -> None: 180 | """Test error handling in hierarchy processing. 181 | 182 | Args: 183 | None. 184 | 185 | Returns: 186 | None. 
187 | """ 188 | # Create a result that will fail 189 | base_result = Mock() 190 | base_result.url = "https://example.com/bad" 191 | 192 | enhanced_result = CrawlResultWithHierarchy(base_result) 193 | enhanced_result.url = "https://example.com/bad" 194 | enhanced_result.parent_url = None 195 | enhanced_result.depth = 0 196 | enhanced_result.title = "Bad Page" 197 | 198 | with patch("src.common.processor_enhanced.crawl_url_with_hierarchy") as mock_crawl: 199 | mock_crawl.return_value = [enhanced_result] 200 | 201 | with patch("src.common.processor_enhanced.extract_page_text") as mock_extract: 202 | # Make extraction fail 203 | mock_extract.side_effect = Exception("Extraction failed") 204 | 205 | with patch("src.common.processor_enhanced.DatabaseOperations") as MockDB: 206 | mock_db = MockDB.return_value 207 | mock_db.update_job_status = AsyncMock() 208 | 209 | # Process should continue despite error 210 | page_ids = await process_crawl_with_hierarchy( 211 | url="https://example.com", job_id="test-job" 212 | ) 213 | 214 | # Should return empty list due to error 215 | assert page_ids == [] 216 | 217 | # Job status won't be updated if all pages fail 218 | mock_db.update_job_status.assert_not_called() 219 | -------------------------------------------------------------------------------- /src/web_service/api/diagnostics.py: -------------------------------------------------------------------------------- 1 | """Diagnostic API routes for debugging purposes. 2 | Important: Should be disabled in production. 3 | """ 4 | 5 | from fastapi import APIRouter, HTTPException, Query 6 | 7 | from src.common.logger import get_logger 8 | from src.lib.database import DatabaseOperations 9 | from src.web_service.services.debug_bm25 import debug_bm25_search 10 | 11 | # Get logger for this module 12 | logger = get_logger(__name__) 13 | 14 | # Create router 15 | router = APIRouter(tags=["diagnostics"]) 16 | 17 | 18 | @router.get("/bm25_diagnostics", operation_id="bm25_diagnostics") 19 | async def bm25_diagnostics_endpoint( 20 | query: str = Query("test", description="The search query to test with"), 21 | initialize: bool = Query( 22 | True, description="If true, attempt to initialize FTS if in write mode." 23 | ), 24 | ): 25 | """ 26 | Combined endpoint for DuckDB FTS initialization and diagnostics. 27 | - Checks for the existence of the 'pages' table and 'raw_text' column before attempting FTS index creation. 28 | - Always includes the full error message from FTS index creation in the response. 29 | - Adds verbose logging and schema checks to the diagnostics output. 
30 | """ 31 | logger.info( 32 | "API: Running combined BM25 diagnostics and initialization endpoint (DuckDB-specific)" 33 | ) 34 | 35 | db = DatabaseOperations(read_only=not initialize) 36 | conn = db.db.ensure_connection() 37 | read_only = db.db.read_only 38 | initialization_summary = {} 39 | schema_info = {} 40 | 41 | # Check for 'pages' table and 'raw_text' column existence 42 | try: 43 | tables = conn.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall() 44 | table_names = [row[0] for row in tables] 45 | schema_info["tables"] = table_names 46 | pages_table_exists = "pages" in table_names 47 | initialization_summary["pages_table_exists"] = pages_table_exists 48 | if pages_table_exists: 49 | columns = conn.execute("PRAGMA table_info('pages')").fetchall() 50 | col_names = [row[1] for row in columns] 51 | schema_info["pages_columns"] = col_names 52 | initialization_summary["pages_columns"] = col_names 53 | raw_text_exists = "raw_text" in col_names 54 | initialization_summary["raw_text_column_exists"] = raw_text_exists 55 | else: 56 | initialization_summary["pages_columns"] = [] 57 | initialization_summary["raw_text_column_exists"] = False 58 | except Exception as e_schema: 59 | initialization_summary["schema_check_error"] = str(e_schema) 60 | schema_info["schema_check_error"] = str(e_schema) 61 | 62 | try: 63 | if initialize and not read_only: 64 | logger.info("Attempting FTS initialization in write mode...") 65 | # 1. Load FTS extension 66 | try: 67 | conn.execute("INSTALL fts;") 68 | conn.execute("LOAD fts;") 69 | initialization_summary["fts_extension_loaded"] = True 70 | except Exception as e_fts: 71 | logger.warning(f"FTS extension loading during initialization: {e_fts}") 72 | initialization_summary["fts_extension_loaded"] = False 73 | initialization_summary["fts_extension_error"] = str(e_fts) 74 | # 2. Create FTS index (only if table and column exist) 75 | if initialization_summary.get("pages_table_exists") and initialization_summary.get( 76 | "raw_text_column_exists" 77 | ): 78 | try: 79 | conn.execute("PRAGMA create_fts_index('pages', 'id', 'raw_text');") 80 | initialization_summary["fts_index_created"] = True 81 | initialization_summary["fts_index_creation_error"] = None 82 | except Exception as e_index: 83 | logger.warning(f"FTS index creation error: {e_index}") 84 | initialization_summary["fts_index_created"] = False 85 | initialization_summary["fts_index_creation_error"] = str(e_index) 86 | if "already exists" in str(e_index): 87 | initialization_summary["fts_index_exists"] = True 88 | else: 89 | initialization_summary["fts_index_created"] = False 90 | initialization_summary["fts_index_creation_error"] = ( 91 | "'pages' table or 'raw_text' column does not exist." 92 | ) 93 | # 2b. 
Check for FTS index existence using DuckDB's sqlite_master and information_schema.tables 94 | try: 95 | fts_indexes = conn.execute( 96 | "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'fts_idx_%'" 97 | ).fetchall() 98 | initialization_summary["fts_indexes_sqlite_master"] = ( 99 | [row[0] for row in fts_indexes] if fts_indexes else [] 100 | ) 101 | initialization_summary["fts_index_exists_sqlite_master"] = bool(fts_indexes) 102 | fts_idx_tables = conn.execute( 103 | "SELECT table_name FROM information_schema.tables WHERE table_name LIKE 'fts_idx_%'" 104 | ).fetchall() 105 | initialization_summary["fts_indexes_information_schema"] = ( 106 | [row[0] for row in fts_idx_tables] if fts_idx_tables else [] 107 | ) 108 | initialization_summary["fts_index_exists_information_schema"] = bool(fts_idx_tables) 109 | except Exception as e_fts_idx: 110 | initialization_summary["fts_index_check_error"] = str(e_fts_idx) 111 | # 3. Drop legacy table 112 | try: 113 | table_exists_result = conn.execute( 114 | "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'fts_main_pages'" 115 | ).fetchone() 116 | table_exists = table_exists_result[0] > 0 if table_exists_result else False 117 | if table_exists: 118 | conn.execute("DROP TABLE IF EXISTS fts_main_pages;") 119 | initialization_summary["legacy_fts_main_pages_dropped"] = True 120 | else: 121 | initialization_summary["legacy_fts_main_pages_dropped"] = False 122 | except Exception as e_drop: 123 | initialization_summary["legacy_fts_main_pages_drop_error"] = str(e_drop) 124 | else: 125 | initialization_summary["initialization_skipped"] = True 126 | initialization_summary["read_only_mode"] = read_only 127 | except Exception as e: 128 | logger.error(f"Error during initialization: {e}") 129 | initialization_summary["initialization_error"] = str(e) 130 | 131 | # Always run diagnostics 132 | try: 133 | diagnostics = await debug_bm25_search(conn, query) 134 | # Add FTS index check and schema info to diagnostics as well 135 | try: 136 | fts_indexes_diag = conn.execute( 137 | "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'fts_idx_%'" 138 | ).fetchall() 139 | diagnostics["fts_indexes_sqlite_master"] = ( 140 | [row[0] for row in fts_indexes_diag] if fts_indexes_diag else [] 141 | ) 142 | diagnostics["fts_index_exists_sqlite_master"] = bool(fts_indexes_diag) 143 | fts_idx_tables_diag = conn.execute( 144 | "SELECT table_name FROM information_schema.tables WHERE table_name LIKE 'fts_idx_%'" 145 | ).fetchall() 146 | diagnostics["fts_indexes_information_schema"] = ( 147 | [row[0] for row in fts_idx_tables_diag] if fts_idx_tables_diag else [] 148 | ) 149 | diagnostics["fts_index_exists_information_schema"] = bool(fts_idx_tables_diag) 150 | # Add schema info 151 | diagnostics["schema_info"] = schema_info 152 | except Exception as e_fts_idx_diag: 153 | diagnostics["fts_index_check_error"] = str(e_fts_idx_diag) 154 | # Remove BM25 function checks and errors from diagnostics if present 155 | for k in list(diagnostics.keys()): 156 | if "bm25" in k or "BM25" in k: 157 | diagnostics.pop(k) 158 | # Remove recommendations about BM25 function 159 | recs = diagnostics.get("recommendations", []) 160 | diagnostics["recommendations"] = [rec for rec in recs if "match_bm25 function" not in rec] 161 | except Exception as e: 162 | logger.error(f"Error during BM25 diagnostics: {e!s}") 163 | raise HTTPException(status_code=500, detail=f"Diagnostic error: {e!s}") 164 | finally: 165 | db.db.close() 166 | 167 | # Compose response 168 | response = { 169 
| "initialization": initialization_summary, 170 | "diagnostics": diagnostics, 171 | } 172 | # Recommendations: only mention /bm25_diagnostics?initialize=true for FTS index creation if no FTS index is found 173 | fts_index_exists = response["diagnostics"].get("fts_index_exists_sqlite_master") or response[ 174 | "diagnostics" 175 | ].get("fts_index_exists_information_schema") 176 | recs = diagnostics.get("recommendations", []) 177 | new_recs = [ 178 | rec for rec in recs if "/initialize_bm25" not in rec and "match_bm25 function" not in rec 179 | ] 180 | if not fts_index_exists: 181 | new_recs.append( 182 | "No FTS indexes found. Run /bm25_diagnostics?initialize=true with a write connection." 183 | ) 184 | response["recommendations"] = new_recs 185 | if "status" in diagnostics: 186 | response["status"] = diagnostics["status"] 187 | return response 188 | -------------------------------------------------------------------------------- /tests/integration/services/test_document_service.py: -------------------------------------------------------------------------------- 1 | """Integration tests for the document service with real DuckDB.""" 2 | 3 | from unittest.mock import patch 4 | 5 | import pytest 6 | 7 | from src.common.config import VECTOR_SIZE 8 | from src.common.indexer import VectorIndexer 9 | from src.web_service.services.document_service import ( 10 | get_doc_page, 11 | list_doc_pages, 12 | list_tags, 13 | search_docs, 14 | ) 15 | 16 | 17 | @pytest.mark.integration 18 | @pytest.mark.async_test 19 | async def test_search_docs_with_duckdb(in_memory_duckdb_connection): 20 | """Test searching documents with DuckDB backend using hybrid search (vector and BM25).""" 21 | # Create test data in the in-memory database 22 | indexer = VectorIndexer(connection=in_memory_duckdb_connection) 23 | 24 | # Insert test pages with raw text for BM25 search 25 | in_memory_duckdb_connection.execute(""" 26 | INSERT INTO pages (id, url, domain, crawl_date, tags, raw_text, job_id) 27 | VALUES 28 | ('page1', 'https://example.com/ai', 'example.com', '2023-01-01', 29 | '["ai", "tech"]', 'This is a document about artificial intelligence and machine learning.', 'job1'), 30 | ('page2', 'https://example.com/ml', 'example.com', '2023-01-02', 31 | '["ml", "tech"]', 'This is a document about machine learning algorithms and neural networks.', 'job1'), 32 | ('page3', 'https://example.com/nlp', 'example.com', '2023-01-03', 33 | '["nlp", "tech"]', 'Natural language processing is a field of artificial intelligence.', 'job1') 34 | """) 35 | 36 | # Try creating FTS index for BM25 search 37 | try: 38 | in_memory_duckdb_connection.execute("INSTALL fts; LOAD fts;") 39 | in_memory_duckdb_connection.execute("PRAGMA create_fts_index('pages', 'id', 'raw_text');") 40 | except Exception as e: 41 | print(f"Warning: Could not create FTS index: {e}. 
BM25 search may not work in tests.") 42 | 43 | # Create test vectors and payloads for vector search 44 | test_vector1 = [0.1] * VECTOR_SIZE 45 | test_payload1 = { 46 | "text": "This is a document about artificial intelligence", 47 | "page_id": "page1", 48 | "url": "https://example.com/ai", 49 | "domain": "example.com", 50 | "tags": ["ai", "tech"], 51 | "job_id": "job1", 52 | } 53 | 54 | test_vector2 = [0.2] * VECTOR_SIZE 55 | test_payload2 = { 56 | "text": "This is a document about machine learning", 57 | "page_id": "page2", 58 | "url": "https://example.com/ml", 59 | "domain": "example.com", 60 | "tags": ["ml", "tech"], 61 | "job_id": "job1", 62 | } 63 | 64 | test_vector3 = [0.3] * VECTOR_SIZE 65 | test_payload3 = { 66 | "text": "Natural language processing is a field of artificial intelligence", 67 | "page_id": "page3", 68 | "url": "https://example.com/nlp", 69 | "domain": "example.com", 70 | "tags": ["nlp", "tech"], 71 | "job_id": "job1", 72 | } 73 | 74 | # Index test data for vector search 75 | await indexer.index_vector(test_vector1, test_payload1) 76 | await indexer.index_vector(test_vector2, test_payload2) 77 | await indexer.index_vector(test_vector3, test_payload3) 78 | 79 | # Mock the embedding generation to return a vector similar to test_vector1 80 | with patch("src.lib.embedder.generate_embedding", return_value=[0.11] * VECTOR_SIZE): 81 | # Test 1: Standard hybrid search with a query that should match via both vector and BM25 82 | result = await search_docs( 83 | conn=in_memory_duckdb_connection, 84 | query="artificial intelligence", 85 | tags=None, 86 | max_results=5, 87 | ) 88 | 89 | # Verify results - should find page1 and page3 as they mention 'artificial intelligence' 90 | assert len(result.results) >= 2 # Should return at least 2 results 91 | # The first result should be about artificial intelligence 92 | assert any("artificial intelligence" in r.chunk_text for r in result.results) 93 | # Check page_ids - both page1 and page3 should be in results 94 | result_page_ids = [r.page_id for r in result.results] 95 | assert "page1" in result_page_ids 96 | assert "page3" in result_page_ids 97 | 98 | # Test 2: Using tag filter 99 | result = await search_docs( 100 | conn=in_memory_duckdb_connection, 101 | query="artificial intelligence", 102 | tags=["ai"], 103 | max_results=5, 104 | ) 105 | 106 | # Verify filtered results 107 | assert any(r.page_id == "page1" for r in result.results) 108 | assert all("ai" in r.tags for r in result.results) 109 | 110 | # Test 3: Test with BM25-specific term not in vectors but in raw text 111 | result = await search_docs( 112 | conn=in_memory_duckdb_connection, 113 | query="neural networks", 114 | tags=None, 115 | max_results=5, 116 | ) 117 | 118 | # Verify BM25 results - should find page2 which mentions 'neural networks' 119 | assert any(r.page_id == "page2" for r in result.results) 120 | assert any("neural networks" in r.chunk_text for r in result.results) 121 | 122 | # Test 4: Adjusting hybrid weights 123 | result = await search_docs( 124 | conn=in_memory_duckdb_connection, 125 | query="artificial intelligence", 126 | tags=None, 127 | max_results=5, 128 | hybrid_weight=0.3, # More weight on BM25 (0.7) than vector (0.3) 129 | ) 130 | 131 | # With more weight on BM25, we expect different result ordering 132 | # Check that the results are not empty 133 | assert len(result.results) > 0 134 | 135 | 136 | @pytest.mark.integration 137 | @pytest.mark.async_test 138 | async def test_list_doc_pages_with_duckdb(in_memory_duckdb_connection): 139 | """Test 
listing document pages with DuckDB backend.""" 140 | # Insert test data directly into the pages table 141 | in_memory_duckdb_connection.execute(""" 142 | INSERT INTO pages (id, url, domain, crawl_date, tags, raw_text) 143 | VALUES 144 | ('page1', 'https://example.com/page1', 'example.com', '2023-01-01', 145 | '["doc","example"]', 'This is page 1'), 146 | ('page2', 'https://example.com/page2', 'example.com', '2023-01-02', 147 | '["doc","test"]', 'This is page 2'), 148 | ('page3', 'https://example.org/page3', 'example.org', '2023-01-03', 149 | '["other","example"]', 'This is page 3') 150 | """) 151 | 152 | # Test with no filters 153 | result = await list_doc_pages(conn=in_memory_duckdb_connection, page=1, tags=None) 154 | 155 | # Verify results 156 | assert result.total_pages >= 1 # There is at least 1 page of results 157 | assert result.current_page == 1 158 | assert len(result.doc_pages) == 3 159 | 160 | # Test with tag filter 161 | result = await list_doc_pages(conn=in_memory_duckdb_connection, page=1, tags=["doc"]) 162 | 163 | # Verify filtered results 164 | assert len(result.doc_pages) == 2 165 | assert all("doc" in page.tags for page in result.doc_pages) 166 | 167 | 168 | @pytest.mark.integration 169 | @pytest.mark.async_test 170 | async def test_get_doc_page_with_duckdb(in_memory_duckdb_connection): 171 | """Test retrieving a document page with DuckDB backend.""" 172 | # Insert test data directly into the pages table 173 | in_memory_duckdb_connection.execute(""" 174 | INSERT INTO pages (id, url, domain, crawl_date, tags, raw_text) 175 | VALUES ('test-page', 'https://example.com/test', 'example.com', '2023-01-01', 176 | '["test"]', 'Line 1\nLine 2\nLine 3\nLine 4\nLine 5') 177 | """) 178 | 179 | # Test retrieving the entire page 180 | result = await get_doc_page( 181 | conn=in_memory_duckdb_connection, 182 | page_id="test-page", 183 | starting_line=1, 184 | ending_line=-1, 185 | ) 186 | 187 | # Verify results 188 | assert result.text == "Line 1\nLine 2\nLine 3\nLine 4\nLine 5" 189 | assert result.total_lines == 5 190 | 191 | # Test retrieving specific lines 192 | result = await get_doc_page( 193 | conn=in_memory_duckdb_connection, 194 | page_id="test-page", 195 | starting_line=2, 196 | ending_line=4, 197 | ) 198 | 199 | # Verify partial results 200 | assert result.text == "Line 2\nLine 3\nLine 4" 201 | assert result.total_lines == 5 202 | 203 | 204 | @pytest.mark.integration 205 | @pytest.mark.async_test 206 | async def test_list_tags_with_duckdb(in_memory_duckdb_connection): 207 | """Test listing unique tags with DuckDB backend.""" 208 | # Insert test data directly into the pages table with different tags 209 | in_memory_duckdb_connection.execute(""" 210 | INSERT INTO pages (id, url, domain, crawl_date, tags, raw_text) 211 | VALUES 212 | ('page1', 'https://example.com/page1', 'example.com', '2023-01-01', 213 | '["tag1","common"]', 'Page 1'), 214 | ('page2', 'https://example.com/page2', 'example.com', '2023-01-02', 215 | '["tag2","common"]', 'Page 2'), 216 | ('page3', 'https://example.com/page3', 'example.com', '2023-01-03', 217 | '["tag3","special"]', 'Page 3') 218 | """) 219 | 220 | # Test listing all tags 221 | result = await list_tags(conn=in_memory_duckdb_connection, search_substring=None) 222 | 223 | # Verify results contain all unique tags 224 | assert len(result.tags) == 5 225 | assert set(result.tags) == {"tag1", "tag2", "tag3", "common", "special"} 226 | 227 | # Test with search substring 228 | result = await list_tags(conn=in_memory_duckdb_connection, search_substring="tag") 
229 | 230 | # Verify filtered results 231 | assert len(result.tags) == 3 232 | assert all("tag" in tag for tag in result.tags) 233 | -------------------------------------------------------------------------------- /tests/services/test_map_service.py: -------------------------------------------------------------------------------- 1 | """Tests for the map service.""" 2 | 3 | import pytest 4 | from unittest.mock import AsyncMock, patch 5 | import datetime 6 | 7 | from src.web_service.services.map_service import MapService 8 | 9 | 10 | @pytest.mark.asyncio 11 | class TestMapService: 12 | """Test the MapService class.""" 13 | 14 | async def test_get_all_sites(self) -> None: 15 | """Test getting all sites. 16 | 17 | Args: 18 | None. 19 | 20 | Returns: 21 | None. 22 | """ 23 | service = MapService() 24 | 25 | # Mock the database response - sites from different domains 26 | mock_sites = [ 27 | { 28 | "id": "site1", 29 | "url": "https://example1.com", 30 | "title": "Site 1", 31 | "domain": "example1.com", 32 | }, 33 | { 34 | "id": "site2", 35 | "url": "https://example2.com", 36 | "title": "Site 2", 37 | "domain": "example2.com", 38 | }, 39 | ] 40 | 41 | with patch.object(service.db_ops, "get_root_pages", new_callable=AsyncMock) as mock_get: 42 | mock_get.return_value = mock_sites 43 | 44 | sites = await service.get_all_sites() 45 | 46 | # Since they're from different domains, they should not be grouped 47 | assert len(sites) == 2 48 | assert sites[0]["id"] == "site1" 49 | assert sites[1]["id"] == "site2" 50 | mock_get.assert_called_once() 51 | 52 | async def test_build_page_tree(self) -> None: 53 | """Test building a page tree structure. 54 | 55 | Args: 56 | None. 57 | 58 | Returns: 59 | None. 60 | """ 61 | service = MapService() 62 | root_id = "root-123" 63 | 64 | # Mock hierarchy data 65 | mock_pages = [ 66 | { 67 | "id": "root-123", 68 | "parent_page_id": None, 69 | "title": "Home", 70 | "url": "https://example.com", 71 | }, 72 | { 73 | "id": "page-1", 74 | "parent_page_id": "root-123", 75 | "title": "About", 76 | "url": "https://example.com/about", 77 | }, 78 | { 79 | "id": "page-2", 80 | "parent_page_id": "root-123", 81 | "title": "Docs", 82 | "url": "https://example.com/docs", 83 | }, 84 | { 85 | "id": "page-3", 86 | "parent_page_id": "page-2", 87 | "title": "API", 88 | "url": "https://example.com/docs/api", 89 | }, 90 | ] 91 | 92 | with patch.object(service.db_ops, "get_page_hierarchy", new_callable=AsyncMock) as mock_get: 93 | mock_get.return_value = mock_pages 94 | 95 | tree = await service.build_page_tree(root_id) 96 | 97 | # Verify tree structure 98 | assert tree["id"] == "root-123" 99 | assert tree["title"] == "Home" 100 | assert len(tree["children"]) == 2 101 | 102 | # Check first level children 103 | about_page = next(c for c in tree["children"] if c["id"] == "page-1") 104 | docs_page = next(c for c in tree["children"] if c["id"] == "page-2") 105 | 106 | assert about_page["title"] == "About" 107 | assert len(about_page["children"]) == 0 108 | 109 | assert docs_page["title"] == "Docs" 110 | assert len(docs_page["children"]) == 1 111 | assert docs_page["children"][0]["title"] == "API" 112 | 113 | async def test_get_navigation_context(self) -> None: 114 | """Test getting navigation context for a page. 115 | 116 | Args: 117 | None. 118 | 119 | Returns: 120 | None. 
121 | """ 122 | service = MapService() 123 | page_id = "page-123" 124 | 125 | # Mock current page 126 | mock_page = { 127 | "id": page_id, 128 | "parent_page_id": "parent-123", 129 | "root_page_id": "root-123", 130 | "title": "Current Page", 131 | } 132 | 133 | # Mock related pages 134 | mock_parent = {"id": "parent-123", "title": "Parent Page"} 135 | mock_siblings = [ 136 | {"id": "sibling-1", "title": "Sibling 1"}, 137 | {"id": "sibling-2", "title": "Sibling 2"}, 138 | ] 139 | mock_children = [ 140 | {"id": "child-1", "title": "Child 1"}, 141 | ] 142 | mock_root = {"id": "root-123", "title": "Home"} 143 | 144 | with patch.object( 145 | service.db_ops, "get_page_by_id", new_callable=AsyncMock 146 | ) as mock_get_page: 147 | mock_get_page.side_effect = [mock_page, mock_parent, mock_root] 148 | 149 | with patch.object( 150 | service.db_ops, "get_sibling_pages", new_callable=AsyncMock 151 | ) as mock_siblings_fn: 152 | mock_siblings_fn.return_value = mock_siblings 153 | 154 | with patch.object( 155 | service.db_ops, "get_child_pages", new_callable=AsyncMock 156 | ) as mock_children_fn: 157 | mock_children_fn.return_value = mock_children 158 | 159 | context = await service.get_navigation_context(page_id) 160 | 161 | assert context["current_page"] == mock_page 162 | assert context["parent"] == mock_parent 163 | assert context["siblings"] == mock_siblings 164 | assert context["children"] == mock_children 165 | assert context["root"] == mock_root 166 | 167 | def test_render_page_html(self) -> None: 168 | """Test rendering a page as HTML. 169 | 170 | Args: 171 | None. 172 | 173 | Returns: 174 | None. 175 | """ 176 | service = MapService() 177 | 178 | page = { 179 | "id": "page-123", 180 | "title": "Test Page", 181 | "raw_text": "# Test Page\n\nThis is **markdown** content.", 182 | } 183 | 184 | navigation = { 185 | "current_page": page, 186 | "parent": {"id": "parent-123", "title": "Parent"}, 187 | "siblings": [ 188 | {"id": "sib-1", "title": "Sibling 1"}, 189 | ], 190 | "children": [ 191 | {"id": "child-1", "title": "Child 1"}, 192 | ], 193 | "root": {"id": "root-123", "title": "Home"}, 194 | } 195 | 196 | html = service.render_page_html(page, navigation) 197 | 198 | # Check that HTML contains expected elements 199 | assert "Test Page" in html 200 | assert "
<h1>Test Page</h1>
" in html 201 | assert "markdown" in html # Markdown was rendered 202 | assert 'href="/map/page/parent-123"' in html # Parent link 203 | assert 'href="/map/page/sib-1"' in html # Sibling link 204 | assert 'href="/map/page/child-1"' in html # Child link 205 | assert 'href="/map/site/root-123"' in html # Site map link 206 | 207 | def test_format_site_list(self) -> None: 208 | """Test formatting a list of sites. 209 | 210 | Args: 211 | None. 212 | 213 | Returns: 214 | None. 215 | """ 216 | service = MapService() 217 | 218 | sites = [ 219 | { 220 | "id": "site1", 221 | "url": "https://example1.com", 222 | "title": "Example Site 1", 223 | "crawl_date": datetime.datetime(2024, 1, 1, 12, 0, 0), 224 | }, 225 | { 226 | "id": "site2", 227 | "url": "https://example2.com", 228 | "title": "Example Site 2", 229 | "crawl_date": datetime.datetime(2024, 1, 2, 12, 0, 0), 230 | }, 231 | ] 232 | 233 | html = service.format_site_list(sites) 234 | 235 | # Check HTML content 236 | assert "Site Map - All Sites" in html 237 | assert "Example Site 1" in html 238 | assert "Example Site 2" in html 239 | assert 'href="/map/site/site1"' in html 240 | assert 'href="/map/site/site2"' in html 241 | assert "https://example1.com" in html 242 | assert "https://example2.com" in html 243 | 244 | def test_format_site_list_empty(self) -> None: 245 | """Test formatting empty site list. 246 | 247 | Args: 248 | None. 249 | 250 | Returns: 251 | None. 252 | """ 253 | service = MapService() 254 | 255 | html = service.format_site_list([]) 256 | 257 | assert "No Sites Found" in html 258 | assert "No crawled sites are available" in html 259 | 260 | def test_format_site_tree(self) -> None: 261 | """Test formatting a site tree. 262 | 263 | Args: 264 | None. 265 | 266 | Returns: 267 | None. 268 | """ 269 | service = MapService() 270 | 271 | tree = { 272 | "id": "root-123", 273 | "title": "My Site", 274 | "children": [ 275 | {"id": "page-1", "title": "About", "children": []}, 276 | { 277 | "id": "page-2", 278 | "title": "Docs", 279 | "children": [{"id": "page-3", "title": "API", "children": []}], 280 | }, 281 | ], 282 | } 283 | 284 | html = service.format_site_tree(tree) 285 | 286 | # Check HTML structure 287 | assert "My Site - Site Map" in html 288 | assert 'href="/map/page/root-123"' in html 289 | assert 'href="/map/page/page-1"' in html 290 | assert 'href="/map/page/page-2"' in html 291 | assert 'href="/map/page/page-3"' in html 292 | assert "
" in html # Collapsible sections 293 | assert "Back to all sites" in html 294 | 295 | def test_build_tree_html_leaf_node(self) -> None: 296 | """Test building HTML for a leaf node. 297 | 298 | Args: 299 | None. 300 | 301 | Returns: 302 | None. 303 | """ 304 | service = MapService() 305 | 306 | node = {"id": "leaf-123", "title": "Leaf Page", "children": []} 307 | 308 | html = service._build_tree_html(node, is_root=False) 309 | 310 | assert 'Leaf Page' in html 311 | assert '' in html 312 | assert '' in html 313 | assert "
" not in html # No collapsible section for leaf 314 | 315 | def test_build_tree_html_with_children(self) -> None: 316 | """Test building HTML for a node with children. 317 | 318 | Args: 319 | None. 320 | 321 | Returns: 322 | None. 323 | """ 324 | service = MapService() 325 | 326 | node = { 327 | "id": "parent-123", 328 | "title": "Parent Page", 329 | "children": [ 330 | {"id": "child-1", "title": "Child 1", "children": []}, 331 | {"id": "child-2", "title": "Child 2", "children": []}, 332 | ], 333 | } 334 | 335 | html = service._build_tree_html(node, is_root=False) 336 | 337 | assert "
" in html 338 | assert "" in html 339 | assert "Parent Page" in html 340 | assert "Child 1" in html 341 | assert "Child 2" in html 342 | --------------------------------------------------------------------------------