├── .python-version
├── tests
│   ├── __init__.py
│   ├── lib
│   │   ├── __init__.py
│   │   ├── test_embedder.py
│   │   ├── test_crawler.py
│   │   ├── test_chunker.py
│   │   └── test_crawler_enhanced.py
│   ├── integration
│   │   ├── __init__.py
│   │   ├── common
│   │   │   ├── __init__.py
│   │   │   └── test_vector_indexer.py
│   │   ├── services
│   │   │   ├── __init__.py
│   │   │   └── test_document_service.py
│   │   ├── conftest.py
│   │   └── test_processor_enhanced.py
│   ├── services
│   │   ├── __init__.py
│   │   ├── test_admin_service.py
│   │   ├── test_map_service_legacy.py
│   │   ├── test_job_service.py
│   │   └── test_map_service.py
│   ├── README.md
│   ├── conftest.py
│   ├── api
│   │   └── test_map_api.py
│   └── common
│       └── test_processor.py
├── doctor.png
├── src
│   ├── common
│   │   ├── __init__.py
│   │   ├── logger.py
│   │   ├── config.py
│   │   ├── processor.py
│   │   ├── models.py
│   │   └── processor_enhanced.py
│   ├── lib
│   │   ├── __init__.py
│   │   ├── database
│   │   │   ├── __init__.py
│   │   │   ├── utils.py
│   │   │   ├── migrations
│   │   │   │   └── 001_add_hierarchy.sql
│   │   │   ├── migrations.py
│   │   │   └── schema.py
│   │   ├── embedder.py
│   │   ├── chunker.py
│   │   ├── crawler.py
│   │   └── crawler_enhanced.py
│   ├── web_service
│   │   ├── __init__.py
│   │   ├── services
│   │   │   ├── __init__.py
│   │   │   ├── admin_service.py
│   │   │   ├── debug_bm25.py
│   │   │   └── job_service.py
│   │   ├── Dockerfile
│   │   ├── api
│   │   │   ├── __init__.py
│   │   │   ├── admin.py
│   │   │   ├── map.py
│   │   │   ├── jobs.py
│   │   │   ├── documents.py
│   │   │   └── diagnostics.py
│   │   └── main.py
│   └── crawl_worker
│       ├── __init__.py
│       ├── Dockerfile
│       └── main.py
├── .cursor
│   └── rules
│       ├── code-style.mdc
│       ├── testing.mdc
│       └── refactors.mdc
├── pytest.ini
├── Dockerfile.base
├── .pre-commit-config.yaml
├── .gitignore
├── LICENSE.md
├── .github
│   └── workflows
│       └── pytest.yml
├── docker-compose.yml
├── pyproject.toml
├── llms.txt
├── docs
│   └── maps_api_examples.md
└── README.md
/.python-version:
--------------------------------------------------------------------------------
1 | 3.12
2 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for the Doctor application."""
2 |
--------------------------------------------------------------------------------
/tests/lib/__init__.py:
--------------------------------------------------------------------------------
1 | """Unit tests for the lib module."""
2 |
--------------------------------------------------------------------------------
/tests/integration/__init__.py:
--------------------------------------------------------------------------------
1 | """Integration tests package."""
2 |
--------------------------------------------------------------------------------
/doctor.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sisig-ai/doctor/HEAD/doctor.png
--------------------------------------------------------------------------------
/src/common/__init__.py:
--------------------------------------------------------------------------------
1 | """Common utilities for the Doctor project."""
2 |
--------------------------------------------------------------------------------
/tests/services/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for the web service's services."""
2 |
--------------------------------------------------------------------------------
/src/lib/__init__.py:
--------------------------------------------------------------------------------
1 | """Shared library components for the Doctor application."""
2 |
--------------------------------------------------------------------------------
/src/web_service/__init__.py:
--------------------------------------------------------------------------------
1 | """Web Service module for the Doctor project."""
2 |
--------------------------------------------------------------------------------
/src/crawl_worker/__init__.py:
--------------------------------------------------------------------------------
1 | """Crawl worker package for the Doctor application."""
2 |
--------------------------------------------------------------------------------
/src/web_service/services/__init__.py:
--------------------------------------------------------------------------------
1 | """Services package for the web service."""
2 |
--------------------------------------------------------------------------------
/tests/integration/common/__init__.py:
--------------------------------------------------------------------------------
1 | """Integration tests for common modules."""
2 |
--------------------------------------------------------------------------------
/tests/integration/services/__init__.py:
--------------------------------------------------------------------------------
1 | """Integration tests for service modules."""
2 |
--------------------------------------------------------------------------------
/.cursor/rules/code-style.mdc:
--------------------------------------------------------------------------------
1 | ---
2 | description:
3 | globs:
4 | alwaysApply: true
5 | ---
6 | - Use Google-style docstrings
7 | - Always type-annotate variables and methods
8 | - After making all changes, run `ruff format` and `ruff check --fix`
9 |
--------------------------------------------------------------------------------
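
A minimal example of what these rules look like in practice (a hypothetical helper, not a file in this repo):

```python
def normalize_tag(tag: str, *, lowercase: bool = True) -> str:
    """Normalize a tag string before storing it.

    Args:
        tag: The raw tag text.
        lowercase: Whether to lowercase the result.

    Returns:
        str: The trimmed (and optionally lowercased) tag.
    """
    cleaned: str = tag.strip()
    return cleaned.lower() if lowercase else cleaned
```
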
/src/crawl_worker/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM doctor-base:latest AS crawl_worker
2 |
3 | # The WORKDIR, ENV, and code are inherited from the base image.
4 | # If specific overrides or additional ENV vars are needed for crawl_worker, set them here.
5 |
6 | RUN crawl4ai-setup
7 |
8 | # Run the worker
9 | CMD ["python", "-m", "src.crawl_worker.main"]
10 |
--------------------------------------------------------------------------------
/src/web_service/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM doctor-base:latest AS web_service
2 |
3 | # The WORKDIR, ENV, and code are inherited from the base image.
4 | # If specific overrides or additional ENV vars are needed for web_service, set them here.
5 |
6 | # Expose the port
7 | EXPOSE 9111
8 |
9 | # Run the web service
10 | CMD ["python", "src/web_service/main.py"]
11 |
--------------------------------------------------------------------------------
/.cursor/rules/testing.mdc:
--------------------------------------------------------------------------------
1 | ---
2 | description: Writing tests or implementing new logic
3 | globs:
4 | alwaysApply: false
5 | ---
6 | - Always write tests using functional pytest
7 | - Tests should always maintain or increase code coverage
8 | - Use parametrized testing whenever possible, instead of multiple test functions
9 | - Use snapshot tests to check for known/constant values
10 |
--------------------------------------------------------------------------------
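
A sketch of a test written to these rules, exercising only the `TextChunker.split_text` behavior shown in `src/lib/chunker.py` (the test name is hypothetical):

```python
import pytest

from src.lib.chunker import TextChunker


@pytest.mark.unit
@pytest.mark.parametrize(
    "text",
    ["", "   ", "\n\t  "],  # degenerate inputs that should yield no chunks
)
def test_split_text_returns_empty_list_for_blank_input(text: str) -> None:
    """Blank or whitespace-only input produces an empty chunk list."""
    assert TextChunker().split_text(text) == []
```
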
/pytest.ini:
--------------------------------------------------------------------------------
1 | [pytest]
2 | testpaths = tests
3 | python_files = test_*.py
4 | python_classes = Test*
5 | python_functions = test_*
6 | addopts = -v --cov=src --cov-report=term-missing
7 | asyncio_mode = auto
8 | asyncio_default_fixture_loop_scope = function
9 | markers =
10 | unit: marks tests as unit tests
11 | integration: marks tests as integration tests
12 | async_test: marks tests as asynchronous
13 |
--------------------------------------------------------------------------------
/.cursor/rules/refactors.mdc:
--------------------------------------------------------------------------------
1 | ---
2 | description: Refactoring and implementing new logic
3 | globs:
4 | alwaysApply: false
5 | ---
6 | - Run tests after refactor/implementation is completed, ensuring they pass, and add new tests as necessary
7 | - To run the tests, do `pytest` (making sure that you are using `.venv/bin/python`)
8 | - When tests fail, adjust them according to the changes in code OR fix the code if it's a logic issue
9 |
--------------------------------------------------------------------------------
/Dockerfile.base:
--------------------------------------------------------------------------------
1 | FROM python:3.12-slim
2 |
3 | WORKDIR /app
4 |
5 | # Install uv
6 | RUN pip install --no-cache-dir uv
7 |
8 | # Copy project configuration and readme
9 | COPY pyproject.toml README.md /app/
10 |
11 | # Install dependencies
12 | RUN uv pip install --system -e .
13 |
14 | # Copy source code
15 | COPY src /app/src
16 |
17 | # Create data directory
18 | RUN mkdir -p /app/data
19 |
20 | # Set environment variables
21 | ENV PYTHONPATH=/app
22 | ENV PYTHONUNBUFFERED=1
23 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/astral-sh/ruff-pre-commit
3 | rev: v0.11.8
4 | hooks:
5 | - id: ruff
6 | name: ruff check
7 | args: [--fix]
8 | - id: ruff-format
9 | name: ruff format
10 | - repo: https://github.com/pre-commit/pre-commit-hooks
11 | rev: v4.5.0
12 | hooks:
13 | - id: trailing-whitespace
14 | - id: end-of-file-fixer
15 | - id: check-yaml
16 | - repo: local
17 | hooks:
18 | - id: mypy-warn-only
19 | name: mypy (warning only)
20 | entry: bash -c "mypy --ignore-missing-imports || true"
21 | language: system
22 | types: [python]
23 | require_serial: true
24 | verbose: true
25 |
--------------------------------------------------------------------------------
/src/web_service/api/__init__.py:
--------------------------------------------------------------------------------
1 | """API routes package for the web service."""
2 |
3 | from fastapi import APIRouter
4 |
5 | from src.web_service.api.admin import router as admin_router
6 | from src.web_service.api.diagnostics import router as diagnostics_router
7 | from src.web_service.api.documents import router as documents_router
8 | from src.web_service.api.jobs import router as jobs_router
9 | from src.web_service.api.map import router as map_router
10 |
11 | # Create a main router that includes all the other routers
12 | api_router = APIRouter()
13 |
14 | # Include the routers
15 | api_router.include_router(documents_router)
16 | api_router.include_router(jobs_router)
17 | api_router.include_router(admin_router)
18 | api_router.include_router(diagnostics_router)
19 | api_router.include_router(map_router)
20 |
--------------------------------------------------------------------------------
/src/lib/database/__init__.py:
--------------------------------------------------------------------------------
1 | """Initializes the database package, exposing key components for database interaction.
2 |
3 | This package provides a modular approach to database management for the Doctor project,
4 | separating connection handling, high-level operations, schema definitions, and utilities.
5 |
6 | The main interface for database operations is the `DatabaseOperations` class.
7 | """
8 |
9 | from .connection import DuckDBConnectionManager
10 | from .operations import DatabaseOperations
11 | from .schema import (
12 | CREATE_DOCUMENT_EMBEDDINGS_TABLE_SQL,
13 | CREATE_JOBS_TABLE_SQL,
14 | CREATE_PAGES_TABLE_SQL,
15 | )
16 | from .utils import deserialize_tags, serialize_tags
17 |
18 | __all__ = [
19 | "CREATE_DOCUMENT_EMBEDDINGS_TABLE_SQL",
20 | "CREATE_JOBS_TABLE_SQL",
21 | "CREATE_PAGES_TABLE_SQL",
22 | "DatabaseOperations",
23 | "DuckDBConnectionManager",
24 | "deserialize_tags",
25 | "serialize_tags",
26 | ]
27 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Python
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | *.so
6 | .Python
7 | build/
8 | develop-eggs/
9 | dist/
10 | downloads/
11 | eggs/
12 | .eggs/
13 | parts/
14 | sdist/
15 | var/
16 | wheels/
17 | *.egg-info/
18 | .installed.cfg
19 | *.egg
20 | .ropeproject/
21 | .ruff_cache/
22 | .pytest_cache/
23 | data/
24 |
25 | # Virtual Environment
26 | .env
27 | .venv
28 | env/
29 | venv/
30 | ENV/
31 | env.bak/
32 | venv.bak/
33 |
34 | # DuckDB
35 | *.duckdb
36 | *.duckdb.wal
37 |
38 | # IDE
39 | .idea/
40 | .vscode/
41 | *.swp
42 | *.swo
43 | .zed
44 |
45 | # Logs
46 | *.log
47 | logs/
48 |
49 | # Local configuration
50 | .env.local
51 | .env.development.local
52 | .env.test.local
53 | .env.production.local
54 |
55 | # Docker
56 | .docker/
57 | docker-compose.override.yml
58 |
59 | # OS specific
60 | .DS_Store
61 | Thumbs.db
62 |
63 | .roo
64 | .coverage
65 | coverage.xml
66 | .aider*
67 |
68 | **/.claude/settings.local.json
69 | CLAUDE.md
70 | .claude/
71 |
--------------------------------------------------------------------------------
/tests/integration/conftest.py:
--------------------------------------------------------------------------------
1 | """Fixtures for integration tests."""
2 |
3 | import duckdb
4 | import pytest
5 |
6 |
7 | @pytest.fixture
8 | def in_memory_duckdb_connection():
9 | """Create an in-memory DuckDB connection for integration testing.
10 |
11 | This connection has the proper setup for vector search using the same
12 | setup logic as the main application:
13 | - VSS extension loaded
14 | - document_embeddings table created with the proper schema
15 | - HNSW index created
16 | - pages table created (for document service tests)
17 | """
18 | from src.lib.database import DatabaseOperations
19 |
20 | conn = duckdb.connect(":memory:")
21 |
22 | # Use the same setup functions as the main application
23 | try:
24 | # Create a Database instance and initialize with our in-memory connection
25 | db = DatabaseOperations()
26 | db.db.conn = conn
27 |
28 | # Initialize all tables and extensions
29 | db.db.initialize()
30 |
31 | yield conn
32 | finally:
33 | conn.close()
34 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 sisig-ai
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/tests/README.md:
--------------------------------------------------------------------------------
1 | # Tests for Doctor Application
2 |
3 | This directory contains tests for the Doctor application.
4 |
5 | ## Structure
6 |
7 | - `conftest.py`: Common fixtures for all tests
8 | - `lib/`: Tests for the library components
9 | - `test_crawler.py`: Tests for the crawler module
10 | - `test_chunker.py`: Tests for the chunker module
11 | - `test_embedder.py`: Tests for the embedder module
12 | - `test_indexer.py`: Tests for the indexer module
13 | - `test_database.py`: Tests for the database module
14 | - `test_processor.py`: Tests for the processor module
15 |
16 | ## Running Tests
17 |
18 | To run all tests:
19 |
20 | ```bash
21 | pytest
22 | ```
23 |
24 | To run tests with coverage:
25 |
26 | ```bash
27 | pytest --cov=src
28 | ```
29 |
30 | To run specific test categories:
31 |
32 | ```bash
33 | # Run all unit tests
34 | pytest -m unit
35 |
36 | # Run all async tests
37 | pytest -m async_test
38 |
39 | # Run tests for a specific module
40 | pytest tests/lib/test_crawler.py
41 | ```
42 |
43 | ## Test Markers
44 |
45 | - `unit`: Unit tests
46 | - `integration`: Integration tests
47 | - `async_test`: Tests that use asyncio
48 |
--------------------------------------------------------------------------------
/.github/workflows/pytest.yml:
--------------------------------------------------------------------------------
1 | name: Python Tests
2 |
3 | on:
4 | push:
5 | branches: [ main ]
6 | pull_request:
7 | branches: [ main ]
8 |
9 | jobs:
10 | test:
11 | runs-on: ubuntu-latest
12 |
13 | steps:
14 | - uses: actions/checkout@v3
15 |
16 | - name: Set up Python
17 | uses: actions/setup-python@v4
18 | with:
19 | python-version: '3.12'
20 |
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install uv
25 | uv venv .venv
26 | uv sync
27 |
28 | - name: Debug environment
29 | run: |
30 | echo "Current directory: $(pwd)"
31 | echo "Python path: $(which python)"
32 | echo "Python version: $(python --version)"
33 | echo "Pytest version: $(.venv/bin/pytest --version)"
34 | echo "Listing test directory:"
35 | ls -la tests/
36 | ls -la tests/lib/
37 |
38 | - name: Test with pytest
39 | run: |
40 | .venv/bin/pytest tests/ -v --cov=src --cov-report=xml
41 |
42 | - name: Upload coverage to GitHub
43 | env:
44 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
45 | uses: codecov/codecov-action@v3
46 | with:
47 | file: ./coverage.xml
48 | fail_ci_if_error: false
49 |
--------------------------------------------------------------------------------
/src/web_service/services/admin_service.py:
--------------------------------------------------------------------------------
1 | """Admin service for the web service."""
2 |
3 | import uuid
4 |
5 | from rq import Queue
6 |
7 | from src.common.logger import get_logger
8 |
9 | # Get logger for this module
10 | logger = get_logger(__name__)
11 |
12 |
13 | async def delete_docs(
14 | queue: Queue,
15 | tags: list[str] | None = None,
16 | domain: str | None = None,
17 | page_ids: list[str] | None = None,
18 | ) -> str:
19 | """Delete documents from the database based on filters.
20 |
21 | Args:
22 | queue: Redis queue for job processing
23 | tags: Optional list of tags to filter by
24 | domain: Optional domain substring to filter by
25 | page_ids: Optional list of specific page IDs to delete
26 |
27 | Returns:
28 | str: The task ID for tracking
29 |
30 | """
31 | logger.info(
32 | f"Enqueueing delete task with filters: tags={tags}, domain={domain}, page_ids={page_ids}",
33 | )
34 |
35 | # Generate a task ID for tracking logs
36 | task_id = str(uuid.uuid4())
37 |
38 | # Enqueue the delete task
39 | queue.enqueue(
40 | "src.crawl_worker.tasks.delete_docs",
41 | task_id,
42 | tags,
43 | domain,
44 | page_ids,
45 | )
46 |
47 | logger.info(f"Enqueued delete task with ID: {task_id}")
48 |
49 | return task_id
50 |
--------------------------------------------------------------------------------
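
A minimal sketch of invoking this service directly, assuming a reachable Redis at `REDIS_URI` and the `worker` queue name used by the API layer:

```python
import asyncio

import redis
from rq import Queue

from src.common.config import REDIS_URI
from src.web_service.services.admin_service import delete_docs


async def main() -> None:
    # Same queue wiring as the /delete_docs endpoint dependency.
    queue = Queue("worker", connection=redis.from_url(REDIS_URI))
    task_id = await delete_docs(queue=queue, domain="example.com")
    print(f"Enqueued delete task {task_id}")


asyncio.run(main())
```
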
/src/lib/database/utils.py:
--------------------------------------------------------------------------------
1 | """Utility functions for the database module.
2 |
3 | This module contains helper functions for common data transformations or
4 | operations used within the database package, such as serializing and
5 | deserializing tags.
6 | """
7 |
8 | import json
9 |
10 | from src.common.logger import get_logger
11 |
12 | logger = get_logger(__name__)
13 |
14 |
15 | def serialize_tags(tags: list[str] | None) -> str:
16 | """Serialize a list of tags to a JSON string.
17 |
18 | Args:
19 | tags: A list of tag strings, or None.
20 | Returns:
21 | str: A JSON string representation of the tags list. Returns an empty list '[]' if tags is None.
22 | """
23 | if tags is None:
24 | return json.dumps([])
25 | return json.dumps(tags)
26 |
27 |
28 | def deserialize_tags(tags_json: str | None) -> list[str]:
29 | """Deserialize a tags JSON string to a list of strings.
30 |
31 | Args:
32 | tags_json: The JSON string containing the tags, or None.
33 | Returns:
34 | list[str]: A list of tag strings. Returns an empty list if input is None, empty, or invalid JSON.
35 | """
36 | if not tags_json:
37 | return []
38 | try:
39 | return json.loads(tags_json)
40 | except json.JSONDecodeError:
41 | logger.warning(f"Could not decode tags JSON: {tags_json!r}") # Added !r for better logging
42 | return []
43 |
--------------------------------------------------------------------------------
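
A quick round-trip through these helpers (assumes the package imports in `src/lib/database/__init__.py` resolve):

```python
from src.lib.database import deserialize_tags, serialize_tags

raw = serialize_tags(["python", "docs"])
assert raw == '["python", "docs"]'
assert deserialize_tags(raw) == ["python", "docs"]

# None and malformed JSON degrade to an empty list rather than raising.
assert serialize_tags(None) == "[]"
assert deserialize_tags("not json") == []
```
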
/src/lib/database/migrations/001_add_hierarchy.sql:
--------------------------------------------------------------------------------
1 | -- Migration: Add hierarchy columns to pages table
2 | -- This migration adds columns to track page relationships for the maps feature
3 |
4 | -- Add hierarchy columns to pages table
5 | ALTER TABLE pages ADD COLUMN IF NOT EXISTS parent_page_id VARCHAR;
6 | ALTER TABLE pages ADD COLUMN IF NOT EXISTS root_page_id VARCHAR;
7 | ALTER TABLE pages ADD COLUMN IF NOT EXISTS depth INTEGER DEFAULT 0;
8 | ALTER TABLE pages ADD COLUMN IF NOT EXISTS path TEXT;
9 | ALTER TABLE pages ADD COLUMN IF NOT EXISTS title TEXT;
10 |
11 | -- Add foreign key constraints (DuckDB doesn't enforce these, but good for documentation)
12 | -- ALTER TABLE pages ADD CONSTRAINT fk_parent_page FOREIGN KEY (parent_page_id) REFERENCES pages(id);
13 | -- ALTER TABLE pages ADD CONSTRAINT fk_root_page FOREIGN KEY (root_page_id) REFERENCES pages(id);
14 |
15 | -- Create indexes for better performance
16 | CREATE INDEX IF NOT EXISTS idx_pages_parent_page_id ON pages(parent_page_id);
17 | CREATE INDEX IF NOT EXISTS idx_pages_root_page_id ON pages(root_page_id);
18 | CREATE INDEX IF NOT EXISTS idx_pages_depth ON pages(depth);
19 |
20 | -- Update existing pages to have sensible defaults
21 | -- Set root_page_id to self for existing pages (they become roots)
22 | UPDATE pages
23 | SET root_page_id = id
24 | WHERE root_page_id IS NULL;
25 |
26 | -- Set depth to 0 for existing pages
27 | UPDATE pages
28 | SET depth = 0
29 | WHERE depth IS NULL;
30 |
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: '3.8'
2 |
3 | services:
4 | base:
5 | build:
6 | context: .
7 | dockerfile: Dockerfile.base
8 | image: doctor-base:latest
9 |
10 |
11 | redis:
12 | image: redis:alpine
13 | ports:
14 | - "6379:6379"
15 | networks:
16 | - doctor-net
17 | healthcheck:
18 | test: ["CMD", "redis-cli", "ping"]
19 | interval: 10s
20 | timeout: 5s
21 | retries: 5
22 |
23 | crawl_worker:
24 | build:
25 | context: .
26 | dockerfile: src/crawl_worker/Dockerfile
27 | volumes:
28 | - ./src:/app/src
29 | - ./data:/app/data
30 | environment:
31 | - REDIS_URI=redis://redis:6379
32 | - OPENAI_API_KEY=${OPENAI_API_KEY}
33 | - PYTHONPATH=/app
34 | depends_on:
35 | redis:
36 | condition: service_healthy
37 | base:
38 | condition: service_completed_successfully
39 |
40 | networks:
41 | - doctor-net
42 |
43 | web_service:
44 | build:
45 | context: .
46 | dockerfile: src/web_service/Dockerfile
47 | ports:
48 | - "9111:9111"
49 | volumes:
50 | - ./src:/app/src
51 | - ./data:/app/data
52 | environment:
53 | - REDIS_URI=redis://redis:6379
54 | - PYTHONPATH=/app
55 | - OPENAI_API_KEY=${OPENAI_API_KEY}
56 |
57 | depends_on:
58 | redis:
59 | condition: service_healthy
60 | base:
61 | condition: service_completed_successfully
62 |
63 | networks:
64 | - doctor-net
65 |
66 | networks:
67 | doctor-net:
68 | driver: bridge
69 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "doctor"
3 | version = "0.2.0"
4 | description = "LLM agent system for discovering, crawling, and indexing web sites"
5 | readme = "README.md"
6 | requires-python = ">=3.12"
7 | dependencies = [
8 | "fastapi",
9 | "uvicorn[standard]",
10 | "redis",
11 | "rq",
12 | "crawl4ai==0.6.0",
13 | "langchain-text-splitters",
14 | "litellm",
15 | "openai",
16 | "duckdb",
17 | "pydantic>=2.0.0",
18 | "httpx",
19 | "python-multipart",
20 | "uuid",
21 | "fastapi-mcp==0.3.2",
22 | "mcp==1.7.1",
23 | "nest-asyncio>=1.6.0",
24 | "markdown>=3.8",
25 | ]
26 |
27 | [build-system]
28 | requires = ["hatchling"]
29 | build-backend = "hatchling.build"
30 |
31 | [tool.ruff]
32 | line-length = 100
33 | target-version = "py312"
34 |
35 | [tool.black]
36 | line-length = 100
37 | target-version = ["py312"]
38 |
39 | [tool.hatch.build.targets.wheel]
40 | # Tell hatch where to find the packages for the 'doctor' project
41 | packages = ["src/common", "src/crawl_worker", "src/web_service", "src/lib"]
42 |
43 | [dependency-groups]
44 | dev = [
45 | "ruff>=0.11.8",
46 | "pytest-asyncio>=0.21.0",
47 | "pytest-cov>=4.1.0",
48 | "pytest>=8.3.5",
49 | "pre-commit>=4.2.0",
50 | "coverage>=7.8.0",
51 | ]
52 |
53 | [tool.pytest.ini_options]
54 | testpaths = ["tests"]
55 | python_files = "test_*.py"
56 | python_classes = "Test*"
57 | python_functions = "test_*"
58 | markers = [
59 | "unit: marks tests as unit tests",
60 | "integration: marks tests as integration tests",
61 | "async_test: marks tests as asynchronous tests",
62 | ]
63 |
--------------------------------------------------------------------------------
/src/lib/embedder.py:
--------------------------------------------------------------------------------
1 | """Text embedding functionality using LiteLLM."""
2 |
3 | from typing import Literal
4 |
5 | import litellm
6 |
7 | from src.common.config import DOC_EMBEDDING_MODEL, QUERY_EMBEDDING_MODEL
8 | from src.common.logger import get_logger
9 |
10 | # Configure logging
11 | logger = get_logger(__name__)
12 |
13 |
14 | async def generate_embedding(
15 | text: str,
16 | model: str = None,
17 | timeout: int = 30,
18 | text_type: Literal["doc", "query"] = "doc",
19 | ) -> list[float]:
20 | """Generate an embedding for a text chunk.
21 |
22 | Args:
23 | text: The text to embed
24 | model: The embedding model to use (defaults to config value)
25 | timeout: Timeout in seconds for the embedding API call
26 | text_type: The type of text to embed (defaults to "doc")
27 |
28 | Returns:
29 | The generated embedding as a list of floats
30 |
31 | """
32 | if not text.strip():
33 | logger.warning("Received empty text for embedding, cannot proceed")
34 | raise ValueError("Cannot generate embedding for empty text")
35 |
36 | model_name = model or (DOC_EMBEDDING_MODEL if text_type == "doc" else QUERY_EMBEDDING_MODEL)
37 | logger.debug(f"Generating embedding for text of length {len(text)} using model {model_name}")
38 |
39 | try:
40 | embedding_response = await litellm.aembedding(
41 | model=model_name,
42 | input=[text],
43 | timeout=timeout,
44 | )
45 |
46 | # Extract the embedding vector from the response
47 | embedding = embedding_response["data"][0]["embedding"]
48 | logger.debug(f"Successfully generated embedding of dimension {len(embedding)}")
49 |
50 | return embedding
51 |
52 | except Exception as e:
53 | logger.error(f"Error generating embedding: {e!s}")
54 | raise
55 |
--------------------------------------------------------------------------------
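
A minimal usage sketch, assuming `OPENAI_API_KEY` is set so litellm can reach the default `openai/text-embedding-3-large` model:

```python
import asyncio

from src.lib.embedder import generate_embedding


async def main() -> None:
    vector = await generate_embedding("Doctor indexes crawled pages.", text_type="doc")
    # text-embedding-3-large returns 3072-dimensional vectors (VECTOR_SIZE in config).
    print(f"embedding dimension: {len(vector)}")


asyncio.run(main())
```
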
/src/web_service/api/admin.py:
--------------------------------------------------------------------------------
1 | """Admin API routes for the web service."""
2 |
3 | import redis
4 | from fastapi import APIRouter, Depends, status
5 | from rq import Queue
6 |
7 | from src.common.config import REDIS_URI
8 | from src.common.logger import get_logger
9 | from src.common.models import (
10 | DeleteDocsRequest,
11 | )
12 | from src.web_service.services.admin_service import (
13 | delete_docs,
14 | )
15 |
16 | # Get logger for this module
17 | logger = get_logger(__name__)
18 |
19 | # Create router
20 | router = APIRouter(tags=["admin"])
21 |
22 |
23 | @router.post("/delete_docs", status_code=status.HTTP_204_NO_CONTENT, operation_id="delete_docs")
24 | async def delete_docs_endpoint(
25 | request: DeleteDocsRequest,
26 | queue: Queue = Depends(lambda: Queue("worker", connection=redis.from_url(REDIS_URI))),
27 | ) -> None:
28 | """Deletes documents from the database based on filters.
29 |
30 | Args:
31 | request: The delete request with optional filters.
32 | queue: The RQ queue for enqueueing the delete task.
33 |
34 | Returns:
35 | None: Returns a 204 No Content response upon successful enqueueing.
36 |
37 | """
38 | logger.info(
39 | f"API: Deleting docs with filters: tags={request.tags}, domain={request.domain}, page_ids={request.page_ids}",
40 | )
41 |
42 | try:
43 | # Call the service function
44 | await delete_docs(
45 | queue=queue,
46 | tags=request.tags,
47 | domain=request.domain,
48 | page_ids=request.page_ids,
49 | )
50 |
51 | # Return 204 No Content
52 | return
53 | except Exception as e:
54 | logger.error(f"Error deleting documents: {e!s}")
55 | # Since this is an asynchronous operation, we still return 204
56 | # The actual deletion happens in the background
57 | return
58 |
--------------------------------------------------------------------------------
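
A hedged client-side sketch of calling this endpoint on a locally running stack (port 9111 from docker-compose); the body fields mirror the attributes the handler reads, but `DeleteDocsRequest` in `src/common/models.py` is authoritative:

```python
import httpx

response = httpx.post(
    "http://localhost:9111/delete_docs",
    json={"tags": None, "domain": "example.com", "page_ids": None},
)
# The endpoint enqueues the deletion and always returns 204 No Content.
assert response.status_code == 204
```
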
/src/lib/chunker.py:
--------------------------------------------------------------------------------
1 | """Text chunking functionality using LangChain."""
2 |
3 | from langchain_text_splitters import RecursiveCharacterTextSplitter
4 |
5 | from src.common.config import CHUNK_OVERLAP, CHUNK_SIZE
6 | from src.common.logger import get_logger
7 |
8 | # Configure logging
9 | logger = get_logger(__name__)
10 |
11 |
12 | class TextChunker:
13 | """Class for splitting text into semantic chunks."""
14 |
15 | def __init__(self, chunk_size: int = None, chunk_overlap: int = None):
16 | """Initialize the text chunker.
17 |
18 | Args:
19 | chunk_size: Size of each text chunk (defaults to config value)
20 | chunk_overlap: Overlap between chunks (defaults to config value)
21 |
22 | """
23 | self.chunk_size = chunk_size or CHUNK_SIZE
24 | self.chunk_overlap = chunk_overlap or CHUNK_OVERLAP
25 |
26 | self.text_splitter = RecursiveCharacterTextSplitter(
27 | chunk_size=self.chunk_size,
28 | chunk_overlap=self.chunk_overlap,
29 | length_function=len,
30 | )
31 |
32 | logger.debug(
33 | f"Initialized TextChunker with chunk_size={self.chunk_size}, chunk_overlap={self.chunk_overlap}",
34 | )
35 |
36 | def split_text(self, text: str) -> list[str]:
37 | """Split the text into chunks.
38 |
39 | Args:
40 | text: The text to split
41 |
42 | Returns:
43 | List of text chunks
44 |
45 | """
46 | if not text or not text.strip():
47 | logger.warning("Received empty text for chunking, returning empty list")
48 | return []
49 |
50 | chunks = self.text_splitter.split_text(text)
51 | logger.debug(f"Split text of length {len(text)} into {len(chunks)} chunks")
52 |
53 | # Filter out empty chunks
54 | non_empty_chunks = [chunk for chunk in chunks if chunk.strip()]
55 | if len(non_empty_chunks) < len(chunks):
56 | logger.debug(f"Filtered out {len(chunks) - len(non_empty_chunks)} empty chunks")
57 |
58 | return non_empty_chunks
59 |
--------------------------------------------------------------------------------
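
A small usage sketch (default sizes come from `src.common.config`; tiny values are passed here only to make the chunking visible):

```python
from src.lib.chunker import TextChunker

chunker = TextChunker(chunk_size=80, chunk_overlap=20)
chunks = chunker.split_text(
    "Doctor crawls pages, chunks their text, and embeds each chunk for search. " * 5
)
print(len(chunks), chunks[0])
```
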
/src/common/logger.py:
--------------------------------------------------------------------------------
1 | """Centralized logging configuration for the Doctor project."""
2 |
3 | import logging
4 | import os
5 |
6 |
7 | def configure_logging(
8 | name: str = None,
9 | level: str = None,
10 | log_file: str | None = None,
11 | ) -> logging.Logger:
12 | """Configure and return a logger with consistent formatting.
13 |
14 | Args:
15 | name: Logger name (typically __name__ from the calling module)
16 | level: Log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
17 | Falls back to DOCTOR_LOG_LEVEL env var, defaults to INFO
18 | log_file: Optional path to log file
19 |
20 | Returns:
21 | Configured logger instance
22 |
23 | """
24 | # Get log level from params, env var, or default to INFO
25 | if level is None:
26 | level = os.getenv("DOCTOR_LOG_LEVEL", "INFO")
27 |
28 | level = getattr(logging, level.upper())
29 |
30 | # Configure root logger if not already configured
31 | if not logging.getLogger().handlers:
32 | # Basic configuration with consistent format
33 | logging.basicConfig(
34 | level=level,
35 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
36 | datefmt="%Y-%m-%d %H:%M:%S",
37 | )
38 |
39 | # Add file handler if specified
40 | if log_file:
41 | file_handler = logging.FileHandler(log_file)
42 | file_handler.setFormatter(
43 | logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"),
44 | )
45 | logging.getLogger().addHandler(file_handler)
46 |
47 | # Get logger for the specific module
48 | logger = logging.getLogger(name)
49 | logger.setLevel(level)
50 |
51 | return logger
52 |
53 |
54 | def get_logger(name: str = None) -> logging.Logger:
55 | """Get a logger with the specified name.
56 |
57 | Args:
58 | name: Logger name (typically __name__ from the calling module)
59 |
60 | Returns:
61 | Configured logger instance
62 |
63 | """
64 | return configure_logging(name)
65 |
--------------------------------------------------------------------------------
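
Typical usage from any module (the log level falls back to the `DOCTOR_LOG_LEVEL` environment variable, defaulting to INFO):

```python
from src.common.logger import get_logger

logger = get_logger(__name__)
logger.info("Starting crawl job")
logger.debug("Only emitted when DOCTOR_LOG_LEVEL=DEBUG")
```
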
/llms.txt:
--------------------------------------------------------------------------------
1 | # Doctor
2 |
3 | > Doctor is a system that lets LLM agents discover, crawl, and index web sites for better and more up-to-date reasoning and code generation.
4 |
5 | Doctor provides a complete stack for crawling, indexing, and searching web content to enhance LLM capabilities. It handles web crawling, text chunking, embedding generation, and semantic search, making this functionality available to large language models through a Model Context Protocol (MCP) server.
6 |
7 | ## Core Components
8 |
9 | - [Crawl Worker](/src/crawl_worker) - Processes crawl jobs, chunks text, and creates embeddings
10 | - [Web Service](/src/web_service) - FastAPI service exposing endpoints for fetching, searching, and viewing data, and exposing MCP server
11 | - [Common](/src/common) - Shared code, models, and database utilities
12 |
13 | ## Infrastructure
14 |
15 | - DuckDB - Database for storing document data and embeddings with vector search capabilities
16 | - Redis - Message broker for asynchronous task processing
17 | - Docker - Container orchestration for deploying the complete stack
18 |
19 | ## Models
20 |
21 | - OpenAI Text Embeddings - Used for generating vector embeddings of text chunks
22 | - Implementation: text-embedding-3-large (3072 dimensions)
23 | - Integration: Accessed via litellm library
24 |
25 | ## Libraries
26 |
27 | - crawl4ai - For web page crawling
28 | - langchain_text_splitters - For chunking text content
29 | - litellm - Wrapper for accessing embedding models
30 | - fastapi - Web service framework
31 | - fastapi-mcp - MCP server implementation
32 | - duckdb-vss - Vector similarity search extension for DuckDB
33 |
34 | ## Technical Requirements
35 |
36 | - Docker and Docker Compose - For running the complete stack
37 | - Python 3.12+ - Primary programming language
38 | - OpenAI API key - Required for embedding generation
39 |
40 | ## API Configuration
41 |
42 | - OpenAI API key must be provided via environment variable: OPENAI_API_KEY
43 | - Additional environment variables:
44 | - REDIS_URI - URI for Redis connection
45 | - DATA_DIR - Directory for storing DuckDB data files (default: "data")
46 |
--------------------------------------------------------------------------------
/src/common/config.py:
--------------------------------------------------------------------------------
1 | """Configuration settings for the Doctor project."""
2 |
3 | import os
4 |
5 | from src.common.logger import get_logger
6 |
7 | # Get logger for this module
8 | logger = get_logger(__name__)
9 |
10 | # Vector settings
11 | VECTOR_SIZE = 3072 # OpenAI text-embedding-3-large embedding size
12 |
13 | # Redis settings
14 | REDIS_URI = os.getenv("REDIS_URI", "redis://localhost:6379")
15 |
16 | # DuckDB settings
17 | DATA_DIR = os.getenv("DATA_DIR", "data")
18 | DUCKDB_PATH = os.path.join(DATA_DIR, "doctor.duckdb")
19 | DUCKDB_EMBEDDINGS_TABLE = "document_embeddings"
20 | DB_RETRY_ATTEMPTS = int(os.getenv("DB_RETRY_ATTEMPTS", "5"))
21 | DB_RETRY_DELAY_SEC = float(os.getenv("DB_RETRY_DELAY_SEC", "0.5"))
22 |
23 | # OpenAI settings
24 | OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
25 | DOC_EMBEDDING_MODEL = "openai/text-embedding-3-large"
26 | QUERY_EMBEDDING_MODEL = "openai/text-embedding-3-large"
27 |
28 | # Web service settings
29 | WEB_SERVICE_HOST = os.getenv("WEB_SERVICE_HOST", "0.0.0.0")
30 | WEB_SERVICE_PORT = int(os.getenv("WEB_SERVICE_PORT", "9111"))
31 |
32 | # Crawl settings
33 | DEFAULT_MAX_PAGES = 100
34 | CHUNK_SIZE = 1000
35 | CHUNK_OVERLAP = 200
36 |
37 | # Search settings
38 | RETURN_FULL_DOCUMENT_TEXT = True
39 |
40 | # MCP Server settings
41 | DOCTOR_BASE_URL = os.getenv("DOCTOR_BASE_URL", "http://localhost:9111")
42 |
43 |
44 | def validate_config() -> list[str]:
45 | """Validate the configuration and return a list of any issues."""
46 | issues = []
47 |
48 | # Only check for valid API key in production
49 | if not os.getenv("OPENAI_API_KEY"):
50 | issues.append("OPENAI_API_KEY environment variable is not set")
51 |
52 | # Ensure data directory exists
53 | os.makedirs(DATA_DIR, exist_ok=True)
54 |
55 | return issues
56 |
57 |
58 | def check_config() -> bool:
59 | """Check if the configuration is valid and log any issues."""
60 | issues = validate_config()
61 |
62 | if issues:
63 | for issue in issues:
64 | logger.error(f"Configuration error: {issue}")
65 | return False
66 |
67 | return True
68 |
69 |
70 | if __name__ == "__main__":
71 | # When run directly, validate the configuration
72 | check_config()
73 |
--------------------------------------------------------------------------------
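
Because these settings are read from the environment at import time, overrides must be exported before `src.common.config` is imported; a small sketch:

```python
import os

# Hypothetical override: point the DuckDB data directory somewhere writable.
os.environ.setdefault("DATA_DIR", "/tmp/doctor-data")

from src.common.config import DUCKDB_PATH, check_config

print(DUCKDB_PATH)     # /tmp/doctor-data/doctor.duckdb
print(check_config())  # False unless OPENAI_API_KEY is also set
```
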
/tests/lib/test_embedder.py:
--------------------------------------------------------------------------------
1 | """Tests for the embedder module."""
2 |
3 | from unittest.mock import AsyncMock, patch
4 |
5 | import pytest
6 |
7 | from src.lib.embedder import generate_embedding
8 |
9 |
10 | @pytest.fixture
11 | def mock_embedding_response():
12 | """Mock response from litellm.aembedding."""
13 | return {"data": [{"embedding": [0.1, 0.2, 0.3, 0.4, 0.5]}]}
14 |
15 |
16 | @pytest.mark.unit
17 | @pytest.mark.async_test
18 | async def test_generate_embedding(sample_text, mock_embedding_response):
19 | """Test generating an embedding for a text chunk."""
20 | with patch("src.lib.embedder.litellm.aembedding", new_callable=AsyncMock) as mock_aembedding:
21 | mock_aembedding.return_value = mock_embedding_response
22 |
23 | # Test with default model
24 | with patch("src.lib.embedder.DOC_EMBEDDING_MODEL", "text-embedding-3-small"):
25 | embedding = await generate_embedding(sample_text)
26 | assert embedding == mock_embedding_response["data"][0]["embedding"]
27 | mock_aembedding.assert_called_once_with(
28 | model="text-embedding-3-small",
29 | input=[sample_text],
30 | timeout=30,
31 | )
32 |
33 | # Test with custom model and timeout
34 | mock_aembedding.reset_mock()
35 | embedding = await generate_embedding(sample_text, model="custom-model", timeout=60)
36 |
37 | # Check that aembedding was called with the correct arguments
38 | mock_aembedding.assert_called_once_with(
39 | model="custom-model",
40 | input=[sample_text],
41 | timeout=60,
42 | )
43 |
44 |
45 | @pytest.mark.unit
46 | @pytest.mark.async_test
47 | async def test_generate_embedding_with_empty_text():
48 | """Test generating an embedding with empty text."""
49 | with pytest.raises(ValueError, match="Cannot generate embedding for empty text"):
50 | await generate_embedding("")
51 |
52 | with pytest.raises(ValueError, match="Cannot generate embedding for empty text"):
53 | await generate_embedding(" ")
54 |
55 |
56 | @pytest.mark.unit
57 | @pytest.mark.async_test
58 | async def test_generate_embedding_error_handling():
59 | """Test error handling when generating an embedding."""
60 | with patch("src.lib.embedder.litellm.aembedding", new_callable=AsyncMock) as mock_aembedding:
61 | # Simulate an API error
62 | mock_aembedding.side_effect = Exception("API error")
63 |
64 | with pytest.raises(Exception, match="API error"):
65 | await generate_embedding("Some text")
66 |
--------------------------------------------------------------------------------
/tests/integration/common/test_vector_indexer.py:
--------------------------------------------------------------------------------
1 | """Integration tests for the VectorIndexer class with real DuckDB."""
2 |
3 | import pytest
4 |
5 | from src.common.config import VECTOR_SIZE
6 | from src.common.indexer import VectorIndexer
7 |
8 |
9 | @pytest.mark.integration
10 | @pytest.mark.async_test
11 | async def test_vectorindexer_with_real_duckdb(in_memory_duckdb_connection):
12 | """Test the DuckDB implementation of VectorIndexer with a real in-memory database."""
13 | # Create VectorIndexer with the in-memory test connection
14 | indexer = VectorIndexer(connection=in_memory_duckdb_connection)
15 |
16 | # Test vector with random values (correct dimension)
17 | test_vector = [0.1] * VECTOR_SIZE
18 | test_payload = {
19 | "text": "Test chunk",
20 | "page_id": "test-page-id",
21 | "url": "https://example.com",
22 | "domain": "example.com",
23 | "tags": ["test", "example"],
24 | "job_id": "test-job",
25 | }
26 |
27 | # Test index_vector
28 | point_id = await indexer.index_vector(test_vector, test_payload)
29 | assert point_id is not None
30 |
31 | # Test search - should find the vector we just indexed
32 | results = await indexer.search(test_vector, limit=1)
33 | assert len(results) == 1
34 | assert results[0]["id"] == point_id
35 | assert results[0]["payload"]["text"] == "Test chunk"
36 | assert results[0]["payload"]["tags"] == ["test", "example"]
37 |
38 | # Test tag filtering
39 | filter_payload = {"must": [{"key": "tags", "match": {"any": ["test"]}}]}
40 |
41 | results = await indexer.search(test_vector, limit=10, filter_payload=filter_payload)
42 | assert len(results) == 1
43 |
44 | # Test with non-matching tag filter
45 | filter_payload = {"must": [{"key": "tags", "match": {"any": ["nonexistent"]}}]}
46 |
47 | results = await indexer.search(test_vector, limit=10, filter_payload=filter_payload)
48 | assert len(results) == 0
49 |
50 | # Add another vector with different tags
51 | test_vector2 = [0.2] * VECTOR_SIZE
52 | test_payload2 = {
53 | "text": "Another test chunk",
54 | "page_id": "test-page-id-2",
55 | "url": "https://example.org",
56 | "domain": "example.org",
57 | "tags": ["different", "example"],
58 | "job_id": "test-job",
59 | }
60 |
61 | await indexer.index_vector(test_vector2, test_payload2)
62 |
63 | # Test filtering by the "example" tag which both vectors have
64 | filter_payload = {"must": [{"key": "tags", "match": {"any": ["example"]}}]}
65 |
66 | results = await indexer.search(test_vector, limit=10, filter_payload=filter_payload)
67 | assert len(results) == 2
68 |
--------------------------------------------------------------------------------
/src/web_service/main.py:
--------------------------------------------------------------------------------
1 | """Main module for the Doctor Web Service."""
2 |
3 | from collections.abc import AsyncIterator
4 | from contextlib import asynccontextmanager
5 |
6 | import redis
7 | from fastapi import FastAPI
8 | from fastapi_mcp import FastApiMCP
9 |
10 | from src.common.config import (
11 | REDIS_URI,
12 | WEB_SERVICE_HOST,
13 | WEB_SERVICE_PORT,
14 | check_config,
15 | )
16 | from src.common.logger import get_logger
17 | from src.lib.database import DatabaseOperations
18 | from src.web_service.api import api_router
19 |
20 | # Get logger for this module
21 | logger = get_logger(__name__)
22 |
23 |
24 | @asynccontextmanager
25 | async def lifespan(app: FastAPI) -> AsyncIterator[None]:
26 | """Lifespan context manager for the FastAPI application.
27 |
28 | Handles startup and shutdown events.
29 |
30 | Args:
31 | app: The FastAPI application instance.
32 |
33 | Yields:
34 | None: Indicates the application is ready.
35 |
36 | """
37 | # Initialize databases for the web service
38 | DatabaseOperations()
39 | # The db.db.initialize() is called within DatabaseOperations constructor using a context manager.
40 | # The db.db.close() is also handled by the context manager in initialize.
41 | logger.info("Database initialization complete")
42 | if not check_config():
43 | logger.error("Invalid configuration. Exiting.")
44 | exit(1)
45 | yield
46 | logger.info("Shutting down application")
47 |
48 |
49 | def create_application() -> FastAPI:
50 | """Creates and configures the FastAPI application.
51 |
52 | Returns:
53 | FastAPI: Configured FastAPI application
54 |
55 | """
56 | app = FastAPI(
57 | title="Doctor API",
58 | description="API for the Doctor web crawling and indexing system",
59 | version="0.2.0",
60 | lifespan=lifespan,
61 | )
62 |
63 | # Include the API router
64 | app.include_router(api_router)
65 |
66 | # Set up MCP
67 | mcp_server_description = """
68 | Search for documents using semantic search.
69 | 1. Use the `list_tags` endpoint to get a list of all available tags.
70 | 2. Use the `search_docs` endpoint to search for documents using semantic search, optionally filtered by tag.
71 | 3. Use the `get_doc_page` endpoint to get the full text of a document page.
72 | 4. You can also use the `list_doc_pages` endpoint to get a list of all available document pages.
73 | """
74 | mcp = FastApiMCP(
75 | app,
76 | name="Doctor",
77 | description=mcp_server_description,
78 | exclude_operations=["fetch_url", "job_progress", "delete_docs"],
79 | )
80 |
81 | mcp.mount()
82 |
83 | return app
84 |
85 |
86 | # Create the application
87 | app = create_application()
88 |
89 | # Create Redis connection
90 | redis_conn = redis.from_url(REDIS_URI)
91 |
92 |
93 | if __name__ == "__main__":
94 | import uvicorn
95 |
96 | uvicorn.run(app, host=WEB_SERVICE_HOST, port=WEB_SERVICE_PORT)
97 |
--------------------------------------------------------------------------------
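
Building the app does not run the lifespan hook (database initialization happens only at startup), so the route table can be inspected offline; a minimal sketch:

```python
from src.web_service.main import create_application

app = create_application()
print(sorted({route.path for route in app.routes}))
```
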
/src/crawl_worker/main.py:
--------------------------------------------------------------------------------
1 | """Main module for the Doctor Crawl Worker."""
2 |
3 | import redis
4 | from rq import Worker
5 |
6 | from src.common.config import REDIS_URI, check_config
7 | from src.common.logger import get_logger
8 | from src.lib.database import DatabaseOperations
9 |
10 | # Get logger for this module
11 | logger = get_logger(__name__)
12 |
13 |
14 | def main() -> int:
15 | """Main entry point for the Crawl Worker.
16 |
17 | Returns:
18 | int: The exit code (0 for success, 1 for failure).
19 |
20 | """
21 | # Validate configuration
22 | if not check_config():
23 | logger.error("Invalid configuration. Exiting.")
24 | return 1
25 |
26 | # Initialize databases with write access
27 | try:
28 | logger.info("Initializing databases for the crawl worker...")
29 | # DatabaseOperations() constructor now calls initialize() internally using a context manager.
30 | db_ops = DatabaseOperations()
31 | logger.info(
32 | "Database initialization (via DatabaseOperations constructor) completed successfully."
33 | )
34 |
35 | # Double-check that the document_embeddings table exists
36 | # Use a new context-managed connection for this check
37 | with db_ops.db as conn_manager:
38 | actual_conn = conn_manager.conn
39 | if not actual_conn:
40 | logger.error("Failed to get DuckDB connection for table verification.")
41 | return 1
42 |
43 | result = actual_conn.execute(
44 | "SELECT count(*) FROM information_schema.tables WHERE table_name = 'document_embeddings'",
45 | ).fetchone()
46 |
47 | if result is None:
48 | logger.error("Failed to execute query to check for document_embeddings table")
49 | return 1 # conn_manager will close connection
50 |
51 | table_count = result[0]
52 |
53 | if table_count == 0:
54 | logger.error("document_embeddings table is still missing after initialization!")
55 | return 1 # conn_manager will close connection
56 | logger.info("Verified document_embeddings table exists")
57 | # Connection is closed by conn_manager context exit
58 |
59 | except Exception as e:
60 | logger.error(f"Database initialization or verification failed: {e!s}")
61 | return 1
62 |
63 | # Connect to Redis
64 | try:
65 | logger.info(f"Connecting to Redis at {REDIS_URI}")
66 | redis_conn = redis.from_url(REDIS_URI)
67 |
68 | # Start worker
69 | logger.info("Starting worker, listening on queue: worker")
70 | worker = Worker(["worker"], connection=redis_conn)
71 | worker.work(with_scheduler=True)
72 | return 0 # Return success if worker completes normally
73 | except Exception as redis_error:
74 | logger.error(f"Redis worker error: {redis_error!s}")
75 | return 1
76 |
77 |
78 | if __name__ == "__main__":
79 | main()
80 |
--------------------------------------------------------------------------------
/tests/lib/test_crawler.py:
--------------------------------------------------------------------------------
1 | """Tests for the crawler module."""
2 |
3 | from unittest.mock import AsyncMock, MagicMock, patch
4 |
5 | import pytest
6 |
7 | from src.lib.crawler import crawl_url, extract_page_text
8 |
9 |
10 | @pytest.mark.unit
11 | @pytest.mark.async_test
12 | async def test_crawl_url(sample_url):
13 | """Test crawling a URL."""
14 | # Create mock results
15 | mock_results = [
16 | MagicMock(url=sample_url),
17 | MagicMock(url=f"{sample_url}/page1"),
18 | MagicMock(url=f"{sample_url}/page2"),
19 | ]
20 |
21 | # Create a mock AsyncWebCrawler class
22 | mock_crawler_instance = AsyncMock()
23 | mock_crawler_instance.arun = AsyncMock(return_value=mock_results)
24 |
25 | # Create a mock context manager
26 | mock_crawler_class = MagicMock()
27 | mock_crawler_class.return_value.__aenter__.return_value = mock_crawler_instance
28 |
29 | # Patch the AsyncWebCrawler
30 | with patch("src.lib.crawler.AsyncWebCrawler", mock_crawler_class):
31 | results = await crawl_url(sample_url, max_pages=3, max_depth=2)
32 |
33 | # Check that arun was called with the correct arguments
34 | mock_crawler_instance.arun.assert_called_once()
35 | args, kwargs = mock_crawler_instance.arun.call_args
36 | assert kwargs["url"] == sample_url
37 | assert "config" in kwargs
38 |
39 | # Check that we got the expected results
40 | assert results == mock_results
41 | assert len(results) == 3
42 |
43 |
44 | @pytest.mark.unit
45 | def test_extract_page_text_with_markdown(sample_crawl_result):
46 | """Test extracting text from a crawl result with markdown."""
47 | text = extract_page_text(sample_crawl_result)
48 | assert text == "# Example Page\n\nThis is some example content."
49 |
50 |
51 | @pytest.mark.unit
52 | def test_extract_page_text_with_extracted_content():
53 | """Test extracting text from a crawl result with extracted content."""
54 | # Create a mock result with only extracted content
55 | mock_markdown = MagicMock()
56 | mock_markdown.fit_markdown = "Example Page. This is some example content."
57 |
58 | mock_result = MagicMock(
59 | url="https://example.com",
60 | markdown=mock_markdown,
61 | _markdown=None,
62 | extracted_content="Fallback content that should not be used",
63 | html="...",
64 | )
65 |
66 | text = extract_page_text(mock_result)
67 | assert text == "Example Page. This is some example content."
68 |
69 |
70 | @pytest.mark.unit
71 | def test_extract_page_text_with_html_only():
72 | """Test extracting text from a crawl result with only HTML."""
73 | # Create a mock result with only HTML
74 |     html_content = "<html><body><h1>Example</h1></body></html>"
75 |
76 | # Create mock with no markdown or extracted content, only HTML
77 | mock_result = MagicMock(
78 | url="https://example.com",
79 | markdown=None,
80 | _markdown=None,
81 | extracted_content=None,
82 | html=html_content,
83 | )
84 |
85 | text = extract_page_text(mock_result)
86 | assert text == html_content
87 |
--------------------------------------------------------------------------------
/tests/lib/test_chunker.py:
--------------------------------------------------------------------------------
1 | """Tests for the chunker module."""
2 |
3 | from unittest.mock import MagicMock, patch
4 |
5 | import pytest
6 |
7 | from src.lib.chunker import TextChunker
8 |
9 |
10 | @pytest.fixture
11 | def mock_text_splitter():
12 | """Mock for the RecursiveCharacterTextSplitter."""
13 | mock = MagicMock()
14 | mock.split_text.return_value = [
15 | "This is chunk 1.",
16 | "This is chunk 2.",
17 | "This is chunk 3.",
18 | " ", # Empty chunk (with whitespace) that should be filtered
19 | ]
20 | return mock
21 |
22 |
23 | @pytest.mark.unit
24 | def test_text_chunker_initialization():
25 | """Test TextChunker initialization with default values."""
26 | with (
27 | patch("src.lib.chunker.CHUNK_SIZE", 1000),
28 | patch("src.lib.chunker.CHUNK_OVERLAP", 100),
29 | patch("src.lib.chunker.RecursiveCharacterTextSplitter") as mock_splitter_class,
30 | ):
31 | chunker = TextChunker()
32 |
33 | # Check that the chunker was initialized with the correct values
34 | assert chunker.chunk_size == 1000
35 | assert chunker.chunk_overlap == 100
36 |
37 | # Check that the text splitter was initialized with the correct values
38 | mock_splitter_class.assert_called_once_with(
39 | chunk_size=1000,
40 | chunk_overlap=100,
41 | length_function=len,
42 | )
43 |
44 |
45 | @pytest.mark.unit
46 | def test_text_chunker_initialization_with_custom_values():
47 | """Test TextChunker initialization with custom values."""
48 | with patch("src.lib.chunker.RecursiveCharacterTextSplitter") as mock_splitter_class:
49 | chunker = TextChunker(chunk_size=500, chunk_overlap=50)
50 |
51 | # Check that the chunker was initialized with the correct values
52 | assert chunker.chunk_size == 500
53 | assert chunker.chunk_overlap == 50
54 |
55 | # Check that the text splitter was initialized with the correct values
56 | mock_splitter_class.assert_called_once_with(
57 | chunk_size=500,
58 | chunk_overlap=50,
59 | length_function=len,
60 | )
61 |
62 |
63 | @pytest.mark.unit
64 | def test_split_text(sample_text, mock_text_splitter):
65 | """Test splitting text into chunks."""
66 | with patch("src.lib.chunker.RecursiveCharacterTextSplitter", return_value=mock_text_splitter):
67 | chunker = TextChunker()
68 | chunks = chunker.split_text(sample_text)
69 |
70 | # Check that the text splitter was called with the correct input
71 | mock_text_splitter.split_text.assert_called_once_with(sample_text)
72 |
73 | # Check that we got the expected chunks (empty chunks filtered out)
74 | assert chunks == ["This is chunk 1.", "This is chunk 2.", "This is chunk 3."]
75 | assert len(chunks) == 3
76 |
77 |
78 | @pytest.mark.unit
79 | def test_split_text_empty_input():
80 | """Test splitting empty text."""
81 | chunker = TextChunker()
82 |
83 | # Test with empty string
84 | chunks = chunker.split_text("")
85 | assert chunks == []
86 |
87 | # Test with whitespace only
88 | chunks = chunker.split_text(" \n ")
89 | assert chunks == []
90 |
91 |
92 | @pytest.mark.unit
93 | def test_split_text_all_empty_chunks():
94 | """Test when all chunks are empty after filtering."""
95 | with patch("src.lib.chunker.RecursiveCharacterTextSplitter") as mock_splitter_class:
96 | # Mock splitter to return only empty chunks
97 | mock_splitter = MagicMock()
98 | mock_splitter.split_text.return_value = [" ", "\n", " \t "]
99 | mock_splitter_class.return_value = mock_splitter
100 |
101 | chunker = TextChunker()
102 | chunks = chunker.split_text("Some text")
103 |
104 | # Check that we got an empty list
105 | assert chunks == []
106 |
--------------------------------------------------------------------------------
/src/web_service/services/debug_bm25.py:
--------------------------------------------------------------------------------
1 | """BM25 diagnostic function."""
2 |
3 | import duckdb
4 |
5 | from src.common.logger import get_logger
6 |
7 | # Get logger for this module
8 | logger = get_logger(__name__)
9 |
10 |
11 | async def debug_bm25_search(conn: duckdb.DuckDBPyConnection, query: str) -> dict:
12 | """Diagnose BM25/FTS search issues by always attempting to query the FTS index using DuckDB's FTS extension.
13 | Reports the result of the FTS query attempt, any errors, and explains FTS index visibility.
14 | """
15 | logger.info(f"Running BM25/FTS search diagnostics for query: '{query}' (DuckDB-specific)")
16 | results = {}
17 |
18 | try:
19 | # Check if FTS extension is loaded
20 | fts_loaded = conn.execute("""
21 | SELECT * FROM duckdb_extensions()
22 | WHERE extension_name = 'fts' AND loaded = true
23 | """).fetchall()
24 | results["fts_extension_loaded"] = bool(fts_loaded)
25 | logger.info(f"FTS extension loaded: {results['fts_extension_loaded']}")
26 |
27 | # Count records in pages table
28 | try:
29 | pages_count = conn.execute("SELECT COUNT(*) FROM pages").fetchone()[0]
30 | results["pages_count"] = pages_count
31 | logger.info(f"Pages table record count: {pages_count}")
32 | except Exception as e:
33 | results["pages_count"] = None
34 | results["pages_count_error"] = str(e)
35 | logger.warning(f"Could not count pages: {e}")
36 |
37 | # Always attempt FTS/BM25 search using DuckDB FTS extension
38 | try:
39 | escaped_query = query.replace("'", "''")
40 | bm25_sql = f"""
41 | SELECT p.id, p.url, fts_main_pages.match_bm25(p.id, '{escaped_query}') AS score
42 | FROM pages p
43 | WHERE score IS NOT NULL
44 | ORDER BY score DESC
45 | LIMIT 5
46 | """
47 | logger.info(f"Attempting FTS/BM25 search: {bm25_sql}")
48 | bm25_results = conn.execute(bm25_sql).fetchall()
49 | results["fts_bm25_search"] = {
50 | "success": True,
51 | "count": len(bm25_results),
52 | "samples": [
53 | {"id": row[0], "url": row[1], "score": row[2]} for row in bm25_results[:3]
54 | ],
55 | "error": None,
56 | }
57 | logger.info(f"FTS/BM25 search returned {len(bm25_results)} results")
58 | except Exception as e_bm25:
59 | results["fts_bm25_search"] = {
60 | "success": False,
61 | "count": 0,
62 | "samples": [],
63 | "error": str(e_bm25),
64 | }
65 | logger.warning(f"FTS/BM25 search failed: {e_bm25}")
66 |
67 | # Add a note about FTS index visibility in DuckDB
68 | results["fts_index_visibility_note"] = (
69 | "DuckDB FTS indexes are not visible in sqlite_master or information_schema.tables. "
70 | "If FTS index creation returned an 'already exists' error, the index is present and can be queried "
71 |             "using fts_main_<table>.match_bm25 or similar functions, even if not listed in system tables."
72 | )
73 |
74 | # Add table/column info for clarity
75 | try:
76 | tables = conn.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()
77 | table_names = [row[0] for row in tables]
78 | results["tables"] = table_names
79 | if "pages" in table_names:
80 | columns = conn.execute("PRAGMA table_info('pages')").fetchall()
81 | col_names = [row[1] for row in columns]
82 | results["pages_columns"] = col_names
83 | except Exception as e_schema:
84 | results["schema_check_error"] = str(e_schema)
85 |
86 | except Exception as e:
87 | logger.error(f"BM25/FTS diagnostics failed: {e}")
88 | results["error"] = str(e)
89 |
90 | return results
91 |
--------------------------------------------------------------------------------
/src/web_service/api/map.py:
--------------------------------------------------------------------------------
1 | """API endpoints for the site map feature."""
2 |
3 | from fastapi import APIRouter, HTTPException, Response
4 | from fastapi.responses import HTMLResponse
5 |
6 | from src.common.logger import get_logger
7 | from src.web_service.services.map_service import MapService
8 |
9 | logger = get_logger(__name__)
10 |
11 | router = APIRouter(tags=["map"])
12 |
13 |
14 | @router.get("/map", response_class=HTMLResponse)
15 | async def get_site_index() -> str:
16 | """Get an index of all crawled sites.
17 |
18 | Returns:
19 | HTML page listing all root pages/sites.
20 | """
21 | try:
22 | service = MapService()
23 | sites = await service.get_all_sites()
24 | # Log the sites data for debugging
25 | logger.debug(f"Retrieved {len(sites)} sites from database")
26 | if sites:
27 | logger.debug(f"First site data: {sites[0]}")
28 | return service.format_site_list(sites)
29 | except Exception as e:
30 | logger.error(f"Error getting site index: {e}", exc_info=True)
31 | raise HTTPException(status_code=500, detail=str(e))
32 |
33 |
34 | @router.get("/map/site/{root_page_id}", response_class=HTMLResponse)
35 | async def get_site_tree(root_page_id: str) -> str:
36 | """Get the hierarchical tree view for a specific site.
37 |
38 | Args:
39 | root_page_id: The ID of the root page.
40 |
41 | Returns:
42 | HTML page showing the site's page hierarchy.
43 | """
44 | try:
45 | service = MapService()
46 | tree = await service.build_page_tree(root_page_id)
47 | logger.debug(f"Built tree for {root_page_id}: {tree}")
48 | return service.format_site_tree(tree)
49 | except Exception as e:
50 | logger.error(f"Error getting site tree for {root_page_id}: {e}", exc_info=True)
51 | raise HTTPException(status_code=500, detail=str(e))
52 |
53 |
54 | @router.get("/map/page/{page_id}", response_class=HTMLResponse)
55 | async def view_page(page_id: str) -> str:
56 | """View a specific page with navigation.
57 |
58 | Args:
59 | page_id: The ID of the page to view.
60 |
61 | Returns:
62 | HTML page with the page content and navigation.
63 | """
64 | try:
65 | service = MapService()
66 |
67 | # Get the page
68 | page = await service.db_ops.get_page_by_id(page_id)
69 | if not page:
70 | raise HTTPException(status_code=404, detail="Page not found")
71 |
72 | # Get navigation context
73 | navigation = await service.get_navigation_context(page_id)
74 |
75 | # Render the page
76 | return service.render_page_html(page, navigation)
77 | except HTTPException:
78 | raise
79 | except Exception as e:
80 | logger.error(f"Error viewing page {page_id}: {e}")
81 | raise HTTPException(status_code=500, detail=str(e))
82 |
83 |
84 | @router.get("/map/page/{page_id}/raw", response_class=Response)
85 | async def get_page_raw(page_id: str) -> Response:
86 | """Get the raw markdown content of a page.
87 |
88 | Args:
89 | page_id: The ID of the page.
90 |
91 | Returns:
92 | Raw markdown content.
93 | """
94 | try:
95 | service = MapService()
96 |
97 | # Get the page
98 | page = await service.db_ops.get_page_by_id(page_id)
99 | if not page:
100 | raise HTTPException(status_code=404, detail="Page not found")
101 |
102 | # Return raw text as markdown
103 | return Response(
104 | content=page.get("raw_text", ""),
105 | media_type="text/markdown",
106 | headers={"Content-Disposition": f'inline; filename="{page.get("title", "page")}.md"'},
107 | )
108 | except HTTPException:
109 | raise
110 | except Exception as e:
111 | logger.error(f"Error getting raw content for page {page_id}: {e}")
112 | raise HTTPException(status_code=500, detail=str(e))
113 |
--------------------------------------------------------------------------------
/src/lib/database/migrations.py:
--------------------------------------------------------------------------------
1 | """Database migration management for the Doctor project.
2 |
3 | This module handles applying database migrations to ensure the schema is up-to-date.
4 | """
5 |
6 | import pathlib
7 |
8 | import duckdb
9 |
10 | from src.common.logger import get_logger
11 |
12 | logger = get_logger(__name__)
13 |
14 |
15 | class MigrationRunner:
16 | """Manages and executes database migrations."""
17 |
18 | def __init__(self, conn: duckdb.DuckDBPyConnection) -> None:
19 | """Initialize the migration runner.
20 |
21 | Args:
22 | conn: Active DuckDB connection.
23 |
24 | Returns:
25 | None.
26 | """
27 | self.conn = conn
28 | self.migrations_dir = pathlib.Path(__file__).parent / "migrations"
29 |
30 | def _ensure_migration_table(self) -> None:
31 | """Create the migrations tracking table if it doesn't exist.
32 |
33 | Args:
34 | None.
35 |
36 | Returns:
37 | None.
38 | """
39 | self.conn.execute("""
40 | CREATE TABLE IF NOT EXISTS _migrations (
41 | migration_name VARCHAR PRIMARY KEY,
42 | applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
43 | )
44 | """)
45 |
46 | def _get_applied_migrations(self) -> set[str]:
47 | """Get the set of already applied migrations.
48 |
49 | Args:
50 | None.
51 |
52 | Returns:
53 | set[str]: Set of migration names that have been applied.
54 | """
55 | result = self.conn.execute("SELECT migration_name FROM _migrations").fetchall()
56 | return {row[0] for row in result}
57 |
58 | def _apply_migration(self, migration_path: pathlib.Path) -> None:
59 | """Apply a single migration file.
60 |
61 | Args:
62 | migration_path: Path to the migration SQL file.
63 |
64 | Returns:
65 | None.
66 |
67 | Raises:
68 | Exception: If migration fails.
69 | """
70 | migration_name = migration_path.name
71 | logger.info(f"Applying migration: {migration_name}")
72 |
73 | try:
74 | # Read and execute the migration
75 | sql_content = migration_path.read_text()
76 |
77 | # Split by semicolons and execute each statement
78 | # Filter out empty statements
79 | statements = [s.strip() for s in sql_content.split(";") if s.strip()]
80 |
81 | for statement in statements:
82 | self.conn.execute(statement)
83 |
84 | # Record the migration as applied
85 | self.conn.execute(
86 | "INSERT INTO _migrations (migration_name) VALUES (?)", [migration_name]
87 | )
88 |
89 | logger.info(f"Successfully applied migration: {migration_name}")
90 |
91 | except Exception as e:
92 | logger.error(f"Failed to apply migration {migration_name}: {e}")
93 | raise
94 |
95 | def run_migrations(self) -> None:
96 | """Run all pending migrations in order.
97 |
98 | Args:
99 | None.
100 |
101 | Returns:
102 | None.
103 | """
104 | # Ensure migration tracking table exists
105 | self._ensure_migration_table()
106 |
107 | # Get already applied migrations
108 | applied = self._get_applied_migrations()
109 |
110 | # Find all migration files
111 | if not self.migrations_dir.exists():
112 | logger.debug("No migrations directory found")
113 | return
114 |
115 | migration_files = sorted(self.migrations_dir.glob("*.sql"))
116 |
117 | # Apply pending migrations
118 | pending_count = 0
119 | for migration_file in migration_files:
120 | if migration_file.name not in applied:
121 | self._apply_migration(migration_file)
122 | pending_count += 1
123 |
124 | if pending_count == 0:
125 | logger.debug("No pending migrations to apply")
126 | else:
127 | logger.info(f"Applied {pending_count} migration(s)")
128 |
--------------------------------------------------------------------------------
/src/lib/database/schema.py:
--------------------------------------------------------------------------------
1 | """Schema definitions for Doctor project database tables."""
2 | # Table creation SQLs and schema helpers for Doctor DB
3 |
4 | from src.common.config import VECTOR_SIZE
5 |
6 | # Table creation SQL statements
7 | # These constants define the SQL for creating the main tables in the Doctor database.
8 | CREATE_JOBS_TABLE_SQL = """
9 | CREATE TABLE IF NOT EXISTS jobs (
10 | job_id VARCHAR PRIMARY KEY,
11 | start_url VARCHAR,
12 | status VARCHAR,
13 | pages_discovered INTEGER DEFAULT 0,
14 | pages_crawled INTEGER DEFAULT 0,
15 | max_pages INTEGER,
16 | tags VARCHAR, -- JSON string array
17 | created_at TIMESTAMP,
18 | updated_at TIMESTAMP,
19 | error_message VARCHAR
20 | )
21 | """
22 |
23 | CREATE_PAGES_TABLE_SQL = """
24 | CREATE TABLE IF NOT EXISTS pages (
25 | id VARCHAR PRIMARY KEY,
26 | url VARCHAR,
27 | domain VARCHAR,
28 | raw_text TEXT,
29 | crawl_date TIMESTAMP,
30 | tags VARCHAR, -- JSON string array
31 | job_id VARCHAR, -- Reference to the job that crawled this page
32 | parent_page_id VARCHAR, -- Reference to parent page for hierarchy
33 | root_page_id VARCHAR, -- Reference to root page of the site
34 | depth INTEGER DEFAULT 0, -- Distance from root page
35 | path TEXT, -- Relative path from root page
36 | title TEXT -- Extracted page title
37 | )
38 | """
39 |
40 | CREATE_DOCUMENT_EMBEDDINGS_TABLE_SQL = f"""
41 | CREATE TABLE IF NOT EXISTS document_embeddings (
42 | id VARCHAR PRIMARY KEY,
43 | embedding FLOAT[{VECTOR_SIZE}] NOT NULL,
44 | text_chunk VARCHAR,
45 | page_id VARCHAR,
46 | url VARCHAR,
47 | domain VARCHAR,
48 | tags VARCHAR[],
49 | job_id VARCHAR
50 | );
51 | """
52 |
53 | # Extension management SQL
54 | # These constants are used to manage DuckDB extensions.
55 | CHECK_EXTENSION_LOADED_SQL = "SELECT * FROM duckdb_loaded_extensions() WHERE name = '{0}';"
56 | INSTALL_EXTENSION_SQL = "INSTALL {0};"
57 | LOAD_EXTENSION_SQL = "LOAD {0};"
58 |
59 | # FTS (Full-Text Search) related SQL
60 | # These constants are used for managing FTS indexes and tables.
61 | CREATE_FTS_INDEX_SQL = "PRAGMA create_fts_index('pages', 'id', 'raw_text', overwrite=1);"
62 | CHECK_FTS_INDEXES_SQL = (
63 | "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'fts_idx_%'"
64 | )
65 | CHECK_FTS_MAIN_PAGES_TABLE_SQL = (
66 | "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'fts_main_pages'"
67 | )
68 | DROP_FTS_MAIN_PAGES_TABLE_SQL = "DROP TABLE IF EXISTS fts_main_pages;"
69 |
70 | # HNSW (Vector Search) related SQL
71 | # These constants are used for managing HNSW vector search indexes.
72 | SET_HNSW_PERSISTENCE_SQL = "SET hnsw_enable_experimental_persistence = true;"
73 | CHECK_HNSW_INDEX_SQL = (
74 | "SELECT count(*) FROM duckdb_indexes() WHERE index_name = 'hnsw_index_on_embeddings'"
75 | )
76 | CREATE_HNSW_INDEX_SQL = """
77 | CREATE INDEX hnsw_index_on_embeddings
78 | ON document_embeddings
79 | USING HNSW (embedding)
80 | WITH (metric = 'cosine');
81 | """
82 |
83 | # VSS (Vector Similarity Search) verification SQL
84 | # These constants are used to verify VSS extension functionality.
85 | VSS_ARRAY_TO_STRING_TEST_SQL = "SELECT array_to_string([0.1, 0.2]::FLOAT[], ', ');"
86 | VSS_COSINE_SIMILARITY_TEST_SQL = "SELECT list_cosine_similarity([0.1,0.2],[0.2,0.3]);"
87 |
88 | # Table existence check SQL
89 | # Used to check if a table exists in the database.
90 | CHECK_TABLE_EXISTS_SQL = "SELECT count(*) FROM information_schema.tables WHERE table_name = '{0}'"
91 |
92 | # Transaction management SQL
93 | # Used for managing transactions and checkpoints.
94 | BEGIN_TRANSACTION_SQL = "BEGIN TRANSACTION"
95 | CHECKPOINT_SQL = "CHECKPOINT"
96 | TEST_CONNECTION_SQL = "SELECT 1"
97 |
98 | # DML (Data Manipulation Language) SQL
99 | # Used for inserting and updating data in tables.
100 | INSERT_PAGE_SQL = """
101 | INSERT INTO pages (id, url, domain, raw_text, crawl_date, tags, job_id,
102 | parent_page_id, root_page_id, depth, path, title)
103 | VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
104 | """
105 |
106 | # Base for dynamic UPDATE job query
107 | UPDATE_JOB_STATUS_BASE_SQL = "UPDATE jobs SET status = ?, updated_at = ?"
108 |
--------------------------------------------------------------------------------
/src/lib/crawler.py:
--------------------------------------------------------------------------------
1 | """Web crawling functionality using crawl4ai."""
2 |
3 | from typing import Any
4 |
5 | from crawl4ai import AsyncWebCrawler, CrawlerRunConfig
6 | from crawl4ai.content_filter_strategy import PruningContentFilter
7 | from crawl4ai.deep_crawling import BFSDeepCrawlStrategy
8 | from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
9 |
10 | from src.common.logger import get_logger
11 |
12 | # Configure logging
13 | logger = get_logger(__name__)
14 |
15 |
16 | async def crawl_url(
17 | url: str,
18 | max_pages: int = 100,
19 | max_depth: int = 2,
20 | strip_urls: bool = True,
21 | ) -> list[Any]:
22 | """Crawl a URL and return the results.
23 |
24 | Args:
25 | url: The URL to start crawling from
26 | max_pages: Maximum number of pages to crawl
27 | max_depth: Maximum depth for the BFS crawl
28 | strip_urls: Whether to strip URLs from the returned markdown
29 |
30 | Returns:
31 | List of crawled page results
32 |
33 | """
34 | logger.info(f"Starting crawl for URL: {url} with max_pages={max_pages}")
35 |
36 | # Create content filter to remove navigation elements and other non-essential content
37 | content_filter = PruningContentFilter(threshold=0.6, threshold_type="fixed")
38 |
39 | # Configure markdown generator to ignore links and navigation elements
40 | markdown_generator = DefaultMarkdownGenerator(
41 | content_filter=content_filter,
42 | options={
43 | "ignore_links": strip_urls,
44 | "body_width": 0,
45 | "ignore_images": True,
46 | "single_line_break": True,
47 | },
48 | )
49 |
50 | config = CrawlerRunConfig(
51 | deep_crawl_strategy=BFSDeepCrawlStrategy(
52 | max_depth=max_depth,
53 | max_pages=max_pages,
54 | logger=get_logger("crawl4ai"),
55 | include_external=False,
56 | ),
57 | markdown_generator=markdown_generator,
58 | excluded_tags=["nav", "footer", "aside", "header"],
59 | remove_overlay_elements=True,
60 | verbose=True,
61 | )
62 |
63 | # Initialize the crawler
64 | async with AsyncWebCrawler() as crawler:
65 | crawl_results = await crawler.arun(url=url, config=config)
66 | logger.info(f"Deep crawl discovered {len(crawl_results)} pages")
67 | return crawl_results
68 |
69 |
70 | def extract_page_text(page_result: Any) -> str:
71 | """Extract the text content from a crawl4ai page result.
72 |
73 | Args:
74 | page_result: The crawl result for the page
75 |
76 | Returns:
77 | The extracted text content
78 |
79 | """
80 | # Use filtered markdown if available, otherwise use raw markdown,
81 | # extracted content, or HTML as fallbacks
82 | if hasattr(page_result, "markdown") and page_result.markdown:
83 | if hasattr(page_result.markdown, "fit_markdown") and page_result.markdown.fit_markdown:
84 | page_text = page_result.markdown.fit_markdown
85 | logger.debug(f"Using fit markdown text of length {len(page_text)}")
86 | elif hasattr(page_result.markdown, "raw_markdown"):
87 | page_text = page_result.markdown.raw_markdown
88 | logger.debug(f"Using raw markdown text of length {len(page_text)}")
89 | else:
90 | # Handle string-like markdown (backward compatibility)
91 | page_text = str(page_result.markdown)
92 | logger.debug(f"Using string markdown text of length {len(page_text)}")
93 | elif hasattr(page_result, "_markdown") and page_result._markdown:
94 | if hasattr(page_result._markdown, "fit_markdown") and page_result._markdown.fit_markdown:
95 | page_text = page_result._markdown.fit_markdown
96 | logger.debug(f"Using fit markdown text from _markdown of length {len(page_text)}")
97 | else:
98 | page_text = page_result._markdown.raw_markdown
99 | logger.debug(f"Using raw markdown text from _markdown of length {len(page_text)}")
100 | elif page_result.extracted_content:
101 | page_text = page_result.extracted_content
102 | logger.debug(f"Using extracted content of length {len(page_text)}")
103 | else:
104 | page_text = page_result.html
105 | logger.debug(f"Using HTML content of length {len(page_text)}")
106 |
107 | return page_text
108 |
--------------------------------------------------------------------------------
/docs/maps_api_examples.md:
--------------------------------------------------------------------------------
1 | # Maps Feature API Examples
2 |
3 | This document provides examples of using the Maps feature API endpoints.
4 |
5 | ## Overview
6 |
7 | The Maps feature tracks page hierarchies during crawling, allowing you to navigate crawled sites as they were originally structured. This makes it easy for LLMs to understand site organization and find related content.
8 |
9 | ## API Endpoints
10 |
11 | ### 1. Get All Sites - `/map`
12 |
13 | Lists all crawled root pages (sites).
14 |
15 | **Request:**
16 | ```http
17 | GET /map
18 | ```
19 |
20 | **Response:**
21 | HTML page listing all sites with links to their tree views.
22 |
23 | **Example Usage:**
24 | ```bash
25 | curl http://localhost:9111/map
26 | ```
27 |
28 | ### 2. View Site Tree - `/map/site/{root_page_id}`
29 |
30 | Shows the hierarchical structure of a specific site.
31 |
32 | **Request:**
33 | ```http
34 | GET /map/site/550e8400-e29b-41d4-a716-446655440000
35 | ```
36 |
37 | **Response:**
38 | HTML page with collapsible tree view of all pages in the site.
39 |
40 | **Example Usage:**
41 | ```bash
42 | curl http://localhost:9111/map/site/550e8400-e29b-41d4-a716-446655440000
43 | ```
44 |
45 | ### 3. View Page - `/map/page/{page_id}`
46 |
47 | Displays a specific page with navigation links.
48 |
49 | **Request:**
50 | ```http
51 | GET /map/page/660e8400-e29b-41d4-a716-446655440001
52 | ```
53 |
54 | **Response:**
55 | HTML page with:
56 | - Breadcrumb navigation
57 | - Rendered markdown content
58 | - Links to parent, siblings, and child pages
59 | - Link back to site map
60 |
61 | **Example Usage:**
62 | ```bash
63 | curl http://localhost:9111/map/page/660e8400-e29b-41d4-a716-446655440001
64 | ```
65 |
66 | ### 4. Get Raw Content - `/map/page/{page_id}/raw`
67 |
68 | Returns the raw markdown content of a page.
69 |
70 | **Request:**
71 | ```http
72 | GET /map/page/660e8400-e29b-41d4-a716-446655440001/raw
73 | ```
74 |
75 | **Response:**
76 | ```markdown
77 | # Page Title
78 |
79 | Raw markdown content of the page...
80 | ```
81 |
82 | **Example Usage:**
83 | ```bash
84 | curl http://localhost:9111/map/page/660e8400-e29b-41d4-a716-446655440001/raw > page.md
85 | ```
86 |
87 | ## Workflow Example
88 |
89 | ### 1. Crawl a Website
90 |
91 | First, crawl a website to populate the hierarchy:
92 |
93 | ```bash
94 | curl -X POST http://localhost:9111/fetch_url \
95 | -H "Content-Type: application/json" \
96 | -d '{
97 | "url": "https://docs.example.com",
98 | "max_pages": 100,
99 | "tags": ["documentation"]
100 | }'
101 | ```
102 |
103 | ### 2. View All Sites
104 |
105 | Visit the map index to see all crawled sites:
106 |
107 | ```bash
108 | curl http://localhost:9111/map
109 | ```
110 |
111 | ### 3. Explore Site Structure
112 |
113 | Click on a site or use its ID to view the tree structure:
114 |
115 | ```bash
116 | curl http://localhost:9111/map/site/{root_page_id}
117 | ```
118 |
119 | ### 4. Navigate Pages
120 |
121 | Click through pages to explore content with full navigation context.
122 |
123 | ## Database Schema
124 |
125 | The hierarchy is tracked using these fields in the `pages` table (an example query follows the list):
126 |
127 | - `parent_page_id`: ID of the parent page (NULL for root pages)
128 | - `root_page_id`: ID of the root page for this site
129 | - `depth`: Distance from the root (0 for root pages)
130 | - `path`: Relative path from the root
131 | - `title`: Extracted page title
132 |
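For a concrete picture of how these fields fit together, here is a minimal sketch that queries the hierarchy directly with DuckDB. It assumes read-only access to the database file (`doctor.duckdb` is a placeholder for wherever `DUCKDB_PATH` points) and uses only the columns listed above:

```python
# Minimal sketch: list each crawled site (root page) and count its descendants.
# The database path below is an assumption; substitute the real DUCKDB_PATH value.
import duckdb

conn = duckdb.connect("doctor.duckdb", read_only=True)

rows = conn.execute(
    """
    SELECT r.id, r.title, r.url, COUNT(c.id) AS descendant_pages
    FROM pages r
    LEFT JOIN pages c
        ON c.root_page_id = r.id AND c.id != r.id
    WHERE r.parent_page_id IS NULL          -- root pages only
    GROUP BY r.id, r.title, r.url
    ORDER BY descendant_pages DESC
    """
).fetchall()

for page_id, title, url, descendant_pages in rows:
    print(f"{title} ({url}): {descendant_pages} descendant pages [{page_id}]")

conn.close()
```
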
133 | ## Integration with LLMs
134 |
135 | The Maps feature is particularly useful for LLMs because:
136 |
137 | 1. **Contextual Understanding**: LLMs can understand how pages relate to each other
138 | 2. **Efficient Navigation**: Direct links to related content without searching
139 | 3. **Site Structure**: Understanding the organization helps with better responses
140 | 4. **No JavaScript**: Clean HTML that's easy for LLMs to parse
141 |
142 | ## Python Client Example
143 |
144 | ```python
145 | import httpx
146 | import asyncio
147 |
148 | async def explore_site_map():
149 | async with httpx.AsyncClient() as client:
150 | # Get all sites
151 | sites_response = await client.get("http://localhost:9111/map")
152 |
153 | # Parse the HTML to extract site IDs (simplified example)
154 | # In practice, use an HTML parser like BeautifulSoup
155 |
156 | # Get a specific site's tree
157 | tree_response = await client.get(
158 | "http://localhost:9111/map/site/550e8400-e29b-41d4-a716-446655440000"
159 | )
160 |
161 | # Get raw content of a page
162 | raw_response = await client.get(
163 | "http://localhost:9111/map/page/660e8400-e29b-41d4-a716-446655440001/raw"
164 | )
165 |
166 | return raw_response.text
167 |
168 | # Run the example
169 | content = asyncio.run(explore_site_map())
170 | print(content)
171 | ```
172 |
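The comment in the example above defers HTML parsing to a library such as BeautifulSoup. As a rough sketch of that step (assuming `beautifulsoup4` is installed and that the index page links to each site's tree view at `/map/site/{root_page_id}`; the exact markup of the generated HTML is an assumption):

```python
# Hedged sketch: pull root page IDs out of the /map index HTML.
# Assumes the page contains <a> tags whose href includes /map/site/{root_page_id}.
import httpx
from bs4 import BeautifulSoup


def list_site_ids(base_url: str = "http://localhost:9111") -> list[str]:
    html = httpx.get(f"{base_url}/map").text
    soup = BeautifulSoup(html, "html.parser")
    site_ids = []
    for anchor in soup.find_all("a", href=True):
        href = anchor["href"]
        if "/map/site/" in href:
            # The root page ID is the last path segment of the tree-view link.
            site_ids.append(href.rstrip("/").split("/")[-1])
    return site_ids


if __name__ == "__main__":
    print(list_site_ids())
```
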
173 | ## Notes
174 |
175 | - The Maps feature automatically tracks hierarchy during crawling
176 | - No additional configuration is needed - it works with existing crawl jobs
177 | - Pages maintain their relationships even after crawling is complete
178 | - The UI requires no JavaScript, making it accessible and fast
179 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | """Common fixtures for tests."""
2 |
3 | import asyncio
4 | import random
5 |
6 | import duckdb
7 | import pytest
8 |
9 | from src.common.config import VECTOR_SIZE
10 |
11 |
12 | @pytest.fixture
13 | def event_loop():
14 | """Create an instance of the default event loop for each test."""
15 | loop = asyncio.get_event_loop_policy().new_event_loop()
16 | yield loop
17 | loop.close()
18 |
19 |
20 | @pytest.fixture
21 | def sample_url():
22 | """Sample URL for testing."""
23 | return "https://example.com"
24 |
25 |
26 | @pytest.fixture
27 | def sample_text():
28 | """Sample text content for testing."""
29 | return """
30 | Lorem ipsum dolor sit amet, consectetur adipiscing elit. Nullam auctor,
31 | nisl nec ultricies lacinia, nisl nisl aliquet nisl, nec ultricies nisl
32 | nisl nec ultricies lacinia, nisl nisl aliquet nisl, nec ultricies nisl
33 | nisl nec ultricies lacinia, nisl nisl aliquet nisl, nec ultricies nisl.
34 |
35 | Pellentesque habitant morbi tristique senectus et netus et malesuada
36 | fames ac turpis egestas. Sed euismod, nisl nec ultricies lacinia, nisl
37 | nisl aliquet nisl, nec ultricies nisl nisl nec ultricies lacinia.
38 | """
39 |
40 |
41 | @pytest.fixture
42 | def sample_embedding():
43 | """Sample embedding vector for testing (legacy version: 384 dim)."""
44 | random.seed(42) # For reproducibility
45 | return [random.random() for _ in range(384)]
46 |
47 |
48 | @pytest.fixture
49 | def sample_embedding_full_size() -> list[float]:
50 | """Sample embedding vector with the full VECTOR_SIZE dimension for DuckDB tests."""
51 | random.seed(42) # For reproducibility
52 | return [random.random() for _ in range(VECTOR_SIZE)]
53 |
54 |
55 | @pytest.fixture
56 | def sample_crawl_result():
57 | """Sample crawl result for testing."""
58 |
59 | class MockCrawlResult:
60 | def __init__(self, url, markdown=None, extracted_content=None, html=None):
61 | self.url = url
62 | self._markdown = markdown
63 | self.extracted_content = extracted_content
64 | self.html = html
65 |
66 | # Create a mock _markdown attribute if provided
67 | if markdown:
68 |
69 | class MockMarkdown:
70 | def __init__(self, raw_markdown):
71 | self.raw_markdown = raw_markdown
72 |
73 | self._markdown = MockMarkdown(markdown)
74 |
75 | return MockCrawlResult(
76 | url="https://example.com",
77 | markdown="# Example Page\n\nThis is some example content.",
78 | extracted_content="Example Page. This is some example content.",
79 |         html="<html><head><title>Example</title></head><body><h1>Example Page</h1><p>This is some example content.</p></body></html>",
80 | )
81 |
82 |
83 | @pytest.fixture
84 | def job_id():
85 | """Sample job ID for testing."""
86 | return "test-job-123"
87 |
88 |
89 | @pytest.fixture
90 | def page_id():
91 | """Sample page ID for testing."""
92 | return "test-page-456"
93 |
94 |
95 | @pytest.fixture
96 | def sample_tags():
97 | """Sample tags for testing."""
98 | return ["test", "example", "documentation"]
99 |
100 |
101 | @pytest.fixture
102 | def in_memory_duckdb_connection():
103 | """Create an in-memory DuckDB connection for testing.
104 |
105 | This connection has the proper setup for vector search using the same
106 | setup logic as the main application:
107 | - VSS extension loaded
108 | - document_embeddings table created with the proper schema
109 | - HNSW index created
110 | - pages table created (for document service tests)
111 |
112 | Usage:
113 | def test_something(in_memory_duckdb_connection):
114 | # Use the connection for testing
115 | ...
116 | """
117 | from src.lib.database import DatabaseOperations
118 |
119 | # Create in-memory connection
120 | conn = duckdb.connect(":memory:")
121 |
122 | try:
123 | # Use the Database class to set up the connection
124 | db = DatabaseOperations()
125 | db.conn = conn
126 |
127 | # Create tables and set up VSS extension
128 | db.db.ensure_tables()
129 | db.db.ensure_vss_extension()
130 |
131 | yield conn
132 | finally:
133 | conn.close()
134 |
135 |
136 | @pytest.fixture(scope="session", autouse=True)
137 | def ensure_duckdb_database():
138 | """Ensure the DuckDB database file exists before running tests.
139 |
140 | This fixture runs once per test session and initializes the database
141 | if it doesn't exist yet, which is especially important for CI environments.
142 | """
143 | import os
144 |
145 | from src.common.config import DUCKDB_PATH
146 | from src.lib.database import DatabaseOperations
147 |
148 | # Only initialize if the file doesn't exist
149 | if not os.path.exists(DUCKDB_PATH):
150 | print(f"Database file {DUCKDB_PATH} does not exist. Creating it for tests...")
151 | try:
152 | db = DatabaseOperations(read_only=False)
153 | db.db.initialize()
154 | db.db.close()
155 | print(f"Successfully created database at {DUCKDB_PATH}")
156 | except Exception as e:
157 | print(f"Warning: Failed to create database: {e}")
158 | # Don't fail the tests if we can't create the DB - individual tests
159 | # that need it can handle the situation appropriately
160 |
--------------------------------------------------------------------------------
/src/lib/crawler_enhanced.py:
--------------------------------------------------------------------------------
1 | """Enhanced web crawler with hierarchy tracking for the maps feature."""
2 |
3 | from typing import Any, Optional
4 | from urllib.parse import urlparse
5 | import re
6 |
7 | from src.common.logger import get_logger
8 | from src.lib.crawler import crawl_url as base_crawl_url, extract_page_text
9 |
10 | logger = get_logger(__name__)
11 |
12 |
13 | class CrawlResultWithHierarchy:
14 | """Enhanced crawl result that includes hierarchy information."""
15 |
16 | def __init__(self, base_result: Any) -> None:
17 | """Initialize with a base crawl4ai result.
18 |
19 | Args:
20 | base_result: The original crawl4ai result object.
21 |
22 | Returns:
23 | None.
24 | """
25 | self.base_result = base_result
26 | self.url = base_result.url
27 | self.parent_url: Optional[str] = None
28 | self.root_url: Optional[str] = None
29 | self.depth: int = 0
30 | self.relative_path: str = ""
31 | self.title: Optional[str] = None
32 |
33 | # Extract hierarchy info from metadata if available
34 | if hasattr(base_result, "metadata"):
35 | self.parent_url = base_result.metadata.get("parent_url")
36 | self.depth = base_result.metadata.get("depth", 0)
37 |
38 | # Extract title from the page
39 | self._extract_title()
40 |
41 | def _extract_title(self) -> None:
42 | """Extract the page title from the HTML or markdown content.
43 |
44 | Args:
45 | None.
46 |
47 | Returns:
48 | None.
49 | """
50 | try:
51 | # Try to extract from HTML title tag first
52 | if hasattr(self.base_result, "html") and self.base_result.html:
53 | title_match = re.search(
54 |                     r"<title[^>]*>(.*?)</title>", self.base_result.html, re.IGNORECASE | re.DOTALL
55 | )
56 | if title_match:
57 | self.title = title_match.group(1).strip()
58 | # Clean up common HTML entities
59 | self.title = (
60 |                         self.title.replace("&amp;", "&").replace("&lt;", "<").replace("&gt;", ">")
61 | )
62 | return
63 |
64 | # Fallback to first H1 in markdown
65 | text = extract_page_text(self.base_result)
66 | if text:
67 | # Look for first markdown heading
68 | h1_match = re.search(r"^#\s+(.+)$", text, re.MULTILINE)
69 | if h1_match:
70 | self.title = h1_match.group(1).strip()
71 | return
72 |
73 | # Look for first line of non-empty text as last resort
74 | lines = text.strip().split("\n")
75 | for line in lines[:5]: # Check first 5 lines
76 | if line.strip() and len(line.strip()) > 10:
77 | self.title = line.strip()[:100] # Limit to 100 chars
78 | break
79 |
80 | except Exception as e:
81 | logger.warning(f"Error extracting title from {self.url}: {e}")
82 |
83 | # Default to URL path if no title found
84 | if not self.title:
85 | path = urlparse(self.url).path
86 | self.title = path.strip("/").split("/")[-1] or "Home"
87 |
88 |
89 | async def crawl_url_with_hierarchy(
90 | url: str,
91 | max_pages: int = 100,
92 | max_depth: int = 2,
93 | strip_urls: bool = True,
94 | ) -> list[CrawlResultWithHierarchy]:
95 | """Crawl a URL and return results with hierarchy information.
96 |
97 | Args:
98 | url: The URL to start crawling from.
99 | max_pages: Maximum number of pages to crawl.
100 | max_depth: Maximum depth for the BFS crawl.
101 | strip_urls: Whether to strip URLs from the returned markdown.
102 |
103 | Returns:
104 | List of crawl results with hierarchy information.
105 | """
106 | # Get base crawl results
107 | base_results = await base_crawl_url(url, max_pages, max_depth, strip_urls)
108 |
109 | # Enhance results with hierarchy information
110 | enhanced_results = []
111 | root_url = url # The starting URL is the root
112 |
113 | # Build a map of URLs to results for quick lookup
114 | url_to_result = {}
115 | for result in base_results:
116 | enhanced = CrawlResultWithHierarchy(result)
117 | enhanced.root_url = root_url
118 | url_to_result[enhanced.url] = enhanced
119 | enhanced_results.append(enhanced)
120 |
121 | # Calculate relative paths
122 | for enhanced in enhanced_results:
123 | if enhanced.parent_url and enhanced.parent_url in url_to_result:
124 | parent = url_to_result[enhanced.parent_url]
125 | # Calculate relative path based on parent's path
126 | if parent.relative_path == "/":
127 | enhanced.relative_path = enhanced.title
128 | elif parent.relative_path:
129 | enhanced.relative_path = f"{parent.relative_path}/{enhanced.title}"
130 | else:
131 | enhanced.relative_path = enhanced.title
132 | elif enhanced.url == root_url:
133 | enhanced.relative_path = "/"
134 | else:
135 | # Fallback to URL-based path
136 | enhanced.relative_path = urlparse(enhanced.url).path
137 |
138 | logger.info(f"Enhanced {len(enhanced_results)} pages with hierarchy information")
139 | return enhanced_results
140 |
--------------------------------------------------------------------------------
/tests/services/test_admin_service.py:
--------------------------------------------------------------------------------
1 | """Tests for the admin service."""
2 |
3 | from unittest.mock import MagicMock, patch
4 |
5 | import pytest
6 | from rq import Queue
7 |
8 | from src.web_service.services.admin_service import delete_docs
9 |
10 |
11 | @pytest.fixture
12 | def mock_queue():
13 | """Mock Redis queue."""
14 | mock = MagicMock(spec=Queue)
15 | mock.enqueue = MagicMock()
16 | return mock
17 |
18 |
19 | @pytest.mark.unit
20 | @pytest.mark.async_test
21 | async def test_delete_docs_basic(mock_queue):
22 | """Test basic document deletion without filters."""
23 | # Set a fixed UUID for testing
24 | with patch("uuid.uuid4", return_value="delete-task-123"):
25 | # Call the function without filters
26 | task_id = await delete_docs(queue=mock_queue)
27 |
28 | # Verify queue.enqueue was called with the right parameters
29 | mock_queue.enqueue.assert_called_once()
30 | call_args = mock_queue.enqueue.call_args[0]
31 | assert call_args[0] == "src.crawl_worker.tasks.delete_docs"
32 | assert call_args[1] == "delete-task-123" # task_id
33 | assert call_args[2] is None # tags
34 | assert call_args[3] is None # domain
35 | assert call_args[4] is None # page_ids
36 |
37 | # Check that the function returned the expected task ID
38 | assert task_id == "delete-task-123"
39 |
40 |
41 | @pytest.mark.unit
42 | @pytest.mark.async_test
43 | async def test_delete_docs_with_tags(mock_queue):
44 | """Test document deletion with tag filter."""
45 | # Set a fixed UUID for testing
46 | with patch("uuid.uuid4", return_value="delete-task-456"):
47 | # Call the function with tags filter
48 | task_id = await delete_docs(
49 | queue=mock_queue,
50 | tags=["test", "example"],
51 | )
52 |
53 | # Verify queue.enqueue was called with the right parameters
54 | mock_queue.enqueue.assert_called_once()
55 | call_args = mock_queue.enqueue.call_args[0]
56 | assert call_args[0] == "src.crawl_worker.tasks.delete_docs"
57 | assert call_args[1] == "delete-task-456" # task_id
58 | assert call_args[2] == ["test", "example"] # tags
59 | assert call_args[3] is None # domain
60 | assert call_args[4] is None # page_ids
61 |
62 | # Check that the function returned the expected task ID
63 | assert task_id == "delete-task-456"
64 |
65 |
66 | @pytest.mark.unit
67 | @pytest.mark.async_test
68 | async def test_delete_docs_with_domain(mock_queue):
69 | """Test document deletion with domain filter."""
70 | # Set a fixed UUID for testing
71 | with patch("uuid.uuid4", return_value="delete-task-789"):
72 | # Call the function with domain filter
73 | task_id = await delete_docs(
74 | queue=mock_queue,
75 | domain="example.com",
76 | )
77 |
78 | # Verify queue.enqueue was called with the right parameters
79 | mock_queue.enqueue.assert_called_once()
80 | call_args = mock_queue.enqueue.call_args[0]
81 | assert call_args[0] == "src.crawl_worker.tasks.delete_docs"
82 | assert call_args[1] == "delete-task-789" # task_id
83 | assert call_args[2] is None # tags
84 | assert call_args[3] == "example.com" # domain
85 | assert call_args[4] is None # page_ids
86 |
87 | # Check that the function returned the expected task ID
88 | assert task_id == "delete-task-789"
89 |
90 |
91 | @pytest.mark.unit
92 | @pytest.mark.async_test
93 | async def test_delete_docs_with_page_ids(mock_queue):
94 | """Test document deletion with specific page IDs."""
95 | # Set a fixed UUID for testing
96 | with patch("uuid.uuid4", return_value="delete-task-abc"):
97 | # Call the function with page_ids filter
98 | page_ids = ["page-1", "page-2", "page-3"]
99 | task_id = await delete_docs(
100 | queue=mock_queue,
101 | page_ids=page_ids,
102 | )
103 |
104 | # Verify queue.enqueue was called with the right parameters
105 | mock_queue.enqueue.assert_called_once()
106 | call_args = mock_queue.enqueue.call_args[0]
107 | assert call_args[0] == "src.crawl_worker.tasks.delete_docs"
108 | assert call_args[1] == "delete-task-abc" # task_id
109 | assert call_args[2] is None # tags
110 | assert call_args[3] is None # domain
111 | assert call_args[4] == page_ids # page_ids
112 |
113 | # Check that the function returned the expected task ID
114 | assert task_id == "delete-task-abc"
115 |
116 |
117 | @pytest.mark.unit
118 | @pytest.mark.async_test
119 | async def test_delete_docs_with_all_filters(mock_queue):
120 | """Test document deletion with all filters."""
121 | # Set a fixed UUID for testing
122 | with patch("uuid.uuid4", return_value="delete-task-all"):
123 | # Call the function with all filters
124 | page_ids = ["page-1", "page-2"]
125 | task_id = await delete_docs(
126 | queue=mock_queue,
127 | tags=["test"],
128 | domain="example.com",
129 | page_ids=page_ids,
130 | )
131 |
132 | # Verify queue.enqueue was called with the right parameters
133 | mock_queue.enqueue.assert_called_once()
134 | call_args = mock_queue.enqueue.call_args[0]
135 | assert call_args[0] == "src.crawl_worker.tasks.delete_docs"
136 | assert call_args[1] == "delete-task-all" # task_id
137 | assert call_args[2] == ["test"] # tags
138 | assert call_args[3] == "example.com" # domain
139 | assert call_args[4] == page_ids # page_ids
140 |
141 | # Check that the function returned the expected task ID
142 | assert task_id == "delete-task-all"
143 |
--------------------------------------------------------------------------------
/src/web_service/services/job_service.py:
--------------------------------------------------------------------------------
1 | """Job service for the web service."""
2 |
3 | import uuid
4 |
5 | import duckdb
6 | from rq import Queue
7 |
8 | from src.common.logger import get_logger
9 | from src.common.models import (
10 | FetchUrlResponse,
11 | JobProgressResponse,
12 | )
13 |
14 | # Get logger for this module
15 | logger = get_logger(__name__)
16 |
17 |
18 | async def fetch_url(
19 | queue: Queue,
20 | url: str,
21 | tags: list[str] | None = None,
22 | max_pages: int = 100,
23 | ) -> FetchUrlResponse:
24 | """Initiate a fetch job to crawl a website.
25 |
26 | Args:
27 | queue: Redis queue for job processing
28 | url: The URL to start indexing from
29 | tags: Optional tags to assign this website
30 | max_pages: How many pages to index
31 |
32 | Returns:
33 | FetchUrlResponse: The job ID
34 |
35 | """
36 | # Generate a temporary job ID
37 | job_id = str(uuid.uuid4())
38 |
39 | # Enqueue the job creation task
40 | # The crawl worker will handle both creating the job record and enqueueing the crawl task
41 | queue.enqueue(
42 | "src.crawl_worker.tasks.create_job",
43 | url,
44 | job_id,
45 | tags=tags,
46 | max_pages=max_pages,
47 | )
48 |
49 | logger.info(f"Enqueued job creation for URL: {url}, job_id: {job_id}")
50 |
51 | return FetchUrlResponse(job_id=job_id)
52 |
53 |
54 | async def get_job_progress(
55 | conn: duckdb.DuckDBPyConnection,
56 | job_id: str,
57 | ) -> JobProgressResponse | None:
58 | """Check the progress of a job.
59 |
60 | Args:
61 | conn: Connected DuckDB connection
62 | job_id: The job ID to check progress for
63 |
64 | Returns:
65 | JobProgressResponse: Job progress information or None if job not found
66 |
67 | """
68 | logger.info(f"Checking progress for job {job_id}")
69 |
70 | # First try exact job_id match
71 | logger.info(f"Looking for exact match for job ID: {job_id}")
72 | result = conn.execute(
73 | """
74 | SELECT job_id, start_url, status, pages_discovered, pages_crawled,
75 | max_pages, tags, created_at, updated_at, error_message
76 | FROM jobs
77 | WHERE job_id = ?
78 | """,
79 | (job_id,),
80 | ).fetchone()
81 |
82 | # If not found, try partial match (useful if client received truncated UUID)
83 | if not result and len(job_id) >= 8:
84 | logger.info(f"Exact match not found, trying partial match for job ID: {job_id}")
85 | result = conn.execute(
86 | """
87 | SELECT job_id, start_url, status, pages_discovered, pages_crawled,
88 | max_pages, tags, created_at, updated_at, error_message
89 | FROM jobs
90 | WHERE job_id LIKE ?
91 | ORDER BY created_at DESC
92 | LIMIT 1
93 | """,
94 | (f"{job_id}%",),
95 | ).fetchone()
96 |
97 | if not result:
98 | logger.warning(f"Job {job_id} not found")
99 | return None
100 |
101 | # --- Job found in database ---
102 | (
103 | job_id,
104 | url,
105 | status,
106 | pages_discovered,
107 | pages_crawled,
108 | max_pages,
109 | tags_json,
110 | created_at,
111 | updated_at,
112 | error_message,
113 | ) = result
114 |
115 | logger.info(
116 | f"Found job with ID: {job_id}, status: {status}, discovered: {pages_discovered}, crawled: {pages_crawled}, updated: {updated_at}",
117 | )
118 |
119 | # Determine if job is completed
120 | completed = status in ["completed", "failed"]
121 |
122 | # Calculate progress percentage
123 | progress_percent = 0
124 | if max_pages > 0 and pages_discovered > 0:
125 | progress_percent = min(100, int((pages_crawled / min(pages_discovered, max_pages)) * 100))
126 |
127 | logger.info(
128 | f"Job {job_id} progress: {pages_crawled}/{pages_discovered} pages, {progress_percent}% complete, status: {status}",
129 | )
130 |
131 | return JobProgressResponse(
132 | pages_crawled=pages_crawled,
133 | pages_total=pages_discovered,
134 | completed=completed,
135 | status=status,
136 | error_message=error_message,
137 | progress_percent=progress_percent,
138 | url=url,
139 | max_pages=max_pages,
140 | created_at=created_at,
141 | updated_at=updated_at,
142 | )
143 |
144 |
145 | async def get_job_count() -> int:
146 | """Get the total number of jobs in the database.
147 | Used for debugging purposes.
148 |
149 | Returns:
150 | int: Total number of jobs
151 |
152 | """
153 | from src.lib.database import DatabaseOperations
154 |
155 | db_ops = DatabaseOperations()
156 | try:
157 | # Use DuckDBConnectionManager as a context manager
158 | with db_ops.db as conn_manager:
159 | actual_conn = conn_manager.conn
160 | if not actual_conn:
161 | logger.error("Failed to obtain database connection for get_job_count.")
162 | return -1 # Or raise an error
163 | job_count_result = actual_conn.execute("SELECT COUNT(*) FROM jobs").fetchone()
164 | if job_count_result:
165 | job_count = job_count_result[0]
166 | logger.info(f"Database contains {job_count} total jobs.")
167 | return job_count
168 | logger.warning("Could not retrieve job count from database.")
169 | return -1
170 | except Exception as count_error:
171 | logger.warning(f"Failed to count jobs in database: {count_error!s}")
172 | return -1
173 | # No finally block needed to close connection, context manager handles it.
174 |
--------------------------------------------------------------------------------
/src/common/processor.py:
--------------------------------------------------------------------------------
1 | """Page processing pipeline combining crawling, chunking, embedding, and indexing."""
2 |
3 | import asyncio
4 | from itertools import islice
5 | from typing import Any
6 | from urllib.parse import urlparse
7 |
8 | from src.common.indexer import VectorIndexer
9 | from src.common.logger import get_logger
10 | from src.lib.chunker import TextChunker
11 | from src.lib.crawler import extract_page_text
12 | from src.lib.database import DatabaseOperations
13 | from src.lib.embedder import generate_embedding
14 |
15 | # Configure logging
16 | logger = get_logger(__name__)
17 |
18 |
19 | async def process_crawl_result(
20 | page_result: Any,
21 | job_id: str,
22 | tags: list[str] | None = None,
23 | max_concurrent_embeddings: int = 5,
24 | ) -> str:
25 | """Process a single crawled page result through the entire pipeline.
26 |
27 | Args:
28 | page_result: The crawl result for the page
29 | job_id: The ID of the crawl job
30 | tags: Optional tags to associate with the page
31 | max_concurrent_embeddings: Maximum number of concurrent embedding generations
32 |
33 | Returns:
34 | The ID of the processed page
35 |
36 | """
37 | if tags is None:
38 | tags = [] # Initialize as empty list instead of None
39 |
40 | try:
41 | logger.info(f"Processing page: {page_result.url}")
42 |
43 | # Extract text from the crawl result
44 | page_text = extract_page_text(page_result)
45 |
46 | # Store the page in the database
47 | db_ops = DatabaseOperations()
48 | page_id = await db_ops.store_page( # store_page now handles its own connection
49 | url=page_result.url,
50 | text=page_text,
51 | job_id=job_id,
52 | tags=tags,
53 | )
54 | # No explicit close needed as store_page uses context manager internally
55 |
56 | # Initialize components
57 | chunker = TextChunker()
58 | indexer = VectorIndexer() # VectorIndexer will now create and manage its own connection.
59 |
60 | # Split text into chunks
61 | chunks = chunker.split_text(page_text)
62 | logger.info(f"Split page into {len(chunks)} chunks")
63 |
64 | # Extract domain from URL
65 | domain = urlparse(page_result.url).netloc
66 |
67 | # Process chunks in parallel batches
68 | successful_chunks = 0
69 |
70 | async def process_chunk(chunk_text):
71 | try:
72 | # Generate embedding
73 | embedding = await generate_embedding(chunk_text)
74 |
75 | # Prepare payload
76 | payload = {
77 | "text": chunk_text,
78 | "page_id": page_id,
79 | "url": page_result.url,
80 | "domain": domain,
81 | "tags": tags,
82 | "job_id": job_id,
83 | }
84 |
85 | # Index the vector
86 | await indexer.index_vector(embedding, payload)
87 | return True
88 | except Exception as chunk_error:
89 | logger.error(f"Error processing chunk: {chunk_error!s}")
90 | return False
91 |
92 | # Process chunks in batches with limited concurrency
93 | i = 0
94 | while i < len(chunks):
95 | # Take up to max_concurrent_embeddings chunks
96 | batch_chunks = list(islice(chunks, i, i + max_concurrent_embeddings))
97 | i += max_concurrent_embeddings
98 |
99 | # Process this batch in parallel
100 | results = await asyncio.gather(
101 | *[process_chunk(chunk) for chunk in batch_chunks],
102 | return_exceptions=False,
103 | )
104 |
105 | successful_chunks += sum(1 for result in results if result)
106 |
107 | logger.info(
108 | f"Successfully indexed {successful_chunks}/{len(chunks)} chunks for page {page_id}",
109 | )
110 |
111 | return page_id
112 |
113 | except Exception as e:
114 | logger.error(f"Error processing page {page_result.url}: {e!s}")
115 | raise
116 | finally:
117 | # VectorIndexer now manages its own connection if instantiated without one.
118 | # Its __del__ method will attempt to close its connection if self._own_connection is True.
119 | # No explicit action needed here for the indexer's connection.
120 | pass
121 |
122 |
123 | async def process_page_batch(
124 | page_results: list[Any],
125 | job_id: str,
126 | tags: list[str] | None = None,
127 | batch_size: int = 10,
128 | ) -> list[str]:
129 | """Process a batch of crawled pages.
130 |
131 | Args:
132 | page_results: List of crawl results
133 | job_id: The ID of the crawl job
134 | tags: Optional tags to associate with the pages
135 | batch_size: Size of batches for processing
136 |
137 | Returns:
138 | List of processed page IDs
139 |
140 | """
141 | if tags is None:
142 | tags = []
143 |
144 | processed_page_ids = []
145 |
146 | # Process pages in smaller batches to avoid memory issues
147 | for i in range(0, len(page_results), batch_size):
148 | batch = page_results[i : i + batch_size]
149 | logger.info(f"Processing batch of {len(batch)} pages (batch {i // batch_size + 1})")
150 |
151 | for page_result in batch:
152 | try:
153 | page_id = await process_crawl_result(page_result, job_id, tags)
154 | processed_page_ids.append(page_id)
155 |
156 | # Update job progress
157 | db_ops_status = DatabaseOperations()
158 | await (
159 | db_ops_status.update_job_status( # update_job_status handles its own connection
160 | job_id=job_id,
161 | status="running",
162 | pages_discovered=len(page_results),
163 | pages_crawled=len(processed_page_ids),
164 | )
165 | )
166 |
167 | except Exception as page_error:
168 | logger.error(f"Error in batch processing for {page_result.url}: {page_error!s}")
169 | continue
170 |
171 | return processed_page_ids
172 |
--------------------------------------------------------------------------------
/src/web_service/api/jobs.py:
--------------------------------------------------------------------------------
1 | """Job API routes for the web service."""
2 |
3 | import asyncio
4 |
5 | import redis
6 | from fastapi import APIRouter, Depends, HTTPException, Query
7 | from rq import Queue
8 |
9 | from src.common.config import REDIS_URI
10 | from src.common.logger import get_logger
11 | from src.common.models import (
12 | FetchUrlRequest,
13 | FetchUrlResponse,
14 | JobProgressResponse,
15 | )
16 | from src.lib.database import DatabaseOperations
17 | from src.web_service.services.job_service import (
18 | fetch_url,
19 | get_job_count,
20 | get_job_progress,
21 | )
22 |
23 | # Get logger for this module
24 | logger = get_logger(__name__)
25 |
26 | # Create router
27 | router = APIRouter(tags=["jobs"])
28 |
29 |
30 | @router.post("/fetch_url", response_model=FetchUrlResponse, operation_id="fetch_url")
31 | async def fetch_url_endpoint(
32 | request: FetchUrlRequest,
33 | queue: Queue = Depends(lambda: Queue("worker", connection=redis.from_url(REDIS_URI))),
34 | ):
35 | """Initiate a fetch job to crawl a website.
36 |
37 | Args:
38 | request: The fetch URL request
39 |
40 | Returns:
41 | The job ID
42 |
43 | """
44 | logger.info(f"API: Initiating fetch for URL: {request.url}")
45 |
46 | try:
47 | # Call the service function
48 | response = await fetch_url(
49 | queue=queue,
50 | url=request.url,
51 | tags=request.tags,
52 | max_pages=request.max_pages,
53 | )
54 | return response
55 | except Exception as e:
56 | logger.error(f"Error initiating fetch: {e!s}")
57 | raise HTTPException(status_code=500, detail=f"Error initiating fetch: {e!s}")
58 |
59 |
60 | @router.get("/job_progress", response_model=JobProgressResponse, operation_id="job_progress")
61 | async def job_progress_endpoint(
62 | job_id: str = Query(..., description="The job ID to check progress for"),
63 | ):
64 | """Check the progress of a job.
65 |
66 | Args:
67 | job_id: The job ID to check progress for
68 |
69 | Returns:
70 | Job progress information
71 |
72 | """
73 | logger.info(f"API: BEGIN job_progress for job {job_id}")
74 |
75 | # Create a fresh connection to get the latest data
76 | # This ensures we always get the most recent job state
77 | attempts = 0
78 | max_attempts = 3
79 | retry_delay = 0.1 # seconds
80 |
81 | logger.info(f"Starting check loop for job {job_id} (max_attempts={max_attempts})")
82 | while attempts < max_attempts:
83 | attempts += 1
84 | db_ops = DatabaseOperations() # Instantiate inside the loop
85 |
86 | try:
87 | with db_ops.db as conn_manager: # Use DuckDBConnectionManager as context manager
88 | actual_conn = conn_manager.conn
89 | if not actual_conn:
90 | # This case should ideally be handled by DuckDBConnectionManager.connect() raising an error
91 | logger.error(
92 | f"Attempt {attempts}: Failed to obtain DB connection from manager for job {job_id}."
93 | )
94 | if attempts >= max_attempts:
95 | raise HTTPException(
96 | status_code=500, detail="Database connection error after retries."
97 | )
98 | await asyncio.sleep(retry_delay)
99 | retry_delay *= 2
100 | continue # To next attempt
101 |
102 | logger.info(
103 | f"Established fresh connection to database (attempt {attempts}) for job {job_id}"
104 | )
105 |
106 | # Call the service function
107 | result = await get_job_progress(actual_conn, job_id)
108 |
109 | if not result:
110 | # Job not found on this attempt
111 | logger.warning(
112 | f"Job {job_id} not found (attempt {attempts}/{max_attempts}). Retrying in {retry_delay}s...",
113 | )
114 | # Connection is closed by conn_manager context exit
115 | if attempts >= max_attempts:
116 | logger.warning(f"Job {job_id} not found after {max_attempts} attempts.")
117 | break # Exit the while loop to raise 404
118 | await asyncio.sleep(retry_delay)
119 | retry_delay *= 2
120 | continue # Go to next iteration of the while loop
121 |
122 | # Job found, return the result
123 | return result
124 |
125 | except HTTPException:
126 | # Re-raise HTTP exceptions as-is
127 | logger.warning(
128 | f"HTTPException occurred during attempt {attempts} for job {job_id}, re-raising.",
129 | )
130 | raise
131 | except Exception as e: # Includes duckdb.Error if connect fails in conn_manager
132 | logger.error(
133 | f"Error checking job progress (attempt {attempts}) for job {job_id}: {e!s}",
134 | )
135 | if attempts < max_attempts:
136 | logger.info(f"Non-fatal error on attempt {attempts}, will retry after delay.")
137 | # Connection (if opened by conn_manager) is closed on context exit
138 | await asyncio.sleep(retry_delay)
139 | retry_delay *= 2
140 | continue
141 | logger.error(f"Error on final attempt ({attempts}) for job {job_id}. Raising 500.")
142 | raise HTTPException(
143 | status_code=500,
144 | detail=f"Database error after retries: {e!s}",
145 | )
146 | # The 'finally' block for closing 'conn' is removed as conn_manager handles it.
147 |
148 | # If the loop finished without finding the job (i.e., break was hit after max attempts)
149 | # raise the 404 error.
150 | # Check job count for debugging before raising 404
151 | job_count = await get_job_count()
152 | logger.info(f"Database contains {job_count} total jobs during final 404 check.")
153 |
154 | logger.warning(f"Raising 404 for job {job_id} after all retries.") # Add log before raising
155 | raise HTTPException(status_code=404, detail=f"Job {job_id} not found after retries")
156 |
--------------------------------------------------------------------------------
/src/common/models.py:
--------------------------------------------------------------------------------
1 | """Pydantic models for the Doctor project."""
2 |
3 | from datetime import datetime
4 |
5 | from pydantic import BaseModel, Field
6 |
7 |
8 | class FetchUrlRequest(BaseModel):
9 | """Request model for the /fetch_url endpoint."""
10 |
11 | url: str = Field(..., description="The URL to start indexing from")
12 | tags: list[str] | None = Field(default=None, description="Tags to assign this website")
13 | max_pages: int = Field(default=100, description="How many pages to index", ge=1, le=1000)
14 |
15 |
16 | class FetchUrlResponse(BaseModel):
17 | """Response model for the /fetch_url endpoint."""
18 |
19 | job_id: str = Field(..., description="The job ID of the index job")
20 |
21 |
22 | class SearchDocsRequest(BaseModel):
23 | """Request model for the /search_docs endpoint."""
24 |
25 | query: str = Field(..., description="The search string to query the database with")
26 | tags: list[str] | None = Field(default=None, description="Tags to limit the search with")
27 | max_results: int = Field(
28 | default=10,
29 | description="Maximum number of results to return",
30 | ge=1,
31 | le=100,
32 | )
33 |
34 |
35 | class SearchResult(BaseModel):
36 | """A single search result."""
37 |
38 | chunk_text: str = Field(..., description="The text of the chunk")
39 | page_id: str = Field(..., description="Reference to the original page")
40 | tags: list[str] = Field(default_factory=list, description="Tags associated with the chunk")
41 | score: float = Field(..., description="Similarity score")
42 | url: str = Field(..., description="Original URL of the page")
43 |
44 |
45 | class SearchDocsResponse(BaseModel):
46 | """Response model for the /search_docs endpoint."""
47 |
48 | results: list[SearchResult] = Field(default_factory=list, description="Search results")
49 |
50 |
51 | class JobProgressRequest(BaseModel):
52 | """Request model for the /job_progress endpoint."""
53 |
54 | job_id: str = Field(..., description="The job ID to check progress for")
55 |
56 |
57 | class JobProgressResponse(BaseModel):
58 | """Response model for the /job_progress endpoint."""
59 |
60 | pages_crawled: int = Field(..., description="Number of pages crawled so far")
61 | pages_total: int = Field(..., description="Total number of pages discovered")
62 | completed: bool = Field(..., description="Whether the job is completed")
63 | status: str = Field(..., description="Current job status")
64 | error_message: str | None = Field(default=None, description="Error message if job failed")
65 | progress_percent: int | None = Field(
66 | default=None,
67 | description="Percentage of crawl completed",
68 | )
69 | url: str | None = Field(default=None, description="URL being crawled")
70 | max_pages: int | None = Field(default=None, description="Maximum pages to crawl")
71 | created_at: datetime | None = Field(default=None, description="When the job was created")
72 | updated_at: datetime | None = Field(
73 | default=None,
74 | description="When the job was last updated",
75 | )
76 |
77 |
78 | class ListDocPagesRequest(BaseModel):
79 | """Request model for the /list_doc_pages endpoint."""
80 |
81 | page: int = Field(default=1, description="Page number", ge=1)
82 | tags: list[str] | None = Field(default=None, description="Tags to filter by")
83 |
84 |
85 | class DocPageSummary(BaseModel):
86 | """Summary information about a document page."""
87 |
88 | page_id: str = Field(..., description="Unique page ID")
89 | domain: str = Field(..., description="Domain of the page")
90 | tags: list[str] = Field(default_factory=list, description="Tags associated with the page")
91 | crawl_date: datetime = Field(..., description="When the page was crawled")
92 | url: str = Field(..., description="URL of the page")
93 |
94 |
95 | class ListDocPagesResponse(BaseModel):
96 | """Response model for the /list_doc_pages endpoint."""
97 |
98 | doc_pages: list[DocPageSummary] = Field(
99 | default_factory=list,
100 | description="List of document pages",
101 | )
102 | total_pages: int = Field(..., description="Total number of pages matching the query")
103 | current_page: int = Field(..., description="Current page number")
104 | pages_per_page: int = Field(default=100, description="Number of items per page")
105 |
106 |
107 | class GetDocPageRequest(BaseModel):
108 | """Request model for the /get_doc_page endpoint."""
109 |
110 | page_id: str = Field(..., description="The page ID to retrieve")
111 | starting_line: int = Field(default=1, description="Line to view from", ge=1)
112 | ending_line: int = Field(default=100, description="Line to view up to", ge=1)
113 |
114 |
115 | class GetDocPageResponse(BaseModel):
116 | """Response model for the /get_doc_page endpoint."""
117 |
118 | text: str = Field(..., description="The document page text")
119 | total_lines: int = Field(..., description="Total number of lines in the document")
120 |
121 |
122 | class Job(BaseModel):
123 | """Internal model for a crawl job."""
124 |
125 | job_id: str
126 | start_url: str
127 | status: str = "pending" # pending, running, completed, failed
128 | pages_discovered: int = 0
129 | pages_crawled: int = 0
130 | max_pages: int
131 | tags: list[str] = Field(default_factory=list)
132 | created_at: datetime = Field(default_factory=datetime.now)
133 | updated_at: datetime = Field(default_factory=datetime.now)
134 | error_message: str | None = None
135 |
136 |
137 | class Page(BaseModel):
138 | """Internal model for a crawled page."""
139 |
140 | id: str
141 | url: str
142 | domain: str
143 | raw_text: str
144 | crawl_date: datetime = Field(default_factory=datetime.now)
145 | tags: list[str] = Field(default_factory=list)
146 |
147 |
148 | class Chunk(BaseModel):
149 | """Internal model for a text chunk with its embedding."""
150 |
151 | id: str
152 | text: str
153 | page_id: str
154 | domain: str
155 | tags: list[str] = Field(default_factory=list)
156 | embedding: list[float]
157 |
158 |
159 | class DeleteDocsRequest(BaseModel):
160 | """Request model for the /delete_docs endpoint."""
161 |
162 | tags: list[str] | None = Field(default=None, description="Tags to filter by")
163 | domain: str | None = Field(default=None, description="Domain substring to filter by")
164 | page_ids: list[str] | None = Field(default=None, description="Specific page IDs to delete")
165 |
166 |
167 | class ListTagsResponse(BaseModel):
168 | """Response model for the /list_tags endpoint."""
169 |
170 | tags: list[str] = Field(
171 | default_factory=list,
172 | description="List of all unique tags in the database",
173 | )
174 |
--------------------------------------------------------------------------------
/tests/api/test_map_api.py:
--------------------------------------------------------------------------------
1 | """Tests for the map API endpoints."""
2 |
3 | import pytest
4 | from fastapi.testclient import TestClient
5 | from unittest.mock import AsyncMock, patch
6 |
7 | from src.web_service.main import app
8 |
9 |
10 | @pytest.fixture
11 | def client():
12 | """Create a test client for the FastAPI app.
13 |
14 | Args:
15 | None.
16 |
17 | Returns:
18 | TestClient: Test client instance.
19 | """
20 | return TestClient(app)
21 |
22 |
23 | class TestMapAPI:
24 | """Test the map API endpoints."""
25 |
26 | def test_get_site_index(self, client: TestClient) -> None:
27 | """Test the /map endpoint for site index.
28 |
29 | Args:
30 | client: The test client.
31 |
32 | Returns:
33 | None.
34 | """
35 | # Mock the map service
36 | with patch("src.web_service.api.map.MapService") as MockService:
37 | mock_service = MockService.return_value
38 | mock_service.get_all_sites = AsyncMock(
39 | return_value=[
40 | {"id": "site1", "title": "Site 1", "url": "https://example1.com"},
41 | {"id": "site2", "title": "Site 2", "url": "https://example2.com"},
42 | ]
43 | )
44 | mock_service.format_site_list.return_value = "Site List"
45 |
46 | response = client.get("/map")
47 |
48 | assert response.status_code == 200
49 | assert response.headers["content-type"] == "text/html; charset=utf-8"
50 | assert response.text == "Site List"
51 |
52 | def test_get_site_index_error(self, client: TestClient) -> None:
53 | """Test /map endpoint error handling.
54 |
55 | Args:
56 | client: The test client.
57 |
58 | Returns:
59 | None.
60 | """
61 | with patch("src.web_service.api.map.MapService") as MockService:
62 | mock_service = MockService.return_value
63 | mock_service.get_all_sites = AsyncMock(side_effect=Exception("Database error"))
64 |
65 | response = client.get("/map")
66 |
67 | assert response.status_code == 500
68 | assert "Database error" in response.json()["detail"]
69 |
70 | def test_get_site_tree(self, client: TestClient) -> None:
71 | """Test the /map/site/{root_page_id} endpoint.
72 |
73 | Args:
74 | client: The test client.
75 |
76 | Returns:
77 | None.
78 | """
79 | root_id = "root-123"
80 |
81 | with patch("src.web_service.api.map.MapService") as MockService:
82 | mock_service = MockService.return_value
83 | mock_service.build_page_tree = AsyncMock(
84 | return_value={"id": root_id, "title": "Test Site", "children": []}
85 | )
86 | mock_service.format_site_tree.return_value = "Site Tree"
87 |
88 | response = client.get(f"/map/site/{root_id}")
89 |
90 | assert response.status_code == 200
91 | assert response.headers["content-type"] == "text/html; charset=utf-8"
92 | assert response.text == "Site Tree"
93 |
94 | def test_view_page(self, client: TestClient) -> None:
95 | """Test the /map/page/{page_id} endpoint.
96 |
97 | Args:
98 | client: The test client.
99 |
100 | Returns:
101 | None.
102 | """
103 | page_id = "page-123"
104 |
105 | with patch("src.web_service.api.map.MapService") as MockService:
106 | mock_service = MockService.return_value
107 | mock_service.db_ops.get_page_by_id = AsyncMock(
108 | return_value={
109 | "id": page_id,
110 | "title": "Test Page",
111 | "raw_text": "# Test Content",
112 | }
113 | )
114 | mock_service.get_navigation_context = AsyncMock(
115 | return_value={
116 | "current_page": {"id": page_id, "title": "Test Page"},
117 | "parent": None,
118 | "siblings": [],
119 | "children": [],
120 | "root": None,
121 | }
122 | )
123 | mock_service.render_page_html.return_value = "Rendered Page"
124 |
125 | response = client.get(f"/map/page/{page_id}")
126 |
127 | assert response.status_code == 200
128 | assert response.headers["content-type"] == "text/html; charset=utf-8"
129 | assert response.text == "Rendered Page"
130 |
131 | def test_view_page_not_found(self, client: TestClient) -> None:
132 | """Test viewing a non-existent page.
133 |
134 | Args:
135 | client: The test client.
136 |
137 | Returns:
138 | None.
139 | """
140 | page_id = "nonexistent"
141 |
142 | with patch("src.web_service.api.map.MapService") as MockService:
143 | mock_service = MockService.return_value
144 | mock_service.db_ops.get_page_by_id = AsyncMock(return_value=None)
145 |
146 | response = client.get(f"/map/page/{page_id}")
147 |
148 | assert response.status_code == 404
149 | assert response.json()["detail"] == "Page not found"
150 |
151 | def test_get_page_raw(self, client: TestClient) -> None:
152 | """Test the /map/page/{page_id}/raw endpoint.
153 |
154 | Args:
155 | client: The test client.
156 |
157 | Returns:
158 | None.
159 | """
160 | page_id = "page-123"
161 | raw_content = "# Test Page\n\nThis is markdown content."
162 |
163 | with patch("src.web_service.api.map.MapService") as MockService:
164 | mock_service = MockService.return_value
165 | mock_service.db_ops.get_page_by_id = AsyncMock(
166 | return_value={
167 | "id": page_id,
168 | "title": "Test Page",
169 | "raw_text": raw_content,
170 | }
171 | )
172 |
173 | response = client.get(f"/map/page/{page_id}/raw")
174 |
175 | assert response.status_code == 200
176 | assert response.headers["content-type"] == "text/markdown; charset=utf-8"
177 | assert response.text == raw_content
178 | assert 'filename="Test Page.md"' in response.headers["content-disposition"]
179 |
180 | def test_get_page_raw_not_found(self, client: TestClient) -> None:
181 | """Test getting raw content for non-existent page.
182 |
183 | Args:
184 | client: The test client.
185 |
186 | Returns:
187 | None.
188 | """
189 | page_id = "nonexistent"
190 |
191 | with patch("src.web_service.api.map.MapService") as MockService:
192 | mock_service = MockService.return_value
193 | mock_service.db_ops.get_page_by_id = AsyncMock(return_value=None)
194 |
195 | response = client.get(f"/map/page/{page_id}/raw")
196 |
197 | assert response.status_code == 404
198 | assert response.json()["detail"] == "Page not found"
199 |
--------------------------------------------------------------------------------
/tests/services/test_map_service_legacy.py:
--------------------------------------------------------------------------------
1 | """Tests for the map service legacy page handling."""
2 |
3 | import pytest
4 | from unittest.mock import AsyncMock, patch
5 | import datetime
6 |
7 | from src.web_service.services.map_service import MapService
8 |
9 |
10 | @pytest.mark.asyncio
11 | class TestMapServiceLegacy:
12 | """Test the MapService legacy page handling."""
13 |
14 | async def test_get_all_sites_with_legacy(self) -> None:
15 | """Test getting all sites including legacy domain groups.
16 |
17 | Args:
18 | None.
19 |
20 | Returns:
21 | None.
22 | """
23 | service = MapService()
24 |
25 | # Mock root pages (pages with hierarchy)
26 | mock_root_pages = [
27 | {
28 | "id": "site1",
29 | "url": "https://docs.example.com",
30 | "title": "Documentation Site",
31 | "domain": "docs.example.com",
32 | "crawl_date": datetime.datetime(2024, 1, 2),
33 | }
34 | ]
35 |
36 | # Mock legacy pages
37 | mock_legacy_pages = [
38 | {
39 | "id": "page1",
40 | "url": "https://blog.example.com/post1",
41 | "domain": "blog.example.com",
42 | "title": "Blog Post 1",
43 | "crawl_date": datetime.datetime(2024, 1, 1),
44 | "root_page_id": None,
45 | "parent_page_id": None,
46 | },
47 | {
48 | "id": "page2",
49 | "url": "https://blog.example.com/post2",
50 | "domain": "blog.example.com",
51 | "title": "Blog Post 2",
52 | "crawl_date": datetime.datetime(2024, 1, 3),
53 | "root_page_id": None,
54 | "parent_page_id": None,
55 | },
56 | {
57 | "id": "page3",
58 | "url": "https://shop.example.com/product1",
59 | "domain": "shop.example.com",
60 | "title": "Product 1",
61 | "crawl_date": datetime.datetime(2024, 1, 1),
62 | "root_page_id": None,
63 | "parent_page_id": None,
64 | },
65 | ]
66 |
67 | with patch.object(
68 | service.db_ops, "get_root_pages", new_callable=AsyncMock
69 | ) as mock_get_root:
70 | mock_get_root.return_value = mock_root_pages
71 |
72 | with patch.object(
73 | service.db_ops, "get_legacy_pages", new_callable=AsyncMock
74 | ) as mock_get_legacy:
75 | mock_get_legacy.return_value = mock_legacy_pages
76 |
77 | sites = await service.get_all_sites()
78 |
79 | # Should have 1 root page + 2 domain groups
80 | assert len(sites) == 3
81 |
82 | # Check that we have the regular site
83 | regular_sites = [s for s in sites if not s.get("is_synthetic")]
84 | assert len(regular_sites) == 1
85 | assert regular_sites[0]["title"] == "Documentation Site"
86 |
87 | # Check that we have domain groups
88 | domain_groups = [s for s in sites if s.get("is_synthetic")]
89 | assert len(domain_groups) == 2
90 |
91 | # Check blog domain group
92 | blog_group = next(s for s in domain_groups if "blog.example.com" in s["title"])
93 | assert blog_group["id"] == "legacy-domain-blog.example.com"
94 | assert blog_group["page_count"] == 2
95 | assert blog_group["is_synthetic"] is True
96 |
97 | # Check shop domain group
98 | shop_group = next(s for s in domain_groups if "shop.example.com" in s["title"])
99 | assert shop_group["id"] == "legacy-domain-shop.example.com"
100 | assert shop_group["page_count"] == 1
101 |
102 | async def test_build_domain_tree(self) -> None:
103 | """Test building a tree for a domain group.
104 |
105 | Args:
106 | None.
107 |
108 | Returns:
109 | None.
110 | """
111 | service = MapService()
112 | domain = "blog.example.com"
113 |
114 | # Mock pages for the domain
115 | mock_domain_pages = [
116 | {
117 | "id": "page1",
118 | "url": "https://blog.example.com/post1",
119 | "title": "Post 1",
120 | "root_page_id": None,
121 | },
122 | {
123 | "id": "page2",
124 | "url": "https://blog.example.com/post2",
125 | "title": "Post 2",
126 | "root_page_id": None,
127 | },
128 | {
129 | "id": "page3",
130 | "url": "https://blog.example.com/post3",
131 | "title": "Post 3",
132 | "root_page_id": "page3", # This is a proper root, not legacy
133 | "parent_page_id": None,
134 | "depth": 0,
135 | },
136 | ]
137 |
138 | with patch.object(
139 | service.db_ops, "get_pages_by_domain", new_callable=AsyncMock
140 | ) as mock_get:
141 | mock_get.return_value = mock_domain_pages
142 |
143 | tree = await service.build_page_tree(f"legacy-domain-{domain}")
144 |
145 | # Check the synthetic root
146 | assert tree["id"] == f"legacy-domain-{domain}"
147 | assert tree["is_synthetic"] is True
148 | assert tree["title"] == f"{domain} (3 Pages)"
149 |
150 | # Should only include legacy pages (not page3)
151 | assert len(tree["children"]) == 2
152 | assert all(child["id"] in ["page1", "page2"] for child in tree["children"])
153 |
154 | # Children should be sorted by URL
155 | assert tree["children"][0]["url"] < tree["children"][1]["url"]
156 |
157 | def test_format_site_list_with_legacy(self) -> None:
158 | """Test formatting site list with legacy domain groups.
159 |
160 | Args:
161 | None.
162 |
163 | Returns:
164 | None.
165 | """
166 | service = MapService()
167 |
168 | sites = [
169 | {
170 | "id": "site1",
171 | "url": "https://docs.example.com",
172 | "title": "Documentation",
173 | "crawl_date": datetime.datetime(2024, 1, 1),
174 | "is_synthetic": False,
175 | },
176 | {
177 | "id": "legacy-domain-blog.example.com",
178 | "url": "https://blog.example.com",
179 | "title": "blog.example.com (Legacy Pages)",
180 | "crawl_date": datetime.datetime(2024, 1, 1),
181 | "is_synthetic": True,
182 | "page_count": 5,
183 | },
184 | ]
185 |
186 | html = service.format_site_list(sites)
187 |
188 | # Check regular site formatting
189 | assert "Documentation" in html
190 | assert "Crawled: 2024-01-01" in html
191 |
192 | # Check legacy domain group formatting
193 | assert "blog.example.com (Legacy Pages)" in html
194 | assert "Domain Group • 5 pages" in html
195 | assert "First crawled: 2024-01-01" in html
196 |
--------------------------------------------------------------------------------
/src/common/processor_enhanced.py:
--------------------------------------------------------------------------------
1 | """Enhanced page processing pipeline with hierarchy tracking."""
2 |
3 | import asyncio
4 | from itertools import islice
5 | from urllib.parse import urlparse
6 |
7 | from src.common.indexer import VectorIndexer
8 | from src.common.logger import get_logger
9 | from src.lib.chunker import TextChunker
10 | from src.lib.crawler import extract_page_text
11 | from src.lib.crawler_enhanced import CrawlResultWithHierarchy, crawl_url_with_hierarchy
12 | from src.lib.database import DatabaseOperations
13 | from src.lib.embedder import generate_embedding
14 |
15 | # Configure logging
16 | logger = get_logger(__name__)
17 |
18 |
19 | async def process_crawl_result_with_hierarchy(
20 | page_result: CrawlResultWithHierarchy,
21 | job_id: str,
22 | tags: list[str] | None = None,
23 | max_concurrent_embeddings: int = 5,
24 | url_to_page_id: dict[str, str] | None = None,
25 | ) -> str:
26 | """Process a single crawled page result with hierarchy information.
27 |
28 | Args:
29 | page_result: The enhanced crawl result with hierarchy info.
30 | job_id: The ID of the crawl job.
31 | tags: Optional tags to associate with the page.
32 | max_concurrent_embeddings: Maximum number of concurrent embedding generations.
33 | url_to_page_id: Optional mapping of URLs to page IDs for parent lookup.
34 |
35 | Returns:
36 | The ID of the processed page.
37 | """
38 | if tags is None:
39 | tags = []
40 | if url_to_page_id is None:
41 | url_to_page_id = {}
42 |
43 | try:
44 | logger.info(f"Processing page with hierarchy: {page_result.url}")
45 |
46 | # Extract text from the crawl result
47 | page_text = extract_page_text(page_result.base_result)
48 |
49 | # Look up parent page ID if we have a parent URL
50 | parent_page_id = None
51 | if page_result.parent_url and page_result.parent_url in url_to_page_id:
52 | parent_page_id = url_to_page_id[page_result.parent_url]
53 |
54 | # Look up root page ID
55 | root_page_id = None
56 | if page_result.root_url and page_result.root_url in url_to_page_id:
57 | root_page_id = url_to_page_id[page_result.root_url]
58 |
59 | # Store the page in the database with hierarchy info
60 | db_ops = DatabaseOperations()
61 | page_id = await db_ops.store_page(
62 | url=page_result.url,
63 | text=page_text,
64 | job_id=job_id,
65 | tags=tags,
66 | parent_page_id=parent_page_id,
67 | root_page_id=root_page_id,
68 | depth=page_result.depth,
69 | path=page_result.relative_path,
70 | title=page_result.title,
71 | )
72 |
73 | # Store the mapping for future lookups
74 | url_to_page_id[page_result.url] = page_id
75 |
76 | # Initialize components
77 | chunker = TextChunker()
78 | indexer = VectorIndexer()
79 |
80 | # Split text into chunks
81 | chunks = chunker.split_text(page_text)
82 | logger.info(f"Split page into {len(chunks)} chunks")
83 |
84 | # Extract domain from URL
85 | domain = urlparse(page_result.url).netloc
86 |
87 | # Process chunks in parallel batches
88 | successful_chunks = 0
89 |
90 | async def process_chunk(chunk_text):
91 | try:
92 | # Generate embedding
93 | embedding = await generate_embedding(chunk_text)
94 |
95 | # Prepare payload
96 | payload = {
97 | "text": chunk_text,
98 | "page_id": page_id,
99 | "url": page_result.url,
100 | "domain": domain,
101 | "tags": tags,
102 | "job_id": job_id,
103 | }
104 |
105 | # Index the vector
106 | await indexer.index_vector(embedding, payload)
107 | return True
108 | except Exception as chunk_error:
109 | logger.error(f"Error processing chunk: {chunk_error!s}")
110 | return False
111 |
112 | # Process chunks in batches with limited concurrency
113 | i = 0
114 | while i < len(chunks):
115 | # Take up to max_concurrent_embeddings chunks
116 | batch_chunks = list(islice(chunks, i, i + max_concurrent_embeddings))
117 | i += max_concurrent_embeddings
118 |
119 | # Process this batch in parallel
120 | results = await asyncio.gather(
121 | *[process_chunk(chunk) for chunk in batch_chunks],
122 | return_exceptions=False,
123 | )
124 |
125 | successful_chunks += sum(1 for result in results if result)
126 |
127 | logger.info(
128 | f"Successfully indexed {successful_chunks}/{len(chunks)} chunks for page {page_id}",
129 | )
130 |
131 | return page_id
132 |
133 | except Exception as e:
134 | logger.error(f"Error processing page {page_result.url}: {e!s}")
135 | raise
136 |
137 |
138 | async def process_crawl_with_hierarchy(
139 | url: str,
140 | job_id: str,
141 | tags: list[str] | None = None,
142 | max_pages: int = 100,
143 | max_depth: int = 2,
144 | strip_urls: bool = True,
145 | ) -> list[str]:
146 | """Crawl a URL and process all pages with hierarchy tracking.
147 |
148 | Args:
149 | url: The URL to start crawling from.
150 | job_id: The ID of the crawl job.
151 | tags: Optional tags to associate with the pages.
152 | max_pages: Maximum number of pages to crawl.
153 | max_depth: Maximum depth for the BFS crawl.
154 | strip_urls: Whether to strip URLs from the returned markdown.
155 |
156 | Returns:
157 | List of processed page IDs.
158 | """
159 | if tags is None:
160 | tags = []
161 |
162 | # Crawl with hierarchy tracking
163 | crawl_results = await crawl_url_with_hierarchy(url, max_pages, max_depth, strip_urls)
164 |
165 | # Sort results by depth to ensure parents are processed before children
166 | crawl_results.sort(key=lambda r: r.depth)
167 |
168 | # Map to track URL to page ID relationships
169 | url_to_page_id = {}
170 | processed_page_ids = []
171 |
172 | # Process pages in order
173 | for i, page_result in enumerate(crawl_results):
174 | try:
175 | page_id = await process_crawl_result_with_hierarchy(
176 | page_result, job_id, tags, url_to_page_id=url_to_page_id
177 | )
178 | processed_page_ids.append(page_id)
179 |
180 | # Update job progress
181 | db_ops_status = DatabaseOperations()
182 | await db_ops_status.update_job_status(
183 | job_id=job_id,
184 | status="running",
185 | pages_discovered=len(crawl_results),
186 | pages_crawled=len(processed_page_ids),
187 | )
188 |
189 | except Exception as page_error:
190 | logger.error(f"Error processing page {page_result.url}: {page_error!s}")
191 | continue
192 |
193 | return processed_page_ids
194 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | # 🩺 Doctor
8 |
9 | [GitHub](https://github.com/sisig-ai/doctor)
10 | [License](LICENSE.md)
11 | [Tests](https://github.com/sisig-ai/doctor/actions/workflows/pytest.yml)
12 | [Coverage](https://codecov.io/gh/sisig-ai/doctor)
13 |
14 | A tool for discovering, crawling, and indexing websites, exposed as an MCP server so LLM agents can reason and generate code with better, more up-to-date information.
15 |
16 |
17 |
18 | ---
19 |
20 | ### 🔍 Overview
21 |
22 | Doctor provides a complete stack for:
23 | - Crawling web pages using crawl4ai with hierarchy tracking
24 | - Chunking text with LangChain
25 | - Creating embeddings with OpenAI via litellm
26 | - Storing data in DuckDB with vector search support
27 | - Exposing search functionality via a FastAPI web service
28 | - Making these capabilities available to LLMs through an MCP server
29 | - Navigating crawled sites with hierarchical site maps
30 |
31 | ---
32 |
33 | ### 🏗️ Core Infrastructure
34 |
35 | #### 🗄️ DuckDB
36 | - Database for storing document data and embeddings with vector search capabilities
37 | - Managed by a unified Database class
38 |
39 | #### 📨 Redis
40 | - Message broker for asynchronous task processing
41 |
42 | #### 🕸️ Crawl Worker
43 | - Processes crawl jobs
44 | - Chunks text
45 | - Creates embeddings
46 |
47 | #### 🌐 Web Server
48 | - FastAPI service exposing HTTP endpoints
49 | - Handles fetching, searching, and viewing indexed data
50 | - Exposes the MCP server
51 |
52 | ---
53 |
54 | ### 💻 Setup
55 |
56 | #### ⚙️ Prerequisites
57 | - Docker and Docker Compose
58 | - Python 3.10+
59 | - uv (Python package manager)
60 | - OpenAI API key
61 |
62 | #### 📦 Installation
63 | 1. Clone this repository
64 | 2. Set up environment variables:
65 | ```bash
66 | export OPENAI_API_KEY=your-openai-key
67 | ```
68 | 3. Run the stack:
69 | ```bash
70 | docker compose up
71 | ```
72 |
73 | ---
74 |
75 | ### 👁 Usage
76 | 1. Go to http://localhost:9111/docs to see the OpenAPI docs
77 | 2. Look for the `/fetch_url` endpoint and start a crawl job by providing a URL
78 | 3. Use `/job_progress` to see the current job status
79 | 4. Configure your editor to use `http://localhost:9111/mcp` as an MCP server
80 |
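For example, steps 2 and 3 can also be driven from the command line. This is a minimal sketch that assumes `/fetch_url` accepts the `FetchUrlRequest` fields as a JSON body and that `/job_progress` takes the job ID as a `job_id` query parameter; check the OpenAPI docs at `/docs` for the exact shapes:

```bash
# Start a crawl job (stack running locally on port 9111)
curl -X POST http://localhost:9111/fetch_url \
  -H "Content-Type: application/json" \
  -d '{"url": "https://example.com/docs", "tags": ["docs"], "max_pages": 50}'
# => {"job_id": "<job-id>"}

# Poll the job with the returned ID until it reports completed
curl "http://localhost:9111/job_progress?job_id=<job-id>"
```
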
81 | ---
82 |
83 | ### ☁️ Web API
84 |
85 | #### Core Endpoints
86 | - `POST /fetch_url`: Start crawling a URL
87 | - `GET /search_docs`: Search indexed documents
88 | - `GET /job_progress`: Check crawl job progress
89 | - `GET /list_doc_pages`: List indexed pages
90 | - `GET /get_doc_page`: Get full text of a page
91 |
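The GET endpoints take plain query parameters, so they can also be exercised with curl; a short sketch (the `page_id` placeholder comes from a prior search or listing):

```bash
# Semantic search over indexed chunks
curl "http://localhost:9111/search_docs?query=vector%20search&max_results=5"

# List indexed pages, optionally filtered by tag
curl "http://localhost:9111/list_doc_pages?page=1&tags=docs"

# Fetch a page's full text (ending_line=-1 returns the whole page)
curl "http://localhost:9111/get_doc_page?page_id=<page-id>&ending_line=-1"
```
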
92 | #### Site Map Feature
93 | The Maps feature provides a hierarchical view of crawled websites, making it easy to navigate and explore the structure of indexed sites.
94 |
95 | **Endpoints:**
96 | - `GET /map`: View an index of all crawled sites
97 | - `GET /map/site/{root_page_id}`: View the hierarchical tree structure of a specific site
98 | - `GET /map/page/{page_id}`: View a specific page with navigation (parent, siblings, children)
99 | - `GET /map/page/{page_id}/raw`: Get the raw markdown content of a page
100 |
101 | **Features:**
102 | - **Hierarchical Navigation**: Pages maintain parent-child relationships, allowing you to navigate through the site structure
103 | - **Domain Grouping**: Pages from the same domain that were crawled individually are automatically grouped together
104 | - **Automatic Title Extraction**: Page titles are extracted from HTML or markdown content
105 | - **Breadcrumb Navigation**: Easy navigation with breadcrumbs showing the path from root to current page
106 | - **Sibling Navigation**: Quick access to pages at the same level in the hierarchy
107 | - **Legacy Page Support**: Pages crawled before hierarchy tracking are grouped by domain for easy access
108 | - **No JavaScript Required**: All navigation works with pure HTML and CSS for maximum compatibility
109 |
110 | **Usage Example:**
111 | 1. Crawl a website using the `/fetch_url` endpoint
112 | 2. Visit `/map` to see all crawled sites
113 | 3. Click on a site to view its hierarchical structure
114 | 4. Navigate through pages using the provided links
115 |
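The same flow works directly over HTTP, for example (the IDs below are placeholders taken from the `/map` listing):

```bash
# Index of all crawled sites
curl http://localhost:9111/map

# Tree view for one site, then a single page and its raw markdown
curl http://localhost:9111/map/site/<root-page-id>
curl http://localhost:9111/map/page/<page-id>
curl -o page.md http://localhost:9111/map/page/<page-id>/raw
```
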
116 | ---
117 |
118 | ### 🔧 MCP Integration
119 | Ensure that your Docker Compose stack is up, then add the following to your Cursor or VS Code MCP servers configuration:
120 |
121 | ```json
122 | "doctor": {
123 | "type": "sse",
124 | "url": "http://localhost:9111/mcp"
125 | }
126 | ```
127 |
128 | ---
129 |
130 | ### 🧪 Testing
131 |
132 | #### Running Tests
133 | To run all tests:
134 | ```bash
135 | # Run all tests with coverage report
136 | pytest
137 | ```
138 |
139 | To run specific test categories:
140 | ```bash
141 | # Run only unit tests
142 | pytest -m unit
143 |
144 | # Run only async tests
145 | pytest -m async_test
146 |
147 | # Run tests for a specific component
148 | pytest tests/lib/test_crawler.py
149 | ```
150 |
151 | #### Test Coverage
152 | The project is configured to generate coverage reports automatically:
153 | ```bash
154 | # Run tests with detailed coverage report
155 | pytest --cov=src --cov-report=term-missing
156 | ```
157 |
158 | #### Test Structure
159 | - `tests/conftest.py`: Common fixtures for all tests
160 | - `tests/lib/`: Tests for library components
161 | - `test_crawler.py`: Tests for the crawler module
162 | - `test_crawler_enhanced.py`: Tests for enhanced crawler with hierarchy tracking
163 | - `test_chunker.py`: Tests for the chunker module
164 | - `test_embedder.py`: Tests for the embedder module
165 | - `test_database.py`: Tests for the unified Database class
166 | - `test_database_hierarchy.py`: Tests for database hierarchy operations
167 | - `tests/common/`: Tests for common modules
168 | - `tests/services/`: Tests for service layer
169 | - `test_map_service.py`: Tests for the map service
170 | - `tests/api/`: Tests for API endpoints
171 | - `test_map_api.py`: Tests for map API endpoints
172 | - `tests/integration/`: Integration tests
173 | - `test_processor_enhanced.py`: Tests for enhanced processor with hierarchy
174 |
175 | ---
176 |
177 | ### 🐞 Code Quality
178 |
179 | #### Pre-commit Hooks
180 | The project is configured with pre-commit hooks that run automatically before each commit:
181 | - `ruff check --fix`: Lints code and automatically fixes issues
182 | - `ruff format`: Formats code according to project style
183 | - Trailing whitespace removal
184 | - End-of-file fixing
185 | - YAML validation
186 | - Large file checks
187 |
188 | #### Setup Pre-commit
189 | To set up pre-commit hooks:
190 | ```bash
191 | # Install pre-commit
192 | uv pip install pre-commit
193 |
194 | # Install the git hooks
195 | pre-commit install
196 | ```
197 |
198 | #### Running Pre-commit Manually
199 | You can run the pre-commit hooks manually on all files:
200 | ```bash
201 | # Run all pre-commit hooks
202 | pre-commit run --all-files
203 | ```
204 |
205 | Or on staged files only:
206 | ```bash
207 | # Run on staged files
208 | pre-commit run
209 | ```
210 |
211 | ---
212 |
213 | ### ⚖️ License
214 | This project is licensed under the MIT License - see the [LICENSE.md](LICENSE.md) file for details.
215 |
--------------------------------------------------------------------------------
/src/web_service/api/documents.py:
--------------------------------------------------------------------------------
1 | """Document API routes for the web service."""
2 |
3 | from fastapi import APIRouter, HTTPException, Query
4 |
5 | from src.common.config import RETURN_FULL_DOCUMENT_TEXT
6 | from src.common.logger import get_logger
7 | from src.common.models import (
8 | GetDocPageResponse,
9 | ListDocPagesResponse,
10 | ListTagsResponse,
11 | SearchDocsResponse,
12 | )
13 | from src.lib.database import DatabaseOperations
14 | from src.web_service.services.document_service import (
15 | get_doc_page,
16 | list_doc_pages,
17 | list_tags,
18 | search_docs,
19 | )
20 |
21 | # Get logger for this module
22 | logger = get_logger(__name__)
23 |
24 | # Create router
25 | router = APIRouter(tags=["documents"])
26 |
27 |
28 | @router.get("/search_docs", response_model=SearchDocsResponse, operation_id="search_docs")
29 | async def search_docs_endpoint(
30 | query: str = Query(..., description="The search string to query the database with"),
31 | tags: list[str] | None = Query(None, description="Tags to limit the search with"),
32 | max_results: int = Query(10, description="Maximum number of results to return", ge=1, le=100),
33 | return_full_document_text: bool = Query(
34 | RETURN_FULL_DOCUMENT_TEXT,
35 | description="Whether to return the full document text instead of the matching chunks only",
36 | ),
37 | ):
38 | """Search for documents using semantic search. Use `get_doc_page` to get the full text of a document page.
39 |
40 | Args:
41 | query: The search query
42 | tags: Optional tags to filter by
43 | max_results: Maximum number of results to return
44 |
45 | Returns:
46 | Search results
47 |
48 | """
49 | logger.info(
50 | f"API: Searching docs with query: '{query}', tags: {tags}, max_results: {max_results}, return_full_document_text: {return_full_document_text}",
51 | )
52 |
53 | db_ops = DatabaseOperations()
54 | try:
55 | # Use DuckDBConnectionManager as a context manager
56 | with db_ops.db as conn_manager:
57 | actual_conn = conn_manager.conn
58 | if not actual_conn:
59 | logger.error("Failed to obtain database connection for search_docs.")
60 | raise HTTPException(status_code=500, detail="Database connection error.")
61 | # Call the service function
62 | response = await search_docs(
63 | actual_conn, query, tags, max_results, return_full_document_text
64 | )
65 | return response
66 | except HTTPException: # Re-raise HTTP exceptions directly
67 | raise
68 | except (
69 | Exception
70 | ) as e: # Catch other exceptions (like DB errors from service or connection issues)
71 | logger.error(f"Error searching documents: {e!s}")
72 | # Log the specific error, but return a generic server error to the client
73 | raise HTTPException(status_code=500, detail=f"Search error: {e!s}")
74 | # No finally block to close connection, context manager handles it.
75 |
76 |
77 | @router.get("/list_doc_pages", response_model=ListDocPagesResponse, operation_id="list_doc_pages")
78 | async def list_doc_pages_endpoint(
79 | page: int = Query(1, description="Page number", ge=1),
80 | tags: list[str] | None = Query(None, description="Tags to filter by"),
81 | ):
82 | """List all available indexed pages.
83 |
84 | Args:
85 | page: Page number (1-based)
86 | tags: Optional tags to filter by
87 |
88 | Returns:
89 | List of document pages
90 |
91 | """
92 | logger.info(f"API: Listing document pages (page={page}, tags={tags})")
93 |
94 | db_ops = DatabaseOperations()
95 | try:
96 | with db_ops.db as conn_manager:
97 | actual_conn = conn_manager.conn
98 | if not actual_conn:
99 | logger.error("Failed to obtain database connection for list_doc_pages.")
100 | raise HTTPException(status_code=500, detail="Database connection error.")
101 | # Call the service function
102 | response = await list_doc_pages(actual_conn, page, tags)
103 | return response
104 | except HTTPException:
105 | raise
106 | except Exception as e:
107 | logger.error(f"Error listing document pages: {e!s}")
108 | raise HTTPException(status_code=500, detail=f"Database error: {e!s}")
109 | # No finally block to close connection
110 |
111 |
112 | @router.get("/get_doc_page", response_model=GetDocPageResponse, operation_id="get_doc_page")
113 | async def get_doc_page_endpoint(
114 | page_id: str = Query(..., description="The page ID to retrieve"),
115 | starting_line: int = Query(1, description="Line to view from", ge=1),
116 | ending_line: int = Query(
117 | -1,
118 | description="Line to view up to. Set to -1 to view the entire page.",
119 | ),
120 | ):
121 | """Get the full text of a document page. Use `search_docs` or `list_doc_pages` to get the page IDs.
122 |
123 | Args:
124 | page_id: The page ID
125 | starting_line: Line to view from (1-based)
126 | ending_line: Line to view up to (1-based). Set to -1 to view the entire page.
127 |
128 | Returns:
129 | Document page text
130 |
131 | """
132 | logger.info(f"API: Retrieving document page {page_id} (lines {starting_line}-{ending_line})")
133 |
134 | db_ops = DatabaseOperations()
135 | try:
136 | with db_ops.db as conn_manager:
137 | actual_conn = conn_manager.conn
138 | if not actual_conn:
139 | logger.error("Failed to obtain database connection for get_doc_page.")
140 | raise HTTPException(status_code=500, detail="Database connection error.")
141 | # Call the service function
142 | response = await get_doc_page(actual_conn, page_id, starting_line, ending_line)
143 | if response is None:
144 | raise HTTPException(status_code=404, detail=f"Page {page_id} not found")
145 | return response
146 | except HTTPException:
147 | # Re-raise HTTP exceptions as-is
148 | raise
149 | except Exception as e:
150 | logger.error(f"Error retrieving document page {page_id}: {e!s}")
151 | raise HTTPException(status_code=500, detail=f"Database error: {e!s}")
152 | # No finally block to close connection
153 |
154 |
155 | @router.get("/list_tags", response_model=ListTagsResponse, operation_id="list_tags")
156 | async def list_tags_endpoint(
157 | search_substring: str | None = Query(
158 | None,
159 | description="Optional substring to filter tags (case-insensitive fuzzy matching)",
160 | ),
161 | ):
162 | """List all unique tags available in the document database. Use `search_docs` or `list_doc_pages` to get the page IDs using the tags.
163 |
164 | Args:
165 | search_substring: Optional substring to filter tags using case-insensitive fuzzy matching
166 |
167 | Returns:
168 | List of unique tags
169 |
170 | """
171 | logger.info(
172 | f"API: Listing all unique document tags{' with filter: ' + search_substring if search_substring else ''}",
173 | )
174 |
175 | db_ops = DatabaseOperations()
176 | try:
177 | with db_ops.db as conn_manager:
178 | actual_conn = conn_manager.conn
179 | if not actual_conn:
180 | logger.error("Failed to obtain database connection for list_tags.")
181 | raise HTTPException(status_code=500, detail="Database connection error.")
182 | # Call the service function
183 | response = await list_tags(actual_conn, search_substring)
184 | return response
185 | except HTTPException:
186 | raise
187 | except Exception as e:
188 | logger.error(f"Error listing document tags: {e!s}")
189 | raise HTTPException(status_code=500, detail=f"Database error: {e!s}")
190 | # No finally block to close connection
191 |
--------------------------------------------------------------------------------
/tests/common/test_processor.py:
--------------------------------------------------------------------------------
1 | """Tests for the processor module."""
2 |
3 | from unittest.mock import AsyncMock, MagicMock, patch
4 |
5 | import pytest
6 |
7 | from src.common.processor import process_crawl_result, process_page_batch
8 |
9 |
10 | @pytest.fixture
11 | def mock_processor_dependencies():
12 | """Setup mock dependencies for processor tests."""
13 | extract_page_text_mock = MagicMock()
14 | extract_page_text_mock.return_value = "Extracted text content"
15 |
16 | # Mock for DatabaseOperations class store_page method
17 | store_page_mock = AsyncMock()
18 | store_page_mock.return_value = "test-page-123"
19 |
20 | # Mock for DatabaseOperations class update_job_status method
21 | update_job_status_mock = MagicMock()
22 |
23 | # Create a DatabaseOperations class mock that can be both used directly
24 | # and as a context manager
25 | database_mock = MagicMock(
26 | __enter__=lambda self: self,
27 | __exit__=lambda *args: None,
28 | store_page=store_page_mock,
29 | update_job_status=update_job_status_mock,
30 | )
31 |
32 | generate_embedding_mock = AsyncMock()
33 | generate_embedding_mock.return_value = [0.1, 0.2, 0.3, 0.4, 0.5]
34 |
35 | mocks = {
36 | "extract_page_text": extract_page_text_mock,
37 | "database": database_mock,
38 | "store_page": store_page_mock,
39 | "TextChunker": MagicMock(),
40 | "generate_embedding": generate_embedding_mock,
41 | "VectorIndexer": MagicMock(),
42 | "update_job_status": update_job_status_mock,
43 | }
44 |
45 | chunker_instance = MagicMock()
46 | chunker_instance.split_text.return_value = ["Chunk 1", "Chunk 2", "Chunk 3"]
47 | mocks["TextChunker"].return_value = chunker_instance
48 |
49 | indexer_instance = MagicMock()
50 | indexer_instance.index_vector = AsyncMock(return_value="vector-id-123")
51 | mocks["VectorIndexer"].return_value = indexer_instance
52 |
53 | return mocks
54 |
55 |
56 | @pytest.mark.unit
57 | @pytest.mark.async_test
58 | async def test_process_crawl_result(
59 | sample_crawl_result,
60 | job_id,
61 | sample_tags,
62 | mock_processor_dependencies,
63 | ):
64 | """Test processing a single crawl result."""
65 | with (
66 | patch(
67 | "src.common.processor.extract_page_text",
68 | mock_processor_dependencies["extract_page_text"],
69 | ),
70 | patch(
71 | "src.common.processor.DatabaseOperations",
72 | return_value=mock_processor_dependencies["database"],
73 | ),
74 | patch("src.common.processor.TextChunker", mock_processor_dependencies["TextChunker"]),
75 | patch(
76 | "src.common.processor.generate_embedding",
77 | mock_processor_dependencies["generate_embedding"],
78 | ),
79 | patch("src.common.processor.VectorIndexer", mock_processor_dependencies["VectorIndexer"]),
80 | ):
81 | page_id = await process_crawl_result(
82 | page_result=sample_crawl_result,
83 | job_id=job_id,
84 | tags=sample_tags,
85 | )
86 |
87 | mock_processor_dependencies["extract_page_text"].assert_called_once_with(
88 | sample_crawl_result,
89 | )
90 | mock_processor_dependencies["store_page"].assert_called_once_with(
91 | url=sample_crawl_result.url,
92 | text="Extracted text content",
93 | job_id=job_id,
94 | tags=sample_tags,
95 | )
96 | chunker_instance = mock_processor_dependencies["TextChunker"].return_value
97 | chunker_instance.split_text.assert_called_once_with("Extracted text content")
98 | assert mock_processor_dependencies["generate_embedding"].call_count == 3
99 | indexer_instance = mock_processor_dependencies["VectorIndexer"].return_value
100 | assert indexer_instance.index_vector.call_count == 3
101 | assert page_id == "test-page-123"
102 |
103 |
104 | @pytest.mark.unit
105 | @pytest.mark.async_test
106 | async def test_process_crawl_result_with_errors(
107 | sample_crawl_result,
108 | job_id,
109 | sample_tags,
110 | mock_processor_dependencies,
111 | ):
112 | """Test processing a crawl result with errors during chunking/embedding."""
113 | with (
114 | patch(
115 | "src.common.processor.extract_page_text",
116 | mock_processor_dependencies["extract_page_text"],
117 | ),
118 | patch(
119 | "src.common.processor.DatabaseOperations",
120 | return_value=mock_processor_dependencies["database"],
121 | ),
122 | patch("src.common.processor.TextChunker", mock_processor_dependencies["TextChunker"]),
123 | patch("src.common.processor.generate_embedding", side_effect=Exception("Embedding error")),
124 | patch("src.common.processor.VectorIndexer", mock_processor_dependencies["VectorIndexer"]),
125 | ):
126 | page_id = await process_crawl_result(
127 | page_result=sample_crawl_result,
128 | job_id=job_id,
129 | tags=sample_tags,
130 | )
131 | assert page_id == "test-page-123"
132 | mock_processor_dependencies["store_page"].assert_called_once()
133 | indexer_instance = mock_processor_dependencies["VectorIndexer"].return_value
134 | assert indexer_instance.index_vector.call_count == 0
135 |
136 |
137 | @pytest.mark.unit
138 | @pytest.mark.async_test
139 | async def test_process_page_batch(mock_processor_dependencies):
140 | """Test processing a batch of crawl results."""
141 | mock_results = [
142 | MagicMock(url="https://example.com/page1"),
143 | MagicMock(url="https://example.com/page2"),
144 | MagicMock(url="https://example.com/page3"),
145 | ]
146 | mock_process_result = AsyncMock(side_effect=["page-1", "page-2", "page-3"])
147 |
148 | with (
149 | patch("src.common.processor.process_crawl_result", mock_process_result),
150 | patch(
151 | "src.common.processor.DatabaseOperations",
152 | return_value=mock_processor_dependencies["database"],
153 | ),
154 | ):
155 | job_id = "test-job"
156 | tags = ["test", "batch"]
157 | page_ids = await process_page_batch(
158 | page_results=mock_results,
159 | job_id=job_id,
160 | tags=tags,
161 | batch_size=2,
162 | )
163 | assert mock_process_result.call_count == 3
164 | mock_process_result.assert_any_call(mock_results[0], job_id, tags)
165 | mock_process_result.assert_any_call(mock_results[1], job_id, tags)
166 | mock_process_result.assert_any_call(mock_results[2], job_id, tags)
167 | assert mock_processor_dependencies["update_job_status"].call_count == 3
168 | assert page_ids == ["page-1", "page-2", "page-3"]
169 |
170 |
171 | @pytest.mark.unit
172 | @pytest.mark.async_test
173 | async def test_process_page_batch_with_errors(mock_processor_dependencies):
174 | """Test processing a batch with errors on some pages."""
175 | mock_results = [
176 | MagicMock(url="https://example.com/page1"),
177 | MagicMock(url="https://example.com/page2"),
178 | MagicMock(url="https://example.com/page3"),
179 | ]
180 | mock_process_result = AsyncMock(
181 | side_effect=["page-1", Exception("Error processing page 2"), "page-3"],
182 | )
183 |
184 | with (
185 | patch("src.common.processor.process_crawl_result", mock_process_result),
186 | patch(
187 | "src.common.processor.DatabaseOperations",
188 | return_value=mock_processor_dependencies["database"],
189 | ),
190 | ):
191 | job_id = "test-job"
192 | tags = ["test", "batch"]
193 | page_ids = await process_page_batch(page_results=mock_results, job_id=job_id, tags=tags)
194 | assert mock_process_result.call_count == 3
195 | assert mock_processor_dependencies["update_job_status"].call_count == 2
196 | assert page_ids == ["page-1", "page-3"]
197 |
198 |
199 | @pytest.mark.unit
200 | @pytest.mark.async_test
201 | async def test_process_page_batch_empty():
202 | """Test processing an empty batch of pages."""
203 | page_ids = await process_page_batch(page_results=[], job_id="test-job", tags=["test"])
204 | assert page_ids == []
205 |
--------------------------------------------------------------------------------
/tests/lib/test_crawler_enhanced.py:
--------------------------------------------------------------------------------
1 | """Tests for the enhanced crawler with hierarchy tracking."""
2 |
3 | import pytest
4 | from unittest.mock import Mock, AsyncMock, patch
5 | from src.lib.crawler_enhanced import CrawlResultWithHierarchy, crawl_url_with_hierarchy
6 |
7 |
8 | class TestCrawlResultWithHierarchy:
9 | """Test the CrawlResultWithHierarchy class."""
10 |
11 | def test_init_with_parent_url(self) -> None:
12 | """Test initialization with parent URL in metadata.
13 |
14 | Args:
15 | None.
16 |
17 | Returns:
18 | None.
19 | """
20 | # Create mock base result
21 | base_result = Mock()
22 | base_result.url = "https://example.com/page1"
23 | base_result.html = "Test Page"
24 | base_result.metadata = {"parent_url": "https://example.com", "depth": 1}
25 |
26 | # Create enhanced result
27 | enhanced = CrawlResultWithHierarchy(base_result)
28 |
29 | assert enhanced.url == "https://example.com/page1"
30 | assert enhanced.parent_url == "https://example.com"
31 | assert enhanced.depth == 1
32 | assert enhanced.title == "Test Page"
33 |
34 | def test_init_without_metadata(self) -> None:
35 | """Test initialization without metadata.
36 |
37 | Args:
38 | None.
39 |
40 | Returns:
41 | None.
42 | """
43 | # Create mock base result without metadata
44 | base_result = Mock()
45 | base_result.url = "https://example.com"
46 | base_result.html = "Content"
47 | base_result.metadata = None
48 |
49 | # Mock hasattr to return False for metadata
50 | with patch("builtins.hasattr", side_effect=lambda obj, attr: attr != "metadata"):
51 | enhanced = CrawlResultWithHierarchy(base_result)
52 |
53 | assert enhanced.url == "https://example.com"
54 | assert enhanced.parent_url is None
55 | assert enhanced.depth == 0
56 |
57 | def test_extract_title_from_html(self) -> None:
58 | """Test extracting title from HTML.
59 |
60 | Args:
61 | None.
62 |
63 | Returns:
64 | None.
65 | """
66 | base_result = Mock()
67 | base_result.url = "https://example.com"
68 | base_result.html = "My Test Page"
69 | base_result.metadata = {}
70 |
71 | enhanced = CrawlResultWithHierarchy(base_result)
72 | assert enhanced.title == "My Test Page"
73 |
74 | def test_extract_title_from_markdown(self) -> None:
75 | """Test extracting title from markdown when no HTML title.
76 |
77 | Args:
78 | None.
79 |
80 | Returns:
81 | None.
82 | """
83 | base_result = Mock()
84 | base_result.url = "https://example.com"
85 | base_result.html = "Content"
86 | base_result.metadata = {}
87 |
88 | # Mock extract_page_text to return markdown
89 | with patch(
90 | "src.lib.crawler_enhanced.extract_page_text",
91 | return_value="# Main Heading\n\nSome content",
92 | ):
93 | enhanced = CrawlResultWithHierarchy(base_result)
94 |
95 | assert enhanced.title == "Main Heading"
96 |
97 | def test_extract_title_fallback_to_url(self) -> None:
98 | """Test title fallback to URL path when no title found.
99 |
100 | Args:
101 | None.
102 |
103 | Returns:
104 | None.
105 | """
106 | base_result = Mock()
107 | base_result.url = "https://example.com/docs/api"
108 | base_result.html = ""
109 | base_result.metadata = {}
110 |
111 | with patch("src.lib.crawler_enhanced.extract_page_text", return_value=""):
112 | enhanced = CrawlResultWithHierarchy(base_result)
113 |
114 | assert enhanced.title == "api"
115 |
116 | def test_extract_title_home_for_root(self) -> None:
117 | """Test title defaults to 'Home' for root URLs.
118 |
119 | Args:
120 | None.
121 |
122 | Returns:
123 | None.
124 | """
125 | base_result = Mock()
126 | base_result.url = "https://example.com/"
127 | base_result.html = ""
128 | base_result.metadata = {}
129 |
130 | with patch("src.lib.crawler_enhanced.extract_page_text", return_value=""):
131 | enhanced = CrawlResultWithHierarchy(base_result)
132 |
133 | assert enhanced.title == "Home"
134 |
135 |
136 | @pytest.mark.asyncio
137 | class TestCrawlUrlWithHierarchy:
138 | """Test the crawl_url_with_hierarchy function."""
139 |
140 | async def test_crawl_with_hierarchy(self) -> None:
141 | """Test crawling with hierarchy enhancement.
142 |
143 | Args:
144 | None.
145 |
146 | Returns:
147 | None.
148 | """
149 | # Mock base crawl results
150 | mock_results = [
151 | Mock(
152 | url="https://example.com",
153 | metadata={"parent_url": None, "depth": 0},
154 | html="Home",
155 | ),
156 | Mock(
157 | url="https://example.com/about",
158 | metadata={"parent_url": "https://example.com", "depth": 1},
159 | html="About",
160 | ),
161 | Mock(
162 | url="https://example.com/contact",
163 | metadata={"parent_url": "https://example.com", "depth": 1},
164 | html="Contact",
165 | ),
166 | ]
167 |
168 | with patch("src.lib.crawler_enhanced.base_crawl_url", new_callable=AsyncMock) as mock_crawl:
169 | mock_crawl.return_value = mock_results
170 |
171 | results = await crawl_url_with_hierarchy("https://example.com", max_pages=10)
172 |
173 | assert len(results) == 3
174 | assert all(isinstance(r, CrawlResultWithHierarchy) for r in results)
175 |
176 | # Check hierarchy information
177 | assert results[0].parent_url is None
178 | assert results[0].root_url == "https://example.com"
179 | assert results[0].depth == 0
180 |
181 | assert results[1].parent_url == "https://example.com"
182 | assert results[1].root_url == "https://example.com"
183 | assert results[1].depth == 1
184 |
185 | async def test_relative_path_calculation(self) -> None:
186 | """Test calculation of relative paths.
187 |
188 | Args:
189 | None.
190 |
191 | Returns:
192 | None.
193 | """
194 | # Mock results with parent-child relationship
195 | mock_results = [
196 | Mock(
197 | url="https://example.com",
198 | metadata={"parent_url": None, "depth": 0},
199 | html="Home",
200 | ),
201 | Mock(
202 | url="https://example.com/docs",
203 | metadata={"parent_url": "https://example.com", "depth": 1},
204 | html="Documentation",
205 | ),
206 | Mock(
207 | url="https://example.com/docs/api",
208 | metadata={"parent_url": "https://example.com/docs", "depth": 2},
209 | html="API Reference",
210 | ),
211 | ]
212 |
213 | with patch("src.lib.crawler_enhanced.base_crawl_url", new_callable=AsyncMock) as mock_crawl:
214 | mock_crawl.return_value = mock_results
215 |
216 | results = await crawl_url_with_hierarchy("https://example.com")
217 |
218 | # Check relative paths
219 | assert results[0].relative_path == "/"
220 | assert results[1].relative_path == "Documentation"
221 | assert results[2].relative_path == "Documentation/API Reference"
222 |
223 | async def test_empty_crawl_results(self) -> None:
224 | """Test handling of empty crawl results.
225 |
226 | Args:
227 | None.
228 |
229 | Returns:
230 | None.
231 | """
232 | with patch("src.lib.crawler_enhanced.base_crawl_url", new_callable=AsyncMock) as mock_crawl:
233 | mock_crawl.return_value = []
234 |
235 | results = await crawl_url_with_hierarchy("https://example.com")
236 |
237 | assert results == []
238 |
--------------------------------------------------------------------------------
/tests/services/test_job_service.py:
--------------------------------------------------------------------------------
1 | """Tests for the job service."""
2 |
3 | import datetime
4 | from unittest.mock import MagicMock, patch
5 |
6 | import duckdb
7 | import pytest
8 | from rq import Queue
9 |
10 | from src.common.models import (
11 | FetchUrlResponse,
12 | JobProgressResponse,
13 | )
14 | from src.web_service.services.job_service import (
15 | fetch_url,
16 | get_job_count,
17 | get_job_progress,
18 | )
19 |
20 |
21 | @pytest.fixture
22 | def mock_duckdb_connection():
23 | """Mock DuckDB connection."""
24 | mock = MagicMock(spec=duckdb.DuckDBPyConnection)
25 | return mock
26 |
27 |
28 | @pytest.fixture
29 | def mock_queue():
30 | """Mock Redis queue."""
31 | mock = MagicMock(spec=Queue)
32 | mock.enqueue = MagicMock()
33 | return mock
34 |
35 |
36 | @pytest.mark.unit
37 | @pytest.mark.async_test
38 | async def test_fetch_url_basic(mock_queue):
39 | """Test fetching a URL with basic parameters."""
40 | # Set a fixed UUID for testing
41 | with patch("uuid.uuid4", return_value="test-job-123"):
42 | # Call the function with just URL
43 | result = await fetch_url(
44 | queue=mock_queue,
45 | url="https://example.com",
46 | )
47 |
48 | # Verify queue.enqueue was called with the right parameters
49 | mock_queue.enqueue.assert_called_once()
50 | call_args = mock_queue.enqueue.call_args[0]
51 | assert call_args[0] == "src.crawl_worker.tasks.create_job"
52 | assert call_args[1] == "https://example.com"
53 | assert call_args[2] == "test-job-123"
54 |
55 | # Should receive keyword arguments for tags and max_pages
56 | kwargs = mock_queue.enqueue.call_args[1]
57 | assert kwargs["tags"] is None
58 | assert kwargs["max_pages"] == 100
59 |
60 | # Check that the function returned the expected response
61 | assert isinstance(result, FetchUrlResponse)
62 | assert result.job_id == "test-job-123"
63 |
64 |
65 | @pytest.mark.unit
66 | @pytest.mark.async_test
67 | async def test_fetch_url_with_options(mock_queue):
68 | """Test fetching a URL with all parameters."""
69 | # Set a fixed UUID for testing
70 | with patch("uuid.uuid4", return_value="test-job-456"):
71 | # Call the function with all parameters
72 | result = await fetch_url(
73 | queue=mock_queue,
74 | url="https://example.com",
75 | tags=["test", "example"],
76 | max_pages=50,
77 | )
78 |
79 | # Verify queue.enqueue was called with the right parameters
80 | mock_queue.enqueue.assert_called_once()
81 | call_args = mock_queue.enqueue.call_args[0]
82 | assert call_args[0] == "src.crawl_worker.tasks.create_job"
83 | assert call_args[1] == "https://example.com"
84 | assert call_args[2] == "test-job-456"
85 |
86 | # Should receive keyword arguments with the provided values
87 | kwargs = mock_queue.enqueue.call_args[1]
88 | assert kwargs["tags"] == ["test", "example"]
89 | assert kwargs["max_pages"] == 50
90 |
91 | # Check that the function returned the expected response
92 | assert isinstance(result, FetchUrlResponse)
93 | assert result.job_id == "test-job-456"
94 |
95 |
96 | @pytest.mark.unit
97 | @pytest.mark.async_test
98 | async def test_get_job_progress_exact_match(mock_duckdb_connection):
99 | """Test getting job progress with exact job ID match."""
100 | # Mock database results for successful job
101 | created_at = datetime.datetime(2023, 1, 1, 10, 0, 0)
102 | updated_at = datetime.datetime(2023, 1, 1, 10, 5, 0)
103 |
104 | mock_result = (
105 | "test-job-123", # job_id
106 | "https://example.com", # start_url
107 | "running", # status
108 | 200, # pages_discovered
109 | 50, # pages_crawled
110 | 100, # max_pages
111 | "test,example", # tags_json
112 | created_at, # created_at
113 | updated_at, # updated_at
114 | None, # error_message
115 | )
116 |
117 | # Set up mock execution and result
118 | mock_cursor = MagicMock()
119 | mock_cursor.fetchone.return_value = mock_result
120 | mock_duckdb_connection.execute.return_value = mock_cursor
121 |
122 | # Call the function
123 | result = await get_job_progress(
124 | conn=mock_duckdb_connection,
125 | job_id="test-job-123",
126 | )
127 |
128 | # Verify database was queried with exact job ID
129 | mock_duckdb_connection.execute.assert_called_once()
130 | query = mock_duckdb_connection.execute.call_args[0][0]
131 | assert "SELECT job_id, start_url, status" in query
132 | assert "WHERE job_id = ?" in query
133 |
134 | # Check that results are returned correctly
135 | assert isinstance(result, JobProgressResponse)
136 | assert result.status == "running"
137 | assert result.pages_crawled == 50
138 | assert result.pages_total == 200
139 | assert result.completed is False
140 | assert result.progress_percent == 50 # 50/100 = 50%
141 | assert result.url == "https://example.com"
142 | assert result.max_pages == 100
143 | assert result.created_at == created_at
144 | assert result.updated_at == updated_at
145 | assert result.error_message is None
146 |
147 |
148 | @pytest.mark.unit
149 | @pytest.mark.async_test
150 | async def test_get_job_progress_partial_match(mock_duckdb_connection):
151 | """Test getting job progress with partial job ID match."""
152 | # First attempt should return None to simulate no exact match
153 | mock_cursor1 = MagicMock()
154 | mock_cursor1.fetchone.return_value = None
155 |
156 | # Second attempt (partial match) should succeed
157 | created_at = datetime.datetime(2023, 1, 1, 10, 0, 0)
158 | updated_at = datetime.datetime(2023, 1, 1, 10, 5, 0)
159 |
160 | mock_result = (
161 | "test-job-123-full", # job_id
162 | "https://example.com", # start_url
163 | "completed", # status
164 | 100, # pages_discovered
165 | 100, # pages_crawled
166 | 100, # max_pages
167 | "test,example", # tags_json
168 | created_at, # created_at
169 | updated_at, # updated_at
170 | None, # error_message
171 | )
172 |
173 | mock_cursor2 = MagicMock()
174 | mock_cursor2.fetchone.return_value = mock_result
175 |
176 | # Set up side effect for consecutive calls
177 | mock_duckdb_connection.execute.side_effect = [mock_cursor1, mock_cursor2]
178 |
179 | # Call the function with partial job ID
180 | result = await get_job_progress(
181 | conn=mock_duckdb_connection,
182 | job_id="test-job",
183 | )
184 |
185 | # Should make two database calls
186 | assert mock_duckdb_connection.execute.call_count == 2
187 |
188 | # Second call should use LIKE with wildcard
189 | query2 = mock_duckdb_connection.execute.call_args_list[1][0][0]
190 | assert "WHERE job_id LIKE ?" in query2
191 | params2 = mock_duckdb_connection.execute.call_args_list[1][0][1]
192 | assert params2[0] == "test-job%"
193 |
194 | # Check that results are returned correctly
195 | assert isinstance(result, JobProgressResponse)
196 | assert result.status == "completed"
197 | assert result.pages_crawled == 100
198 | assert result.pages_total == 100
199 | assert result.completed is True
200 | assert result.progress_percent == 100 # 100/100 = 100%
201 |
202 |
203 | @pytest.mark.unit
204 | @pytest.mark.async_test
205 | async def test_get_job_progress_not_found(mock_duckdb_connection):
206 | """Test getting job progress for a non-existent job."""
207 | # Both exact and partial matches return None
208 | mock_cursor1 = MagicMock()
209 | mock_cursor1.fetchone.return_value = None
210 |
211 | mock_cursor2 = MagicMock()
212 | mock_cursor2.fetchone.return_value = None
213 |
214 | # Set up side effect for consecutive calls
215 | mock_duckdb_connection.execute.side_effect = [mock_cursor1, mock_cursor2]
216 |
217 | # Call the function
218 | result = await get_job_progress(
219 | conn=mock_duckdb_connection,
220 | job_id="nonexistent",
221 | )
222 |
223 | # Should return None for non-existent job
224 | assert result is None
225 |
226 |
227 | @pytest.mark.unit
228 | @pytest.mark.async_test
229 | async def test_get_job_count():
230 | """Test getting job count."""
231 | # Mock database connection
232 | mock_conn = MagicMock()
233 | mock_conn.execute.return_value.fetchone.return_value = (42,)
234 |
235 | with patch("src.web_service.services.job_service.duckdb.connect", return_value=mock_conn):
236 | # Call the function
237 | result = await get_job_count()
238 |
239 | # No explicit close assertion here: the current implementation closes the
240 | # Database wrapper rather than the raw connection
241 |
242 | # Check result
243 | assert result == 42
244 |
245 |
246 | @pytest.mark.unit
247 | @pytest.mark.async_test
248 | async def test_get_job_count_error():
249 | """Test getting job count when an error occurs."""
250 | # Mock database connection with an error
251 | mock_conn = MagicMock()
252 | mock_conn.execute.side_effect = duckdb.Error("Database error")
253 |
254 | with patch("src.web_service.services.job_service.duckdb.connect", return_value=mock_conn):
255 | # Call the function and assert duckdb.Error is raised
256 | with pytest.raises(duckdb.Error, match="Database error"):
257 | await get_job_count()
258 |
259 | # The database should still be closed even when an error occurs; the
260 | # current implementation closes the Database wrapper, not the raw connection
261 |
--------------------------------------------------------------------------------
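Note on the job service tests above: progress_percent is asserted against max_pages rather than pages_discovered (50 crawled of 100 max reads as 50% even though 200 pages were discovered). A minimal sketch of that arithmetic, independent of the real job_service implementation:

def progress_percent(pages_crawled: int, max_pages: int) -> int:
    # Mirrors only what the tests assert; the real JobProgressResponse
    # computation may handle edge cases differently.
    if max_pages <= 0:
        return 0
    return min(100, int(pages_crawled / max_pages * 100))

assert progress_percent(50, 100) == 50    # running job from the exact-match test
assert progress_percent(100, 100) == 100  # completed job from the partial-match test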
/tests/integration/test_processor_enhanced.py:
--------------------------------------------------------------------------------
1 | """Integration tests for the enhanced processor with hierarchy tracking."""
2 |
3 | import pytest
4 | from unittest.mock import Mock, AsyncMock, patch
5 |
6 | from src.common.processor_enhanced import (
7 | process_crawl_result_with_hierarchy,
8 | process_crawl_with_hierarchy,
9 | )
10 | from src.lib.crawler_enhanced import CrawlResultWithHierarchy
11 |
12 |
13 | @pytest.mark.asyncio
14 | @pytest.mark.integration
15 | class TestProcessorEnhancedIntegration:
16 | """Integration tests for enhanced processor."""
17 |
18 | async def test_process_crawl_result_with_hierarchy(self) -> None:
19 | """Test processing a single crawl result with hierarchy.
20 |
21 | Args:
22 | None.
23 |
24 | Returns:
25 | None.
26 | """
27 | # Create a mock enhanced crawl result
28 | base_result = Mock()
29 | base_result.url = "https://example.com/docs"
30 | base_result.html = "Documentation"
31 |
32 | enhanced_result = CrawlResultWithHierarchy(base_result)
33 | enhanced_result.parent_url = "https://example.com"
34 | enhanced_result.root_url = "https://example.com"
35 | enhanced_result.depth = 1
36 | enhanced_result.relative_path = "/docs"
37 | enhanced_result.title = "Documentation"
38 |
39 | # Mock dependencies
40 | with patch("src.common.processor_enhanced.extract_page_text") as mock_extract:
41 | mock_extract.return_value = "# Documentation\n\nThis is the documentation."
42 |
43 | with patch("src.common.processor_enhanced.DatabaseOperations") as MockDB:
44 | mock_db = MockDB.return_value
45 | mock_db.store_page = AsyncMock(return_value="page-123")
46 |
47 | with patch("src.common.processor_enhanced.TextChunker") as MockChunker:
48 | mock_chunker = MockChunker.return_value
49 | mock_chunker.split_text.return_value = ["Chunk 1 content", "Chunk 2 content"]
50 |
51 | with patch("src.common.processor_enhanced.generate_embedding") as mock_embed:
52 | mock_embed.return_value = [0.1, 0.2, 0.3]
53 |
54 | with patch("src.common.processor_enhanced.VectorIndexer") as MockIndexer:
55 | mock_indexer = MockIndexer.return_value
56 | mock_indexer.index_vector = AsyncMock()
57 |
58 | # Process with hierarchy tracking
59 | url_to_page_id = {"https://example.com": "root-123"}
60 | page_id = await process_crawl_result_with_hierarchy(
61 | enhanced_result,
62 | job_id="test-job",
63 | tags=["test"],
64 | url_to_page_id=url_to_page_id,
65 | )
66 |
67 | # Verify results
68 | assert page_id == "page-123"
69 | assert url_to_page_id[enhanced_result.url] == "page-123"
70 |
71 | # Verify store_page was called with hierarchy info
72 | mock_db.store_page.assert_called_once_with(
73 | url="https://example.com/docs",
74 | text="# Documentation\n\nThis is the documentation.",
75 | job_id="test-job",
76 | tags=["test"],
77 | parent_page_id="root-123", # Parent ID was looked up
78 | root_page_id="root-123",
79 | depth=1,
80 | path="/docs",
81 | title="Documentation",
82 | )
83 |
84 | # Verify chunks were indexed
85 | assert mock_indexer.index_vector.call_count == 2
86 |
87 | async def test_process_crawl_with_hierarchy_full_flow(self) -> None:
88 | """Test the full crawl and process flow with hierarchy.
89 |
90 | Args:
91 | None.
92 |
93 | Returns:
94 | None.
95 | """
96 | # Mock crawl results with hierarchy
97 | mock_crawl_results = []
98 | for i, (url, parent_url, depth) in enumerate(
99 | [
100 | ("https://example.com", None, 0),
101 | ("https://example.com/about", "https://example.com", 1),
102 | ("https://example.com/docs", "https://example.com", 1),
103 | ("https://example.com/docs/api", "https://example.com/docs", 2),
104 | ]
105 | ):
106 | base = Mock()
107 | base.url = url
108 | base.html = f"Page {i}"
109 |
110 | enhanced = CrawlResultWithHierarchy(base)
111 | enhanced.url = url
112 | enhanced.parent_url = parent_url
113 | enhanced.root_url = "https://example.com"
114 | enhanced.depth = depth
115 | enhanced.title = f"Page {i}"
116 | enhanced.relative_path = url.replace("https://example.com", "") or "/"
117 |
118 | mock_crawl_results.append(enhanced)
119 |
120 | # Mock the enhanced crawler
121 | with patch("src.common.processor_enhanced.crawl_url_with_hierarchy") as mock_crawl:
122 | mock_crawl.return_value = mock_crawl_results
123 |
124 | # Mock other dependencies
125 | with patch("src.common.processor_enhanced.extract_page_text") as mock_extract:
126 | mock_extract.return_value = "Page content"
127 |
128 | with patch("src.common.processor_enhanced.DatabaseOperations") as MockDB:
129 | mock_db = MockDB.return_value
130 | # Return unique IDs for each page
131 | mock_db.store_page = AsyncMock(
132 | side_effect=["page-0", "page-1", "page-2", "page-3"]
133 | )
134 | mock_db.update_job_status = AsyncMock()
135 |
136 | with patch("src.common.processor_enhanced.TextChunker") as MockChunker:
137 | mock_chunker = MockChunker.return_value
138 | mock_chunker.split_text.return_value = ["chunk"]
139 |
140 | with patch(
141 | "src.common.processor_enhanced.generate_embedding"
142 | ) as mock_embed:
143 | mock_embed.return_value = [0.1, 0.2]
144 |
145 | with patch(
146 | "src.common.processor_enhanced.VectorIndexer"
147 | ) as MockIndexer:
148 | mock_indexer = MockIndexer.return_value
149 | mock_indexer.index_vector = AsyncMock()
150 |
151 | # Process everything
152 | page_ids = await process_crawl_with_hierarchy(
153 | url="https://example.com",
154 | job_id="test-job",
155 | tags=["test"],
156 | max_pages=10,
157 | )
158 |
159 | # Verify results
160 | assert len(page_ids) == 4
161 | assert page_ids == ["page-0", "page-1", "page-2", "page-3"]
162 |
163 | # Verify pages were stored in order (parent before child)
164 | assert mock_db.store_page.call_count == 4
165 |
166 | # Check that parent IDs were properly resolved
167 | # First page (root) has no parent
168 | first_call = mock_db.store_page.call_args_list[0]
169 | assert first_call.kwargs["parent_page_id"] is None
170 |
171 | # Second page's parent should be the root's ID
172 | second_call = mock_db.store_page.call_args_list[1]
173 | assert second_call.kwargs["parent_page_id"] == "page-0"
174 |
175 | # Fourth page's parent should be the third page's ID
176 | fourth_call = mock_db.store_page.call_args_list[3]
177 | assert fourth_call.kwargs["parent_page_id"] == "page-2"
178 |
179 | async def test_process_crawl_with_hierarchy_error_handling(self) -> None:
180 | """Test error handling in hierarchy processing.
181 |
182 | Args:
183 | None.
184 |
185 | Returns:
186 | None.
187 | """
188 | # Create a result that will fail
189 | base_result = Mock()
190 | base_result.url = "https://example.com/bad"
191 |
192 | enhanced_result = CrawlResultWithHierarchy(base_result)
193 | enhanced_result.url = "https://example.com/bad"
194 | enhanced_result.parent_url = None
195 | enhanced_result.depth = 0
196 | enhanced_result.title = "Bad Page"
197 |
198 | with patch("src.common.processor_enhanced.crawl_url_with_hierarchy") as mock_crawl:
199 | mock_crawl.return_value = [enhanced_result]
200 |
201 | with patch("src.common.processor_enhanced.extract_page_text") as mock_extract:
202 | # Make extraction fail
203 | mock_extract.side_effect = Exception("Extraction failed")
204 |
205 | with patch("src.common.processor_enhanced.DatabaseOperations") as MockDB:
206 | mock_db = MockDB.return_value
207 | mock_db.update_job_status = AsyncMock()
208 |
209 | # Process should continue despite error
210 | page_ids = await process_crawl_with_hierarchy(
211 | url="https://example.com", job_id="test-job"
212 | )
213 |
214 | # Should return empty list due to error
215 | assert page_ids == []
216 |
217 | # Job status won't be updated if all pages fail
218 | mock_db.update_job_status.assert_not_called()
219 |
--------------------------------------------------------------------------------
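Note on the hierarchy tests above: parent IDs are resolved by processing pages parent-first and recording each stored page's ID in a url_to_page_id map keyed by URL. A hedged sketch of that lookup pattern; StoredPage and the synthetic page IDs are illustrative stand-ins, not the real processor API:

from dataclasses import dataclass


@dataclass
class StoredPage:
    url: str
    parent_url: str | None


def resolve_parent_ids(pages: list[StoredPage]) -> dict[str, str | None]:
    # Assumes pages arrive parent-before-child, as the enhanced crawler yields them.
    url_to_page_id: dict[str, str] = {}
    parent_ids: dict[str, str | None] = {}
    for i, page in enumerate(pages):
        page_id = f"page-{i}"  # stand-in for the ID returned by store_page
        parent_ids[page_id] = url_to_page_id.get(page.parent_url) if page.parent_url else None
        url_to_page_id[page.url] = page_id
    return parent_ids

With the four pages from test_process_crawl_with_hierarchy_full_flow, this yields page-1 and page-2 parented to page-0 and page-3 parented to page-2, matching the assertions above.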
/src/web_service/api/diagnostics.py:
--------------------------------------------------------------------------------
1 | """Diagnostic API routes for debugging purposes.
2 | Important: Should be disabled in production.
3 | """
4 |
5 | from fastapi import APIRouter, HTTPException, Query
6 |
7 | from src.common.logger import get_logger
8 | from src.lib.database import DatabaseOperations
9 | from src.web_service.services.debug_bm25 import debug_bm25_search
10 |
11 | # Get logger for this module
12 | logger = get_logger(__name__)
13 |
14 | # Create router
15 | router = APIRouter(tags=["diagnostics"])
16 |
17 |
18 | @router.get("/bm25_diagnostics", operation_id="bm25_diagnostics")
19 | async def bm25_diagnostics_endpoint(
20 | query: str = Query("test", description="The search query to test with"),
21 | initialize: bool = Query(
22 | True, description="If true, attempt to initialize FTS if in write mode."
23 | ),
24 | ):
25 | """
26 | Combined endpoint for DuckDB FTS initialization and diagnostics.
27 | - Checks for the existence of the 'pages' table and 'raw_text' column before attempting FTS index creation.
28 | - Always includes the full error message from FTS index creation in the response.
29 | - Adds verbose logging and schema checks to the diagnostics output.
30 | """
31 | logger.info(
32 | "API: Running combined BM25 diagnostics and initialization endpoint (DuckDB-specific)"
33 | )
34 |
35 | db = DatabaseOperations(read_only=not initialize)
36 | conn = db.db.ensure_connection()
37 | read_only = db.db.read_only
38 | initialization_summary = {}
39 | schema_info = {}
40 |
41 | # Check for 'pages' table and 'raw_text' column existence
42 | try:
43 | tables = conn.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()
44 | table_names = [row[0] for row in tables]
45 | schema_info["tables"] = table_names
46 | pages_table_exists = "pages" in table_names
47 | initialization_summary["pages_table_exists"] = pages_table_exists
48 | if pages_table_exists:
49 | columns = conn.execute("PRAGMA table_info('pages')").fetchall()
50 | col_names = [row[1] for row in columns]
51 | schema_info["pages_columns"] = col_names
52 | initialization_summary["pages_columns"] = col_names
53 | raw_text_exists = "raw_text" in col_names
54 | initialization_summary["raw_text_column_exists"] = raw_text_exists
55 | else:
56 | initialization_summary["pages_columns"] = []
57 | initialization_summary["raw_text_column_exists"] = False
58 | except Exception as e_schema:
59 | initialization_summary["schema_check_error"] = str(e_schema)
60 | schema_info["schema_check_error"] = str(e_schema)
61 |
62 | try:
63 | if initialize and not read_only:
64 | logger.info("Attempting FTS initialization in write mode...")
65 | # 1. Load FTS extension
66 | try:
67 | conn.execute("INSTALL fts;")
68 | conn.execute("LOAD fts;")
69 | initialization_summary["fts_extension_loaded"] = True
70 | except Exception as e_fts:
71 | logger.warning(f"FTS extension loading during initialization: {e_fts}")
72 | initialization_summary["fts_extension_loaded"] = False
73 | initialization_summary["fts_extension_error"] = str(e_fts)
74 | # 2. Create FTS index (only if table and column exist)
75 | if initialization_summary.get("pages_table_exists") and initialization_summary.get(
76 | "raw_text_column_exists"
77 | ):
78 | try:
79 | conn.execute("PRAGMA create_fts_index('pages', 'id', 'raw_text');")
80 | initialization_summary["fts_index_created"] = True
81 | initialization_summary["fts_index_creation_error"] = None
82 | except Exception as e_index:
83 | logger.warning(f"FTS index creation error: {e_index}")
84 | initialization_summary["fts_index_created"] = False
85 | initialization_summary["fts_index_creation_error"] = str(e_index)
86 | if "already exists" in str(e_index):
87 | initialization_summary["fts_index_exists"] = True
88 | else:
89 | initialization_summary["fts_index_created"] = False
90 | initialization_summary["fts_index_creation_error"] = (
91 | "'pages' table or 'raw_text' column does not exist."
92 | )
93 | # 2b. Check for FTS index existence using DuckDB's sqlite_master and information_schema.tables
94 | try:
95 | fts_indexes = conn.execute(
96 | "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'fts_idx_%'"
97 | ).fetchall()
98 | initialization_summary["fts_indexes_sqlite_master"] = (
99 | [row[0] for row in fts_indexes] if fts_indexes else []
100 | )
101 | initialization_summary["fts_index_exists_sqlite_master"] = bool(fts_indexes)
102 | fts_idx_tables = conn.execute(
103 | "SELECT table_name FROM information_schema.tables WHERE table_name LIKE 'fts_idx_%'"
104 | ).fetchall()
105 | initialization_summary["fts_indexes_information_schema"] = (
106 | [row[0] for row in fts_idx_tables] if fts_idx_tables else []
107 | )
108 | initialization_summary["fts_index_exists_information_schema"] = bool(fts_idx_tables)
109 | except Exception as e_fts_idx:
110 | initialization_summary["fts_index_check_error"] = str(e_fts_idx)
111 | # 3. Drop legacy table
112 | try:
113 | table_exists_result = conn.execute(
114 | "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'fts_main_pages'"
115 | ).fetchone()
116 | table_exists = table_exists_result[0] > 0 if table_exists_result else False
117 | if table_exists:
118 | conn.execute("DROP TABLE IF EXISTS fts_main_pages;")
119 | initialization_summary["legacy_fts_main_pages_dropped"] = True
120 | else:
121 | initialization_summary["legacy_fts_main_pages_dropped"] = False
122 | except Exception as e_drop:
123 | initialization_summary["legacy_fts_main_pages_drop_error"] = str(e_drop)
124 | else:
125 | initialization_summary["initialization_skipped"] = True
126 | initialization_summary["read_only_mode"] = read_only
127 | except Exception as e:
128 | logger.error(f"Error during initialization: {e}")
129 | initialization_summary["initialization_error"] = str(e)
130 |
131 | # Always run diagnostics
132 | try:
133 | diagnostics = await debug_bm25_search(conn, query)
134 | # Add FTS index check and schema info to diagnostics as well
135 | try:
136 | fts_indexes_diag = conn.execute(
137 | "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'fts_idx_%'"
138 | ).fetchall()
139 | diagnostics["fts_indexes_sqlite_master"] = (
140 | [row[0] for row in fts_indexes_diag] if fts_indexes_diag else []
141 | )
142 | diagnostics["fts_index_exists_sqlite_master"] = bool(fts_indexes_diag)
143 | fts_idx_tables_diag = conn.execute(
144 | "SELECT table_name FROM information_schema.tables WHERE table_name LIKE 'fts_idx_%'"
145 | ).fetchall()
146 | diagnostics["fts_indexes_information_schema"] = (
147 | [row[0] for row in fts_idx_tables_diag] if fts_idx_tables_diag else []
148 | )
149 | diagnostics["fts_index_exists_information_schema"] = bool(fts_idx_tables_diag)
150 | # Add schema info
151 | diagnostics["schema_info"] = schema_info
152 | except Exception as e_fts_idx_diag:
153 | diagnostics["fts_index_check_error"] = str(e_fts_idx_diag)
154 | # Remove BM25 function checks and errors from diagnostics if present
155 | for k in list(diagnostics.keys()):
156 | if "bm25" in k or "BM25" in k:
157 | diagnostics.pop(k)
158 | # Remove recommendations about BM25 function
159 | recs = diagnostics.get("recommendations", [])
160 | diagnostics["recommendations"] = [rec for rec in recs if "match_bm25 function" not in rec]
161 | except Exception as e:
162 | logger.error(f"Error during BM25 diagnostics: {e!s}")
163 | raise HTTPException(status_code=500, detail=f"Diagnostic error: {e!s}")
164 | finally:
165 | db.db.close()
166 |
167 | # Compose response
168 | response = {
169 | "initialization": initialization_summary,
170 | "diagnostics": diagnostics,
171 | }
172 | # Recommendations: only mention /bm25_diagnostics?initialize=true for FTS index creation if no FTS index is found
173 | fts_index_exists = response["diagnostics"].get("fts_index_exists_sqlite_master") or response[
174 | "diagnostics"
175 | ].get("fts_index_exists_information_schema")
176 | recs = diagnostics.get("recommendations", [])
177 | new_recs = [
178 | rec for rec in recs if "/initialize_bm25" not in rec and "match_bm25 function" not in rec
179 | ]
180 | if not fts_index_exists:
181 | new_recs.append(
182 | "No FTS indexes found. Run /bm25_diagnostics?initialize=true with a write connection."
183 | )
184 | response["recommendations"] = new_recs
185 | if "status" in diagnostics:
186 | response["status"] = diagnostics["status"]
187 | return response
188 |
--------------------------------------------------------------------------------
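Example call against the diagnostics endpoint above, using httpx purely as an illustrative client; the base URL and port are assumptions about a local deployment and should be adjusted to match your docker-compose setup:

import httpx

# initialize=True asks the endpoint to attempt FTS index creation when the
# connection is writable; in read-only mode initialization is skipped.
resp = httpx.get(
    "http://localhost:8000/bm25_diagnostics",
    params={"query": "artificial intelligence", "initialize": True},
)
body = resp.json()
print(body["initialization"])
print(body["recommendations"])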
/tests/integration/services/test_document_service.py:
--------------------------------------------------------------------------------
1 | """Integration tests for the document service with real DuckDB."""
2 |
3 | from unittest.mock import patch
4 |
5 | import pytest
6 |
7 | from src.common.config import VECTOR_SIZE
8 | from src.common.indexer import VectorIndexer
9 | from src.web_service.services.document_service import (
10 | get_doc_page,
11 | list_doc_pages,
12 | list_tags,
13 | search_docs,
14 | )
15 |
16 |
17 | @pytest.mark.integration
18 | @pytest.mark.async_test
19 | async def test_search_docs_with_duckdb(in_memory_duckdb_connection):
20 | """Test searching documents with DuckDB backend using hybrid search (vector and BM25)."""
21 | # Create test data in the in-memory database
22 | indexer = VectorIndexer(connection=in_memory_duckdb_connection)
23 |
24 | # Insert test pages with raw text for BM25 search
25 | in_memory_duckdb_connection.execute("""
26 | INSERT INTO pages (id, url, domain, crawl_date, tags, raw_text, job_id)
27 | VALUES
28 | ('page1', 'https://example.com/ai', 'example.com', '2023-01-01',
29 | '["ai", "tech"]', 'This is a document about artificial intelligence and machine learning.', 'job1'),
30 | ('page2', 'https://example.com/ml', 'example.com', '2023-01-02',
31 | '["ml", "tech"]', 'This is a document about machine learning algorithms and neural networks.', 'job1'),
32 | ('page3', 'https://example.com/nlp', 'example.com', '2023-01-03',
33 | '["nlp", "tech"]', 'Natural language processing is a field of artificial intelligence.', 'job1')
34 | """)
35 |
36 | # Try creating FTS index for BM25 search
37 | try:
38 | in_memory_duckdb_connection.execute("INSTALL fts; LOAD fts;")
39 | in_memory_duckdb_connection.execute("PRAGMA create_fts_index('pages', 'id', 'raw_text');")
40 | except Exception as e:
41 | print(f"Warning: Could not create FTS index: {e}. BM25 search may not work in tests.")
42 |
43 | # Create test vectors and payloads for vector search
44 | test_vector1 = [0.1] * VECTOR_SIZE
45 | test_payload1 = {
46 | "text": "This is a document about artificial intelligence",
47 | "page_id": "page1",
48 | "url": "https://example.com/ai",
49 | "domain": "example.com",
50 | "tags": ["ai", "tech"],
51 | "job_id": "job1",
52 | }
53 |
54 | test_vector2 = [0.2] * VECTOR_SIZE
55 | test_payload2 = {
56 | "text": "This is a document about machine learning",
57 | "page_id": "page2",
58 | "url": "https://example.com/ml",
59 | "domain": "example.com",
60 | "tags": ["ml", "tech"],
61 | "job_id": "job1",
62 | }
63 |
64 | test_vector3 = [0.3] * VECTOR_SIZE
65 | test_payload3 = {
66 | "text": "Natural language processing is a field of artificial intelligence",
67 | "page_id": "page3",
68 | "url": "https://example.com/nlp",
69 | "domain": "example.com",
70 | "tags": ["nlp", "tech"],
71 | "job_id": "job1",
72 | }
73 |
74 | # Index test data for vector search
75 | await indexer.index_vector(test_vector1, test_payload1)
76 | await indexer.index_vector(test_vector2, test_payload2)
77 | await indexer.index_vector(test_vector3, test_payload3)
78 |
79 | # Mock the embedding generation to return a vector similar to test_vector1
80 | with patch("src.lib.embedder.generate_embedding", return_value=[0.11] * VECTOR_SIZE):
81 | # Test 1: Standard hybrid search with a query that should match via both vector and BM25
82 | result = await search_docs(
83 | conn=in_memory_duckdb_connection,
84 | query="artificial intelligence",
85 | tags=None,
86 | max_results=5,
87 | )
88 |
89 | # Verify results - should find page1 and page3 as they mention 'artificial intelligence'
90 | assert len(result.results) >= 2 # Should return at least 2 results
91 | # The first result should be about artificial intelligence
92 | assert any("artificial intelligence" in r.chunk_text for r in result.results)
93 | # Check page_ids - both page1 and page3 should be in results
94 | result_page_ids = [r.page_id for r in result.results]
95 | assert "page1" in result_page_ids
96 | assert "page3" in result_page_ids
97 |
98 | # Test 2: Using tag filter
99 | result = await search_docs(
100 | conn=in_memory_duckdb_connection,
101 | query="artificial intelligence",
102 | tags=["ai"],
103 | max_results=5,
104 | )
105 |
106 | # Verify filtered results
107 | assert any(r.page_id == "page1" for r in result.results)
108 | assert all("ai" in r.tags for r in result.results)
109 |
110 | # Test 3: Test with BM25-specific term not in vectors but in raw text
111 | result = await search_docs(
112 | conn=in_memory_duckdb_connection,
113 | query="neural networks",
114 | tags=None,
115 | max_results=5,
116 | )
117 |
118 | # Verify BM25 results - should find page2 which mentions 'neural networks'
119 | assert any(r.page_id == "page2" for r in result.results)
120 | assert any("neural networks" in r.chunk_text for r in result.results)
121 |
122 | # Test 4: Adjusting hybrid weights
123 | result = await search_docs(
124 | conn=in_memory_duckdb_connection,
125 | query="artificial intelligence",
126 | tags=None,
127 | max_results=5,
128 | hybrid_weight=0.3, # More weight on BM25 (0.7) than vector (0.3)
129 | )
130 |
131 | # With more weight on BM25, we expect different result ordering
132 | # Check that the results are not empty
133 | assert len(result.results) > 0
134 |
135 |
136 | @pytest.mark.integration
137 | @pytest.mark.async_test
138 | async def test_list_doc_pages_with_duckdb(in_memory_duckdb_connection):
139 | """Test listing document pages with DuckDB backend."""
140 | # Insert test data directly into the pages table
141 | in_memory_duckdb_connection.execute("""
142 | INSERT INTO pages (id, url, domain, crawl_date, tags, raw_text)
143 | VALUES
144 | ('page1', 'https://example.com/page1', 'example.com', '2023-01-01',
145 | '["doc","example"]', 'This is page 1'),
146 | ('page2', 'https://example.com/page2', 'example.com', '2023-01-02',
147 | '["doc","test"]', 'This is page 2'),
148 | ('page3', 'https://example.org/page3', 'example.org', '2023-01-03',
149 | '["other","example"]', 'This is page 3')
150 | """)
151 |
152 | # Test with no filters
153 | result = await list_doc_pages(conn=in_memory_duckdb_connection, page=1, tags=None)
154 |
155 | # Verify results
156 | assert result.total_pages >= 1 # There is at least 1 page of results
157 | assert result.current_page == 1
158 | assert len(result.doc_pages) == 3
159 |
160 | # Test with tag filter
161 | result = await list_doc_pages(conn=in_memory_duckdb_connection, page=1, tags=["doc"])
162 |
163 | # Verify filtered results
164 | assert len(result.doc_pages) == 2
165 | assert all("doc" in page.tags for page in result.doc_pages)
166 |
167 |
168 | @pytest.mark.integration
169 | @pytest.mark.async_test
170 | async def test_get_doc_page_with_duckdb(in_memory_duckdb_connection):
171 | """Test retrieving a document page with DuckDB backend."""
172 | # Insert test data directly into the pages table
173 | in_memory_duckdb_connection.execute("""
174 | INSERT INTO pages (id, url, domain, crawl_date, tags, raw_text)
175 | VALUES ('test-page', 'https://example.com/test', 'example.com', '2023-01-01',
176 | '["test"]', 'Line 1\nLine 2\nLine 3\nLine 4\nLine 5')
177 | """)
178 |
179 | # Test retrieving the entire page
180 | result = await get_doc_page(
181 | conn=in_memory_duckdb_connection,
182 | page_id="test-page",
183 | starting_line=1,
184 | ending_line=-1,
185 | )
186 |
187 | # Verify results
188 | assert result.text == "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
189 | assert result.total_lines == 5
190 |
191 | # Test retrieving specific lines
192 | result = await get_doc_page(
193 | conn=in_memory_duckdb_connection,
194 | page_id="test-page",
195 | starting_line=2,
196 | ending_line=4,
197 | )
198 |
199 | # Verify partial results
200 | assert result.text == "Line 2\nLine 3\nLine 4"
201 | assert result.total_lines == 5
202 |
203 |
204 | @pytest.mark.integration
205 | @pytest.mark.async_test
206 | async def test_list_tags_with_duckdb(in_memory_duckdb_connection):
207 | """Test listing unique tags with DuckDB backend."""
208 | # Insert test data directly into the pages table with different tags
209 | in_memory_duckdb_connection.execute("""
210 | INSERT INTO pages (id, url, domain, crawl_date, tags, raw_text)
211 | VALUES
212 | ('page1', 'https://example.com/page1', 'example.com', '2023-01-01',
213 | '["tag1","common"]', 'Page 1'),
214 | ('page2', 'https://example.com/page2', 'example.com', '2023-01-02',
215 | '["tag2","common"]', 'Page 2'),
216 | ('page3', 'https://example.com/page3', 'example.com', '2023-01-03',
217 | '["tag3","special"]', 'Page 3')
218 | """)
219 |
220 | # Test listing all tags
221 | result = await list_tags(conn=in_memory_duckdb_connection, search_substring=None)
222 |
223 | # Verify results contain all unique tags
224 | assert len(result.tags) == 5
225 | assert set(result.tags) == {"tag1", "tag2", "tag3", "common", "special"}
226 |
227 | # Test with search substring
228 | result = await list_tags(conn=in_memory_duckdb_connection, search_substring="tag")
229 |
230 | # Verify filtered results
231 | assert len(result.tags) == 3
232 | assert all("tag" in tag for tag in result.tags)
233 |
--------------------------------------------------------------------------------
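Note on Test 4 above: hybrid_weight is the share given to the vector score, with the remainder going to BM25, so 0.3 means 70% BM25. A minimal sketch of that blending convention; this is not the actual search_docs scoring code:

def hybrid_score(vector_score: float, bm25_score: float, hybrid_weight: float = 0.5) -> float:
    # hybrid_weight is the vector share; (1 - hybrid_weight) is the BM25 share.
    return hybrid_weight * vector_score + (1.0 - hybrid_weight) * bm25_score


assert hybrid_score(0.8, 0.2, hybrid_weight=1.0) == 0.8  # pure vector ranking
assert hybrid_score(0.8, 0.2, hybrid_weight=0.0) == 0.2  # pure BM25 ranking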
/tests/services/test_map_service.py:
--------------------------------------------------------------------------------
1 | """Tests for the map service."""
2 |
3 | import pytest
4 | from unittest.mock import AsyncMock, patch
5 | import datetime
6 |
7 | from src.web_service.services.map_service import MapService
8 |
9 |
10 | @pytest.mark.asyncio
11 | class TestMapService:
12 | """Test the MapService class."""
13 |
14 | async def test_get_all_sites(self) -> None:
15 | """Test getting all sites.
16 |
17 | Args:
18 | None.
19 |
20 | Returns:
21 | None.
22 | """
23 | service = MapService()
24 |
25 | # Mock the database response - sites from different domains
26 | mock_sites = [
27 | {
28 | "id": "site1",
29 | "url": "https://example1.com",
30 | "title": "Site 1",
31 | "domain": "example1.com",
32 | },
33 | {
34 | "id": "site2",
35 | "url": "https://example2.com",
36 | "title": "Site 2",
37 | "domain": "example2.com",
38 | },
39 | ]
40 |
41 | with patch.object(service.db_ops, "get_root_pages", new_callable=AsyncMock) as mock_get:
42 | mock_get.return_value = mock_sites
43 |
44 | sites = await service.get_all_sites()
45 |
46 | # Since they're from different domains, they should not be grouped
47 | assert len(sites) == 2
48 | assert sites[0]["id"] == "site1"
49 | assert sites[1]["id"] == "site2"
50 | mock_get.assert_called_once()
51 |
52 | async def test_build_page_tree(self) -> None:
53 | """Test building a page tree structure.
54 |
55 | Args:
56 | None.
57 |
58 | Returns:
59 | None.
60 | """
61 | service = MapService()
62 | root_id = "root-123"
63 |
64 | # Mock hierarchy data
65 | mock_pages = [
66 | {
67 | "id": "root-123",
68 | "parent_page_id": None,
69 | "title": "Home",
70 | "url": "https://example.com",
71 | },
72 | {
73 | "id": "page-1",
74 | "parent_page_id": "root-123",
75 | "title": "About",
76 | "url": "https://example.com/about",
77 | },
78 | {
79 | "id": "page-2",
80 | "parent_page_id": "root-123",
81 | "title": "Docs",
82 | "url": "https://example.com/docs",
83 | },
84 | {
85 | "id": "page-3",
86 | "parent_page_id": "page-2",
87 | "title": "API",
88 | "url": "https://example.com/docs/api",
89 | },
90 | ]
91 |
92 | with patch.object(service.db_ops, "get_page_hierarchy", new_callable=AsyncMock) as mock_get:
93 | mock_get.return_value = mock_pages
94 |
95 | tree = await service.build_page_tree(root_id)
96 |
97 | # Verify tree structure
98 | assert tree["id"] == "root-123"
99 | assert tree["title"] == "Home"
100 | assert len(tree["children"]) == 2
101 |
102 | # Check first level children
103 | about_page = next(c for c in tree["children"] if c["id"] == "page-1")
104 | docs_page = next(c for c in tree["children"] if c["id"] == "page-2")
105 |
106 | assert about_page["title"] == "About"
107 | assert len(about_page["children"]) == 0
108 |
109 | assert docs_page["title"] == "Docs"
110 | assert len(docs_page["children"]) == 1
111 | assert docs_page["children"][0]["title"] == "API"
112 |
113 | async def test_get_navigation_context(self) -> None:
114 | """Test getting navigation context for a page.
115 |
116 | Args:
117 | None.
118 |
119 | Returns:
120 | None.
121 | """
122 | service = MapService()
123 | page_id = "page-123"
124 |
125 | # Mock current page
126 | mock_page = {
127 | "id": page_id,
128 | "parent_page_id": "parent-123",
129 | "root_page_id": "root-123",
130 | "title": "Current Page",
131 | }
132 |
133 | # Mock related pages
134 | mock_parent = {"id": "parent-123", "title": "Parent Page"}
135 | mock_siblings = [
136 | {"id": "sibling-1", "title": "Sibling 1"},
137 | {"id": "sibling-2", "title": "Sibling 2"},
138 | ]
139 | mock_children = [
140 | {"id": "child-1", "title": "Child 1"},
141 | ]
142 | mock_root = {"id": "root-123", "title": "Home"}
143 |
144 | with patch.object(
145 | service.db_ops, "get_page_by_id", new_callable=AsyncMock
146 | ) as mock_get_page:
147 | mock_get_page.side_effect = [mock_page, mock_parent, mock_root]
148 |
149 | with patch.object(
150 | service.db_ops, "get_sibling_pages", new_callable=AsyncMock
151 | ) as mock_siblings_fn:
152 | mock_siblings_fn.return_value = mock_siblings
153 |
154 | with patch.object(
155 | service.db_ops, "get_child_pages", new_callable=AsyncMock
156 | ) as mock_children_fn:
157 | mock_children_fn.return_value = mock_children
158 |
159 | context = await service.get_navigation_context(page_id)
160 |
161 | assert context["current_page"] == mock_page
162 | assert context["parent"] == mock_parent
163 | assert context["siblings"] == mock_siblings
164 | assert context["children"] == mock_children
165 | assert context["root"] == mock_root
166 |
167 | def test_render_page_html(self) -> None:
168 | """Test rendering a page as HTML.
169 |
170 | Args:
171 | None.
172 |
173 | Returns:
174 | None.
175 | """
176 | service = MapService()
177 |
178 | page = {
179 | "id": "page-123",
180 | "title": "Test Page",
181 | "raw_text": "# Test Page\n\nThis is **markdown** content.",
182 | }
183 |
184 | navigation = {
185 | "current_page": page,
186 | "parent": {"id": "parent-123", "title": "Parent"},
187 | "siblings": [
188 | {"id": "sib-1", "title": "Sibling 1"},
189 | ],
190 | "children": [
191 | {"id": "child-1", "title": "Child 1"},
192 | ],
193 | "root": {"id": "root-123", "title": "Home"},
194 | }
195 |
196 | html = service.render_page_html(page, navigation)
197 |
198 | # Check that HTML contains expected elements
199 | assert "Test Page" in html
200 | assert "Test Page
" in html
201 | assert "markdown" in html # Markdown was rendered
202 | assert 'href="/map/page/parent-123"' in html # Parent link
203 | assert 'href="/map/page/sib-1"' in html # Sibling link
204 | assert 'href="/map/page/child-1"' in html # Child link
205 | assert 'href="/map/site/root-123"' in html # Site map link
206 |
207 | def test_format_site_list(self) -> None:
208 | """Test formatting a list of sites.
209 |
210 | Args:
211 | None.
212 |
213 | Returns:
214 | None.
215 | """
216 | service = MapService()
217 |
218 | sites = [
219 | {
220 | "id": "site1",
221 | "url": "https://example1.com",
222 | "title": "Example Site 1",
223 | "crawl_date": datetime.datetime(2024, 1, 1, 12, 0, 0),
224 | },
225 | {
226 | "id": "site2",
227 | "url": "https://example2.com",
228 | "title": "Example Site 2",
229 | "crawl_date": datetime.datetime(2024, 1, 2, 12, 0, 0),
230 | },
231 | ]
232 |
233 | html = service.format_site_list(sites)
234 |
235 | # Check HTML content
236 | assert "Site Map - All Sites" in html
237 | assert "Example Site 1" in html
238 | assert "Example Site 2" in html
239 | assert 'href="/map/site/site1"' in html
240 | assert 'href="/map/site/site2"' in html
241 | assert "https://example1.com" in html
242 | assert "https://example2.com" in html
243 |
244 | def test_format_site_list_empty(self) -> None:
245 | """Test formatting empty site list.
246 |
247 | Args:
248 | None.
249 |
250 | Returns:
251 | None.
252 | """
253 | service = MapService()
254 |
255 | html = service.format_site_list([])
256 |
257 | assert "No Sites Found" in html
258 | assert "No crawled sites are available" in html
259 |
260 | def test_format_site_tree(self) -> None:
261 | """Test formatting a site tree.
262 |
263 | Args:
264 | None.
265 |
266 | Returns:
267 | None.
268 | """
269 | service = MapService()
270 |
271 | tree = {
272 | "id": "root-123",
273 | "title": "My Site",
274 | "children": [
275 | {"id": "page-1", "title": "About", "children": []},
276 | {
277 | "id": "page-2",
278 | "title": "Docs",
279 | "children": [{"id": "page-3", "title": "API", "children": []}],
280 | },
281 | ],
282 | }
283 |
284 | html = service.format_site_tree(tree)
285 |
286 | # Check HTML structure
287 | assert "My Site - Site Map" in html
288 | assert 'href="/map/page/root-123"' in html
289 | assert 'href="/map/page/page-1"' in html
290 | assert 'href="/map/page/page-2"' in html
291 | assert 'href="/map/page/page-3"' in html
292 | assert "" in html # Collapsible sections
293 | assert "Back to all sites" in html
294 |
295 | def test_build_tree_html_leaf_node(self) -> None:
296 | """Test building HTML for a leaf node.
297 |
298 | Args:
299 | None.
300 |
301 | Returns:
302 | None.
303 | """
304 | service = MapService()
305 |
306 | node = {"id": "leaf-123", "title": "Leaf Page", "children": []}
307 |
308 | html = service._build_tree_html(node, is_root=False)
309 |
310 | assert '<a href="/map/page/leaf-123">Leaf Page</a>' in html
311 | assert '' in html
312 | assert '' in html
313 | assert "" not in html # No collapsible section for leaf
314 |
315 | def test_build_tree_html_with_children(self) -> None:
316 | """Test building HTML for a node with children.
317 |
318 | Args:
319 | None.
320 |
321 | Returns:
322 | None.
323 | """
324 | service = MapService()
325 |
326 | node = {
327 | "id": "parent-123",
328 | "title": "Parent Page",
329 | "children": [
330 | {"id": "child-1", "title": "Child 1", "children": []},
331 | {"id": "child-2", "title": "Child 2", "children": []},
332 | ],
333 | }
334 |
335 | html = service._build_tree_html(node, is_root=False)
336 |
337 | assert "" in html
338 | assert "" in html
339 | assert "Parent Page" in html
340 | assert "Child 1" in html
341 | assert "Child 2" in html
342 |
--------------------------------------------------------------------------------
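Note on test_build_page_tree above: the expected tree is simply the flat hierarchy rows regrouped under their parent_page_id, starting from the requested root. A hedged sketch of that regrouping; build_tree here is illustrative, and MapService.build_page_tree may differ in detail:

def build_tree(pages: list[dict], root_id: str) -> dict:
    # Group rows by parent so each node can collect its children.
    by_parent: dict[str | None, list[dict]] = {}
    for page in pages:
        by_parent.setdefault(page["parent_page_id"], []).append(page)

    def node(page: dict) -> dict:
        return {
            "id": page["id"],
            "title": page["title"],
            "url": page["url"],
            "children": [node(child) for child in by_parent.get(page["id"], [])],
        }

    root = next(p for p in pages if p["id"] == root_id)
    return node(root)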