├── src ├── __init__.py ├── api │ ├── __init__.py │ ├── config │ │ ├── __init__.py │ │ └── cors_config.py │ ├── handlers │ │ ├── __init__.py │ │ └── error_handlers.py │ ├── system │ │ ├── __init__.py │ │ ├── sentry_debug.py │ │ └── health.py │ ├── middleware │ │ ├── __init__.py │ │ ├── debug_middleware.py │ │ └── rate_limit.py │ ├── v0 │ │ ├── endpoints │ │ │ ├── __init__.py │ │ │ ├── webhooks.py │ │ │ ├── sources.py │ │ │ └── chat.py │ │ └── schemas │ │ │ ├── __init__.py │ │ │ ├── health_schemas.py │ │ │ ├── base_schemas.py │ │ │ └── webhook_schemas.py │ ├── routes.py │ └── dependencies.py ├── core │ ├── __init__.py │ ├── chat │ │ ├── __init__.py │ │ ├── prompt_manager.py │ │ ├── tool_manager.py │ │ ├── tools │ │ │ └── tools.yaml │ │ └── prompts │ │ │ └── prompts.yaml │ ├── content │ │ └── __init__.py │ └── search │ │ ├── __init__.py │ │ ├── embedding_manager.py │ │ ├── reranker.py │ │ └── retriever.py ├── infra │ ├── __init__.py │ ├── arq │ │ ├── __init__.py │ │ ├── worker.py │ │ ├── redis_pool.py │ │ ├── arq_settings.py │ │ └── worker_services.py │ ├── data │ │ └── __init__.py │ ├── events │ │ ├── __init__.py │ │ ├── channels.py │ │ ├── event_publisher.py │ │ └── event_consumer.py │ ├── misc │ │ ├── __init__.py │ │ └── ngrok_service.py │ ├── external │ │ ├── __init__.py │ │ ├── chroma_manager.py │ │ ├── supabase_manager.py │ │ └── redis_manager.py │ ├── logger.py │ └── service_container.py ├── models │ ├── __init__.py │ ├── task_models.py │ ├── pubsub_models.py │ ├── base_models.py │ ├── vector_models.py │ ├── llm_models.py │ └── job_models.py ├── services │ ├── __init__.py │ ├── webhook_handler.py │ └── job_manager.py └── app.py ├── tests ├── unit │ ├── infra │ │ ├── __init__.py │ │ ├── arq │ │ │ ├── __init__.py │ │ │ ├── test_worker.py │ │ │ ├── test_arq_settings.py │ │ │ ├── test_redis_pool.py │ │ │ └── test_task_definitions.py │ │ ├── external │ │ │ ├── __init__.py │ │ │ ├── test_supabase_manager.py │ │ │ ├── test_chroma_manager.py │ │ │ └── test_redis_manager.py │ │ └── storage │ │ │ └── __init__.py │ ├── __init__.py │ ├── conftest.py │ ├── test_app.py │ ├── test_supabase_client.py │ └── test_settings.py ├── integration │ ├── infra │ │ ├── __init__.py │ │ ├── storage │ │ │ └── __init__.py │ │ ├── external │ │ │ ├── __init__.py │ │ │ ├── test_chroma_manager_integration.py │ │ │ └── test_redis_manager_integration.py │ │ └── arq │ │ │ ├── __init__.py │ │ │ ├── test_worker_integration.py │ │ │ ├── arq_fixtures_integration.py │ │ │ ├── test_worker_settings_integration.py │ │ │ ├── test_redis_pool_integration.py │ │ │ └── test_arq_settings_integration.py │ ├── __init__.py │ ├── conftest.py │ ├── services │ │ ├── __init__.py │ │ └── test_chat_service.py │ └── test_app.py ├── __init__.py └── data │ └── __init__.py ├── scripts └── docker │ ├── entrypoint.sh │ ├── Dockerfile │ ├── .dockerignore │ └── compose.yaml ├── .github ├── ISSUE_TEMPLATE │ ├── feature_request.md │ └── bug_report.md └── pull_request_template.md ├── docs ├── feature-spec │ ├── content-ingestion.md │ └── arq-migration.md ├── changelog │ └── CHANGELOG.md └── user-guide.md ├── config └── .env.example ├── Makefile ├── nixpacks.toml ├── .pre-commit-config.yaml └── pyproject.toml /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/api/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /src/core/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/infra/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/api/config/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/api/handlers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/api/system/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/core/chat/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/core/content/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/core/search/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/infra/arq/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/infra/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/infra/events/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/infra/misc/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/services/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/infra/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/api/middleware/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/api/v0/endpoints/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/api/v0/schemas/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /src/infra/external/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/infra/arq/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/integration/infra/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/infra/external/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/infra/storage/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/integration/infra/storage/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/integration/infra/external/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/infra/external/test_supabase_manager.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """Test suite for the Kollektiv application.""" 2 | -------------------------------------------------------------------------------- /tests/data/__init__.py: -------------------------------------------------------------------------------- 1 | """Test suite for the Kollektiv application.""" 2 | -------------------------------------------------------------------------------- /tests/unit/__init__.py: -------------------------------------------------------------------------------- 1 | """Test suite for the Kollektiv application.""" 2 | -------------------------------------------------------------------------------- /tests/integration/__init__.py: -------------------------------------------------------------------------------- 1 | """Test suite for the Kollektiv application.""" 2 | -------------------------------------------------------------------------------- /tests/unit/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | pytestmark = pytest.mark.unit 4 | -------------------------------------------------------------------------------- /tests/integration/infra/arq/__init__.py: -------------------------------------------------------------------------------- 1 | """Integration tests for ARQ components.""" 2 | -------------------------------------------------------------------------------- /tests/integration/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | pytestmark = pytest.mark.integration 4 | -------------------------------------------------------------------------------- /tests/integration/services/__init__.py: 
-------------------------------------------------------------------------------- 1 | """Test suite for the Kollektiv application.""" 2 | -------------------------------------------------------------------------------- /tests/integration/services/test_chat_service.py: -------------------------------------------------------------------------------- 1 | """Integration tests for the Chat Service.""" 2 | 3 | # 1. Methods to test all three endpoints with proper mocking of conversation manager, llm assistant, and redis 4 | # 2. Ger response test 5 | # 3. Process stream test 6 | -------------------------------------------------------------------------------- /src/api/v0/schemas/health_schemas.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | from pydantic import BaseModel, Field 4 | 5 | 6 | class HealthCheckResponse(BaseModel): 7 | """Response for the health check endpoint.""" 8 | 9 | status: Literal["operational", "degraded", "down"] = Field(..., description="The health status of the system") 10 | message: str = Field(..., description="A message describing the health status of the system") 11 | -------------------------------------------------------------------------------- /src/api/system/sentry_debug.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, Response, status 2 | 3 | from src.api.routes import Routes 4 | from src.infra.logger import get_logger 5 | 6 | router = APIRouter() 7 | logger = get_logger() 8 | 9 | 10 | @router.get(Routes.System.SENTRY_DEBUG) 11 | async def trigger_error() -> Response: 12 | """Trigger a Sentry error.""" 13 | division_by_zero = 1 / 0 14 | return Response(status_code=status.HTTP_204_NO_CONTENT) 15 | -------------------------------------------------------------------------------- /scripts/docker/entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | 4 | # Set SUPABASE_URL dynamically 5 | # Necessary to point to the correct supabase url locally 6 | # In CI, the supabase url is set in the .env file 7 | if [ "$ENVIRONMENT" != "local" ]; then 8 | export SUPABASE_URL="${SUPABASE_URL}" 9 | else 10 | export SUPABASE_URL="http://host.docker.internal:54321" 11 | fi 12 | 13 | # Debug output 14 | echo "Using SUPABASE_URL: $SUPABASE_URL" 15 | 16 | # Execute the main process 17 | exec "$@" -------------------------------------------------------------------------------- /src/infra/events/channels.py: -------------------------------------------------------------------------------- 1 | from uuid import UUID 2 | 3 | 4 | class Channels: 5 | """Channel definitions for pub/sub events""" 6 | 7 | # Base namespaces 8 | CONTENT_PROCESSING = "content_processing" 9 | CHAT = "chat" 10 | 11 | @staticmethod 12 | def content_processing_channel(source_id: UUID | str) -> str: 13 | """Creates a source-specific content processing channel.""" 14 | return f"{Channels.CONTENT_PROCESSING}/{str(source_id)}" 15 | 16 | class Config: 17 | """Configuration for channels""" 18 | 19 | SSE_TIMEOUT = 60 * 60 # 1 hour 20 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a 
problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /src/models/task_models.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import Any 3 | 4 | from pydantic import BaseModel, Field 5 | 6 | 7 | class KollektivTaskStatus(str, Enum): 8 | """Individual status of a Kollektiv task.""" 9 | 10 | SUCCESS = "success" 11 | FAILED = "failed" 12 | # Can be expanded later to include more statuses 13 | 14 | 15 | class KollektivTaskResult(BaseModel): 16 | """Base model for all Kollektiv task results. Must be returned by all tasks.""" 17 | 18 | status: KollektivTaskStatus = Field(..., description="Status of the task") 19 | message: str = Field(..., description="Message of the task") 20 | data: dict[str, Any] | None = Field(None, description="Any additional data for the task") 21 | -------------------------------------------------------------------------------- /tests/integration/infra/arq/test_worker_integration.py: -------------------------------------------------------------------------------- 1 | """Integration tests for ARQ worker functionality.""" 2 | 3 | import sys 4 | from unittest.mock import patch 5 | 6 | import pytest 7 | 8 | 9 | @pytest.mark.asyncio 10 | async def test_run_worker(): 11 | """Test that worker runs through CLI with correct settings.""" 12 | with patch.object(sys, "argv", ["arq", "src.infra.arq.worker.WorkerSettings"]), patch("arq.cli.cli") as mock_cli: 13 | from src.infra.arq.worker import run_worker 14 | 15 | run_worker() 16 | 17 | # Verify CLI was called 18 | mock_cli.assert_called_once() 19 | 20 | # Verify argv was set correctly 21 | assert sys.argv == ["arq", "src.infra.arq.worker.WorkerSettings"] 22 | -------------------------------------------------------------------------------- /docs/feature-spec/content-ingestion.md: -------------------------------------------------------------------------------- 1 | # Content Ingestion Feature 2 | 3 | ## SSE Workflow 4 | 5 | SourceEvents are emitted at the following stages: 6 | CREATED -> published after initial data source entry created 7 | CRAWLING_STARTED -> published after first crawl_started event 8 | PROCESSING_SCHEDULED -> published after processing task is scheduled 9 | SUMMARY_GENERATED -> publshed after summary is generated 10 | COMPLETED -> published after processing is complete 11 | FAILED -> published if any step of content ingestion process fails 12 | 13 | ### Mapping of ContentProcessingEvents to SourceEvents 14 | STARTED -> internal event, is not sent to FE 15 | CHUNKS_GENERATED -> internal event, is not sent to FE 16 | SUMMARY_GENERATED -> sent to FE 17 | COMPLETED -> sent to FE 18 | FAILED -> sent to FE 19 | 20 | 21 | ## API Endpoints 22 | 23 | ## Data Model 24 | -------------------------------------------------------------------------------- /config/.env.example: -------------------------------------------------------------------------------- 1 | # Environment 2 | ENVIRONMENT=local/staging/production 3 | 4 | # Server 
Configuration 5 | PORT=8080 6 | DEBUG=true 7 | SERVICE=api/worker 8 | 9 | # Debug mode 10 | DEBUG=true 11 | 12 | # API Keys 13 | FIRECRAWL_API_KEY=api-key 14 | ANTHROPIC_API_KEY=api-key 15 | OPENAI_API_KEY=api-key 16 | COHERE_API_KEY=api-key 17 | WEAVE_PROJECT_NAME='project-name' 18 | 19 | # Supabase 20 | SUPABASE_URL=api-url 21 | SUPABASE_SERVICE_KEY=supabase-service-key 22 | 23 | # Redis 24 | REDIS_URL=redis://localhost:6379 25 | 26 | # Monitoring 27 | LOGFIRE_TOKEN= 28 | 29 | # Railway tokens, used when deploying to Railway via CLI / GitHub Actions 30 | RAILWAY_TOKEN=token 31 | 32 | # ChromaDB 33 | CHROMA_PRIVATE_URL=http://localhost:8000 # a single unified variable 34 | 35 | # Ngrok 36 | NGROK_AUTH_TOKEN=your-ngrok-auth-token # only used for local dev 37 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: up down ps logs worker rebuild push-ghcr rebuild-and-push 2 | 3 | # Docker commands 4 | up: 5 | docker compose --env-file config/.env -f scripts/docker/compose.yaml up -d --remove-orphans 6 | 7 | down: 8 | docker compose --env-file config/.env -f scripts/docker/compose.yaml down 9 | 10 | ps: 11 | docker compose --env-file config/.env -f scripts/docker/compose.yaml ps 12 | 13 | logs: 14 | docker compose --env-file config/.env -f scripts/docker/compose.yaml logs -f 15 | 16 | rebuild: 17 | docker compose --env-file config/.env -f scripts/docker/compose.yaml build 18 | docker compose --env-file config/.env -f scripts/docker/compose.yaml up -d --remove-orphans 19 | 20 | push-ghcr: 21 | docker build -f scripts/docker/Dockerfile -t ghcr.io/alexander-zuev/kollektiv-rq:latest . 22 | docker push ghcr.io/alexander-zuev/kollektiv-rq:latest 23 | 24 | rebuild-and-push: rebuild push-ghcr 25 | -------------------------------------------------------------------------------- /scripts/docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # Use a base image 2 | ############################## 3 | # Base stage 4 | ############################## 5 | FROM python:3.12-slim AS base 6 | 7 | ENV PYTHONDONTWRITEBYTECODE=1 \ 8 | PYTHONUNBUFFERED=1 \ 9 | PYTHONPATH=/app 10 | 11 | WORKDIR /app 12 | 13 | ############################## 14 | # Builder stage: install deps 15 | ############################## 16 | FROM base AS builder 17 | 18 | # Install system dependencies 19 | RUN apt-get update && apt-get install -y --no-install-recommends curl \ 20 | && rm -rf /var/lib/apt/lists/* 21 | 22 | # Install Poetry 23 | RUN pip install poetry 24 | 25 | # Copy only dependency files first to leverage caching 26 | COPY pyproject.toml poetry.lock ./ 27 | 28 | # Use a cache mount for Poetry’s cache 29 | RUN --mount=type=cache,target=/root/.cache/pypoetry \ 30 | poetry config virtualenvs.create false && \ 31 | poetry install --only main --no-interaction --no-ansi -------------------------------------------------------------------------------- /nixpacks.toml: -------------------------------------------------------------------------------- 1 | [variables] 2 | NIXPACKS_PYTHON_VERSION = "3.12.7" 3 | NIXPACKS_POETRY_VERSION = "1.8.5" 4 | 5 | [phases.setup] 6 | # Specifies the required Python version 7 | nixpkgs = ["python312", "gcc"] 8 | 9 | [phases.install] 10 | # Set up a virtual environment and install dependencies 11 | # Cache dependencies layer 12 | cache_directories = [ 13 | "/root/.cache/pip", # Pip downloads 14 | "/opt/venv", # Virtual env 15 | 
"~/.cache/poetry", # Poetry cache 16 | "/root/.cache/pip/wheels", # Pre-built wheels 17 | ] 18 | cmds = [ 19 | "echo '🚀 Starting Kollektiv installation...'", 20 | "python -m venv /opt/venv", 21 | ". /opt/venv/bin/activate", 22 | "pip install --no-cache-dir poetry", 23 | "poetry config virtualenvs.create false", 24 | "poetry install --only main --no-interaction --no-ansi --no-root", 25 | "echo \"🚀 Starting Kollektiv service: ${SERVICE}\"", 26 | ] 27 | 28 | [start] 29 | cmd = "poetry run $SERVICE" 30 | -------------------------------------------------------------------------------- /src/models/pubsub_models.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from datetime import UTC, datetime 4 | from enum import Enum 5 | from typing import Any 6 | 7 | from pydantic import BaseModel, Field 8 | 9 | 10 | class EventType(str, Enum): 11 | """High-level categorization of events in the system.""" 12 | 13 | CONTENT_PROCESSING = "content_processing" 14 | # Later we might add: 15 | # AUTH = "auth" 16 | # BILLING = "billing" 17 | # etc. 18 | 19 | 20 | class KollektivEvent(BaseModel): 21 | """Base model for all events emitted by the Kollektiv tasks.""" 22 | 23 | event_type: EventType = Field(..., description="Type of the event") 24 | error: str | None = Field(default=None, description="Error message, null if no error") 25 | metadata: dict[str, Any] | None = Field(None, description="Optional metadata for the event") 26 | timestamp: datetime = Field(default_factory=lambda: datetime.now(UTC), description="Timestamp of the event") 27 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Browser [e.g. chrome, safari] 29 | - Version [e.g. 22] 30 | 31 | **Smartphone (please complete the following information):** 32 | - Device: [e.g. iPhone6] 33 | - OS: [e.g. iOS8.1] 34 | - Browser [e.g. stock browser, safari] 35 | - Version [e.g. 22] 36 | 37 | **Additional context** 38 | Add any other context about the problem here. 39 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | # PR Title 2 | 3 | ## Key changes summary 4 | 5 | Provide a brief description of the changes and the problem they solve. Mention any related issues. 6 | - List the key changes made in this PR. 7 | - Highlight any new features, bug fixes, or improvements. 8 | 9 | ## Type of change 10 | 11 | Please delete options that are not relevant. 
12 | - [ ] New feature (non-breaking change which adds functionality, including tests) 13 | - [ ] Bug fix (non-breaking change which fixes an issue) 14 | - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) 15 | - [ ] Documentation update 16 | - [ ] Refactoring (no functional changes, no bug fixes) 17 | 18 | ## Testing 19 | 20 | - Describe how the changes were tested. 21 | - Mention any specific test cases or scenarios. 22 | 23 | ## MUST-HAVE 24 | 25 | - [X] Code follows project style guidelines 26 | - [X] Self-reviewed code 27 | - [X] Added/updated documentation 28 | - [X] Added/updated tests 29 | - [X] All tests pass -------------------------------------------------------------------------------- /tests/integration/infra/arq/arq_fixtures_integration.py: -------------------------------------------------------------------------------- 1 | """Common fixtures for ARQ integration tests.""" 2 | 3 | from unittest.mock import AsyncMock, Mock 4 | 5 | import pytest 6 | 7 | from src.infra.arq.worker import WorkerSettings 8 | from src.infra.arq.worker_services import WorkerServices 9 | 10 | 11 | @pytest.fixture 12 | async def mock_worker_services(): 13 | """Create a mock worker services instance.""" 14 | mock_services = AsyncMock(spec=WorkerServices) 15 | mock_services.arq_redis_pool = Mock() 16 | mock_services.shutdown_services = AsyncMock() 17 | return mock_services 18 | 19 | 20 | @pytest.fixture 21 | def worker_settings(): 22 | """Create a worker settings instance.""" 23 | return WorkerSettings() 24 | 25 | 26 | @pytest.fixture 27 | async def mock_worker_context(mock_worker_services): 28 | """Create a mock worker context.""" 29 | ctx = { 30 | "worker_services": mock_worker_services, 31 | "arq_redis": mock_worker_services.arq_redis_pool, 32 | } 33 | return ctx 34 | -------------------------------------------------------------------------------- /tests/integration/infra/external/test_chroma_manager_integration.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from src.infra.external.chroma_manager import ChromaManager 4 | 5 | 6 | @pytest.mark.integration 7 | class TestChromaManagerIntegration: 8 | """Integration tests for ChromaManager with real ChromaDB.""" 9 | 10 | @pytest.mark.asyncio 11 | async def test_connection_and_heartbeat(self): 12 | """Test that ChromaManager can connect and verify connection.""" 13 | manager = await ChromaManager.create_async() 14 | client = await manager.get_async_client() 15 | 16 | # Verify connection is alive 17 | await client.heartbeat() 18 | 19 | @pytest.mark.asyncio 20 | async def test_client_reuse(self): 21 | """Test that manager properly reuses the client.""" 22 | manager = ChromaManager() 23 | 24 | client1 = await manager.get_async_client() 25 | client2 = await manager.get_async_client() 26 | 27 | assert client1 is client2 28 | await client1.heartbeat() # Verify the connection still works 29 | -------------------------------------------------------------------------------- /src/api/config/cors_config.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | from src.models.base_models import Environment 4 | 5 | 6 | class AllowedOrigins(str, Enum): 7 | """Allowed origins for CORS.""" 8 | 9 | LOCAL = "*" # Allow all origins for local development 10 | STAGING = [ 11 | "https://staging.thekollektiv.ai", 12 | "https://*.railway.app", 13 | "https://*.up.railway.app", 14 | 
"https://*.railway.internal", 15 | ] 16 | PRODUCTION = [ 17 | "https://thekollektiv.ai", 18 | "https://*.railway.app", 19 | "https://*.up.railway.app", 20 | "https://*.railway.internal", 21 | ] 22 | 23 | 24 | def get_cors_config(environment: Environment) -> dict: 25 | """Get the CORS configuration based on the environment.""" 26 | return { 27 | "allow_origins": AllowedOrigins[environment.value.upper()].value, 28 | "allow_credentials": True, 29 | "allow_methods": ["GET", "POST", "OPTIONS", "DELETE", "PATCH", "PUT"], 30 | "allow_headers": [ 31 | "Authorization", 32 | "Content-Type", 33 | "X-Request-ID", 34 | "baggage", 35 | "sentry-trace", 36 | ], 37 | } 38 | -------------------------------------------------------------------------------- /tests/unit/infra/arq/test_worker.py: -------------------------------------------------------------------------------- 1 | from src.infra.arq.arq_settings import get_arq_settings 2 | from src.infra.arq.serializer import deserialize, serialize 3 | from src.infra.arq.task_definitions import task_list 4 | from src.infra.arq.worker import WorkerSettings 5 | 6 | arq_settings = get_arq_settings() 7 | 8 | 9 | def test_worker_settings_configuration(): 10 | """Test WorkerSettings class has correct configuration.""" 11 | settings = WorkerSettings() 12 | 13 | # Test task configuration 14 | assert settings.functions == task_list 15 | 16 | # Test startup/shutdown handlers 17 | assert callable(settings.on_startup) 18 | assert callable(settings.on_shutdown) 19 | 20 | # Test Redis configuration 21 | assert settings.redis_settings == arq_settings.redis_settings 22 | 23 | # Test job configuration 24 | assert settings.health_check_interval == arq_settings.health_check_interval 25 | assert settings.max_jobs == arq_settings.max_jobs 26 | assert settings.max_retries == arq_settings.job_retries 27 | 28 | # Test serialization configuration 29 | assert settings.job_serializer == serialize 30 | assert settings.job_deserializer == deserialize 31 | 32 | # Test result configuration 33 | assert settings.keep_result == 60 # 60 seconds 34 | -------------------------------------------------------------------------------- /scripts/docker/.dockerignore: -------------------------------------------------------------------------------- 1 | # Include any files or directories that you don't want to be copied to your 2 | # container here (e.g., local build artifacts, temporary files, etc.). 
3 | # 4 | # For more help, visit the .dockerignore file reference guide at 5 | # https://docs.docker.com/go/build-context-dockerignore/ 6 | 7 | # Python 8 | __pycache__/ 9 | *.py[cod] 10 | .Python 11 | *.egg-info/ 12 | .installed.cfg 13 | .pytest_cache/ 14 | .coverage 15 | .tox/ 16 | .nox/ 17 | .mypy_cache/ 18 | .ruff_cache/ 19 | 20 | # Environment & Config 21 | .env* 22 | .venv 23 | venv/ 24 | ENV/ 25 | .secrets/ 26 | config/environments/* 27 | !config/environments/.env.example 28 | 29 | # Development & IDE 30 | .idea/ 31 | .vscode/ 32 | *.swp 33 | .DS_Store 34 | 35 | # Project Specific 36 | src/data/raw/* 37 | src/data/chunks/* 38 | src/vector_storage/chroma/* 39 | logs/* 40 | docs/ 41 | output_chunks/ 42 | .sentry/ 43 | .logfire/ 44 | 45 | # Testing 46 | tests/ 47 | *_test.py 48 | *_tests.py 49 | coverage.xml 50 | 51 | # Documentation 52 | *.md 53 | LICENSE* 54 | CHANGELOG* 55 | roadmap.md 56 | 57 | # Docker 58 | Dockerfile* 59 | docker-compose* 60 | .docker 61 | 62 | # Dependencies 63 | poetry.lock # Only if you want to use the one from build context 64 | 65 | # Version Control 66 | .git 67 | .gitignore 68 | .gitattributes 69 | .github 70 | -------------------------------------------------------------------------------- /src/api/handlers/error_handlers.py: -------------------------------------------------------------------------------- 1 | from fastapi import Request 2 | from fastapi.responses import JSONResponse 3 | 4 | from src.api.v0.schemas.base_schemas import ErrorCode, ErrorResponse 5 | from src.core._exceptions import NonRetryableError 6 | from src.infra.logger import get_logger 7 | 8 | logger = get_logger() 9 | 10 | 11 | async def global_exception_handler(request: Request, exc: Exception) -> JSONResponse: 12 | """Global exception handler for any unhandled exceptions.""" 13 | logger.critical(f"Unhandled exception at {request.url.path}: {exc}", exc_info=True) 14 | return JSONResponse( 15 | status_code=500, 16 | content=ErrorResponse( 17 | code=ErrorCode.SERVER_ERROR, detail="An internal server error occurred, please contact support." 
18 | ), 19 | ) 20 | 21 | 22 | async def non_retryable_exception_handler(request: Request, exc: NonRetryableError) -> JSONResponse: 23 | """Catch and log a non-retryable error.""" 24 | logger.error(f"Non-retryable error at {request.url.path}: {exc.error_message}") 25 | return JSONResponse( 26 | status_code=500, 27 | content=ErrorResponse( 28 | code=ErrorCode.SERVER_ERROR, 29 | detail=f"An internal error occurred while processing your request: {exc.error_message}.", 30 | ), 31 | ) 32 | -------------------------------------------------------------------------------- /src/api/routes.py: -------------------------------------------------------------------------------- 1 | """API route definitions.""" 2 | 3 | from typing import Final 4 | 5 | # API version prefix 6 | CURRENT_API_VERSION: Final = "/v0" 7 | 8 | 9 | class Routes: 10 | """API route definitions.""" 11 | 12 | class System: 13 | """System routes (non-versioned).""" 14 | 15 | HEALTH = "/health" 16 | SENTRY_DEBUG = "/sentry-debug" 17 | 18 | class Webhooks: 19 | """Webhook routes.""" 20 | 21 | BASE = "/webhooks" 22 | FIRECRAWL = f"{BASE}/firecrawl" 23 | 24 | class V0: # Use lowercase for consistency with prefix 25 | """API routes (versioned).""" 26 | 27 | SOURCES = "/sources" 28 | CHAT = "/chat" 29 | CONVERSATIONS = "/conversations" 30 | 31 | class Sources: 32 | """Content management routes.""" 33 | 34 | SOURCES = "/sources" 35 | SOURCE_EVENTS = "/sources/{source_id}/events" 36 | 37 | class Chat: 38 | """Chat routes.""" 39 | 40 | CHAT = "/chat" # for sending and receiving messages 41 | 42 | class Conversations: 43 | """Conversation routes.""" 44 | 45 | LIST = "/conversations" # for getting conversation list 46 | GET = "/conversations/{conversation_id}" # for getting a single conversation 47 | -------------------------------------------------------------------------------- /src/api/v0/schemas/base_schemas.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import Generic, TypeVar 3 | 4 | from pydantic import BaseModel, Field 5 | 6 | T = TypeVar("T") 7 | 8 | 9 | class BaseResponse(BaseModel, Generic[T]): 10 | """Base API response for all API endpoints in Kollektiv API.""" 11 | 12 | success: bool 13 | data: T | None = None 14 | message: str | None = None 15 | 16 | 17 | class ErrorCode(str, Enum): 18 | """Base api error code for all API endpoints in Kollektiv API.""" 19 | 20 | UNKNOWN_ERROR = "UNKNOWN_ERROR" 21 | SERVER_ERROR = "SERVER_ERROR" 22 | CLIENT_ERROR = "CLIENT_ERROR" 23 | 24 | 25 | class ErrorResponse(BaseModel): 26 | """Base api error response for all API endpoints in Kollektiv API.""" 27 | 28 | code: ErrorCode = Field(..., description="Custom error code classification shared by FE and BE.") 29 | detail: str = Field( 30 | default="An unknown error occurred.", 31 | description="Error message for the client. 
All unknown errors should be properly parsed and grouped periodically.", 32 | ) 33 | 34 | class Config: 35 | """Example configuration.""" 36 | 37 | json_schema_extra = { 38 | "example": { 39 | "detail": "An error occurred while processing your request", 40 | "code": "SERVER_ERROR", # Matches FE ERROR_CODES 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /tests/unit/test_app.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import AsyncMock, patch 2 | 3 | import pytest 4 | from fastapi import FastAPI 5 | 6 | from src.app import lifespan 7 | from src.infra.service_container import ServiceContainer 8 | 9 | 10 | @pytest.mark.unit 11 | def test_app_basic_configuration(mock_app): 12 | """Test basic FastAPI app configuration.""" 13 | assert mock_app.title == "Kollektiv API" 14 | assert mock_app.description == "RAG-powered LLM chat application" 15 | # Test error handlers are registered 16 | assert mock_app.exception_handlers is not None 17 | assert Exception in mock_app.exception_handlers 18 | 19 | 20 | @pytest.mark.unit 21 | async def test_startup_error_handling(): 22 | """Test error handling during startup.""" 23 | app = FastAPI() 24 | with patch("src.infra.settings.settings.environment", "staging"): 25 | with patch("src.app.ServiceContainer.create", side_effect=Exception("Startup failed")): 26 | with pytest.raises(Exception, match="Startup failed"): 27 | async with lifespan(app): 28 | pass 29 | 30 | 31 | @pytest.mark.unit 32 | async def test_shutdown_services(): 33 | """Test services are properly shutdown.""" 34 | app = FastAPI() 35 | mock_container = AsyncMock(spec=ServiceContainer) 36 | 37 | with patch("src.app.ServiceContainer.create", return_value=mock_container): 38 | async with lifespan(app): 39 | pass 40 | 41 | mock_container.shutdown_services.assert_awaited_once() 42 | -------------------------------------------------------------------------------- /src/api/middleware/debug_middleware.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections.abc import Awaitable, Callable 3 | 4 | from fastapi import Request, Response 5 | from starlette.middleware.base import BaseHTTPMiddleware 6 | 7 | from src.infra.logger import get_logger 8 | from src.infra.settings import settings 9 | 10 | logger = get_logger() 11 | 12 | 13 | class RequestDebugMiddleware(BaseHTTPMiddleware): 14 | """Debug middleware to log request details.""" 15 | 16 | async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response: 17 | """Dispatch the request.""" 18 | try: 19 | # Log detailed request info 20 | logger.debug( 21 | f"\n{'='*50}\n" 22 | f"REQUEST DETAILS:\n" 23 | f"Method: {request.method}\n" 24 | f"Path: {request.url.path}\n" 25 | f"Client: {request.client.host if request.client else 'Unknown'}\n" 26 | f"Headers: {json.dumps(dict(request.headers), indent=2)}\n" 27 | f"Environment: {settings.environment}" 28 | ) 29 | 30 | # Process request 31 | response = await call_next(request) 32 | 33 | # Log response info 34 | logger.debug(f"\nRESPONSE DETAILS:\n" f"Status: {response.status_code}\n" f"{'='*50}") 35 | 36 | return response 37 | except Exception as e: 38 | logger.error(f"Error in debug middleware: {str(e)}") 39 | return await call_next(request) 40 | -------------------------------------------------------------------------------- /src/infra/arq/worker.py: 
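BaseResponse above is a generic envelope shared by the Kollektiv API endpoints; together with HealthCheckResponse from src/api/v0/schemas/health_schemas.py it can be parametrised per endpoint. A short sketch of that combination — the real health handler lives in src/api/system/health.py, which is not shown here:

# Sketch only; the actual health endpoint implementation is not part of this excerpt.
from src.api.v0.schemas.base_schemas import BaseResponse
from src.api.v0.schemas.health_schemas import HealthCheckResponse

payload = HealthCheckResponse(status="operational", message="All systems nominal.")
envelope = BaseResponse[HealthCheckResponse](success=True, data=payload)

# Pydantic v2 serialises the parametrised generic with the nested payload inlined.
print(envelope.model_dump())
# {'success': True, 'data': {'status': 'operational', 'message': 'All systems nominal.'}, 'message': None}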
-------------------------------------------------------------------------------- 1 | from concurrent import futures 2 | from typing import Any 3 | 4 | from src.infra.arq.arq_settings import get_arq_settings 5 | from src.infra.arq.serializer import deserialize, serialize 6 | from src.infra.arq.task_definitions import task_list 7 | from src.infra.arq.worker_services import WorkerServices 8 | from src.infra.logger import configure_logging, get_logger 9 | from src.infra.settings import get_settings 10 | 11 | settings = get_settings() 12 | arq_settings = get_arq_settings() 13 | 14 | configure_logging() 15 | logger = get_logger() 16 | 17 | 18 | async def on_startup(ctx: dict[str, Any]) -> None: 19 | """Runs on startup.""" 20 | ctx["worker_services"] = await WorkerServices.create() 21 | ctx["arq_redis"] = ctx["worker_services"].arq_redis_pool 22 | ctx["pool"] = futures.ProcessPoolExecutor() 23 | 24 | 25 | async def on_shutdown(ctx: dict[str, Any]) -> None: 26 | """Runs on shutdown.""" 27 | await ctx["worker_services"].shutdown_services() 28 | 29 | 30 | class WorkerSettings: 31 | """Settings for the Arq worker.""" 32 | 33 | functions = task_list 34 | on_startup = on_startup 35 | on_shutdown = on_shutdown 36 | redis_settings = arq_settings.redis_settings 37 | health_check_interval = arq_settings.health_check_interval 38 | max_jobs = arq_settings.max_jobs 39 | max_retries = arq_settings.job_retries 40 | job_serializer = serialize 41 | job_deserializer = deserialize 42 | keep_result = 60 # Keep results for 60 seconds after completion 43 | 44 | 45 | def run_worker() -> None: 46 | """Run Arq worker.""" 47 | import sys 48 | 49 | from arq.cli import cli 50 | 51 | sys.argv = ["arq", "src.infra.arq.worker.WorkerSettings"] 52 | cli() 53 | -------------------------------------------------------------------------------- /src/models/base_models.py: -------------------------------------------------------------------------------- 1 | from datetime import UTC, datetime 2 | from enum import Enum 3 | from typing import Any, ClassVar, Self 4 | 5 | from pydantic import BaseModel, Field, PrivateAttr 6 | from pydantic.alias_generators import to_camel 7 | 8 | 9 | # TODO: this doesn't belong here - it's not a model 10 | class Environment(str, Enum): 11 | """Supported application environments.""" 12 | 13 | LOCAL = "local" 14 | STAGING = "staging" 15 | PRODUCTION = "production" 16 | 17 | 18 | class SupabaseModel(BaseModel): 19 | """Base class for all models stored in the Supabase database.""" 20 | 21 | _db_config: ClassVar[dict] = PrivateAttr( 22 | default={ 23 | "schema": "", # Database schema name 24 | "table": "", # Table name 25 | "primary_key": "", # Primary key field 26 | } 27 | ) 28 | 29 | created_at: datetime = Field(default_factory=lambda: datetime.now(UTC), description="Creation timestamp") 30 | updated_at: datetime | None = Field(default=None, description="Last updated timestamp") 31 | 32 | def update(self, **kwargs: Any) -> Self: 33 | """Generic update method preserving model constraints.""" 34 | protected = getattr(self, "_protected_fields", set()) 35 | allowed_updates = {k: v for k, v in kwargs.items() if k not in protected} 36 | return self.model_copy(update=allowed_updates) 37 | 38 | 39 | class APIModel(BaseModel): 40 | """Base class for all API data models. 
41 | 42 | Enables: 43 | - Incoming JSON: to be converted from camelCase to snake_case 44 | - Outgoing JSON: to be converted from snake_case to camelCase 45 | """ 46 | 47 | class Config: 48 | """Pydantic model configuration.""" 49 | 50 | alias_generator = to_camel 51 | populate_by_name = True 52 | -------------------------------------------------------------------------------- /src/infra/misc/ngrok_service.py: -------------------------------------------------------------------------------- 1 | import ngrok 2 | 3 | from src.infra.logger import get_logger 4 | from src.infra.settings import settings 5 | from src.models.base_models import Environment 6 | 7 | logger = get_logger() 8 | 9 | 10 | class NgrokService: 11 | """Ngrok service used to create a tunnel in local environment.""" 12 | 13 | def __init__(self, ngrok_authtoken: str | None = settings.ngrok_authtoken): 14 | self.ngrok_authtoken = ngrok_authtoken 15 | 16 | @classmethod 17 | async def create(cls, ngrok_authtoken: str | None = settings.ngrok_authtoken) -> "NgrokService": 18 | """Creates an instance of NgrokService.""" 19 | if settings.environment == Environment.LOCAL: 20 | instance = cls(ngrok_authtoken=ngrok_authtoken) 21 | await instance.start_tunnel() 22 | return instance 23 | logger.info("✓ Skipping ngrok tunnel initialization for non-local environment") 24 | return None 25 | 26 | async def start_tunnel(self) -> str: 27 | """Start the ngrok tunnel and return the listener.""" 28 | try: 29 | listener = await ngrok.forward(addr=f"localhost:{settings.api_port}", authtoken=self.ngrok_authtoken) 30 | settings.ngrok_url = listener.url() 31 | logger.info(f"✓ Initialized ngrok tunnel successfully at: {settings.ngrok_url}") 32 | return settings.ngrok_url 33 | except Exception as e: 34 | logger.error(f"Failed to start ngrok tunnel: {e}") 35 | raise e 36 | 37 | async def stop_tunnel(self) -> None: 38 | """Stop the ngrok tunnel.""" 39 | try: 40 | ngrok.disconnect() 41 | logger.info("Disconnecting ngrok tunnel") 42 | except Exception as e: 43 | logger.error(f"Failed to disconnect ngrok tunnel: {e}") 44 | raise e 45 | -------------------------------------------------------------------------------- /docs/changelog/CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Kollektiv Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), 6 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 7 | 8 | ## [Unreleased] 9 | 10 | ## [0.2.0] - 2024-12-29 11 | 12 | Complete and full rebuild of the project. 13 | 14 | ### Added 15 | 16 | - Content Service: Added a new service to handle content ingestion and processing end to end. 17 | - Chat functionality: Exposed a chat API endpoint to chat with the content. 18 | - Frontend integration: Integrated with a with React frontend application. 19 | 20 | ### Changed 21 | 22 | - Updated documentations to reflect new project structure 23 | - Updated to Python 3.12.6 -> 3.12.7 24 | - Restructured project layout for better domain separation 25 | 26 | ### Fixed 27 | 28 | - Fixed errors in streaming & non-streaming responses 29 | 30 | ### Security 31 | 32 | - Added a new security policy 33 | - Ratelimiting of API endpoints 34 | 35 | ## [0.1.6] - 2024-10-19 36 | 37 | ### Added 38 | 39 | - Web UI: You can now chat with synced content via web interface. Built using Chainlit. 
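A short sketch of what the APIModel base class above provides to its subclasses: camelCase JSON on the wire, snake_case attributes in Python. The ChatRequest model here is hypothetical, purely for illustration:

# Hypothetical subclass for illustration; the real request models live elsewhere in src/.
from src.models.base_models import APIModel


class ChatRequest(APIModel):
    conversation_id: str
    user_message: str


# Incoming JSON arrives with camelCase keys...
request = ChatRequest.model_validate({"conversationId": "abc", "userMessage": "hi"})
assert request.conversation_id == "abc"

# ...and dumps back to camelCase when serialised by alias.
assert request.model_dump(by_alias=True) == {"conversationId": "abc", "userMessage": "hi"}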
40 | 41 | ### Under development 42 | 43 | - Basic evaluation suite is setup using Weave. 44 | 45 | ## [0.1.5] - 2024-10-19 46 | 47 | ### Changed 48 | 49 | - Kollektiv is born - the project was renamed in order to exclude confusion with regards to Anthropic's Claude family of models. 50 | 51 | ## [0.1.4] - 2024-09-28 52 | 53 | ### Added 54 | 55 | - Added Anthropic API exception handling 56 | 57 | ### Changed 58 | 59 | - Updated pre-processing of chunker to remove images due to lack of multi-modal embeddings support 60 | 61 | ### Removed 62 | 63 | - Removed redundant QueryGenerator class 64 | 65 | ### Fixed 66 | 67 | - Fixed errors in streaming & non-streaming responses 68 | -------------------------------------------------------------------------------- /src/services/webhook_handler.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from src.api.v0.schemas.webhook_schemas import ( 4 | FireCrawlWebhookEvent, 5 | FireCrawlWebhookResponse, 6 | WebhookProvider, 7 | WebhookResponse, 8 | ) 9 | 10 | 11 | class FireCrawlWebhookHandler: 12 | """Handles FireCrawl webhook processing logic.""" 13 | 14 | @staticmethod 15 | def _parse_firecrawl_payload(data: dict[str, Any]) -> FireCrawlWebhookResponse: 16 | """Parse FireCrawl webhook payload into a structured response object. 17 | 18 | Args: 19 | data: Raw webhook payload from FireCrawl 20 | 21 | Raises: 22 | ValueError: If required fields are missing 23 | """ 24 | try: 25 | return FireCrawlWebhookResponse( 26 | success=data["success"], 27 | event_type=data["type"], 28 | firecrawl_id=data["id"], 29 | data=data.get("data", []), 30 | error=data.get("error"), 31 | ) 32 | except KeyError as e: 33 | raise ValueError(f"Missing required field: {e.args[0]}") from e 34 | 35 | @staticmethod 36 | def _create_webhook_event( 37 | event_data: FireCrawlWebhookResponse, raw_payload: dict[str, Any] 38 | ) -> FireCrawlWebhookEvent: 39 | """Create internal webhook event from parsed payload.""" 40 | return FireCrawlWebhookEvent(provider=WebhookProvider.FIRECRAWL, raw_payload=raw_payload, data=event_data) 41 | 42 | @staticmethod 43 | def _create_webhook_response(event: FireCrawlWebhookEvent) -> WebhookResponse: 44 | """Create API response for the webhook sender.""" 45 | return WebhookResponse( 46 | event_id=event.event_id, 47 | message=f"Processed {event.data.event_type} event for job {event.data.firecrawl_id}", 48 | provider=WebhookProvider.FIRECRAWL, 49 | ) 50 | -------------------------------------------------------------------------------- /tests/unit/test_supabase_client.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import AsyncMock, patch 2 | 3 | import pytest 4 | 5 | from src.infra.external.supabase_manager import SupabaseManager 6 | from src.infra.settings import settings 7 | 8 | 9 | @pytest.fixture 10 | def mock_create_async_client(): 11 | """Mock the Supabase create_async_client function.""" 12 | with patch("src.infra.external.supabase_manager.create_async_client") as mock: 13 | mock.return_value = AsyncMock() # Return an async mock client 14 | yield mock 15 | 16 | 17 | class TestSupabaseClient: 18 | """Test suite for SupabaseClient.""" 19 | 20 | async def test_successful_initialization(self, mock_create_async_client): 21 | """Test that initialization works properly.""" 22 | # Arrange & Act 23 | manager = await SupabaseManager.create_async() 24 | 25 | # Assert 26 | mock_create_async_client.assert_called_once_with( 27 | 
supabase_url=settings.supabase_url, 28 | supabase_key=settings.supabase_service_role_key, 29 | ) 30 | assert manager._client is mock_create_async_client.return_value 31 | 32 | async def test_connection_failure(self, mock_create_async_client): 33 | """Test that initialization failure is handled properly.""" 34 | # Arrange 35 | mock_create_async_client.side_effect = Exception("Connection failed") 36 | 37 | # Act & Assert 38 | with pytest.raises(Exception) as exc_info: 39 | await SupabaseManager.create_async() 40 | assert str(exc_info.value) == "Connection failed" 41 | 42 | async def test_get_client(self, mock_create_async_client): 43 | """Test that get_client connects if not already connected.""" 44 | # Arrange 45 | manager = await SupabaseManager.create_async() 46 | 47 | # Act 48 | client = await manager.get_async_client() 49 | 50 | # Assert 51 | mock_create_async_client.assert_called_once() 52 | assert client is mock_create_async_client.return_value 53 | -------------------------------------------------------------------------------- /tests/integration/infra/arq/test_worker_settings_integration.py: -------------------------------------------------------------------------------- 1 | """Integration tests for ARQ worker settings.""" 2 | 3 | import pytest 4 | from arq.connections import RedisSettings 5 | 6 | from src.infra.arq.arq_settings import get_arq_settings 7 | from src.infra.arq.serializer import deserialize, serialize 8 | from src.infra.arq.task_definitions import task_list 9 | from src.infra.arq.worker import WorkerSettings 10 | 11 | 12 | @pytest.fixture 13 | def worker_settings(): 14 | """Create a fresh worker settings instance.""" 15 | return WorkerSettings() 16 | 17 | 18 | def test_worker_task_configuration(worker_settings): 19 | """Test that worker has all required tasks configured.""" 20 | assert worker_settings.functions == task_list 21 | 22 | # Verify that all tasks are callable 23 | for task in worker_settings.functions: 24 | assert callable(task) 25 | 26 | 27 | def test_worker_redis_configuration(worker_settings): 28 | """Test worker Redis configuration integration.""" 29 | arq_settings = get_arq_settings() 30 | 31 | assert isinstance(worker_settings.redis_settings, RedisSettings) 32 | assert worker_settings.redis_settings == arq_settings.redis_settings 33 | 34 | 35 | def test_worker_job_configuration(worker_settings): 36 | """Test worker job processing configuration.""" 37 | arq_settings = get_arq_settings() 38 | 39 | assert worker_settings.health_check_interval == arq_settings.health_check_interval 40 | assert worker_settings.max_jobs == arq_settings.max_jobs 41 | assert worker_settings.max_retries == arq_settings.job_retries 42 | 43 | 44 | def test_worker_serialization_configuration(worker_settings): 45 | """Test worker serialization configuration.""" 46 | assert worker_settings.job_serializer == serialize 47 | assert worker_settings.job_deserializer == deserialize 48 | 49 | 50 | def test_worker_lifecycle_handlers(worker_settings): 51 | """Test that worker has lifecycle handlers configured.""" 52 | assert callable(worker_settings.on_startup) 53 | assert callable(worker_settings.on_shutdown) 54 | -------------------------------------------------------------------------------- /src/core/search/embedding_manager.py: -------------------------------------------------------------------------------- 1 | import chromadb.utils.embedding_functions as embedding_functions 2 | 3 | from src.infra.settings import settings 4 | from src.models.vector_models import ( 5 | 
CohereEmbeddingModelName, 6 | EmbeddingProvider, 7 | ) 8 | 9 | 10 | class EmbeddingManager: 11 | """Manages the embedding functions for the vector database.""" 12 | 13 | def __init__( 14 | self, 15 | provider: EmbeddingProvider | None = EmbeddingProvider.COHERE, 16 | model: CohereEmbeddingModelName | None = CohereEmbeddingModelName.SMALL_MULTI, 17 | ): 18 | self.provider = provider 19 | self.model = model 20 | self.embedding_function = self._get_embedding_function() 21 | 22 | def _get_embedding_function(self) -> embedding_functions.EmbeddingFunction: 23 | """Get the embedding function based on the provider.""" 24 | if self.provider == EmbeddingProvider.COHERE: 25 | return self._get_cohere_embedding_function() 26 | elif self.provider == EmbeddingProvider.OPENAI: 27 | return self._get_openai_embedding_function() 28 | else: 29 | raise ValueError(f"Unsupported embedding provider: {self.provider}") 30 | 31 | def _get_cohere_embedding_function(self) -> embedding_functions.EmbeddingFunction: 32 | """Get the Cohere embedding function.""" 33 | self.embedding_function = embedding_functions.CohereEmbeddingFunction( 34 | model_name=self.model, 35 | api_key=settings.cohere_api_key, 36 | ) 37 | return self.embedding_function 38 | 39 | def _get_openai_embedding_function(self) -> embedding_functions.EmbeddingFunction: 40 | """Get the OpenAI embedding function.""" 41 | self.embedding_function = embedding_functions.OpenAIEmbeddingFunction( 42 | api_key=settings.openai_api_key, model_name=self.model 43 | ) 44 | return self.embedding_function 45 | 46 | def get_embedding_function(self) -> embedding_functions.EmbeddingFunction: 47 | """Get the embedding function.""" 48 | return self.embedding_function 49 | -------------------------------------------------------------------------------- /src/core/chat/prompt_manager.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any 3 | 4 | import yaml 5 | 6 | from src.infra.settings import settings 7 | from src.models.llm_models import PromptType, SystemPrompt 8 | 9 | 10 | class PromptManager: 11 | """Manages loading and formatting of system prompts.""" 12 | 13 | def __init__(self, prompt_dir: Path = settings.prompt_dir, prompt_file: str = settings.prompts_file): 14 | self.prompt_path = prompt_dir / prompt_file 15 | self._load_prompts() 16 | 17 | def _load_prompts(self) -> None: 18 | """Load prompts from YAML file.""" 19 | with open(self.prompt_path) as f: 20 | self.prompts = yaml.safe_load(f) 21 | 22 | # TODO: Refactor to be prompt-agnostic 23 | def get_system_prompt(self, **kwargs: Any) -> SystemPrompt: 24 | """Get system prompt model with provided kwargs.""" 25 | text = self.prompts[PromptType.LLM_ASSISTANT_PROMPT].format(**kwargs) 26 | return SystemPrompt(text=text) 27 | 28 | def get_multi_query_prompt(self, **kwargs: Any) -> str: 29 | """Get multi-query prompt.""" 30 | try: 31 | text = self.prompts[PromptType.MULTI_QUERY_PROMPT].format(**kwargs) 32 | except KeyError: 33 | raise ValueError("Multi-query prompt not found") 34 | if isinstance(text, str): 35 | return text 36 | raise ValueError("Multi-query prompt is not a string") 37 | 38 | def get_summary_prompt(self, **kwargs: Any) -> str: 39 | """Get summary prompt.""" 40 | try: 41 | text = self.prompts[PromptType.SUMMARY_PROMPT].format(**kwargs) 42 | except KeyError: 43 | raise ValueError("Summary prompt not found") 44 | return text 45 | 46 | def return_system_prompt(self, prompt_type: PromptType, **kwargs: Any) -> SystemPrompt: 47 | 
"""Return a prompt-agnostic system prompt.""" 48 | try: 49 | text = self.prompts[prompt_type].format(**kwargs) 50 | except KeyError as e: 51 | raise ValueError(f"System prompt not found for {prompt_type}") from e 52 | return SystemPrompt(text=text) 53 | -------------------------------------------------------------------------------- /tests/integration/infra/arq/test_redis_pool_integration.py: -------------------------------------------------------------------------------- 1 | """Integration tests for Redis pool.""" 2 | 3 | import pytest 4 | from arq import ArqRedis 5 | 6 | from src.infra.arq.arq_settings import get_arq_settings 7 | from src.infra.arq.redis_pool import RedisPool 8 | from src.infra.settings import get_settings 9 | 10 | settings = get_settings() 11 | arq_settings = get_arq_settings() 12 | 13 | 14 | @pytest.fixture(scope="function") 15 | async def redis_pool(redis_integration_manager): 16 | """Create a RedisPool instance for testing.""" 17 | pool = RedisPool() 18 | yield pool 19 | # Cleanup 20 | if pool.is_connected and pool._pool: 21 | await pool._pool.close() 22 | 23 | 24 | @pytest.mark.asyncio 25 | async def test_redis_pool_settings_integration(redis_pool): 26 | """Test RedisPool integration with settings.""" 27 | # Verify settings integration 28 | assert redis_pool.arq_settings.redis_host == settings.redis_host 29 | assert redis_pool.arq_settings.redis_port == settings.redis_port 30 | assert redis_pool.arq_settings.redis_user == settings.redis_user 31 | assert redis_pool.arq_settings.redis_password == settings.redis_password 32 | assert redis_pool.arq_settings.connection_retries == 5 # Default from arq_settings 33 | 34 | 35 | @pytest.mark.asyncio 36 | async def test_redis_pool_get_pool(redis_pool): 37 | """Test get_pool method - the main way other services get Redis access.""" 38 | # Act 39 | redis = await redis_pool.get_pool() 40 | 41 | # Assert 42 | assert isinstance(redis, ArqRedis) 43 | assert redis_pool.is_connected 44 | assert await redis.ping() is True # Verify actual Redis connection 45 | 46 | 47 | @pytest.mark.asyncio 48 | async def test_redis_pool_factory_method(redis_integration_manager): 49 | """Test the static factory method - used by worker services.""" 50 | # Act 51 | redis = await RedisPool.create_redis_pool() 52 | 53 | try: 54 | # Assert 55 | assert isinstance(redis, ArqRedis) 56 | assert await redis.ping() is True # Verify actual Redis connection 57 | finally: 58 | # Cleanup 59 | await redis.close() 60 | -------------------------------------------------------------------------------- /src/infra/external/chroma_manager.py: -------------------------------------------------------------------------------- 1 | from urllib.parse import urlparse 2 | 3 | import chromadb 4 | from chromadb.api import AsyncClientAPI 5 | 6 | from src.infra.decorators import tenacity_retry_wrapper 7 | from src.infra.logger import get_logger 8 | from src.infra.settings import settings 9 | 10 | logger = get_logger() 11 | 12 | 13 | class ChromaManager: 14 | """Chroma client manager.""" 15 | 16 | def __init__(self) -> None: 17 | """Initialize ChromaManager with the necessary dependencies.""" 18 | self._client: AsyncClientAPI | None = None 19 | 20 | @staticmethod 21 | def _parse_url(url: str) -> tuple[str, int]: 22 | """Parse the URL to get the host and port.""" 23 | try: 24 | url_parts = urlparse(url) 25 | host = url_parts.hostname 26 | port = url_parts.port 27 | return host, port 28 | except Exception as e: 29 | logger.error(f"Failed to parse URL: {str(e)}", exc_info=True) 30 | raise 31 
| 32 | @tenacity_retry_wrapper() 33 | async def _connect_async(self) -> None: 34 | """Connect to ChromaDB.""" 35 | host, port = self._parse_url(settings.chroma_private_url) 36 | if self._client is None: 37 | try: 38 | self._client = await chromadb.AsyncHttpClient( 39 | host=host, 40 | port=port, 41 | ) 42 | await self._client.heartbeat() 43 | logger.info(f"✓ Initialized Chroma client successfully on {settings.chroma_private_url}") 44 | except Exception as e: 45 | logger.exception(f"Failed to initialize Chroma client: {str(e)}") 46 | raise 47 | 48 | @classmethod 49 | async def create_async(cls) -> "ChromaManager": 50 | """Create a new ChromaManager instance and connect to ChromaDB.""" 51 | instance = cls() 52 | if instance._client is None: 53 | await instance._connect_async() 54 | return instance 55 | 56 | async def get_async_client(self) -> AsyncClientAPI: 57 | """Get the async client, reconnect if not connected.""" 58 | await self._connect_async() 59 | return self._client 60 | -------------------------------------------------------------------------------- /tests/integration/test_app.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from fastapi.testclient import TestClient 3 | 4 | from src.infra.service_container import ServiceContainer 5 | 6 | 7 | @pytest.mark.integration 8 | class TestAppInitialization: 9 | """Test suite for application initialization and configuration.""" 10 | 11 | def test_app_creates_successfully(self, integration_client: TestClient): 12 | """Test that the FastAPI application initializes correctly.""" 13 | assert integration_client.app is not None 14 | assert hasattr(integration_client.app.state, "container") 15 | assert integration_client.app.title == "Kollektiv API" 16 | 17 | def test_service_container_initialization(self, integration_client: TestClient): 18 | """Test that the service container initializes with all required services.""" 19 | container = integration_client.app.state.container 20 | assert isinstance(container, ServiceContainer) 21 | 22 | # Core services that must be present 23 | assert container.job_manager is not None 24 | assert container.content_service is not None 25 | assert container.data_service is not None 26 | 27 | # Redis services 28 | assert container.async_redis_manager is not None 29 | assert container.redis_repository is not None 30 | assert container.event_publisher is not None 31 | 32 | # Vector operations 33 | assert container.chroma_manager is not None 34 | assert container.embedding_manager is not None 35 | assert container.vector_db is not None 36 | assert container.retriever is not None 37 | assert container.reranker is not None 38 | 39 | # Chat services 40 | assert container.claude_assistant is not None 41 | assert container.conversation_manager is not None 42 | assert container.chat_service is not None 43 | 44 | def test_middleware_setup(self, integration_client: TestClient): 45 | """Test middleware configuration.""" 46 | app = integration_client.app 47 | middleware_classes = [m.cls.__name__ for m in app.user_middleware] 48 | 49 | assert "CORSMiddleware" in middleware_classes 50 | assert "HealthCheckRateLimit" in middleware_classes 51 | -------------------------------------------------------------------------------- /scripts/docker/compose.yaml: -------------------------------------------------------------------------------- 1 | name: kollektiv 2 | 3 | x-versions: &versions 4 | REDIS_IMAGE: redis 5 | REDIS_VERSION: 7.4.1-alpine 6 | CHROMA_IMAGE: chromadb/chroma 7 | 
CHROMA_VERSION: 0.5.23 8 | WORKER_IMAGE: kollektiv-worker 9 | WORKER_VERSION: latest 10 | 11 | services: 12 | redis: 13 | image: ${REDIS_IMAGE:-redis}:${REDIS_VERSION:-7.4.1-alpine} 14 | ports: 15 | - "6379:6379" 16 | container_name: redis 17 | healthcheck: 18 | test: [ "CMD", "redis-cli", "ping" ] 19 | interval: 30s 20 | timeout: 5s 21 | retries: 3 22 | start_period: 5s 23 | restart: unless-stopped 24 | chromadb: 25 | image: ${CHROMA_IMAGE:-chromadb/chroma}:${CHROMA_VERSION:-0.5.23} 26 | container_name: chromadb 27 | ports: 28 | - "8000:8000" 29 | env_file: 30 | - ../../config/.env 31 | healthcheck: 32 | test: [ "CMD", "curl", "-f", "${CHROMA_PRIVATE_URL}/api/v1/heartbeat" ] 33 | interval: 5s 34 | timeout: 5s 35 | retries: 5 36 | start_period: 5s 37 | restart: unless-stopped 38 | arq_worker: 39 | container_name: arq_worker 40 | image: ${WORKER_IMAGE:-kollektiv-worker}:${WORKER_VERSION:-latest} 41 | build: 42 | context: ../../ 43 | dockerfile: scripts/docker/Dockerfile 44 | entrypoint: ["/app/scripts/docker/entrypoint.sh"] 45 | command: ["watchfiles", "poetry run worker", "/app/src"] 46 | env_file: 47 | - ../../config/.env 48 | volumes: 49 | - ../../:/app # this allows locally to not relaunch as code changes 50 | environment: 51 | - REDIS_URL=redis://redis:6379 # need to connect to the redis service in the container, not localhost 52 | - CHROMA_PRIVATE_URL=http://chromadb:8000 # need to connect to the chromadb service in the container, not localhost 53 | - SERVICE=worker 54 | depends_on: 55 | redis: 56 | condition: service_healthy 57 | chromadb: 58 | condition: service_healthy 59 | healthcheck: 60 | test: [ "CMD", "poetry", "run", "arq", "--check", "src.infra.arq.worker:WorkerSettings"] 61 | interval: 30s 62 | timeout: 10s 63 | retries: 3 64 | start_period: 5s 65 | restart: unless-stopped 66 | extra_hosts: 67 | - "host.docker.internal:host-gateway" 68 | -------------------------------------------------------------------------------- /src/infra/external/supabase_manager.py: -------------------------------------------------------------------------------- 1 | from supabase import AsyncClient, AuthRetryableError, NotConnectedError, create_async_client 2 | 3 | from src.infra.decorators import tenacity_retry_wrapper 4 | from src.infra.logger import get_logger 5 | from src.infra.settings import settings 6 | 7 | logger = get_logger() 8 | 9 | 10 | class SupabaseManager: 11 | """Efficient Supabase client manager with immediate connection and retrying logic.""" 12 | 13 | def __init__( 14 | self, supabase_url: str = settings.supabase_url, service_role_key: str = settings.supabase_service_role_key 15 | ) -> None: 16 | """Initialize Supabase client and connect immediately.""" 17 | self.supabase_url = supabase_url 18 | self.service_role_key = service_role_key 19 | self._client: AsyncClient | None = None # instance of Supabase async client 20 | 21 | @tenacity_retry_wrapper(exceptions=(AuthRetryableError, NotConnectedError)) 22 | async def _connect_async(self) -> None: 23 | """Connect to Supabase, retry if connection fails.""" 24 | if self._client is None: 25 | try: 26 | logger.debug( 27 | f"Attempting to connect to Supabase at: {self.supabase_url} with key partially masked as:" 28 | f"{self.service_role_key[:5]}..." 
29 | ) 30 | self._client = await create_async_client( 31 | supabase_url=self.supabase_url, 32 | supabase_key=self.service_role_key, 33 | ) 34 | logger.info("✓ Initialized Supabase client successfully") 35 | except (AuthRetryableError, NotConnectedError) as e: 36 | logger.exception(f"Failed to initialize Supabase client: {e}") 37 | raise 38 | 39 | @classmethod 40 | async def create_async(cls) -> "SupabaseManager": 41 | """Factory method to elegantly create a Supabase client instance and connect immediately.""" 42 | instance = cls() 43 | await instance._connect_async() 44 | return instance 45 | 46 | async def get_async_client(self) -> AsyncClient: 47 | """Wrapper method to get the connected client instance or reconnect if necessary.""" 48 | await self._connect_async() 49 | return self._client 50 | -------------------------------------------------------------------------------- /src/infra/arq/redis_pool.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from arq import ArqRedis, create_pool 4 | 5 | from src.infra.arq.arq_settings import ArqSettings, get_arq_settings 6 | from src.infra.logger import get_logger 7 | from src.infra.settings import get_settings 8 | 9 | settings = get_settings() 10 | arq_settings = get_arq_settings() 11 | logger = get_logger() 12 | 13 | 14 | class RedisPool: 15 | """Manages ARQ Redis connection pool with consistent serialization.""" 16 | 17 | def __init__(self, arq_settings: ArqSettings = arq_settings) -> None: 18 | self.arq_settings = arq_settings 19 | self._pool: ArqRedis | None = None 20 | 21 | @property 22 | def is_connected(self) -> bool: 23 | """Check if the Redis connection pool is connected.""" 24 | return self._pool is not None 25 | 26 | async def initialize_pool(self) -> None: 27 | """Initialize the Redis connection pool, if not already initialized.""" 28 | if self.is_connected: 29 | return 30 | try: 31 | logger.debug("Initializing Redis connection pool...") 32 | self._pool = await create_pool( 33 | settings_=self.arq_settings.redis_settings, 34 | job_serializer=self.arq_settings.job_serializer, 35 | job_deserializer=self.arq_settings.job_deserializer, 36 | retry=self.arq_settings.connection_retries, 37 | ) 38 | logger.info("✓ Redis connection pool initialized successfully") 39 | except Exception as e: 40 | logger.exception(f"Failed to initialize Redis connection pool: {e}", exc_info=True) 41 | raise 42 | 43 | @classmethod 44 | async def create_redis_pool(cls) -> ArqRedis: 45 | """Create a RedisPool instance, initialize the pool and return it.""" 46 | instance = cls() 47 | 48 | # Initialize .pool if not set 49 | await instance.initialize_pool() 50 | if not instance._pool: 51 | raise RuntimeError("Redis connection pool not initialized") 52 | return instance._pool 53 | 54 | async def get_pool(self) -> ArqRedis: 55 | """Get the connected pool instance or reconnect if necessary.""" 56 | await self.initialize_pool() 57 | if not self._pool: 58 | raise RuntimeError("Redis connection pool not initialized") 59 | return self._pool 60 | -------------------------------------------------------------------------------- /src/models/vector_models.py: -------------------------------------------------------------------------------- 1 | """Holds all vector related models.""" 2 | 3 | from enum import Enum 4 | from typing import Any, Literal 5 | from uuid import UUID 6 | 7 | from pydantic import BaseModel, Field 8 | 9 | from src.infra.logger import get_logger 10 | from src.models.base_models 
import SupabaseModel 11 | 12 | logger = get_logger() 13 | 14 | 15 | # Embedding models 16 | class EmbeddingProvider(str, Enum): 17 | """Enum for the embedding providers.""" 18 | 19 | OPENAI = "openai" 20 | COHERE = "cohere" 21 | 22 | 23 | class CohereEmbeddingModelName(str, Enum): 24 | """Enum for the Cohere embedding models.""" 25 | 26 | BASE_ENG = "embed-english-v3.0" 27 | BASE_MULTI = "embed-multilingual-v3.0" 28 | SMALL_ENG = "embed-english-v3.0-small" 29 | SMALL_MULTI = "embed-multilingual-v3.0-small" 30 | 31 | 32 | class OpenAIEmbeddingModelName(str, Enum): 33 | """Enum for the OpenAI embedding models.""" 34 | 35 | BASE = "text-embedding-3-large" 36 | SMALL = "text-embedding-3-small" 37 | 38 | 39 | # Vector search & retrieval models 40 | class UserQuery(BaseModel): 41 | """User query model.""" 42 | 43 | user_id: UUID = Field(..., description="User ID") 44 | query: str | list[str] = Field(..., description="User query, which can be a single string or a list of strings") 45 | 46 | 47 | class VectorSearchParams(BaseModel): 48 | """Default vector search params model.""" 49 | 50 | n_results: int = Field(..., description="Number of results to retrieve") 51 | where: dict[str, Any] = Field(..., description="Filtering conditions for the search") 52 | where_document: dict[str, Any] = Field(..., description="Filtering conditions for the documents") 53 | include: list[Literal["documents", "metadatas", "distances"]] = Field( 54 | ..., description="Fields to include in the search results" 55 | ) 56 | 57 | 58 | class VectorCollection(SupabaseModel): 59 | """Represents a vector collection consisting of vectors and metadata.""" 60 | 61 | user_id: UUID = Field(..., description="User ID to which the collection belongs to") 62 | documents_cnt: int = Field(default=0, description="Number of documents in the collection") 63 | deleted: bool = Field(default=False, description="Whether the collection is deleted") 64 | 65 | @property 66 | def name(self) -> str: 67 | """Generate collection name based on user_id.""" 68 | return str(self.user_id) 69 | -------------------------------------------------------------------------------- /tests/unit/infra/external/test_chroma_manager.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import patch 2 | 3 | import pytest 4 | 5 | from src.infra.external.chroma_manager import ChromaManager 6 | 7 | 8 | @pytest.mark.unit 9 | class TestChromaManager: 10 | """Unit tests for ChromaManager.""" 11 | 12 | def test_initialization(self): 13 | """Test that ChromaManager initializes with None client.""" 14 | manager = ChromaManager() 15 | assert manager._client is None 16 | 17 | def test_parse_url(self): 18 | """Test URL parsing.""" 19 | test_url = "http://localhost:8000" 20 | host, port = ChromaManager._parse_url(test_url) 21 | assert host == "localhost" 22 | assert port == 8000 23 | 24 | @pytest.mark.asyncio 25 | async def test_connect_async(self, mock_chroma_client): 26 | """Test async connection.""" 27 | manager = ChromaManager() 28 | 29 | with patch("chromadb.AsyncHttpClient", return_value=mock_chroma_client): 30 | await manager._connect_async() 31 | 32 | assert manager._client is mock_chroma_client 33 | mock_chroma_client.heartbeat.assert_awaited_once() 34 | 35 | @pytest.mark.asyncio 36 | async def test_create_async(self, mock_chroma_client): 37 | """Test create_async class method.""" 38 | with patch("chromadb.AsyncHttpClient", return_value=mock_chroma_client): 39 | manager = await ChromaManager.create_async() 40 | 41 | 
assert isinstance(manager, ChromaManager) 42 | assert manager._client is mock_chroma_client 43 | mock_chroma_client.heartbeat.assert_awaited_once() 44 | 45 | @pytest.mark.asyncio 46 | async def test_get_async_client(self, mock_chroma_client): 47 | """Test get_async_client method.""" 48 | with patch("chromadb.AsyncHttpClient", return_value=mock_chroma_client): 49 | manager = ChromaManager() 50 | client = await manager.get_async_client() 51 | 52 | assert client is mock_chroma_client 53 | mock_chroma_client.heartbeat.assert_awaited_once() 54 | 55 | @pytest.mark.asyncio 56 | async def test_get_async_client_reuse(self, mock_chroma_client): 57 | """Test that get_async_client reuses existing client.""" 58 | with patch("chromadb.AsyncHttpClient", return_value=mock_chroma_client): 59 | manager = ChromaManager() 60 | 61 | client1 = await manager.get_async_client() 62 | client2 = await manager.get_async_client() 63 | 64 | assert client1 is client2 65 | assert mock_chroma_client.heartbeat.await_count == 1 # Only called once 66 | -------------------------------------------------------------------------------- /src/infra/arq/arq_settings.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | from functools import lru_cache 3 | from typing import Any 4 | 5 | from arq.connections import RedisSettings 6 | from pydantic import Field 7 | from pydantic_settings import BaseSettings 8 | 9 | from src.infra.arq.serializer import deserialize, serialize 10 | from src.infra.logger import get_logger 11 | from src.infra.settings import get_settings 12 | 13 | settings = get_settings() 14 | logger = get_logger() 15 | 16 | 17 | class ArqSettings(BaseSettings): 18 | """Centralized settings for the ARQ worker and connection pool.""" 19 | 20 | # Cache for properties 21 | _redis_settings: RedisSettings | None = None 22 | _job_serializer: Callable[[Any], bytes] | None = None 23 | _job_deserializer: Callable[[bytes], Any] | None = None 24 | 25 | # Redis pool settings 26 | redis_host: str = settings.redis_host 27 | redis_port: int = settings.redis_port 28 | redis_user: str | None = settings.redis_user 29 | redis_password: str | None = settings.redis_password 30 | connection_retries: int = Field(5, description="Number of connection retries to redis connection pool") 31 | 32 | # Worker settings 33 | job_retries: int = Field(3, description="Number of default job retries, decreased from 5 to 3") 34 | health_check_interval: int = Field(60, description="Health check interval") 35 | max_jobs: int = Field(1000, description="Maximum number of jobs in the queue") 36 | 37 | @property 38 | def redis_settings(self) -> RedisSettings: 39 | """Get the Redis settings.""" 40 | if self._redis_settings is None: 41 | self._redis_settings = RedisSettings.from_dsn(settings.redis_url) 42 | return self._redis_settings 43 | 44 | @property 45 | def job_serializer(self) -> Callable[[Any], bytes]: 46 | """Get the serializer for the ARQ worker and redis pool.""" 47 | if self._job_serializer is None: 48 | self._job_serializer = serialize 49 | return self._job_serializer 50 | 51 | @property 52 | def job_deserializer(self) -> Callable[[bytes], Any]: 53 | """Get the deserializer for the ARQ worker and redis pool.""" 54 | if self._job_deserializer is None: 55 | self._job_deserializer = deserialize 56 | return self._job_deserializer 57 | 58 | 59 | def initialize_arq_settings() -> ArqSettings: 60 | """Initialize the ARQ settings.""" 61 | return ArqSettings() 62 | 63 | 64 | arq_settings = 
initialize_arq_settings() 65 | 66 | 67 | @lru_cache 68 | def get_arq_settings() -> ArqSettings: 69 | """Get the cached ARQ settings.""" 70 | return arq_settings 71 | -------------------------------------------------------------------------------- /tests/integration/infra/arq/test_arq_settings_integration.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from arq.connections import RedisSettings 3 | 4 | from src.infra.arq.arq_settings import ArqSettings, get_arq_settings 5 | from src.infra.arq.serializer import deserialize, serialize 6 | from src.infra.settings import get_settings 7 | 8 | settings = get_settings() 9 | 10 | 11 | @pytest.fixture 12 | def arq_settings(): 13 | """Get a fresh instance of ARQ settings.""" 14 | return ArqSettings() 15 | 16 | 17 | def test_redis_settings_integration(arq_settings): 18 | """Test Redis settings integration with actual configuration.""" 19 | redis_settings = arq_settings.redis_settings 20 | 21 | # Test that we get a proper RedisSettings instance 22 | assert isinstance(redis_settings, RedisSettings) 23 | 24 | # Test that settings are created from DSN 25 | expected_settings = RedisSettings.from_dsn(settings.redis_url) 26 | assert redis_settings.host == expected_settings.host 27 | assert redis_settings.port == expected_settings.port 28 | assert redis_settings.username == expected_settings.username 29 | assert redis_settings.password == expected_settings.password 30 | 31 | # Test caching behavior 32 | cached_settings = arq_settings.redis_settings 33 | assert cached_settings is redis_settings # Should return the same instance 34 | 35 | 36 | def test_redis_settings_from_env(arq_settings): 37 | """Test that Redis settings match environment configuration.""" 38 | redis_settings = arq_settings.redis_settings 39 | 40 | # Should match our environment settings 41 | assert redis_settings.host == settings.redis_host 42 | assert redis_settings.port == settings.redis_port 43 | assert redis_settings.username == settings.redis_user 44 | assert redis_settings.password == settings.redis_password 45 | 46 | 47 | def test_serializer_integration(arq_settings): 48 | """Test serializer/deserializer integration with actual implementation.""" 49 | assert arq_settings.job_serializer == serialize 50 | assert arq_settings.job_deserializer == deserialize 51 | 52 | # Test that they're actually callable 53 | assert callable(arq_settings.job_serializer) 54 | assert callable(arq_settings.job_deserializer) 55 | 56 | 57 | def test_settings_singleton(): 58 | """Test that get_arq_settings returns the same instance.""" 59 | settings1 = get_arq_settings() 60 | settings2 = get_arq_settings() 61 | 62 | assert settings1 is settings2 63 | assert isinstance(settings1, ArqSettings) 64 | 65 | # Test that the settings are properly configured 66 | assert settings1.job_retries == 3 67 | assert settings1.health_check_interval == 60 68 | assert settings1.max_jobs == 1000 69 | assert settings1.connection_retries == 5 70 | -------------------------------------------------------------------------------- /src/api/middleware/rate_limit.py: -------------------------------------------------------------------------------- 1 | # TO DO: transition to a common supported rate limit library 2 | import time 3 | from collections import defaultdict 4 | 5 | from fastapi import Request, Response 6 | from starlette.middleware.base import BaseHTTPMiddleware 7 | 8 | from src.api.routes import Routes 9 | from src.infra.logger import get_logger 10 | 11 | logger = 
get_logger() 12 | 13 | 14 | class HealthCheckRateLimit(BaseHTTPMiddleware): 15 | """Rate limit for the health endpoint""" 16 | 17 | def __init__(self, app, requests_per_minute: int = 60, cleanup_interval: int = 300): 18 | super().__init__(app) 19 | self.requests_per_minute = requests_per_minute 20 | self.requests = defaultdict(list) 21 | self.last_cleanup = time.time() 22 | self.cleanup_interval = cleanup_interval 23 | 24 | async def dispatch(self, request: Request, call_next): 25 | """Rate limit for the health endpoint""" 26 | if request.url.path == f"{Routes.System.HEALTH}": 27 | now = time.time() 28 | client_ip = request.client.host 29 | 30 | if now - self.last_cleanup > self.cleanup_interval: 31 | self._cleanup_old_data(now) 32 | self.last_cleanup = now 33 | 34 | self.requests[client_ip] = [req_time for req_time in self.requests[client_ip] if req_time > now - 60] 35 | 36 | if len(self.requests[client_ip]) >= self.requests_per_minute: 37 | logger.warning(f"Rate limit exceeded for IP {client_ip} on health endpoint") 38 | return Response( 39 | "Rate limit exceeded", 40 | status_code=429, 41 | headers={ 42 | "Retry-After": "60", 43 | "X-RateLimit-Limit": str(self.requests_per_minute), 44 | "X-RateLimit-Remaining": "0", 45 | "X-RateLimit-Reset": str(int(now + 60)), 46 | }, 47 | ) 48 | 49 | self.requests[client_ip].append(now) 50 | 51 | response = await call_next(request) 52 | remaining = self.requests_per_minute - len(self.requests[client_ip]) 53 | response.headers["X-RateLimit-Limit"] = str(self.requests_per_minute) 54 | response.headers["X-RateLimit-Remaining"] = str(remaining) 55 | response.headers["X-RateLimit-Reset"] = str(int(now + 60)) 56 | return response 57 | 58 | return await call_next(request) 59 | 60 | def _cleanup_old_data(self, now: float): 61 | """Remove IPs that haven't made requests in the last minute""" 62 | for ip in list(self.requests.keys()): 63 | if not any(t > now - 60 for t in self.requests[ip]): 64 | del self.requests[ip] 65 | -------------------------------------------------------------------------------- /tests/integration/infra/external/test_redis_manager_integration.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from src.infra.external.redis_manager import RedisManager 4 | 5 | 6 | @pytest.mark.integration 7 | class TestRedisManagerIntegration: 8 | """Integration tests for RedisManager with real Redis.""" 9 | 10 | @pytest.mark.asyncio 11 | async def test_connection_and_operations(self, redis_integration_manager): 12 | """Test that RedisManager can connect and perform operations.""" 13 | client = await redis_integration_manager.get_async_client() 14 | 15 | # Verify basic operations work 16 | await client.set("test_key", "test_value") 17 | result = await client.get("test_key") 18 | assert result == "test_value" 19 | 20 | def test_sync_client_initialization(self): 21 | """Test sync client creation and connection with real Redis.""" 22 | manager = RedisManager() 23 | client = manager.get_sync_client() 24 | 25 | # Verify client works 26 | client.set("test_key", "test_value") 27 | result = client.get("test_key") 28 | assert result == "test_value" 29 | 30 | @pytest.mark.asyncio 31 | async def test_client_reuse(self): 32 | """Test that both sync and async clients are reused.""" 33 | manager = RedisManager() 34 | 35 | # Test sync client reuse 36 | sync1 = manager.get_sync_client() 37 | sync2 = manager.get_sync_client() 38 | assert sync1 is sync2 39 | 40 | # Test async client reuse 41 | async1 = await 
manager.get_async_client() 42 | async2 = await manager.get_async_client() 43 | assert async1 is async2 44 | 45 | @pytest.mark.asyncio 46 | async def test_decode_responses_setting(self, redis_integration_manager): 47 | """Test decode_responses setting is respected.""" 48 | # First clean up any existing data 49 | client = await redis_integration_manager.get_async_client() 50 | await client.flushall() 51 | 52 | # Test with decode_responses=True (default) 53 | await client.set("text_test", "value") 54 | result = await client.get("text_test") 55 | assert isinstance(result, str) 56 | assert result == "value" 57 | 58 | # Create a new manager with decode_responses=False 59 | binary_manager = RedisManager(decode_responses=False) 60 | try: 61 | binary_client = await binary_manager.get_async_client() 62 | await binary_client.set("binary_test", "value") 63 | result = await binary_client.get("binary_test") 64 | assert isinstance(result, bytes) 65 | assert result.decode() == "value" 66 | finally: 67 | if binary_manager._async_client: 68 | await binary_manager._async_client.close() 69 | binary_manager._async_client = None 70 | -------------------------------------------------------------------------------- /src/infra/events/event_publisher.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | from uuid import UUID 3 | 4 | from pydantic import BaseModel 5 | from redis.exceptions import ConnectionError, TimeoutError 6 | 7 | from src.infra.decorators import tenacity_retry_wrapper 8 | from src.infra.external.redis_manager import RedisManager 9 | from src.infra.logger import get_logger 10 | from src.models.content_models import ContentProcessingEvent, SourceStage 11 | from src.models.pubsub_models import EventType, KollektivEvent 12 | 13 | logger = get_logger() 14 | 15 | 16 | class EventPublisher: 17 | """Responsible for publishing events to the event bus.""" 18 | 19 | def __init__(self, redis_manager: RedisManager) -> None: 20 | self.redis_manager = redis_manager 21 | 22 | async def publish(self, channel: str, message: str) -> None: 23 | """Publish a raw message to the given Redis channel.""" 24 | # get the client 25 | client = await self.redis_manager.get_async_client() 26 | # publish the message 27 | await client.publish(channel, message) 28 | logger.info(f"Event published to {channel}: {message}") 29 | 30 | @classmethod 31 | async def create_async(cls, redis_manager: RedisManager) -> "EventPublisher": 32 | """Creates an instance of EventPublisher.""" 33 | instance = cls(redis_manager) 34 | return instance 35 | 36 | @classmethod 37 | def create_event( 38 | cls, 39 | stage: SourceStage, 40 | source_id: UUID, 41 | error: str | None = None, 42 | metadata: dict[str, Any] | None = None, 43 | ) -> KollektivEvent: 44 | """Create a ContentProcessingEvent for the given source and stage.
45 | 46 | Args: 47 | source_id: ID of the source being processed 48 | stage: SourceStage enum value for the current processing stage 49 | error: Optional error message if something went wrong 50 | metadata: Optional metadata about the event 51 | """ 52 | return ContentProcessingEvent( 53 | source_id=source_id, 54 | event_type=EventType.CONTENT_PROCESSING, # This is fixed for content processing events 55 | stage=stage, # The stage parameter maps to what was previously event_type 56 | error=error, 57 | metadata=metadata, 58 | ) 59 | 60 | @tenacity_retry_wrapper(exceptions=(ConnectionError, TimeoutError)) 61 | async def publish_event( 62 | self, 63 | channel: str, 64 | message: BaseModel, 65 | ) -> None: 66 | """ 67 | Publish a message to the event bus. 68 | 69 | Args: 70 | channel: The channel to publish to 71 | message: The message to publish (will be JSON serialized) 72 | """ 73 | try: 74 | await self.publish(channel=channel, message=message.model_dump_json()) 75 | except (ConnectionError, TimeoutError) as e: 76 | logger.exception(f"Failed to publish event to {channel}: {e}") 77 | raise 78 | -------------------------------------------------------------------------------- /src/api/v0/endpoints/webhooks.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, HTTPException, Request, status 2 | 3 | from src.api.dependencies import ContentServiceDep 4 | from src.api.routes import Routes 5 | from src.api.v0.schemas.webhook_schemas import WebhookResponse 6 | from src.infra.logger import get_logger 7 | from src.services.webhook_handler import FireCrawlWebhookHandler 8 | 9 | logger = get_logger() 10 | router = APIRouter() 11 | 12 | 13 | @router.post( 14 | path=Routes.System.Webhooks.FIRECRAWL, 15 | response_model=None, 16 | status_code=status.HTTP_200_OK, 17 | description="Handle FireCrawl webhook callbacks", 18 | ) 19 | async def handle_firecrawl_webhook( 20 | request: Request, content_service: ContentServiceDep 21 | ) -> WebhookResponse | HTTPException: 22 | """Handle FireCrawl webhook callbacks.
23 | 24 | Args: 25 | request: The incoming webhook request 26 | content_service: Injected content service dependency 27 | 28 | Returns: 29 | WebhookResponse: Response indicating successful processing 30 | 31 | Raises: 32 | HTTPException: 400 if webhook payload is invalid 33 | HTTPException: 500 if processing fails for other reasons 34 | """ 35 | logger.debug(f"Receiving webhook at: {request.url}") 36 | try: 37 | handler = FireCrawlWebhookHandler() 38 | 39 | # Get raw payload 40 | raw_payload = await request.json() 41 | 42 | try: 43 | # Parse the webhook payload 44 | parsed_payload = handler._parse_firecrawl_payload(data=raw_payload) 45 | logger.debug(f"Parsed event type: {parsed_payload.event_type}") 46 | except ValueError as ve: 47 | logger.error(f"Invalid webhook payload: {str(ve)}") 48 | raise HTTPException( 49 | status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid webhook payload: {str(ve)}" 50 | ) from ve 51 | 52 | # Create internal event object 53 | webhook_event = handler._create_webhook_event(event_data=parsed_payload, raw_payload=raw_payload) 54 | 55 | # Process the event 56 | try: 57 | await content_service.handle_webhook_event(event=webhook_event) 58 | except Exception as e: 59 | logger.error(f"Error processing webhook event: {str(e)}") 60 | raise HTTPException( 61 | status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error processing webhook: {str(e)}" 62 | ) from e 63 | 64 | # Return response to webhook sender 65 | return handler._create_webhook_response(event=webhook_event) 66 | 67 | except HTTPException: 68 | # Re-raise HTTP exceptions 69 | raise 70 | except Exception as e: 71 | # Catch any other unexpected errors 72 | logger.error(f"Unexpected error in webhook handler: {str(e)}") 73 | raise HTTPException( 74 | status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Internal server error: {str(e)}" 75 | ) from e 76 | -------------------------------------------------------------------------------- /src/core/chat/tool_manager.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import yaml 4 | from anthropic.types import ToolChoiceToolParam, ToolParam 5 | from pydantic import ValidationError 6 | 7 | from src.core._exceptions import NonRetryableLLMError 8 | from src.infra.logger import get_logger 9 | from src.infra.settings import settings 10 | from src.models.llm_models import Tool, ToolInputSchema, ToolName 11 | 12 | logger = get_logger() 13 | 14 | 15 | class ToolManager: 16 | """Manage LLM tools following Anthropic's API spec.""" 17 | 18 | def __init__(self, tools_dir: Path = settings.tools_dir, tools_file: str = settings.tools_file) -> None: 19 | self.tools_path = tools_dir / tools_file 20 | self.tools: dict[ToolName, Tool] = {} 21 | self._load_tools() 22 | 23 | def _load_tools(self) -> None: 24 | """Load and validate tools from YAML file.""" 25 | try: 26 | with open(self.tools_path) as f: 27 | raw_tools = yaml.safe_load(f) 28 | 29 | for name, tool_data in raw_tools.items(): 30 | # Ensure input_schema is properly structured 31 | if "input_schema" in tool_data: 32 | tool_data["input_schema"] = ToolInputSchema(**tool_data["input_schema"]) 33 | 34 | # Convert string name to ToolName enum 35 | tool_name = ToolName(name) # This validates the name matches an enum value 36 | self.tools[tool_name] = Tool(**tool_data) 37 | 38 | except (yaml.YAMLError, ValidationError) as e: 39 | logger.error(f"Failed to load tools: {e}", exc_info=True) 40 | raise 41 | 42 | def get_all_tools(self) -> list[Tool]: 43 
| """Get all tools with caching enabled.""" 44 | return list(self.tools.values()) 45 | 46 | def get_tool(self, name: ToolName) -> ToolParam: 47 | """Get a specific tool by name.""" 48 | try: 49 | tool = self.tools.get(name) 50 | if not tool: 51 | raise KeyError(f"Tool {name} not found") 52 | return ToolParam(name=tool.name, input_schema=tool.input_schema, description=tool.description) 53 | except KeyError as e: 54 | logger.error(f"Tool {name} not found") 55 | raise NonRetryableLLMError(original_error=e, message=f"Tool {name} not found") from e 56 | 57 | def force_tool_choice(self, name: ToolName) -> ToolChoiceToolParam: 58 | """Forces Claude to always use the tool.""" 59 | return ToolChoiceToolParam(type="tool", name=name.value) 60 | 61 | 62 | if __name__ == "__main__": 63 | # Initialize manager 64 | manager = ToolManager() 65 | 66 | # Test tool retrieval and conversion 67 | rag_tool = manager.get_tool(ToolName.RAG_SEARCH) 68 | # print(f"RAG tool: {rag_tool}") 69 | 70 | my_tool = Tool.from_tool_param(rag_tool) 71 | print(f"My tool: {my_tool}") 72 | 73 | # Test tool choice 74 | tool_choice = manager.force_tool_choice(ToolName.RAG_SEARCH) 75 | print(f"Tool choice: {tool_choice}") 76 | -------------------------------------------------------------------------------- /src/models/llm_models.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import Any, Literal 3 | 4 | from anthropic.types.text_block_param import TextBlockParam 5 | from anthropic.types.tool_param import ToolParam 6 | from pydantic import BaseModel, Field 7 | 8 | 9 | class CacheControl(BaseModel): 10 | """Cache control for Anthropic API.""" 11 | 12 | type: Literal["ephemeral"] = "ephemeral" 13 | 14 | 15 | class ToolInputSchema(BaseModel): 16 | """Base model for tool input schema validation.""" 17 | 18 | type: Literal["object"] = "object" 19 | properties: dict[str, Any] 20 | required: list[str] | None = None 21 | 22 | 23 | class ToolName(str, Enum): 24 | """Tool names.""" 25 | 26 | RAG_SEARCH = "rag_search" 27 | MULTI_QUERY = "multi_query_tool" 28 | SUMMARY = "summary_tool" 29 | 30 | 31 | class Tool(BaseModel): 32 | """Tool definition for LLM following Anthropic's API spec.""" 33 | 34 | name: ToolName = Field(..., description="Tool name. 
Must match regex ^[a-zA-Z0-9_-]{1,64}$") 35 | description: str = Field(..., description="Detailed description of what the tool does and when to use it") 36 | input_schema: ToolInputSchema = Field(..., description="JSON Schema defining expected parameters") 37 | cache_control: CacheControl | None = None 38 | 39 | @classmethod 40 | def from_tool_param(cls, tool_param: ToolParam) -> "Tool": 41 | """Convert Anthropic's ToolParam to our Tool model.""" 42 | return cls( 43 | name=ToolName(tool_param["name"]), 44 | description=tool_param["description"], 45 | input_schema=tool_param["input_schema"], 46 | ) 47 | 48 | def with_cache(self) -> dict[str, Any]: 49 | """Return tool definition with caching enabled.""" 50 | data = self.model_dump() 51 | data["cache_control"] = CacheControl().model_dump() 52 | return data 53 | 54 | def without_cache(self) -> dict[str, Any]: 55 | """Return tool definition without caching.""" 56 | data = self.model_dump() 57 | data.pop("cache_control", None) 58 | return data 59 | 60 | 61 | class PromptType(str, Enum): 62 | """Enum for prompt types for PromptManager.""" 63 | 64 | LLM_ASSISTANT_PROMPT = "llm_assistant_prompt" # Used for the LLM assistant 65 | MULTI_QUERY_PROMPT = "multi_query_prompt" # Used for the multi-query prompt 66 | SUMMARY_PROMPT = "summary_prompt" # Used for the summary prompt 67 | 68 | 69 | class SystemPrompt(BaseModel): 70 | """System prompt model for Anthropic LLMs.""" 71 | 72 | type: Literal["text"] = "text" 73 | text: str 74 | cache_control: CacheControl | None = None 75 | 76 | def with_cache(self) -> TextBlockParam: 77 | """Return prompt with caching enabled.""" 78 | data = self.model_dump() 79 | data["cache_control"] = CacheControl().model_dump() 80 | return data 81 | 82 | def without_cache(self) -> dict[str, Any]: 83 | """Return prompt without caching.""" 84 | data = self.model_dump() 85 | data.pop("cache_control", None) 86 | return data 87 | -------------------------------------------------------------------------------- /src/core/search/reranker.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import cohere 4 | from cohere.v2.types import V2RerankResponse 5 | 6 | from src.infra.logger import get_logger 7 | from src.infra.settings import settings 8 | 9 | logger = get_logger() 10 | 11 | 12 | class Reranker: 13 | """ 14 | Initializes and manages a Cohere Client for document re-ranking. 15 | 16 | Args: 17 | cohere_api_key (str): API key for the Cohere service. 18 | model_name (str): Name of the model to use for re-ranking. Defaults to "rerank-english-v3.0". 19 | """ 20 | 21 | def __init__(self, cohere_api_key: str = settings.cohere_api_key, model_name: str = "rerank-english-v3.0"): 22 | self.cohere_api_key = cohere_api_key 23 | self.model_name = model_name 24 | self.client = None 25 | 26 | self._init() 27 | 28 | def _init(self) -> None: 29 | try: 30 | self.client = cohere.ClientV2(api_key=self.cohere_api_key) 31 | logger.info("✓ Initialized Cohere client successfully") 32 | except Exception as e: 33 | logger.error(f"Failed to initialize Cohere client: {e}", exc_info=True) 34 | raise 35 | 36 | def extract_documents_list(self, unique_documents: dict[str, Any]) -> list[str]: 37 | """ 38 | Extract the 'text' field from each unique document. 39 | 40 | Args: 41 | unique_documents (dict[str, Any]): A dictionary where each value is a document represented as a dictionary 42 | with a 'text' field. 43 | 44 | Returns: 45 | list[str]: A list containing the 'text' field from each document. 
46 | 47 | """ 48 | # extract the 'text' field from each unique document 49 | document_texts = [chunk["text"] for chunk in unique_documents.values()] 50 | return document_texts 51 | 52 | def rerank(self, query: str, documents: dict[str, Any], return_documents: bool = True) -> V2RerankResponse: 53 | """ 54 | Rerank a list of documents based on their relevance to a given query. 55 | 56 | Args: 57 | query (str): The search query to rank the documents against. 58 | documents (dict[str, Any]): A dictionary containing documents to be reranked. 59 | return_documents (bool): A flag indicating whether to return the full documents. Defaults to True. 60 | 61 | Returns: 62 | V2RerankResponse: The reranked list of documents and their relevance scores. 63 | 64 | Raises: 65 | ValueError: If there are no documents to rerank. 66 | 67 | """ 68 | # extract list of documents 69 | document_texts = self.extract_documents_list(documents) 70 | 71 | if document_texts: 72 | # get indexed results 73 | response = self.client.rerank( 74 | model=self.model_name, query=query, documents=document_texts, return_documents=return_documents 75 | ) 76 | else: 77 | raise ValueError("No documents to rerank") 78 | 79 | logger.debug(f"Received {len(response.results)} documents from Cohere.") 80 | return response 81 | -------------------------------------------------------------------------------- /src/services/job_manager.py: -------------------------------------------------------------------------------- 1 | # job_manager.py 2 | from typing import Any 3 | from uuid import UUID 4 | 5 | from src.core._exceptions import JobNotFoundError 6 | from src.infra.logger import get_logger 7 | from src.models.job_models import ( 8 | CrawlJobDetails, 9 | Job, 10 | JobType, 11 | ProcessingJobDetails, 12 | ) 13 | from src.services.data_service import DataService 14 | 15 | logger = get_logger() 16 | 17 | 18 | class JobManager: 19 | """Manages job lifecycle and operations.""" 20 | 21 | def __init__(self, data_service: DataService) -> None: 22 | self.data_service = data_service 23 | 24 | async def create_job(self, job_type: JobType, details: CrawlJobDetails | ProcessingJobDetails) -> Job: 25 | """ 26 | Create and persist a new job. 27 | 28 | Args: 29 | job_type: Type of job to create 30 | details: Details of the job to create 31 | 32 | Returns: 33 | Job: The created job instance 34 | """ 35 | job = Job(job_type=job_type, details=details) 36 | await self.data_service.save_job(job) 37 | return job 38 | 39 | async def update_job(self, job_id: UUID, updates: dict[str, Any]) -> Job: 40 | """ 41 | Update a job with new data. 42 | 43 | Args: 44 | job_id: UUID of the job to update 45 | updates: Dictionary of updates to apply 46 | 47 | Returns: 48 | Job: Updated job instance 49 | 50 | Raises: 51 | JobNotFoundError: If job with given ID doesn't exist 52 | """ 53 | # Get current job state 54 | job = await self.data_service.get_job(job_id) 55 | if not job: 56 | raise JobNotFoundError(f"Job {job_id} not found") 57 | 58 | # Apply updates 59 | updated_job = job.update(**updates) 60 | 61 | # Persist and return 62 | return await self.data_service.save_job(updated_job) 63 | 64 | async def get_by_firecrawl_id(self, firecrawl_id: str) -> Job: 65 | """ 66 | Retrieve a job by its FireCrawl ID.
67 | 68 | Args: 69 | firecrawl_id: FireCrawl identifier 70 | 71 | Returns: 72 | Job: The requested job instance 73 | 74 | Raises: 75 | JobNotFoundError: If job with given FireCrawl ID doesn't exist 76 | """ 77 | job = await self.data_service.get_by_firecrawl_id(firecrawl_id) 78 | if not job: 79 | raise JobNotFoundError(f"Job with FireCrawl ID {firecrawl_id} not found") 80 | return job 81 | 82 | async def mark_job_completed(self, job_id: UUID, result_id: UUID | None = None) -> Job: 83 | """Mark a job as completed.""" 84 | job = await self.data_service.get_job(job_id) 85 | if not job: 86 | raise JobNotFoundError(f"Job {job_id} not found") 87 | job.complete() 88 | return await self.data_service.save_job(job) 89 | 90 | async def mark_job_failed(self, job_id: UUID, error: str) -> Job: 91 | """Mark a job as failed with error information.""" 92 | # job = await self.get_job(job_id) 93 | job = await self.data_service.get_job(job_id) 94 | if not job: 95 | raise JobNotFoundError(f"Job {job_id} not found") 96 | job.fail(error) 97 | return await self.data_service.save_job(job) 98 | -------------------------------------------------------------------------------- /src/infra/events/event_consumer.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | 4 | from redis.asyncio.client import PubSub 5 | from redis.exceptions import ConnectionError, TimeoutError 6 | 7 | from src.infra.decorators import tenacity_retry_wrapper 8 | from src.infra.events.channels import Channels 9 | from src.infra.external.redis_manager import RedisManager 10 | from src.infra.logger import get_logger 11 | from src.infra.settings import settings 12 | from src.models.content_models import ContentProcessingEvent 13 | from src.services.content_service import ContentService 14 | 15 | logger = get_logger() 16 | 17 | 18 | class EventConsumer: 19 | """Consumes events from the event bus and dispatches them to the appropriate services.""" 20 | 21 | def __init__(self, redis_manager: RedisManager, content_service: ContentService) -> None: 22 | self.redis_manager = redis_manager 23 | self.content_service = content_service 24 | self.pubsub: PubSub | None = None 25 | 26 | @classmethod 27 | async def create_async(cls, redis_manager: RedisManager, content_service: ContentService) -> "EventConsumer": 28 | """Creates an instance of EventConsumer, connects to pub/sub, then starts listening for events.""" 29 | instance = cls(redis_manager=redis_manager, content_service=content_service) 30 | redis_client = await instance.redis_manager.get_async_client() 31 | instance.pubsub = redis_client.pubsub() 32 | await instance.subscribe_on_startup() 33 | return instance 34 | 35 | async def subscribe_on_startup(self) -> None: 36 | """Subscribe to the processing channel on startup.""" 37 | # Subscribe to all content processing events 38 | await self.pubsub.subscribe(f"{Channels.CONTENT_PROCESSING}/*") # Global subscriber pattern 39 | logger.info("✓ Event consumer subscribed successfully to content processing events") 40 | 41 | @tenacity_retry_wrapper(exceptions=(ConnectionError, TimeoutError)) 42 | async def start(self) -> None: 43 | """Start listening for events from the event bus.""" 44 | try: 45 | asyncio.create_task(self.listen_for_events()) 46 | 47 | except (ConnectionError, TimeoutError) as e: 48 | logger.exception(f"Failed to subscribe to channel {settings.process_documents_channel}: {e}") 49 | raise 50 | 51 | async def listen_for_events(self) -> None: 52 | """Listen for events from the 
event bus.""" 53 | try: 54 | async for message in self.pubsub.listen(): 55 | logger.debug(f"Received message: {message}") 56 | if message["type"] == "message": 57 | await self.handle_event(message["data"]) 58 | except (ConnectionError, TimeoutError) as e: 59 | logger.exception(f"Failed to listen for events: {e}") 60 | raise 61 | 62 | async def stop(self) -> None: 63 | """Stop listening for events from the event bus.""" 64 | if self.pubsub: 65 | await self.pubsub.unsubscribe() 66 | await self.pubsub.aclose() 67 | logger.info("✓ Event consumer stopped successfully") 68 | 69 | async def handle_event(self, message_data: bytes) -> None: 70 | """Handle an event from the event bus.""" 71 | try: 72 | message = ContentProcessingEvent(**json.loads(message_data)) 73 | 74 | logger.debug("Sending message to content service") 75 | await self.content_service.handle_pubsub_event(message) 76 | except Exception as e: 77 | logger.exception(f"Failed to handle event: {e}") 78 | raise 79 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | # === Syntax & Basic Checks === 3 | - repo: https://github.com/pre-commit/pre-commit-hooks 4 | rev: v5.0.0 5 | hooks: 6 | - id: check-ast 7 | name: Validate Python syntax 8 | - id: check-yaml 9 | name: Validate YAML files 10 | - id: check-json 11 | name: Validate JSON files 12 | - id: check-toml 13 | name: Validate TOML files 14 | - id: mixed-line-ending 15 | name: Normalize line endings 16 | args: ['--fix=lf'] 17 | - id: trailing-whitespace 18 | name: Remove trailing whitespace 19 | - id: end-of-file-fixer 20 | name: Ensure file ends with newline 21 | 22 | # === Security === 23 | - repo: https://github.com/pre-commit/pre-commit-hooks 24 | rev: v5.0.0 25 | hooks: 26 | - id: detect-private-key 27 | name: Check for private keys 28 | stages: [pre-commit, pre-push, manual] 29 | - id: check-merge-conflict 30 | name: Check for merge conflicts 31 | stages: [pre-commit, manual] 32 | - id: debug-statements 33 | name: Check for debugger imports 34 | stages: [pre-commit, manual] 35 | 36 | # === Type Checking === 37 | 38 | - repo: https://github.com/pre-commit/mirrors-mypy 39 | rev: "v1.13.0" 40 | hooks: 41 | - id: mypy 42 | name: Run mypy type checker 43 | args: [ 44 | "--config-file=pyproject.toml", 45 | "--show-error-codes", 46 | "--pretty", 47 | ] 48 | additional_dependencies: [ 49 | "types-requests", 50 | "types-aiofiles", 51 | "types-pytz", 52 | "pydantic", 53 | "chainlit", 54 | "anthropic", 55 | "fastapi", 56 | "httpx", 57 | "tiktoken", 58 | "weave", 59 | "chromadb", 60 | "cohere", 61 | "langchain" 62 | ] 63 | entry: bash -c 'mypy "$@" || true' -- 64 | 65 | # === Code Quality & Style === 66 | - repo: https://github.com/astral-sh/ruff-pre-commit 67 | rev: v0.7.2 68 | hooks: 69 | - id: ruff 70 | name: Run Ruff linter 71 | args: [ 72 | --fix, 73 | --exit-zero, 74 | --quiet, 75 | ] 76 | types_or: [python, pyi, jupyter] 77 | files: ^(src|tests)/ 78 | exclude: ^src/experimental/ 79 | verbose: false 80 | - id: ruff-format 81 | name: Run Ruff formatter 82 | types_or: [python, pyi, jupyter] 83 | 84 | 85 | # === Testing & Documentation === 86 | - repo: local 87 | hooks: 88 | - id: pytest-unit 89 | name: Unit tests 90 | entry: pytest 91 | language: system 92 | pass_filenames: false 93 | always_run: true 94 | stages: [pre-commit, manual] 95 | args: [ 96 | "-v", # Verbose output 97 | "--no-header", 98 | "-ra", # Show extra test summary info 99 
| "--color=yes", # Colorized output 100 | "tests/unit/", 101 | ] 102 | verbose: true 103 | 104 | - id: pytest-integration 105 | name: Integration tests 106 | entry: pytest 107 | language: system 108 | pass_filenames: false 109 | always_run: false 110 | stages: [pre-commit, manual] 111 | args: [ 112 | "-v", # Verbose output 113 | "--no-header", 114 | "-ra", # Show extra test summary info 115 | "--color=yes", # Colorized output 116 | "tests/integration/", 117 | ] 118 | verbose: true 119 | -------------------------------------------------------------------------------- /tests/unit/infra/arq/test_arq_settings.py: -------------------------------------------------------------------------------- 1 | from arq.connections import RedisSettings 2 | 3 | from src.infra.arq.arq_settings import ArqSettings, get_arq_settings 4 | from src.infra.arq.serializer import deserialize, serialize 5 | from src.infra.settings import get_settings 6 | 7 | settings = get_settings() 8 | 9 | 10 | def test_arq_settings_initialization(): 11 | """Test basic initialization of ARQ settings.""" 12 | arq_settings = ArqSettings() 13 | 14 | # Test default values 15 | assert arq_settings.job_retries == 3 16 | assert arq_settings.health_check_interval == 60 17 | assert arq_settings.max_jobs == 1000 18 | assert arq_settings.connection_retries == 5 19 | 20 | 21 | def test_redis_settings_property(): 22 | """Test Redis settings property creates correct RedisSettings instance.""" 23 | arq_settings = ArqSettings() 24 | 25 | # First call - should create new settings 26 | redis_settings = arq_settings.redis_settings 27 | assert isinstance(redis_settings, RedisSettings) 28 | assert redis_settings.host == settings.redis_host 29 | assert redis_settings.port == settings.redis_port 30 | assert redis_settings.username == settings.redis_user 31 | assert redis_settings.password == settings.redis_password 32 | 33 | # Second call - should return cached settings 34 | cached_settings = arq_settings.redis_settings 35 | assert cached_settings is redis_settings # Should be same instance 36 | 37 | 38 | def test_job_serializer_property(): 39 | """Test job serializer property returns correct serializer function.""" 40 | arq_settings = ArqSettings() 41 | 42 | # First call - should set and return serializer 43 | serializer = arq_settings.job_serializer 44 | assert serializer == serialize 45 | assert callable(serializer) 46 | 47 | # Second call - should return cached serializer 48 | cached_serializer = arq_settings.job_serializer 49 | assert cached_serializer is serializer # Should be same instance 50 | 51 | 52 | def test_job_deserializer_property(): 53 | """Test job deserializer property returns correct deserializer function.""" 54 | arq_settings = ArqSettings() 55 | 56 | # First call - should set and return deserializer 57 | deserializer = arq_settings.job_deserializer 58 | assert deserializer == deserialize 59 | assert callable(deserializer) 60 | 61 | # Second call - should return cached deserializer 62 | cached_deserializer = arq_settings.job_deserializer 63 | assert cached_deserializer is deserializer # Should be same instance 64 | 65 | 66 | def test_get_arq_settings_caching(): 67 | """Test that get_arq_settings caches and returns the same instance.""" 68 | # First call 69 | settings1 = get_arq_settings() 70 | assert isinstance(settings1, ArqSettings) 71 | 72 | # Second call - should return same instance 73 | settings2 = get_arq_settings() 74 | assert settings2 is settings1 # Should be same instance 75 | 76 | 77 | def 
test_arq_settings_with_custom_values(): 78 | """Test ArqSettings respects custom values from environment.""" 79 | custom_settings = ArqSettings(job_retries=5, health_check_interval=120, max_jobs=2000, connection_retries=10) 80 | 81 | # Test the actual configurable values 82 | assert custom_settings.job_retries == 5 83 | assert custom_settings.health_check_interval == 120 84 | assert custom_settings.max_jobs == 2000 85 | assert custom_settings.connection_retries == 10 86 | 87 | # Redis settings should come from settings.redis_url, not direct host/port 88 | redis_settings = custom_settings.redis_settings 89 | assert isinstance(redis_settings, RedisSettings) 90 | # We should test that it uses the URL from settings 91 | assert redis_settings == RedisSettings.from_dsn(settings.redis_url) 92 | -------------------------------------------------------------------------------- /src/models/job_models.py: -------------------------------------------------------------------------------- 1 | # TODO: Transition to redis for job management. There is no reason to use custom implementation. 2 | from __future__ import annotations 3 | 4 | from datetime import UTC, datetime 5 | from enum import Enum 6 | from typing import Any, ClassVar 7 | from uuid import UUID, uuid4 8 | 9 | from pydantic import BaseModel, Field, PrivateAttr 10 | 11 | from src.models.base_models import SupabaseModel 12 | 13 | 14 | class JobStatus(str, Enum): 15 | """Job status enum.""" 16 | 17 | PENDING = "pending" 18 | IN_PROGRESS = "in_progress" 19 | COMPLETED = "completed" 20 | FAILED = "failed" 21 | CANCELLED = "cancelled" 22 | 23 | 24 | class JobType(str, Enum): 25 | """Represents the type of a job.""" 26 | 27 | CRAWL = "crawl" 28 | PROCESSING = "processing" 29 | 30 | 31 | class CrawlJobDetails(BaseModel): 32 | """Detailed information about a crawl job.""" 33 | 34 | source_id: UUID = Field( 35 | ..., 36 | description="Maps each crawl job to the Source object.", 37 | ) 38 | firecrawl_id: str | None = Field( 39 | default=None, description="Job id returned by FireCrawl. 
Added only if a job starts successfully" 40 | ) 41 | pages_crawled: int = Field(default=0, description="Number of pages crawled") 42 | url: str = Field(..., description="URL that was crawled") 43 | 44 | 45 | class ProcessingJobDetails(BaseModel): 46 | """Detailed information about a processing job - chunking and vector storage.""" 47 | 48 | source_id: UUID = Field(..., description="Maps each processing job to the Source object.") 49 | document_ids: list[UUID] = Field(..., description="List of document ids to be processed.") 50 | 51 | 52 | class Job(SupabaseModel): 53 | """Track job status and progress across job types.""" 54 | 55 | # General job info 56 | job_id: UUID = Field(default_factory=uuid4, description="Internal job id in the system") 57 | status: JobStatus = Field(default=JobStatus.PENDING, description="Job status in the system.") 58 | job_type: JobType = Field(..., description="Type of the job.") 59 | 60 | # Job details 61 | details: CrawlJobDetails | ProcessingJobDetails = Field(..., description="Detailed information about the job.") 62 | 63 | # Timing 64 | completed_at: datetime | None = Field(None, description="Completion timestamp") 65 | 66 | # Results 67 | error: str | None = None 68 | 69 | _db_config: ClassVar[dict] = {"schema": "infra", "table": "jobs", "primary_key": "job_id"} 70 | _protected_fields: set[str] = PrivateAttr(default={"job_id", "job_type", "created_at"}) 71 | 72 | def update(self, **kwargs: Any) -> Job: 73 | """Update job fields while preserving protected fields.""" 74 | if "details" in kwargs and isinstance(kwargs["details"], dict): 75 | # Get current details as dict 76 | current_details = self.details.model_dump() 77 | # Update with new values 78 | current_details.update(kwargs["details"]) 79 | # Replace details update with merged version 80 | kwargs["details"] = current_details 81 | 82 | return super().update(**kwargs) 83 | 84 | def complete(self) -> None: 85 | """Mark job as completed.""" 86 | if self.status == JobStatus.COMPLETED: 87 | return 88 | self.status = JobStatus.COMPLETED 89 | self.completed_at = datetime.now(UTC) 90 | 91 | def fail(self, error: str) -> None: 92 | """Mark job as failed.""" 93 | if self.status == JobStatus.FAILED: 94 | return 95 | self.status = JobStatus.FAILED 96 | self.error = error 97 | self.completed_at = datetime.now(UTC) 98 | 99 | class Config: 100 | """Configuration class for the Job model.""" 101 | 102 | json_encoders = {datetime: lambda v: v.isoformat()} 103 | -------------------------------------------------------------------------------- /src/api/v0/schemas/webhook_schemas.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from datetime import datetime 4 | from enum import Enum 5 | from typing import Any, Generic, TypeVar 6 | from uuid import UUID, uuid4 7 | 8 | from pydantic import BaseModel, Field 9 | from pytz import UTC 10 | 11 | T_WebhookData = TypeVar("T_WebhookData", bound=BaseModel) 12 | 13 | 14 | class WebhookProvider(str, Enum): 15 | """Supported webhook providers.""" 16 | 17 | FIRECRAWL = "firecrawl" 18 | 19 | 20 | class WebhookEvent(BaseModel, Generic[T_WebhookData]): 21 | """Base model for all webhook events. 22 | 23 | Provides common tracking fields needed across all webhook types.
24 | 25 | """ 26 | 27 | event_id: UUID = Field(default_factory=lambda: uuid4(), description="Internal event tracking ID") 28 | timestamp: datetime = Field(default_factory=lambda: datetime.now(UTC), description="When the event was received") 29 | raw_payload: dict = Field(..., description="Original webhook payload") 30 | data: T_WebhookData = Field(..., description="Processed event data. Must be defined for each provider") 31 | provider: WebhookProvider = Field(..., description="The webhook provider") 32 | 33 | class Config: 34 | """Pydantic model configuration.""" 35 | 36 | json_encoders = {datetime: lambda v: v.isoformat()} 37 | 38 | 39 | class FireCrawlEventType(str, Enum): 40 | """Base webhook event types.""" 41 | 42 | # FireCrawl events - from docs 43 | CRAWL_STARTED = "crawl.started" 44 | CRAWL_PAGE = "crawl.page" 45 | CRAWL_COMPLETED = "crawl.completed" 46 | CRAWL_FAILED = "crawl.failed" 47 | 48 | 49 | class FireCrawlWebhookResponse(BaseModel): 50 | """FireCrawl specific webhook response model. 51 | 52 | From FireCrawl docs: 53 | - success: If the webhook was successful 54 | - event_type: The type of event that occurred 55 | - firecrawl_id: The ID of the crawl (mapped from 'id' in the webhook) 56 | - data: The data that was scraped (Array). Only non-empty on crawl.page 57 | - error: If the webhook failed, this will contain the error message 58 | """ 59 | 60 | # Aligned with FireCrawl Webhook Response https://docs.firecrawl.dev/features/crawl#webhook-events 61 | success: bool = Field(..., description="If the webhook was successful in crawling the page correctly.") 62 | event_type: FireCrawlEventType = Field(..., alias="type", description="The type of event that occurred") 63 | firecrawl_id: str = Field(..., alias="id", description="The ID of the crawl") 64 | data: list[dict[str, Any]] = Field( 65 | default_factory=list, 66 | description="The data that was scraped (Array). This will only be non empty on crawl.page and will contain 1 " 67 | " item if the page was scraped successfully. The response is the same as the /scrape endpoint.", 68 | ) 69 | error: str | None = Field(None, description="If the webhook failed, this will contain the error message.") 70 | 71 | class Config: 72 | """Pydantic model configuration.""" 73 | 74 | populate_by_name = True 75 | 76 | 77 | class WebhookResponse(BaseModel): 78 | """Standard response model for webhook endpoints. 79 | 80 | Provides consistent response format across all webhook types. 
81 | """ 82 | 83 | status: str = Field("success", description="Status of webhook processing") 84 | message: str = Field(..., description="Human-readable processing result") 85 | event_id: UUID = Field(..., description="ID of the processed event") 86 | provider: WebhookProvider = Field(..., description="Provider that sent the webhook") 87 | timestamp: datetime = Field(default_factory=datetime.utcnow, description="When the webhook was processed") 88 | 89 | class Config: 90 | """Pydantic model configuration.""" 91 | 92 | json_encoders = {datetime: lambda v: v.isoformat()} 93 | 94 | 95 | class FireCrawlWebhookEvent(WebhookEvent[FireCrawlWebhookResponse]): 96 | """Concrete webhook event type for FireCrawl.""" 97 | 98 | pass 99 | -------------------------------------------------------------------------------- /src/infra/external/redis_manager.py: -------------------------------------------------------------------------------- 1 | from redis import Redis as SyncRedis 2 | from redis.asyncio import Redis as AsyncRedis 3 | from redis.exceptions import ConnectionError, TimeoutError 4 | 5 | from src.infra.decorators import tenacity_retry_wrapper 6 | from src.infra.logger import get_logger 7 | from src.infra.settings import settings 8 | 9 | logger = get_logger() 10 | 11 | 12 | class RedisManager: 13 | """Redis client that handles both sync and async connections.""" 14 | 15 | def __init__(self, decode_responses: bool = True) -> None: 16 | """Initialize Redis clients using settings configuration. 17 | 18 | Args: 19 | decode_responses: Whether to decode byte responses to strings. 20 | Note: RQ requires decode_responses=False 21 | """ 22 | self._decode_responses = decode_responses 23 | self._sync_client: SyncRedis | None = None 24 | self._async_client: AsyncRedis | None = None 25 | 26 | def _create_sync_client(self, decode_responses: bool) -> SyncRedis: 27 | """Create sync Redis client.""" 28 | client = SyncRedis.from_url( 29 | url=settings.redis_url, 30 | decode_responses=decode_responses, 31 | ) 32 | return client 33 | 34 | def _create_async_client(self, decode_responses: bool) -> AsyncRedis: 35 | """Create async Redis client.""" 36 | client = AsyncRedis.from_url( 37 | url=settings.redis_url, 38 | decode_responses=decode_responses, 39 | ) 40 | return client 41 | 42 | @tenacity_retry_wrapper(exceptions=(ConnectionError, TimeoutError)) 43 | def _connect_sync(self) -> None: 44 | """Connect to the sync redis client and handle connection errors.""" 45 | if self._sync_client is None: 46 | try: 47 | client = self._create_sync_client(decode_responses=self._decode_responses) 48 | client.ping() 49 | self._sync_client = client 50 | logger.info("✓ Initialized sync Redis client successfully") 51 | except (ConnectionError, TimeoutError) as e: 52 | logger.exception(f"Failed to initialize sync Redis client: {e}") 53 | raise 54 | 55 | @tenacity_retry_wrapper(exceptions=(ConnectionError, TimeoutError)) 56 | async def _connect_async(self) -> None: 57 | """Connect to the async redis client and handle connection errors.""" 58 | if self._async_client is None: 59 | try: 60 | client = self._create_async_client(decode_responses=self._decode_responses) 61 | await client.ping() 62 | self._async_client = client 63 | logger.info("✓ Initialized async Redis client successfully") 64 | except (ConnectionError, TimeoutError) as e: 65 | logger.exception(f"Failed to initialize async Redis client: {e}") 66 | raise 67 | 68 | @classmethod 69 | def create(cls, decode_responses: bool = True) -> "RedisManager": 70 | """Creates a new sync redis 
client""" 71 | instance = cls(decode_responses=decode_responses) 72 | instance._connect_sync() 73 | return instance 74 | 75 | @classmethod 76 | async def create_async(cls, decode_responses: bool = True) -> "RedisManager": 77 | """Creates a new async redis client""" 78 | instance = cls(decode_responses=decode_responses) 79 | await instance._connect_async() 80 | return instance 81 | 82 | def get_sync_client(self) -> SyncRedis: 83 | """Get the sync redis client""" 84 | self._connect_sync() 85 | if self._sync_client is None: 86 | raise RuntimeError("Sync Redis client not initialized") 87 | return self._sync_client 88 | 89 | async def get_async_client(self) -> AsyncRedis: 90 | """Get the async redis client""" 91 | await self._connect_async() 92 | if self._async_client is None: 93 | raise RuntimeError("Async Redis client not initialized") 94 | return self._async_client 95 | -------------------------------------------------------------------------------- /src/core/chat/tools/tools.yaml: -------------------------------------------------------------------------------- 1 | rag_search: 2 | name: rag_search 3 | description: | 4 | Retrieves relevant information from a local document database using RAG (Retrieval Augmented Generation) 5 | technology. This tool performs semantic search on a vector database containing pre-processed and embedded 6 | documents, returning the most relevant content based on the input query. 7 | 8 | What this tool does: 9 | 1. Takes a direct search query and performs vector similarity search 10 | 2. Retrieves and ranks the most relevant document chunks 11 | 3. Returns formatted text passages with relevance scores 12 | 13 | Example input: 14 | { 15 | "rag_query": "How do React hooks manage component state?" 16 | } 17 | 18 | Example output: 19 | - A list of relevant text passages, each containing: 20 | * Document text 21 | * Relevance score 22 | * Source information (if available) 23 | - Returns null if no relevant content is found 24 | 25 | When to use this tool: 26 | 1. When you need specific information from the indexed documentation 27 | 2. When verifying technical details or implementation specifics 28 | 3. When the user asks about topics covered in the loaded documents 29 | 4. When you need authoritative source material for your response 30 | 31 | When NOT to use this tool: 32 | 1. For general knowledge questions not requiring document lookup 33 | 2. When the query is clearly outside the scope of indexed documents 34 | 3. For real-time or dynamic data requests 35 | 4. For questions about very recent changes not yet in your knowledge base 36 | 37 | Important caveats: 38 | 1. The tool's effectiveness depends on the quality of the search query 39 | 2. Results are ranked by vector similarity to the query 40 | 3. Only searches within pre-loaded documentation 41 | 4. Cannot access external resources or the internet 42 | 43 | input_schema: 44 | type: object 45 | properties: 46 | rag_query: 47 | type: string 48 | description: The direct search query to find relevant information in the document database 49 | required: ["rag_query"] 50 | 51 | multi_query_tool: 52 | name: multi_query_tool 53 | description: | 54 | Generates multiple search queries based on the user's original question to improve RAG retrieval results. 55 | You must return a JSON object with a "queries" array containing the generated queries. 56 | 57 | What this tool does: 58 | 1. Analyzes the user's question to identify key concepts and information needs 59 | 2. 
Generates multiple variations and aspects of the question 60 | 3. Returns a list of focused, single-topic queries in JSON format 61 | 62 | Example input: 63 | { 64 | "question": "What are the key features of React hooks?" 65 | } 66 | 67 | Example output: 68 | { 69 | "queries": [ 70 | "What is the basic purpose of React hooks?", 71 | "How do React hooks differ from class components?", 72 | "What are the most commonly used React hooks?", 73 | "What are the rules of using React hooks?", 74 | "What are the performance benefits of React hooks?" 75 | ] 76 | } 77 | 78 | Requirements: 79 | 1. Always return exactly n_queries number of queries (or 3 if not specified) 80 | 2. Each query must be a complete, focused question 81 | 3. Queries should cover different aspects of the original question 82 | 4. Output must be valid JSON with a "queries" key containing an array of strings 83 | 84 | input_schema: 85 | type: object 86 | properties: 87 | queries: 88 | type: array 89 | items: 90 | type: string 91 | description: The list of queries to generate 92 | required: ["queries"] 93 | 94 | summary_tool: 95 | name: summary_tool 96 | description: | 97 | Generates a summary of the source content. 98 | input_schema: 99 | type: object 100 | properties: 101 | summary: 102 | type: string 103 | description: The summary of the source content 104 | keywords: 105 | type: array 106 | items: 107 | type: string 108 | description: The keywords of the source content 109 | required: ["summary", "keywords"] 110 | -------------------------------------------------------------------------------- /src/api/system/health.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, Response, status 2 | 3 | from src.api.dependencies import ChromaManagerDep, RedisManagerDep, SupabaseManagerDep 4 | from src.api.routes import CURRENT_API_VERSION, Routes 5 | from src.api.v0.schemas.health_schemas import HealthCheckResponse 6 | from src.infra.decorators import tenacity_retry_wrapper 7 | from src.infra.logger import get_logger 8 | 9 | logger = get_logger() 10 | 11 | router = APIRouter(prefix=CURRENT_API_VERSION) 12 | 13 | 14 | @router.get( 15 | Routes.System.HEALTH, 16 | response_model=HealthCheckResponse, 17 | responses={ 18 | 200: {"description": "All systems operational"}, 19 | 503: {"description": "One or more services are down"}, 20 | }, 21 | summary="System Health Status", 22 | description="Check the health status of all critical system components.", 23 | ) 24 | async def health_check( 25 | response: Response, 26 | chroma_manager: ChromaManagerDep, 27 | supabase_manager: SupabaseManagerDep, 28 | redis_manager: RedisManagerDep, 29 | ) -> HealthCheckResponse: 30 | """Check if all critical system components are operational. Allows for cold start with the tenacity retry wrapper. 
31 | 32 | This endpoint performs health checks on: 33 | - Redis connection 34 | - Supabase connection 35 | - ChromaDB connection 36 | - Celery workers 37 | 38 | Returns: 39 | HealthCheckResponse: System health status 40 | """ 41 | try: 42 | return await get_services_health( 43 | chroma_manager=chroma_manager, 44 | supabase_manager=supabase_manager, 45 | redis_manager=redis_manager, 46 | # celery_app=celery_app, 47 | ) 48 | except Exception as e: 49 | logger.error(f"✗ Health check failed with error: {str(e)}", exc_info=True) 50 | response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE 51 | return HealthCheckResponse( 52 | status="down", 53 | message=f"Service is currently unavailable: {str(e)}", 54 | ) 55 | 56 | 57 | @tenacity_retry_wrapper(max_attempts=3, min_wait=10, max_wait=30) 58 | async def get_services_health( 59 | chroma_manager: ChromaManagerDep, 60 | supabase_manager: SupabaseManagerDep, 61 | redis_manager: RedisManagerDep, 62 | # celery_app: CeleryAppDep, 63 | ) -> HealthCheckResponse: 64 | """Gets the health of all services, wrapped in the retry decorator.""" 65 | # Check Redis - simple ping 66 | if redis_manager._async_client: 67 | await redis_manager._async_client.ping() 68 | else: 69 | logger.error("✗ Redis client not initialized") 70 | raise RuntimeError("Redis client not initialized") 71 | 72 | # Check Supabase - verify we have a working client 73 | client = await supabase_manager.get_async_client() 74 | if not client: 75 | logger.error("✗ Supabase client not initialized") 76 | raise RuntimeError("Supabase client not initialized") 77 | 78 | # Check ChromaDB - heartbeat check 79 | if chroma_manager._client: 80 | await chroma_manager._client.heartbeat() 81 | else: 82 | logger.error("✗ ChromaDB client not initialized") 83 | raise RuntimeError("ChromaDB client not initialized") 84 | 85 | # Check Celery workers - verify active workers 86 | # if settings.service != ServiceType.API: # Only check workers if we're not the API 87 | # # inspector = celery_app.control.inspect() 88 | # logger.debug(f"Celery broker URL: {celery_app.conf.broker_url}") 89 | # try: 90 | # active_workers = inspector.active() 91 | # if not active_workers: 92 | # logger.error("✗ No active Celery workers found") 93 | # raise RuntimeError("No active Celery workers found") 94 | # except ConnectionRefusedError as e: 95 | # logger.error(f"✗ Failed to connect to Celery broker: {str(e)}") 96 | # raise RuntimeError(f"Failed to connect to Celery broker: {str(e)}") 97 | # except Exception as e: 98 | # logger.error(f"✗ Unexpected Celery error: {str(e)}") 99 | # raise 100 | 101 | result: HealthCheckResponse = HealthCheckResponse( 102 | status="operational", 103 | message="All systems operational", 104 | ) 105 | return result 106 | -------------------------------------------------------------------------------- /tests/unit/infra/arq/test_redis_pool.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import AsyncMock, Mock, patch 2 | 3 | import pytest 4 | from arq import ArqRedis 5 | 6 | from src.infra.arq.arq_settings import ArqSettings 7 | from src.infra.arq.redis_pool import RedisPool 8 | 9 | 10 | @pytest.fixture 11 | def mock_arq_settings(): 12 | """Mock ARQ settings for testing.""" 13 | settings = Mock(spec=ArqSettings) 14 | settings.redis_settings = Mock() 15 | settings.job_serializer = Mock() 16 | settings.job_deserializer = Mock() 17 | settings.connection_retries = 3 18 | return settings 19 | 20 | 21 | @pytest.fixture 22 | def redis_pool(mock_arq_settings): 
23 | """Create RedisPool instance with mocked settings.""" 24 | return RedisPool(arq_settings=mock_arq_settings) 25 | 26 | 27 | def test_redis_pool_initialization(redis_pool, mock_arq_settings): 28 | """Test RedisPool initialization.""" 29 | assert redis_pool.arq_settings == mock_arq_settings 30 | assert redis_pool._pool is None 31 | assert not redis_pool.is_connected 32 | 33 | 34 | def test_is_connected_property(redis_pool): 35 | """Test is_connected property behavior.""" 36 | assert not redis_pool.is_connected 37 | redis_pool._pool = Mock(spec=ArqRedis) 38 | assert redis_pool.is_connected 39 | 40 | 41 | @pytest.mark.asyncio 42 | async def test_initialize_pool_success(redis_pool): 43 | """Test successful pool initialization.""" 44 | mock_pool = Mock(spec=ArqRedis) 45 | 46 | with patch("src.infra.arq.redis_pool.create_pool", AsyncMock(return_value=mock_pool)) as mock_create: 47 | await redis_pool.initialize_pool() 48 | 49 | # Verify pool was created with correct settings 50 | mock_create.assert_called_once_with( 51 | settings_=redis_pool.arq_settings.redis_settings, 52 | job_serializer=redis_pool.arq_settings.job_serializer, 53 | job_deserializer=redis_pool.arq_settings.job_deserializer, 54 | retry=redis_pool.arq_settings.connection_retries, 55 | ) 56 | 57 | assert redis_pool._pool == mock_pool 58 | assert redis_pool.is_connected 59 | 60 | 61 | @pytest.mark.asyncio 62 | async def test_initialize_pool_already_connected(redis_pool): 63 | """Test initialize_pool when already connected.""" 64 | redis_pool._pool = Mock(spec=ArqRedis) 65 | 66 | with patch("src.infra.arq.redis_pool.create_pool", AsyncMock()) as mock_create: 67 | await redis_pool.initialize_pool() 68 | mock_create.assert_not_called() 69 | 70 | 71 | @pytest.mark.asyncio 72 | async def test_initialize_pool_failure(redis_pool): 73 | """Test pool initialization failure.""" 74 | with patch("src.infra.arq.redis_pool.create_pool", AsyncMock(side_effect=ConnectionError("Failed to connect"))): 75 | with pytest.raises(ConnectionError, match="Failed to connect"): 76 | await redis_pool.initialize_pool() 77 | 78 | assert redis_pool._pool is None 79 | assert not redis_pool.is_connected 80 | 81 | 82 | @pytest.mark.asyncio 83 | async def test_create_redis_pool_success(): 84 | """Test create_redis_pool class method.""" 85 | mock_pool = Mock(spec=ArqRedis) 86 | 87 | with patch("src.infra.arq.redis_pool.create_pool", AsyncMock(return_value=mock_pool)): 88 | pool = await RedisPool.create_redis_pool() 89 | assert isinstance(pool, ArqRedis) 90 | assert pool == mock_pool 91 | 92 | 93 | @pytest.mark.asyncio 94 | async def test_create_redis_pool_failure(): 95 | """Test create_redis_pool failure.""" 96 | with patch("src.infra.arq.redis_pool.create_pool", AsyncMock(side_effect=ConnectionError("Failed to connect"))): 97 | with pytest.raises(ConnectionError, match="Failed to connect"): 98 | await RedisPool.create_redis_pool() 99 | 100 | 101 | @pytest.mark.asyncio 102 | async def test_get_pool_success(redis_pool): 103 | """Test get_pool with successful connection.""" 104 | mock_pool = Mock(spec=ArqRedis) 105 | 106 | with patch("src.infra.arq.redis_pool.create_pool", AsyncMock(return_value=mock_pool)): 107 | pool = await redis_pool.get_pool() 108 | assert pool == mock_pool 109 | assert redis_pool.is_connected 110 | 111 | 112 | @pytest.mark.asyncio 113 | async def test_get_pool_failure(redis_pool): 114 | """Test get_pool with connection failure.""" 115 | with patch("src.infra.arq.redis_pool.create_pool", AsyncMock(side_effect=ConnectionError("Failed to 
connect"))): 116 | with pytest.raises(ConnectionError, match="Failed to connect"): 117 | await redis_pool.get_pool() 118 | 119 | assert not redis_pool.is_connected 120 | -------------------------------------------------------------------------------- /src/app.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from collections.abc import AsyncGenerator 3 | from contextlib import asynccontextmanager 4 | 5 | import logfire 6 | import sentry_sdk 7 | import uvicorn 8 | from fastapi import FastAPI 9 | from fastapi.middleware.cors import CORSMiddleware 10 | 11 | from src.api.config.cors_config import get_cors_config 12 | from src.api.handlers.error_handlers import global_exception_handler, non_retryable_exception_handler 13 | from src.api.middleware.rate_limit import HealthCheckRateLimit 14 | from src.api.system.health import router as health_router 15 | from src.api.system.sentry_debug import router as sentry_debug_router 16 | from src.api.v0.endpoints.chat import chat_router, conversations_router 17 | from src.api.v0.endpoints.sources import router as content_router 18 | from src.api.v0.endpoints.webhooks import router as webhook_router 19 | from src.infra.logger import configure_logging, get_logger 20 | from src.infra.service_container import ServiceContainer 21 | from src.infra.settings import Environment, settings 22 | 23 | 24 | @asynccontextmanager 25 | async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: 26 | """Handle application startup and shutdown events.""" 27 | container = None # Initialize container to None 28 | try: 29 | # 2. Initialize services 30 | container = await ServiceContainer.create() 31 | 32 | # 3. Save app state 33 | app.state.container = container 34 | yield 35 | except Exception: 36 | logfire.exception("Failed to start Kollektiv!") 37 | raise 38 | finally: 39 | if container: # Always try to shutdown if container exists, not just in LOCAL 40 | await container.shutdown_services() 41 | 42 | 43 | def create_app() -> FastAPI: 44 | """Create and configure the FastAPI application.""" 45 | # Configure standard logging first 46 | configure_logging(debug=settings.debug) 47 | get_logger() 48 | 49 | # Initialize Sentry 50 | sentry_sdk.init( 51 | dsn=settings.sentry_dsn, 52 | traces_sample_rate=1, # capture 100% of transactions 53 | _experiments={"continuous_profiling_auto_start": True}, 54 | environment=settings.environment.value, 55 | ) 56 | 57 | app = FastAPI( 58 | title="Kollektiv API", 59 | description="RAG-powered LLM chat application", 60 | lifespan=lifespan, 61 | redoc_url="/redoc", 62 | ) 63 | 64 | # Add middleware 65 | app.add_middleware(CORSMiddleware, **get_cors_config(settings.environment)) 66 | app.add_middleware(HealthCheckRateLimit, requests_per_minute=60) 67 | 68 | # Add routes 69 | app.include_router(health_router, tags=["system"]) 70 | app.include_router(sentry_debug_router, tags=["system"]) 71 | app.include_router(webhook_router, tags=["webhooks"]) 72 | app.include_router(content_router, tags=["sources"]) 73 | app.include_router(chat_router, tags=["chat"]) 74 | app.include_router(conversations_router, tags=["chat"]) 75 | 76 | # Add exception handlers 77 | app.add_exception_handler(Exception, global_exception_handler) 78 | app.add_exception_handler(Exception, non_retryable_exception_handler) 79 | 80 | return app 81 | 82 | 83 | def run() -> None: 84 | """Run the FastAPI application with environment-specific settings.""" 85 | try: 86 | if settings.environment == Environment.LOCAL: 87 | # Parse 
arguments only in local development 88 | parser = argparse.ArgumentParser() 89 | parser.add_argument("--reload", action="store_true", help="Enable auto-reload") 90 | parser.add_argument("--workers", type=int, help="Number of workers") 91 | args = parser.parse_args() 92 | 93 | # Development mode: Use Uvicorn with auto-reload 94 | uvicorn.run( 95 | factory=True, 96 | app="src.app:create_app", 97 | host=settings.api_host, 98 | port=settings.api_port, 99 | reload=args.reload or settings.reload, # Use parsed argument 100 | log_level="debug" if settings.debug else "info", 101 | ) 102 | else: 103 | # Staging/Production: Use Uvicorn with more workers 104 | uvicorn.run( 105 | factory=True, 106 | app="src.app:create_app", 107 | host=settings.api_host, 108 | port=settings.api_port, 109 | workers=settings.gunicorn_workers, 110 | log_level="debug" if settings.debug else "info", 111 | ) 112 | except KeyboardInterrupt: 113 | raise 114 | 115 | 116 | if __name__ == "__main__": 117 | run() 118 | -------------------------------------------------------------------------------- /src/infra/arq/worker_services.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from arq import ArqRedis 4 | 5 | from src.core.chat.summary_manager import SummaryManager 6 | from src.core.content.chunker import MarkdownChunker 7 | from src.core.search.embedding_manager import EmbeddingManager 8 | from src.core.search.vector_db import VectorDatabase 9 | from src.infra.arq.redis_pool import RedisPool 10 | from src.infra.data.data_repository import DataRepository 11 | from src.infra.data.redis_repository import RedisRepository 12 | from src.infra.events.event_publisher import EventPublisher 13 | from src.infra.external.chroma_manager import ChromaManager 14 | from src.infra.external.redis_manager import RedisManager 15 | from src.infra.external.supabase_manager import SupabaseManager 16 | from src.infra.logger import get_logger 17 | from src.services.data_service import DataService 18 | from src.services.job_manager import JobManager 19 | 20 | logger = get_logger() 21 | 22 | 23 | class WorkerServices: 24 | """Services singleton necessary for the ARQ worker.""" 25 | 26 | _instance: Union["WorkerServices", None] = None 27 | 28 | def __init__(self) -> None: 29 | logger.info("Initializing worker services...") 30 | 31 | self.job_manager: JobManager | None = None 32 | self.data_service: DataService | None = None 33 | self.repository: DataRepository | None = None 34 | self.supabase_manager: SupabaseManager | None = None 35 | self.vector_db: VectorDatabase | None = None 36 | self.redis_manager: RedisManager | None = None 37 | self.async_redis_manager: RedisManager | None = None 38 | self.redis_repository: RedisRepository | None = None 39 | self.embedding_manager: EmbeddingManager | None = None 40 | self.chroma_manager: ChromaManager | None = None 41 | self.event_publisher: EventPublisher | None = None 42 | self.chunker: MarkdownChunker | None = None 43 | self.arq_redis_pool: ArqRedis | None = None 44 | 45 | async def initialize_services(self) -> None: 46 | """Initialize all necessary worker services.""" 47 | try: 48 | # Database & Repository 49 | self.supabase_manager = await SupabaseManager.create_async() 50 | self.repository = DataRepository(supabase_manager=self.supabase_manager) 51 | self.data_service = DataService(repository=self.repository) 52 | 53 | # Redis 54 | self.async_redis_manager = await RedisManager.create_async() 55 | self.redis_repository = 
RedisRepository(manager=self.async_redis_manager) 56 | self.arq_redis_pool = await RedisPool.create_redis_pool() 57 | 58 | # Job & Content Services 59 | self.job_manager = JobManager(data_service=self.data_service) 60 | self.chunker = MarkdownChunker() 61 | 62 | # Vector operations 63 | self.chroma_manager = await ChromaManager.create_async() 64 | self.embedding_manager = EmbeddingManager() 65 | self.vector_db = VectorDatabase( 66 | chroma_manager=self.chroma_manager, 67 | embedding_manager=self.embedding_manager, 68 | data_service=self.data_service, 69 | ) 70 | 71 | # Events 72 | self.event_publisher = await EventPublisher.create_async(redis_manager=self.async_redis_manager) 73 | 74 | # Source summary 75 | self.summary_manager = SummaryManager(data_service=self.data_service) 76 | 77 | # Result logging 78 | logger.info("✓ Initialized worker services successfully.") 79 | except Exception as e: 80 | logger.error(f"Error during worker service initialization: {e}", exc_info=True) 81 | raise 82 | 83 | async def shutdown_services(self) -> None: 84 | """Shutdown all services.""" 85 | try: 86 | logger.info("Shutting down") 87 | 88 | except Exception as e: 89 | logger.error(f"Error during service shutdown: {e}", exc_info=True) 90 | 91 | @classmethod 92 | async def create(cls) -> "WorkerServices": 93 | """Create a new WorkerServices instance and initialize services.""" 94 | if cls._instance is None: 95 | cls._instance = cls() 96 | await cls._instance.initialize_services() 97 | return cls._instance 98 | 99 | @classmethod 100 | async def get_instance(cls) -> "WorkerServices": 101 | """Get the singleton instance of WorkerServices.""" 102 | await cls.create() 103 | if cls._instance is None: 104 | raise RuntimeError("WorkerServices instance not initialized") 105 | return cls._instance 106 | -------------------------------------------------------------------------------- /tests/unit/test_settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | from unittest.mock import patch 3 | 4 | import pytest 5 | 6 | from src.api.routes import Routes 7 | from src.infra.settings import Environment, Settings 8 | 9 | 10 | def test_environment_independent_settings(): 11 | """Test settings that should be the same regardless of environment.""" 12 | settings = Settings() 13 | 14 | # These should be constant regardless of environment 15 | assert settings.firecrawl_api_url == "https://api.firecrawl.dev/v1" 16 | assert settings.api_host == "127.0.0.1" 17 | assert settings.api_port == 8080 18 | assert settings.debug is True 19 | assert settings.max_retries == 3 20 | assert settings.backoff_factor == 2.0 21 | assert settings.default_page_limit == 25 22 | assert settings.default_max_depth == 5 23 | 24 | 25 | def test_path_settings(): 26 | """Test path configurations that are environment-independent.""" 27 | settings = Settings() 28 | 29 | # Verify paths exist (not directories) 30 | assert settings.src_dir.exists() 31 | assert settings.eval_dir.parent.exists() # Check parent since these are relative paths 32 | assert settings.prompt_dir.parent.exists() 33 | assert settings.tools_dir.parent.exists() 34 | assert isinstance(settings.prompts_file, str) # These are now strings, not paths 35 | assert isinstance(settings.tools_file, str) 36 | 37 | 38 | def test_environment_specific_settings(): 39 | """Test environment-specific settings based on current environment.""" 40 | current_env = os.getenv("ENVIRONMENT", "local") 41 | settings = Settings() 42 | 43 | if current_env == "staging": 
44 | assert settings.environment == Environment.STAGING 45 | assert settings.public_url.startswith("https://") # Using public_url instead of base_url 46 | assert settings.firecrawl_webhook_url == f"{settings.public_url}{Routes.System.Webhooks.FIRECRAWL}" 47 | else: 48 | assert settings.environment == Environment.LOCAL 49 | expected_url = f"http://{settings.api_host}:{settings.api_port}" 50 | assert settings.public_url == expected_url # Using public_url instead of base_url 51 | assert settings.firecrawl_webhook_url == f"{expected_url}{Routes.System.Webhooks.FIRECRAWL}" 52 | 53 | 54 | def test_required_api_keys(): 55 | """Test that required API keys are present.""" 56 | settings = Settings() 57 | 58 | # These should be set in both environments 59 | assert settings.firecrawl_api_key is not None 60 | assert settings.anthropic_api_key is not None 61 | assert settings.openai_api_key is not None 62 | assert settings.cohere_api_key is not None 63 | 64 | 65 | def test_environment_override(): 66 | """Test that environment settings can be explicitly overridden.""" 67 | # Test LOCAL override 68 | with patch.dict(os.environ, {"ENVIRONMENT": "local"}): 69 | settings = Settings() 70 | assert settings.environment == Environment.LOCAL 71 | assert settings.public_url == f"http://{settings.api_host}:{settings.api_port}" # Using public_url 72 | 73 | # Test STAGING override 74 | with patch.dict(os.environ, {"ENVIRONMENT": "staging", "RAILWAY_PUBLIC_DOMAIN": "test.railway.app"}): 75 | settings = Settings() 76 | assert settings.environment == Environment.STAGING 77 | assert settings.public_url == "https://test.railway.app" # Using public_url with railway domain 78 | 79 | 80 | @pytest.mark.skipif("CI" in os.environ, reason="Local environment test only") 81 | def test_production_environment_validation(): 82 | """Test that production environment requires RAILWAY_PUBLIC_DOMAIN.""" 83 | with patch.dict(os.environ, {"ENVIRONMENT": "production"}, clear=True): 84 | settings = Settings() 85 | with pytest.raises(ValueError, match="RAILWAY_PUBLIC_DOMAIN must be set in staging/production"): 86 | _ = settings.public_url # Using public_url instead of base_url 87 | 88 | 89 | @pytest.mark.skipif("CI" in os.environ, reason="Local environment test only") 90 | def test_local_env_file_loading(): 91 | """Test .env file loading in local environment.""" 92 | with patch.dict(os.environ, {"ENVIRONMENT": "local"}): 93 | settings = Settings() 94 | assert settings.environment == Environment.LOCAL 95 | assert settings.public_url.startswith("http://127.0.0.1") # Using public_url 96 | 97 | 98 | @pytest.mark.skipif("CI" not in os.environ, reason="CI environment test only") 99 | def test_ci_environment_settings(): 100 | """Test settings specifically in CI environment.""" 101 | settings = Settings() 102 | assert settings.environment == Environment.STAGING 103 | # In CI, we should have RAILWAY_PUBLIC_DOMAIN set 104 | assert settings.railway_public_domain is not None 105 | assert settings.public_url.startswith("https://") 106 | assert all( 107 | key is not None 108 | for key in [ 109 | settings.firecrawl_api_key, 110 | settings.anthropic_api_key, 111 | settings.openai_api_key, 112 | settings.cohere_api_key, 113 | ] 114 | ) 115 | -------------------------------------------------------------------------------- /src/api/dependencies.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Annotated 4 | from uuid import UUID 5 | 6 | from fastapi import 
Depends, HTTPException, Request 7 | from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer 8 | 9 | from src.api.v0.schemas.base_schemas import ErrorCode, ErrorResponse 10 | from src.core.content.crawler import FireCrawler 11 | 12 | # from src.infra.celery.worker import celery_app 13 | from src.infra.external.chroma_manager import ChromaManager 14 | from src.infra.external.redis_manager import RedisManager 15 | from src.infra.external.supabase_manager import SupabaseManager 16 | from src.infra.service_container import ServiceContainer 17 | from src.services.chat_service import ChatService 18 | from src.services.content_service import ContentService 19 | from src.services.job_manager import JobManager 20 | 21 | security = HTTPBearer() 22 | 23 | 24 | def get_container(request: Request) -> ServiceContainer: 25 | """Retrieve the ServiceContainer instance from app.state.""" 26 | container = getattr(request.app.state, "container", None) 27 | if not isinstance(container, ServiceContainer): 28 | raise RuntimeError("ServiceContainer not initialized") 29 | return container 30 | 31 | 32 | def get_job_manager(container: Annotated[ServiceContainer, Depends(get_container)]) -> JobManager: 33 | """Get JobManager from app state.""" 34 | if container.job_manager is None: 35 | raise RuntimeError("JobManager is not initialized") 36 | return container.job_manager 37 | 38 | 39 | def get_crawler(container: Annotated[ServiceContainer, Depends(get_container)]) -> FireCrawler: 40 | """Get FireCrawler from app state.""" 41 | if container.firecrawler is None: 42 | raise RuntimeError("FireCrawler is not initialized") 43 | return container.firecrawler 44 | 45 | 46 | def get_content_service(container: Annotated[ServiceContainer, Depends(get_container)]) -> ContentService: 47 | """Get ContentService from app state.""" 48 | if container.content_service is None: 49 | raise RuntimeError("ContentService is not initialized") 50 | return container.content_service 51 | 52 | 53 | def get_chat_service(container: Annotated[ServiceContainer, Depends(get_container)]) -> ChatService: 54 | """Get ChatService from app state.""" 55 | if container.chat_service is None: 56 | raise RuntimeError("ChatService is not initialized") 57 | return container.chat_service 58 | 59 | 60 | def get_chroma_manager(container: Annotated[ServiceContainer, Depends(get_container)]) -> ChromaManager: 61 | """Get ChromaManager from app state.""" 62 | if container.chroma_manager is None: 63 | raise RuntimeError("ChromaManager is not initialized") 64 | return container.chroma_manager 65 | 66 | 67 | def get_redis_manager(container: Annotated[ServiceContainer, Depends(get_container)]) -> RedisManager: 68 | """Get RedisManager from app state.""" 69 | if container.async_redis_manager is None: 70 | raise RuntimeError("RedisManager is not initialized") 71 | return container.async_redis_manager 72 | 73 | 74 | def get_supabase_manager(container: Annotated[ServiceContainer, Depends(get_container)]) -> SupabaseManager: 75 | """Get SupabaseManager from app state.""" 76 | if container.supabase_manager is None: 77 | raise RuntimeError("SupabaseManager is not initialized") 78 | return container.supabase_manager 79 | 80 | 81 | # def get_celery_app(container: Annotated[ServiceContainer, Depends(get_container)]) -> Celery: 82 | # """Get Celery app from app state.""" 83 | # if celery_app is None: 84 | # raise RuntimeError("Celery app is not initialized") 85 | # return celery_app 86 | 87 | 88 | async def get_user_id( 89 | credentials: 
Annotated[HTTPAuthorizationCredentials, Depends(security)], 90 | supabase_manager: Annotated[SupabaseManager, Depends(get_supabase_manager)], 91 | ) -> UUID: 92 | """Retrieve the user id from the supabase client.""" 93 | if credentials is None: 94 | raise HTTPException(status_code=401, detail=ErrorResponse(code=ErrorCode.CLIENT_ERROR, detail="Unauthorized")) 95 | 96 | supabase_client = await supabase_manager.get_async_client() 97 | user_response = await supabase_client.auth.get_user(credentials.credentials) 98 | return UUID(user_response.user.id) 99 | 100 | 101 | # Type aliases for cleaner dependency injection 102 | ContainerDep = Annotated[ServiceContainer, Depends(get_container)] 103 | ContentServiceDep = Annotated[ContentService, Depends(get_content_service)] 104 | JobManagerDep = Annotated[JobManager, Depends(get_job_manager)] 105 | FireCrawlerDep = Annotated[FireCrawler, Depends(get_crawler)] 106 | ChatServiceDep = Annotated[ChatService, Depends(get_chat_service)] 107 | ChromaManagerDep = Annotated[ChromaManager, Depends(get_chroma_manager)] 108 | SupabaseManagerDep = Annotated[SupabaseManager, Depends(get_supabase_manager)] 109 | RedisManagerDep = Annotated[RedisManager, Depends(get_redis_manager)] 110 | # CeleryAppDep = Annotated[Celery, Depends(get_celery_app)] 111 | UserIdDep = Annotated[UUID, Depends(get_user_id)] 112 | -------------------------------------------------------------------------------- /src/core/search/retriever.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import Any 3 | from uuid import UUID 4 | 5 | from cohere.v2.types import V2RerankResponse 6 | 7 | from src.core.search.reranker import Reranker 8 | from src.core.search.vector_db import VectorDatabase 9 | from src.infra.logger import get_logger 10 | 11 | logger = get_logger() 12 | 13 | 14 | class Retriever: 15 | """ 16 | Initializes the Retriever with a vector database and a reranker. 17 | 18 | Args: 19 | vector_db (VectorDatabase): The vector database used for querying documents. 20 | reranker (Reranker): The reranker used for reranking documents. 21 | """ 22 | 23 | def __init__(self, vector_db: VectorDatabase, reranker: Reranker): 24 | self.vector_db = vector_db 25 | self.reranker = reranker 26 | 27 | async def retrieve( 28 | self, rag_query: str, combined_queries: list[str], top_n: int | None, user_id: UUID 29 | ) -> list[dict[str, Any]]: 30 | """ 31 | Retrieve and rank documents based on user query and combined queries. 32 | 33 | Args: 34 | rag_query (str): The primary user query for retrieving documents. 35 | combined_queries (list[str]): A list of queries to combine for document retrieval. 36 | top_n (int, optional): The maximum number of top documents to return. Defaults to None. 37 | user_id (UUID): The user ID for the query. 38 | 39 | Returns: 40 | list: A list of limited, ranked, and relevant documents. 41 | 42 | Raises: 43 | DatabaseError: If there is an issue querying the database. 44 | RerankError: If there is an issue with reranking the documents. 
45 | """ 46 | start_time = time.time() # Start timing 47 | 48 | # get expanded search results 49 | search_results = await self.vector_db.query(user_id=user_id, user_query=combined_queries) 50 | if not search_results or not search_results.get("documents")[0]: 51 | logger.warning("No documents found in search results") 52 | return [] 53 | 54 | unique_documents = self.vector_db.deduplicate_documents(search_results) 55 | logger.info(f"Search returned {len(unique_documents)} unique chunks") 56 | 57 | # rerank the results 58 | ranked_documents = self.reranker.rerank(rag_query, unique_documents) 59 | 60 | # filter irrelevnat results 61 | filtered_results = self.filter_irrelevant_results(ranked_documents, relevance_threshold=0.1) 62 | 63 | # limit the number of returned chunks 64 | limited_results = self.limit_results(filtered_results, top_n=top_n) 65 | 66 | # calculate time 67 | end_time = time.time() # End timing 68 | search_time = end_time - start_time 69 | logger.info(f"Search and reranking completed in {search_time:.3f} seconds") 70 | 71 | return limited_results 72 | 73 | def filter_irrelevant_results( 74 | self, response: V2RerankResponse, relevance_threshold: float = 0.1 75 | ) -> dict[int, dict[str, int | float | str]]: 76 | """ 77 | Filter out results below a certain relevance threshold. 78 | 79 | Args: 80 | response (RerankResponse): The response containing the reranked results. 81 | relevance_threshold (float): The minimum relevance score required. Defaults to 0.1. 82 | 83 | Returns: 84 | dict[int, dict[str, int | float | str]]: A dictionary of relevant results with their index, text, 85 | and relevance score. 86 | 87 | Raises: 88 | None 89 | """ 90 | relevant_results = {} 91 | 92 | for result in response.results: 93 | relevance_score = result.relevance_score 94 | index = result.index 95 | text = result.document.text 96 | 97 | if relevance_score >= relevance_threshold: 98 | relevant_results[index] = { 99 | "text": text, 100 | "index": index, 101 | "relevance_score": relevance_score, 102 | } 103 | 104 | return relevant_results 105 | 106 | def limit_results(self, ranked_documents: dict[str, Any], top_n: int | None = None) -> dict[str, Any]: 107 | """ 108 | Limit the number of results based on the given top_n parameter. 109 | 110 | Args: 111 | ranked_documents (dict[str, Any]): A dictionary of documents with relevance scores. 112 | top_n (int, optional): The number of top results to return. Defaults to None. 113 | 114 | Returns: 115 | dict[str, Any]: The dictionary containing the top N ranked documents, or all documents if top_n is None. 116 | 117 | Raises: 118 | ValueError: If top_n is specified and is less than zero. 119 | """ 120 | if top_n is not None and top_n < len(ranked_documents): 121 | # Sort the items by relevance score in descending order 122 | sorted_items = sorted(ranked_documents.items(), key=lambda x: x[1]["relevance_score"], reverse=True) 123 | 124 | # Take the top N items and reconstruct the dictionary 125 | limited_results = dict(sorted_items[:top_n]) 126 | 127 | logger.info( 128 | f"Returning {len(limited_results)} most relevant results (out of total {len(ranked_documents)} " 129 | f"results)." 
130 | ) 131 | return limited_results 132 | 133 | logger.info(f"Returning all {len(ranked_documents)} results") 134 | return ranked_documents 135 | -------------------------------------------------------------------------------- /tests/unit/infra/external/test_redis_manager.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import AsyncMock, patch 2 | 3 | import pytest 4 | 5 | from src.infra.external.redis_manager import RedisManager 6 | 7 | 8 | class TestRedisManager: 9 | """Unit tests for RedisManager.""" 10 | 11 | def test_init_default(self): 12 | """Test default initialization.""" 13 | manager = RedisManager() 14 | assert manager._decode_responses is True 15 | assert manager._sync_client is None 16 | assert manager._async_client is None 17 | 18 | def test_init_custom(self): 19 | """Test initialization with custom decode_responses.""" 20 | manager = RedisManager(decode_responses=False) 21 | assert manager._decode_responses is False 22 | assert manager._sync_client is None 23 | assert manager._async_client is None 24 | 25 | def test_create_sync_client(self, mock_sync_redis): 26 | """Test sync client creation.""" 27 | with patch("redis.Redis.from_url", return_value=mock_sync_redis): 28 | manager = RedisManager() 29 | client = manager._create_sync_client(decode_responses=True) 30 | assert client is mock_sync_redis 31 | 32 | @pytest.mark.asyncio 33 | async def test_create_async_client(self, mock_async_redis): 34 | """Test async client creation.""" 35 | with patch("redis.asyncio.Redis.from_url", return_value=mock_async_redis): 36 | manager = RedisManager() 37 | client = manager._create_async_client(decode_responses=True) 38 | assert client is mock_async_redis 39 | 40 | def test_connect_sync_success(self, mock_sync_redis): 41 | """Test successful sync connection.""" 42 | with patch("redis.Redis.from_url", return_value=mock_sync_redis): 43 | manager = RedisManager() 44 | manager._connect_sync() 45 | assert manager._sync_client is mock_sync_redis 46 | mock_sync_redis.ping.assert_called_once() 47 | 48 | @pytest.mark.asyncio 49 | async def test_connect_async_success(self, mock_async_redis): 50 | """Test successful async connection.""" 51 | with patch("redis.asyncio.Redis.from_url", return_value=mock_async_redis): 52 | manager = RedisManager() 53 | await manager._connect_async() 54 | assert manager._async_client is mock_async_redis 55 | mock_async_redis.ping.assert_awaited_once() 56 | 57 | def test_create_classmethod(self, mock_sync_redis): 58 | """Test create classmethod.""" 59 | with patch("redis.Redis.from_url", return_value=mock_sync_redis): 60 | manager = RedisManager.create() 61 | assert isinstance(manager, RedisManager) 62 | assert manager._sync_client is mock_sync_redis 63 | mock_sync_redis.ping.assert_called_once() 64 | 65 | @pytest.mark.asyncio 66 | async def test_create_async_classmethod(self, mock_async_redis): 67 | """Test create_async classmethod.""" 68 | with patch("redis.asyncio.Redis.from_url", return_value=mock_async_redis): 69 | manager = await RedisManager.create_async() 70 | assert isinstance(manager, RedisManager) 71 | assert manager._async_client is mock_async_redis 72 | mock_async_redis.ping.assert_called_once() 73 | 74 | def test_get_sync_client(self, mock_sync_redis): 75 | """Test get_sync_client method.""" 76 | with patch("redis.Redis.from_url", return_value=mock_sync_redis): 77 | manager = RedisManager() 78 | client = manager.get_sync_client() 79 | assert client is mock_sync_redis 80 | 
mock_sync_redis.ping.assert_called_once() 81 | 82 | @pytest.mark.asyncio 83 | async def test_get_async_client(self, mock_async_redis): 84 | """Test get_async_client method.""" 85 | with patch("redis.asyncio.Redis.from_url", return_value=mock_async_redis): 86 | manager = RedisManager() 87 | client = await manager.get_async_client() 88 | assert client is mock_async_redis 89 | mock_async_redis.ping.assert_called_once() 90 | 91 | def test_get_sync_client_reuses_existing(self, mock_sync_redis): 92 | """Test get_sync_client reuses existing client.""" 93 | with patch("redis.Redis.from_url", return_value=mock_sync_redis): 94 | manager = RedisManager() 95 | client1 = manager.get_sync_client() 96 | client2 = manager.get_sync_client() 97 | assert client1 is client2 98 | mock_sync_redis.ping.assert_called_once() # Only called once for first connection 99 | 100 | @pytest.mark.asyncio 101 | async def test_get_async_client_reuses_existing(self, mock_async_redis): 102 | """Test get_async_client reuses existing client.""" 103 | with patch("redis.asyncio.Redis.from_url", return_value=mock_async_redis): 104 | manager = RedisManager() 105 | client1 = await manager.get_async_client() 106 | client2 = await manager.get_async_client() 107 | assert client1 is client2 108 | assert mock_async_redis.ping.call_count == 1 # Only called once for first connection 109 | 110 | @pytest.mark.asyncio 111 | async def test_connect_async_error(self, mock_async_redis): 112 | """Test async connection error handling.""" 113 | mock_async_redis.ping = AsyncMock(side_effect=ConnectionError("Test error")) 114 | with patch("redis.asyncio.Redis.from_url", return_value=mock_async_redis): 115 | manager = RedisManager() 116 | with pytest.raises(ConnectionError): 117 | await manager._connect_async() 118 | -------------------------------------------------------------------------------- /src/infra/logger.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import json 3 | import logging 4 | import sys 5 | from enum import Enum 6 | 7 | import logfire 8 | from colorama import Fore, Style, init 9 | 10 | # Initialize colorama 11 | init(autoreset=True) 12 | 13 | 14 | class LogSymbols(str, Enum): 15 | """Unified symbols for all application logging.""" 16 | 17 | SUCCESS = "✓" 18 | ERROR = "✗" 19 | INFO = "→" 20 | WARNING = "⚠" 21 | DEBUG = "•" 22 | CRITICAL = "‼" 23 | 24 | 25 | class ColoredFormatter(logging.Formatter): 26 | """Enhance log messages with colors and emojis based on their severity levels.""" 27 | 28 | COLORS = { 29 | logging.INFO: Fore.GREEN, 30 | logging.DEBUG: Fore.LIGHTCYAN_EX, 31 | logging.WARNING: Fore.YELLOW, 32 | logging.ERROR: Fore.RED, 33 | logging.CRITICAL: Fore.MAGENTA + Style.BRIGHT, 34 | } 35 | 36 | SYMBOLS = { 37 | logging.INFO: LogSymbols.INFO.value, 38 | logging.DEBUG: LogSymbols.DEBUG.value, 39 | logging.WARNING: LogSymbols.WARNING.value, 40 | logging.ERROR: LogSymbols.ERROR.value, 41 | logging.CRITICAL: LogSymbols.CRITICAL.value, 42 | } 43 | VALUE_COLOR = Fore.LIGHTBLUE_EX 44 | 45 | def format(self, record: logging.LogRecord) -> str: 46 | """Format the log record with colored level and symbol.""" 47 | # Step 1: Compute message and time exactly like the source 48 | record.message = record.getMessage() 49 | if self.usesTime(): 50 | record.asctime = self.formatTime(record, self.datefmt) 51 | 52 | # Step 2: Format our custom message 53 | timestamp = record.asctime if hasattr(record, "asctime") else self.formatTime(record, "%Y-%m-%d %H:%M:%S") 54 | name = record.name 55 
| lineno = record.lineno 56 | 57 | # Apply colors 58 | color = self.COLORS.get(record.levelno, "") 59 | colored_symbol = f"{color}{self.SYMBOLS.get(record.levelno, '')}{Style.RESET_ALL}" 60 | colored_level = f"{color}{record.levelname}{Style.RESET_ALL}:" 61 | 62 | # Format extra fields consistently 63 | extra_arg_str = "" 64 | if record.args: 65 | extra_arg_str = f"Args: {json.dumps(record.args, default=str)}" 66 | 67 | # Build our custom formatted message 68 | s = f"{colored_symbol} {colored_level} [{timestamp}] {name}:{lineno} - {record.message}. {extra_arg_str}" 69 | 70 | # Step 3: Handle exc_info and stack_info EXACTLY like the source 71 | if record.exc_info: 72 | # Cache the traceback text to avoid converting it multiple times 73 | # (it's constant anyway) 74 | if not record.exc_text: 75 | record.exc_text = self.formatException(record.exc_info) 76 | if record.exc_text: 77 | if s[-1:] != "\n": 78 | s = s + "\n" 79 | s = s + record.exc_text 80 | if record.stack_info: 81 | if s[-1:] != "\n": 82 | s = s + "\n" 83 | s = s + self.formatStack(record.stack_info) 84 | 85 | return s 86 | 87 | 88 | def configure_logging(debug: bool = False) -> None: 89 | """Configure the application's logging system with both local handlers and Logfire.""" 90 | from src.infra.settings import settings 91 | 92 | # Setup log level 93 | log_level = logging.DEBUG if debug else logging.INFO 94 | if settings.logfire_write_token: 95 | logfire.configure( 96 | token=settings.logfire_write_token, 97 | environment=settings.environment, 98 | service_name=settings.project_name, 99 | console=False, 100 | ) 101 | 102 | # 1. Configure kollektiv logger 103 | app_logger = logging.getLogger("kollektiv") 104 | app_logger.setLevel(log_level) 105 | app_logger.handlers.clear() 106 | 107 | # 2. Set up handlers 108 | console_handler = logging.StreamHandler(sys.stdout) 109 | console_handler.setLevel(log_level) 110 | console_handler.setFormatter(ColoredFormatter()) 111 | app_logger.addHandler(console_handler) 112 | 113 | # 3. Environment-specific handlers 114 | logfire_handler = logfire.LogfireLoggingHandler() 115 | app_logger.addHandler(logfire_handler) 116 | 117 | # 4. Add third-party logging handlers and configure their levels 118 | logging.getLogger("fastapi").setLevel(level=log_level) 119 | logging.getLogger("uvicorn.error").setLevel(level=log_level) 120 | logging.getLogger("docker").setLevel(level=log_level) 121 | logging.getLogger("wandb").setLevel(level=log_level) 122 | 123 | # Set Chroma and its dependencies to WARNING to reduce noise 124 | logging.getLogger("chromadb").setLevel(level=logging.WARNING) 125 | logging.getLogger("chromadb.api").setLevel(level=logging.WARNING) 126 | logging.getLogger("chromadb.telemetry").setLevel(level=logging.WARNING) 127 | 128 | # 5. Do not propagate to the root logger 129 | app_logger.propagate = False 130 | 131 | 132 | def get_logger() -> logging.LoggerAdapter: 133 | """Retrieve a logger named after the calling module. 134 | 135 | Returns: 136 | logging.LoggerAdapter: A logger adapter that supports extra context fields. 
137 | """ 138 | frame = inspect.currentframe() 139 | try: 140 | caller_frame = frame.f_back 141 | module = inspect.getmodule(caller_frame) 142 | module_name = module.__name__ if module else "kollektiv" 143 | finally: 144 | del frame # Prevent reference cycles 145 | 146 | logger = logging.getLogger(f"kollektiv.{module_name}") 147 | return logging.LoggerAdapter(logger, extra={}) 148 | 149 | 150 | def _truncate_message(message: str, max_length: int = 200) -> str: 151 | """Truncate long messages for logging.""" 152 | if len(message) > max_length: 153 | return f"{message[:max_length]}..." 154 | return message 155 | -------------------------------------------------------------------------------- /src/core/chat/prompts/prompts.yaml: -------------------------------------------------------------------------------- 1 | llm_assistant_prompt: | 2 | You are an advanced AI assistant with access to various tools, including a powerful RAG (Retrieval 3 | Augmented Generation) system. Your primary function is to provide accurate, relevant, and helpful 4 | information to users by leveraging your broad knowledge base, analytical capabilities, 5 | and the specific information available 6 | through the RAG tool. 7 | 8 | Key guidelines: 9 | You know which documents are loaded in the RAG system, so when deciding whether to use the RAG tool, 10 | you should consider whether the user's question is likely to require information from loaded documents or not. 11 | You do not need to use the RAG tool for all questions, and you should decide each time whether to use it or not. 12 | When using RAG, formulate precise and targeted queries to retrieve the most relevant information. 13 | Seamlessly integrate retrieved information into your responses, citing sources when appropriate. 14 | If the RAG tool doesn't provide relevant information, rely on your general knowledge and analytical 15 | skills. 16 | Always strive for accuracy, clarity, and helpfulness in your responses. 17 | Be transparent about the source of your information (general knowledge vs. RAG-retrieved data). 18 | If you're unsure about information or if it's not in the loaded documents, clearly state your 19 | uncertainty. 20 | Provide context and explanations for complex topics, breaking them down into understandable parts. 21 | Offer follow-up questions or suggestions to guide the user towards more comprehensive understanding. 22 | 23 | Do not: 24 | Invent or hallucinate information not present in your knowledge base or the RAG-retrieved data. 25 | Use the RAG tool for general knowledge questions that don't require specific document retrieval. 26 | Disclose sensitive details about the RAG system's implementation or the document loading process. 27 | Provide personal opinions or biases; stick to factual information from your knowledge base and 28 | RAG system. 29 | Engage in or encourage any illegal, unethical, or harmful activities. 30 | Share personal information about users or any confidential data that may be in the loaded documents. 31 | 32 | LOADED DOCUMENTS: 33 | {document_summary_prompt} 34 | 35 | Use these summaries to guide your use of the RAG tool and to provide context for the types of 36 | questions 37 | you can answer with the loaded documents. 38 | Interaction Style: 39 | 40 | Maintain a professional, friendly, and patient demeanor. 41 | Tailor your language and explanations to the user's apparent level of expertise. 42 | Ask for clarification when the user's query is ambiguous or lacks necessary details. 
43 | 44 | Handling Complex Queries: 45 | 46 | For multi-part questions, address each part systematically. 47 | If a query requires multiple steps or a lengthy explanation, outline your approach before diving 48 | into details. 49 | Offer to break down complex topics into smaller, more manageable segments if needed. 50 | 51 | Continuous Improvement: 52 | 53 | Learn from user interactions to improve your query formulation for the RAG tool. 54 | Adapt your response style based on user feedback and follow-up questions. 55 | 56 | Remember to use your tools judiciously and always prioritize providing the most accurate, 57 | helpful, and contextually relevant information to the user. Adapt your communication style to 58 | the user's level of understanding and the complexity of the topic at hand. 59 | 60 | multi_query_prompt: | 61 | Your task is to generate multiple search queries to help retrieve relevant information from a document database. 62 | For the given user question, generate {n_queries} focused, single-topic queries that cover different aspects 63 | of the information need. 64 | 65 | Guidelines: 66 | - Generate exactly {n_queries} unique queries 67 | - Each query should focus on a single aspect of the question 68 | - Make queries specific and targeted 69 | - Avoid compound questions 70 | - Return ONLY a JSON object with a "queries" key containing an array of strings 71 | 72 | Example output format: 73 | {{"queries": [ 74 | "What is X?", 75 | "How does X work?", 76 | "What are the benefits of X?" 77 | ]}} 78 | 79 | Original question: {user_question} 80 | 81 | summary_prompt: | 82 | You are part of a RAG system that is responsible for summarizing the content of web sources, added by users to the RAG system. 83 | 84 | Context: 85 | - RAG system is built such that users can index any website and add it to the RAG system. 86 | - Another LLM (not you) has access to a RAG tool that can generate queries and retrieve documents from the RAG system. 87 | - Your role in this larger RAG system is to help generate accurate summaries of each indexed data source (website) so that another LLM assistant knows what data it has access to. 88 | 89 | Your task is to analyze the content of a web source and provide: 90 | 91 | 1. A concise summary (100-150 words) that: 92 | - Describes the main topic/purpose 93 | - Indicates content type (docs, blog, etc.) 94 | - Highlights key areas covered 95 | 2. 5-10 specific keywords that: 96 | - Are actually present in the content 97 | - Help identify when this source is relevant 98 | 99 | Remember: Be specific but not too narrow. Focus on topics, not structure.
100 | 101 | INPUT: 102 | - Your input will typically be: 103 | -- A sample of documents representing raw content from a website 104 | -- URL of the website 105 | -- A list of unique URLs that are present in the documents 106 | 107 | OUTPUT FORMAT: 108 | Always return a JSON object with the following keys: 109 | {{ 110 | "summary": "...", 111 | "keywords": ["..."] 112 | }} 113 | -------------------------------------------------------------------------------- /src/api/v0/endpoints/sources.py: -------------------------------------------------------------------------------- 1 | from collections.abc import AsyncGenerator 2 | from uuid import UUID 3 | 4 | from fastapi import APIRouter, HTTPException, status 5 | from sse_starlette.sse import EventSourceResponse 6 | 7 | from src.api.dependencies import ContentServiceDep, UserIdDep 8 | from src.api.routes import CURRENT_API_VERSION, Routes 9 | from src.api.v0.schemas.base_schemas import ErrorCode, ErrorResponse 10 | from src.core._exceptions import CrawlerError, NonRetryableError 11 | from src.infra.logger import get_logger 12 | from src.models.content_models import ( 13 | AddContentSourceRequest, 14 | AddContentSourceResponse, 15 | SourceEvent, 16 | SourceOverview, 17 | SourceSummary, 18 | ) 19 | 20 | logger = get_logger() 21 | router = APIRouter(prefix=f"{CURRENT_API_VERSION}") 22 | 23 | 24 | @router.post( 25 | Routes.V0.Sources.SOURCES, 26 | response_model=AddContentSourceResponse, 27 | responses={ 28 | 201: {"model": AddContentSourceResponse}, 29 | 400: {"model": ErrorResponse}, 30 | 500: {"model": ErrorResponse}, 31 | }, 32 | status_code=status.HTTP_201_CREATED, 33 | ) 34 | async def add_source( 35 | request: AddContentSourceRequest, 36 | content_service: ContentServiceDep, 37 | ) -> AddContentSourceResponse: 38 | """ 39 | Add a new content source. 
40 | 41 | Args: 42 | request: Content source details 43 | content_service: Injected content service 44 | 45 | Returns: 46 | AddContentSourceResponse: Created content source details 47 | 48 | Raises: 49 | HTTPException: If source creation fails for any reason 50 | """ 51 | logger.debug(f"Dumping request for debugging: {request.model_dump()}") 52 | try: 53 | response = await content_service.add_source(request) 54 | return response 55 | except (CrawlerError, NonRetryableError) as e: 56 | raise HTTPException(status_code=500, detail=ErrorResponse(code=ErrorCode.SERVER_ERROR, detail=str(e))) from e 57 | 58 | 59 | @router.get( 60 | Routes.V0.Sources.SOURCE_EVENTS, 61 | response_model=SourceEvent, 62 | responses={ 63 | 200: {"model": SourceEvent}, 64 | 404: {"model": ErrorResponse, "description": "Source not found"}, 65 | 500: {"model": ErrorResponse, "description": "Internal server error"}, 66 | }, 67 | status_code=status.HTTP_200_OK, 68 | ) 69 | async def stream_source_events(source_id: UUID, content_service: ContentServiceDep) -> EventSourceResponse: 70 | """Returns a stream of events for a source.""" 71 | try: 72 | 73 | async def event_stream() -> AsyncGenerator[str, None]: 74 | async for event in content_service.stream_source_events(source_id=source_id): 75 | event_json = event.model_dump_json() 76 | logger.debug(f"Printing event for debugging: {event_json}") 77 | yield event_json 78 | 79 | return EventSourceResponse(event_stream(), media_type="text/event-stream") 80 | except ValueError as e: 81 | raise HTTPException(status_code=404, detail=ErrorResponse(code=ErrorCode.CLIENT_ERROR, detail=str(e))) from e 82 | 83 | 84 | @router.get( 85 | Routes.V0.Sources.SOURCES, 86 | response_model=list[SourceOverview], 87 | responses={ 88 | 200: {"model": list[SourceOverview]}, 89 | 500: {"model": ErrorResponse, "description": "Internal server error"}, 90 | }, 91 | status_code=status.HTTP_200_OK, 92 | ) 93 | async def get_sources(content_service: ContentServiceDep, user_id: UserIdDep) -> list[SourceSummary]: 94 | """Returns a list of all sources that a user has.""" 95 | try: 96 | return await content_service.get_sources(user_id=user_id) 97 | except Exception as e: 98 | raise HTTPException( 99 | status_code=500, 100 | detail=ErrorResponse( 101 | code=ErrorCode.SERVER_ERROR, 102 | detail="An error occurred while trying to get the list of sources. We are working on it already.", 103 | ), 104 | ) from e 105 | 106 | 107 | # @router.patch( 108 | # Routes.V0.Sources.SOURCES, 109 | # response_model=UpdateSourcesResponse, 110 | # responses={ 111 | # 200: {"model": UpdateSourcesResponse}, 112 | # 400: {"model": ErrorResponse}, 113 | # 404: {"model": ErrorResponse}, 114 | # 500: {"model": ErrorResponse}, 115 | # }, 116 | # status_code=status.HTTP_200_OK, 117 | # ) 118 | # async def update_sources(request: UpdateSourcesRequest, content_service: ContentServiceDep) -> UpdateSourcesResponse: 119 | # """Updates a source.""" 120 | # try: 121 | # return await content_service.update_sources(request) 122 | # # How do we handle 400 and 404? What are those? 123 | # except Exception as e: 124 | # raise HTTPException( 125 | # status_code=500, 126 | # detail="An error occurred while trying to update the source.
We are working on it already.", 127 | # ) from e 128 | 129 | 130 | # @router.delete( 131 | # Routes.V0.Sources.SOURCES, 132 | # response_model=DeleteSourcesResponse, 133 | # responses={ 134 | # 200: {"model": DeleteSourcesResponse}, # successfully deleted 135 | # 404: {"model": ErrorResponse}, # Source not found 136 | # 500: {"model": ErrorResponse}, # Internal error 137 | # }, 138 | # status_code=status.HTTP_200_OK, 139 | # ) 140 | # async def delete_sources(request: DeleteSourcesRequest, content_service: ContentServiceDep) -> DeleteSourcesResponse: 141 | # """Deletes a source.""" 142 | # try: 143 | # return await content_service.delete_sources(request) 144 | # # How do we handle 404? 145 | # except Exception as e: 146 | # raise HTTPException( 147 | # status_code=500, 148 | # detail="An error occurred while trying to delete the source. We are working on it already.", 149 | # ) from e 150 | -------------------------------------------------------------------------------- /docs/user-guide.md: -------------------------------------------------------------------------------- 1 | # Kollektiv User Guide 2 | 3 | Kollektiv is a powerful Retrieval-Augmented Generation (RAG) system that allows you to chat with up-to-date library 4 | documentation. This guide will walk you through the process of setting up, running, and using the Kollektiv system. 5 | 6 | ## Table of Contents 7 | 8 | 1. [System Overview](#system-overview) 9 | 2. [Installation](#installation) 10 | 3. [Usage Scenarios](#usage-scenarios) 11 | - [First-Time Setup](#first-time-setup) 12 | - [Adding New Documentation](#adding-new-documentation) 13 | - [Chatting with Existing Documentation](#chatting-with-existing-documentation) 14 | 4. [Step-by-Step Guide](#step-by-step-guide) 15 | - [Crawling Documentation](#crawling-documentation) 16 | - [Chunking Documents](#chunking-documents) 17 | - [Embedding and Storing](#embedding-and-storing) 18 | - [Running the Chat Interface](#running-the-chat-interface) 19 | 5. [Advanced Usage](#advanced-usage) 20 | 6. [Troubleshooting](#troubleshooting) 21 | 22 | ## System Overview 23 | 24 | Kollektiv consists of several components that work together to provide an interactive chat experience with 25 | documentation: 26 | 27 | 1. Web Crawler: Uses the FireCrawl API to fetch documentation from specified websites. 28 | 2. Document Processor: Chunks the crawled documents into manageable pieces. 29 | 3. Vector Database: Stores document chunks and their embeddings for efficient retrieval. 30 | 4. Embedding Generator: Creates vector representations of document chunks. 31 | 5. Query Expander: Generates multiple relevant queries to improve search results. 32 | 6. Re-ranker: Improves the relevance of retrieved documents. 33 | 7. AI Assistant: Interacts with users and generates responses based on retrieved information. 34 | 35 | ## Installation 36 | 37 | 1. Clone the repository: 38 | ```bash 39 | git clone https://github.com/Twist333d/kollektiv.git 40 | cd kollektiv 41 | ``` 42 | 2. Install dependencies: 43 | ```bash 44 | poetry install 45 | ``` 46 | 3. Set up environment variables: 47 | Create a `.env` file in the project root with the following: 48 | ```bash 49 | FIRECRAWL_API_KEY="your_firecrawl_api_key" 50 | OPENAI_API_KEY="your_openai_api_key" 51 | ANTHROPIC_API_KEY="your_anthropic_api_key" 52 | COHERE_API_KEY="your_cohere_api_key" 53 | ``` 54 | 55 | 4. Start the application: 56 | ```bash 57 | poetry run kollektiv 58 | ``` 59 | 60 | This will start both the API server and the Chainlit UI.
You can access: 61 | - API at http://localhost:8000 62 | - Web UI at http://localhost:8001 63 | 64 | ## Usage Scenarios 65 | 66 | ### First-Time Setup 67 | 68 | When using Kollektiv for the first time: 69 | 70 | 1. Configure environment variables (see [Installation](#installation)) 71 | 2. Start the web interface: 72 | 73 | ## Step-by-Step Guide 74 | 75 | ### Crawling Documentation 76 | 77 | To crawl documentation using FireCrawl: 78 | 79 | 1. Open `crawler.py` 80 | 2. Modify the `urls_to_crawl` list with the URLs you want to crawl 81 | 3. Run the crawler: 82 | ```bash 83 | python src/crawling/crawler.py 84 | ``` 85 | 86 | Example: 87 | 88 | ```python 89 | urls_to_crawl = [ 90 | "https://docs.yourlibrary.com", 91 | "https://api.anotherlibrary.com" 92 | ] 93 | crawler.async_crawl(urls_to_crawl, page_limit=100) 94 | ``` 95 | This will save the crawled data in the src/data/raw directory. 96 | 97 | ### Chunking Documents 98 | 99 | After crawling, you need to chunk the documents: 100 | 101 | 1. Open chunking.py 102 | 2. Update the input_filename with the name of your crawled file 103 | 3. Run the chunker: 104 | ```bash 105 | python src/chunking/chunking.py 106 | ``` 107 | Example 108 | 109 | ```python 110 | markdown_chunker = MarkdownChunker(input_filename="cra_docs_yourlibrary_com_20240526_123456.json") 111 | result = markdown_chunker.get_documents() 112 | chunks = markdown_chunker.process_pages(result) 113 | markdown_chunker.save_chunks(chunks) 114 | ``` 115 | 116 | This will save the chunked data in the src/data/chunks directory. 117 | 118 | ### Embedding and Storing 119 | 120 | To embed and store the chunks: 121 | 122 | 1. Open `app.py` 123 | 2. Update the `file_names` list with your chunked document files 124 | 3. Run the embedding and storing process: 125 | ```bash 126 | python app.py 127 | ``` 128 | Example 129 | ```python 130 | file_names = [ 131 | "cra_docs_yourlibrary_com_20240526_123456-chunked.json", 132 | "cra_docs_anotherlibrary_com_20240526_123457-chunked.json", 133 | ] 134 | for file_name in file_names: 135 | document_loader = DocumentProcessor(file_name) 136 | json_data = document_loader.load_json() 137 | await vector_db.add_documents(json_data, claude_assistant) 138 | ``` 139 | This will embed the chunks and store them in the vector database. 140 | 141 | ### Running the Chat Interface 142 | To start chatting with the documentation: 143 | 144 | 1. Ensure all previous steps are completed 145 | 2. Run the main application: 146 | ```bash 147 | python app.py 148 | ``` 149 | 3. Start asking questions in the terminal interface 150 | 151 | ## Advanced Usage 152 | ### Customizing Chunking Parameters 153 | You can customize the chunking process by modifying parameters in the MarkdownChunker class: 154 | ```python 155 | markdown_chunker = MarkdownChunker( 156 | input_filename="your_file.json", 157 | max_tokens=1000, 158 | soft_token_limit=800, 159 | min_chunk_size=100, 160 | overlap_percentage=0.05 161 | ) 162 | ``` 163 | ### Modifying the AI Assistant 164 | To change the behavior of the AI assistant, you can update the system prompt in claude_assistant.py: 165 | ```python 166 | self.base_system_prompt = """ 167 | Your custom instructions here... 168 | """ 169 | ``` 170 | ## Troubleshooting 171 | 172 | - **Crawling Issues:** Ensure your FireCrawl API key is correct and you have sufficient credits. 173 | - **Chunking Errors:** Check the input JSON file format and ensure it matches the expected structure. 
174 | - **Embedding Failures:** Verify your OpenAI API key and check for rate limiting issues. 175 | - **Chat Interface Not Responding:** Make sure all components are initialized correctly and the vector database is 176 | populated. 177 | 178 | For any other issues, check the log files in the `logs` directory for detailed error messages. 179 | *** 180 | This user guide provides a comprehensive overview of the Kollektiv system. For further assistance or to report issues, 181 | open an issue on the project's GitHub repository. 182 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "kollektiv" 3 | version = "0.1.6" 4 | description = "" 5 | authors = ["AZ "] 6 | readme = "README.md" 7 | package-mode = false 8 | packages = [{ include = "src" }] 9 | 10 | 11 | [tool.poetry.dependencies] 12 | python = ">=3.12,<3.13" 13 | openai = "^1.42.0" 14 | colorama = "^0.4.6" 15 | requests = "^2.32.3" 16 | tiktoken = "^0.7.0" 17 | markdown = "^3.7" 18 | cohere = "^5.9.0" 19 | firecrawl-py = "^1.2.3" 20 | jq = "^1.8.0" 21 | weave = "^0.51.12" 22 | wandb = "^0.18.3" 23 | anthropic = "^0.37.1" 24 | ragas = "^0.2.3" 25 | pydantic-settings = "^2.6.1" 26 | supabase = "^2.10.0" 27 | postgrest = "^0.18.0" 28 | libmagic = "^1.0" 29 | sentry-sdk = { extras = ["fastapi"], version = "^2.19.0" } 30 | sentry-cli = "^2.39.1" 31 | sse-starlette = "^2.1.3" 32 | redis = { version = "^5.2.1", extras = ["hiredis"] } 33 | chromadb = "^0.5.23" 34 | types-pyyaml = "^6.0.12.20241221" 35 | logfire = { extras = [ 36 | "asyncpg", 37 | "fastapi", 38 | "httpx", 39 | "redis", 40 | "system-metrics", 41 | ], version = "^2.11.0" } 42 | asyncpg = "^0.30.0" 43 | ngrok = "^1.4.0" 44 | celery = "^5.4.0" 45 | fastapi = { extras = ["standard"], version = "^0.115.6" } 46 | uvicorn = { extras = ["standard"], version = "^0.34.0" } 47 | flower = "^2.0.1" 48 | git-filter-repo = "^2.47.0" 49 | arq = "^0.26.3" 50 | msgpack = "^1.1.0" 51 | idna = "^3.10" 52 | 53 | 54 | [tool.poetry.group.dev.dependencies] 55 | pytest = "^8.3.3" 56 | pre-commit = "^4.0.1" 57 | ruff = "^0.7.3" 58 | mypy = "^1.13.0" 59 | pytest-cov = "^6.0.0" 60 | pytest-asyncio = "^0.24.0" 61 | types-requests = "^2.32.0.20241016" 62 | types-pytz = "^2024.2.0.20241003" 63 | types-aiofiles = "^24.1.0.20240626" 64 | types-colorama = "^0.4.15.20240311" 65 | types-markdown = "^3.7.0.20240822" 66 | fakeredis = "^2.26.2" 67 | 68 | [build-system] 69 | requires = ["poetry-core"] 70 | build-backend = "poetry.core.masonry.api" 71 | 72 | 73 | [tool.poetry.scripts] 74 | crawl = "src.core.content.crawler.crawler:run_crawler" 75 | api = "src.app:run" 76 | worker = "src.infra.arq.worker:run_worker" 77 | 78 | 79 | [tool.pytest.ini_options] 80 | minversion = "6.0" 81 | addopts = "-ra -q" 82 | testpaths = ["tests"] 83 | pythonpath = [".", "src"] 84 | asyncio_mode = "auto" 85 | filterwarnings = ["ignore::DeprecationWarning", "ignore::UserWarning"] 86 | markers = [ 87 | "integration: mark test as integration test", 88 | "e2e: mark test as end-to-end test", 89 | "slow: mark test as slow running", 90 | ] 91 | 92 | [tool.coverage.run] 93 | source = ["src"] 94 | 95 | [tool.coverage.report] 96 | exclude_lines = [ 97 | "pragma: no cover", 98 | "def __repr__", 99 | "if self.debug:", 100 | "if __name__ == .__main__.:", 101 | "raise NotImplementedError", 102 | "pass", 103 | "except ImportError:", 104 | ] 105 | show_missing = true 106 | skip_covered = true 107 | 
skip_empty = true 108 | precision = 2 109 | 110 | 111 | [tool.mypy] 112 | python_version = "3.12" 113 | strict = true 114 | 115 | # handle __init__ specifically 116 | disallow_incomplete_defs = true 117 | 118 | # For cleaner output 119 | disallow_subclassing_any = false 120 | disallow_any_explicit = false 121 | 122 | # For faster performance 123 | follow_imports = "normal" 124 | 125 | plugins = ["pydantic.mypy"] 126 | 127 | # Ignore noise 128 | disable_error_code = [ 129 | "misc", # Ignore misc issues like untyped decorators 130 | "union-attr", # Ignore union attribute access 131 | ] 132 | 133 | 134 | # More relaxed settings for tests 135 | [[tool.mypy.overrides]] 136 | module = ["tests.*"] 137 | disallow_untyped_defs = false # Allow untyped definitions 138 | check_untyped_defs = false # Don't check untyped definitions 139 | warn_return_any = false # Don't warn about implicit Any returns 140 | warn_unused_ignores = true 141 | no_implicit_optional = false 142 | disallow_incomplete_defs = false 143 | disallow_untyped_decorators = false 144 | 145 | 146 | [tool.ruff] 147 | line-length = 120 148 | target-version = "py312" 149 | indent-width = 4 150 | exclude = [ 151 | ".bzr", 152 | ".direnv", 153 | ".eggs", 154 | ".git", 155 | ".git-rewrite", 156 | ".hg", 157 | ".mypy_cache", 158 | ".pytype", 159 | ".ruff_cache", 160 | ".ipynb_checkpoints", 161 | "__pypackages__", 162 | ] 163 | 164 | [tool.ruff.lint] 165 | select = [ 166 | "E", # pycodestyle errors 167 | "D", # pydocstyle 168 | "W", # pycodestyle warnings 169 | "F", # pyflakes 170 | "I", # isort 171 | "B", # flake8-bugbear 172 | "C4", # flake8-comprehensions 173 | "UP", # pyupgrade 174 | "PT", # flake8-pytest-style 175 | "S", # Bandit security rules 176 | "N", # pep8-naming 177 | "TCH", # typechecking 178 | "PYI", # pyi 179 | "ANN", # type annotation checks 180 | ] 181 | fixable = ["ALL"] 182 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 183 | ignore = [ 184 | "D100", # Missing docstring in public module 185 | "D212", # Multi-line docstring should start at the first line 186 | "D107", # Missing docstring in __init__ 187 | "D415", # First line should end with period 188 | "ANN101", # Missing type annotation for self in method 189 | "ANN102", # Missing type annotation for cls in classmethod 190 | "ANN002", # Missing type annotation for *args 191 | "ANN003", # Missing type annotation for **kwargs 192 | "ANN401", # Dynamically typed expressions (typing.Any) are disallowed 193 | "ANN204", # Missing return type annotation for special method __init__ 194 | "E203", # Whitespace before ':' 195 | "E266", # Too many leading '#' for block comment 196 | ] 197 | 198 | [tool.ruff.lint.per-file-ignores] 199 | "__init__.py" = ["F401", "D100", "D104", "D107", "D212"] 200 | "tests/*" = [ 201 | "D100", 202 | "D101", 203 | "D102", 204 | "D103", 205 | "S101", 206 | "ANN001", 207 | "ANN201", 208 | "ANN101", # Type annotation rules 209 | ] 210 | 211 | [tool.ruff.lint.pydocstyle] 212 | convention = "google" 213 | ignore-decorators = ["property", "classmethod", "staticmethod"] 214 | 215 | [tool.ruff.lint.isort] 216 | known-third-party = ["anthropic", "openai", "pydantic"] 217 | known-first-party = ["src"] 218 | 219 | [tool.ruff.lint.pyupgrade] 220 | keep-runtime-typing = true 221 | 222 | [tool.ruff.format] 223 | quote-style = "double" 224 | indent-style = "space" 225 | skip-magic-trailing-comma = false 226 | line-ending = "auto" 227 | docstring-code-format = true 228 | 229 | [tool.ruff.lint.pycodestyle] 230 | ignore-overlong-task-comments = true 231 | 
-------------------------------------------------------------------------------- /docs/feature-spec/arq-migration.md: -------------------------------------------------------------------------------- 1 | # Arq Migration 2 | 3 | ## Context 4 | Celery does not natively support async operations. 5 | 6 | 7 | 8 | ## Problem Statement 9 | Using Celery forces us to wrap async tasks inside synchronous functions (using `asyncio.run` etc.). This adds complexity, prevents genuine parallelism, and creates cumbersome orchestration (notably when using groups or chords for parallelism). 10 | 11 | ## Proposed Solution 12 | Transition to [Arq](https://arq-docs.helpmanual.io/) because: 13 | - Kollektiv is early-stage. 14 | - Arq is actively maintained with native async support. 15 | - It's built by the Pydantic founder. 16 | - Native async design simplifies orchestration and error handling. 17 | 18 | ## Worker Setup Checklist 19 | 1. **Redis Connection** ✅ 20 | - Define `RedisSettings` for connecting to Redis since Arq uses Redis as its queue backend. 21 | 2. **Lifecycle Hooks** ✅ 22 | - **on_startup**: Initialize worker services (e.g. create a singleton for `WorkerServices`). 23 | - **on_shutdown**: Shutdown or clean up any resources. ✅ 24 | 3. **Worker Configuration** ✅ 25 | - Create a main function that: 26 | - Establishes the Redis pool. 27 | - Provides the list of async task functions. 28 | - Optionally initializes a thread/process pool for blocking (sync) jobs. 29 | 30 | ## Key Migration Topics 31 | 32 | ### Local development 33 | - Define changes necessary to compose.yaml 34 | - Define changes to Dockerfile or settings 35 | 36 | ### Concurrency 37 | - **How many workers?** 38 | Arq uses async workers—start with a small pool and scale based on resource utilization. Configure worker concurrency via command-line or settings. 39 | - **Task parallelism:** 40 | Instead of Celery's `group`, simply enqueue jobs concurrently (or use `asyncio.gather` in a parent task if orchestration is needed). 41 | 42 | ### Startup & Shutdown 43 | - **Worker Services:** 44 | Initialize worker services (from `worker_services.py`) in the `on_startup` coroutine. 45 | - **Logging:** 46 | Set up logging in the on_startup hook so that each worker has its configuration upon boot. 47 | 48 | ### Idempotency 49 | - **Pessimistic Execution:** 50 | ARQ may run jobs more than once if a worker shuts down abruptly. 51 | - **Design Considerations:** 52 | Design jobs to be idempotent (use transactions, idempotency keys, or set Redis flags to mark completed API calls). 53 | 54 | ### Healthcheck 55 | - ARQ updates a Redis key every `health_check_interval` seconds. Use this key to confirm that the worker is live. 56 | - You can check health via the CLI: 57 | ``` 58 | arq --check YourWorkerSettingsClass 59 | ``` 60 | 61 | ### Serialization 62 | - **Default Serializer:** 63 | ARQ uses pickle by default (MessagePack is a common alternative), which differs from your JSON-based Celery serialization. 64 | - **Custom Handling:** 65 | Assess if additional serializer customization is needed for your Pydantic models. A helper for converting models to their dict (or JSON) representations may be useful. 66 | 67 | ### Task Queue Orchestration 68 | - **Replacing `group`:** 69 | Enqueue multiple tasks concurrently. In an async context, use `await asyncio.gather(*tasks)` rather than a Celery group. 70 | - **Replacing `chord`:** 71 | Instead of using a chord, chain tasks manually.
For example, enqueue all subtasks and then enqueue a final "notification" task that polls or waits for completion. 72 | - **Retry Policy:** 73 | Define retries within your task definitions or via your worker settings (e.g., using a `max_tries` parameter). ARQ supports retry parameters that can be set per job. 74 | 75 | ### Sync Jobs 76 | - **Blocking Operations:** 77 | For CPU-bound tasks (like chunking), use an executor: 78 | ```python 79 | import asyncio 80 | from concurrent.futures import ProcessPoolExecutor 81 | 82 | async def run_sync_task(ctx, t): 83 | loop = asyncio.get_running_loop() 84 | return await loop.run_in_executor(ctx["pool"], sync_task, t) 85 | ``` 86 | Initialize the executor in `on_startup` and shut it down in `on_shutdown`. 87 | 88 | ### Job Results & Enqueueing Tasks 89 | - **Job Handling:** 90 | ARQ's `enqueue_job` returns a `Job` instance which can be used to query status and results. 91 | - **From ContentService:** 92 | Instead of `celery_app.delay(...)`, use: 93 | ```python 94 | job = await redis.enqueue_job("your_task_name", arg1, arg2) 95 | ``` 96 | This method allows you to chain or await completion if needed. 97 | 98 | ### Defining Tasks 99 | Plan to define these tasks as async functions: 100 | - `process_documents` 101 | - `chunk_documents_batch` 102 | - `persist_chunks` 103 | - `check_chunking_complete` 104 | - `generate_summary` 105 | - `publish_event` 106 | 107 | 108 | 109 | Each task should: 110 | - Be written as an `async def` function. 111 | - Use `await` instead of `asyncio.run`. 112 | - Incorporate a retry mechanism via ARQ settings if necessary. 113 | 114 | #### Simplified workflow 115 | - `process_documents` -> accepts list of documents, user_id, source_id and fires up processing 116 | 117 | It should: 118 | - break down documents into batches (should be fast) 119 | - schedule processing of these batches 120 | - schedule the completion check task with the job ids 121 | - it doesn't need to await the completion 122 | 123 | - `chunk_documents_batch` 124 | - accepts list of documents 125 | - awaits processing of each batch 126 | - schedules storage of all batches AND awaits results? 127 | 128 | - `persist_chunks` 129 | - accepts list of chunks 130 | - sends them to vector db 131 | - sends them to supabase 132 | 133 | - `check_chunking_complete` 134 | - accepts list of job ids 135 | - awaits completion of all jobs 136 | - if all jobs finished successfully -> generates a summary and schedules event publishing 137 | - if some jobs failed -> schedules a notification to user with the error 138 | 139 | - `generate_summary` 140 | - accepts list of documents 141 | - generates LLM summary 142 | - schedules event publishing 143 | 144 | - `publish_event` 145 | - accepts event 146 | - publishes it to the event bus 147 | 148 | --- 149 | 150 | ## Summary 151 | - **Migrate to async-first tasks:** Rewrite tasks (from `tasks.py`) as async functions. 152 | - **Use lifecycle hooks:** Replace Celery's process init logic (in `worker.py`) with ARQ's `on_startup` and `on_shutdown`. 153 | - **Simplify orchestration:** Replace Celery chords/groups with async concurrency (`asyncio.gather`) or task chaining. 154 | - **Plan for idempotency & retry:** Since ARQ may re-run jobs, ensure each job is designed to be safely repeatable. 155 | - **Check serialization needs:** Decide if you need custom serialization (e.g. MessagePack) instead of ARQ's default pickle serializer; a rough sketch of one possible approach is included below.
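As a rough illustration of the serialization point, the sketch below shows one way a MessagePack serializer pair that understands Pydantic models could be plugged into ARQ. This is only a sketch: the `serialize`/`deserialize` names, the UUID handling, and the way models are converted are assumptions made for illustration, not necessarily what `src/infra/arq/serializer.py` actually does.

```python
# Illustrative sketch only: a msgpack-based serializer pair for ARQ job payloads.
# Assumes Pydantic v2 models and that UUIDs may appear as raw job arguments.
from uuid import UUID

import msgpack
from pydantic import BaseModel


def serialize(payload: dict) -> bytes:
    """Pack a job payload, converting Pydantic models and UUIDs to plain types first."""

    def _default(obj: object) -> object:
        if isinstance(obj, BaseModel):
            # JSON-safe dict; nested UUIDs/datetimes become strings
            return obj.model_dump(mode="json")
        if isinstance(obj, UUID):
            return str(obj)
        raise TypeError(f"Cannot serialize object of type {type(obj)}")

    return msgpack.packb(payload, default=_default)


def deserialize(data: bytes) -> dict:
    """Unpack a job payload; models come back as plain dicts, not model instances."""
    return msgpack.unpackb(data, raw=False)
```

If this direction is chosen, the same pair has to be wired into both sides of the queue, e.g. `create_pool(redis_settings, job_serializer=serialize, job_deserializer=deserialize)` on the enqueueing side and the matching `job_serializer` / `job_deserializer` attributes on the worker settings, so that producers and workers agree on the wire format.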
156 | 157 | This document should guide you to account for the differences between Celery and ARQ and help you design a cleaner, native async task queue. 158 | 159 | Happy coding! 160 | -------------------------------------------------------------------------------- /src/infra/service_container.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from redis.asyncio import Redis 4 | 5 | from src.core.chat.conversation_manager import ConversationManager 6 | from src.core.chat.llm_assistant import ClaudeAssistant 7 | from src.core.chat.summary_manager import SummaryManager 8 | from src.core.content.crawler import FireCrawler 9 | from src.core.search.embedding_manager import EmbeddingManager 10 | from src.core.search.reranker import Reranker 11 | from src.core.search.retriever import Retriever 12 | from src.core.search.vector_db import VectorDatabase 13 | from src.infra.arq.redis_pool import RedisPool 14 | from src.infra.data.data_repository import DataRepository 15 | from src.infra.data.redis_repository import RedisRepository 16 | from src.infra.events.event_consumer import EventConsumer 17 | from src.infra.events.event_publisher import EventPublisher 18 | from src.infra.external.chroma_manager import ChromaManager 19 | from src.infra.external.redis_manager import RedisManager 20 | from src.infra.external.supabase_manager import SupabaseManager 21 | from src.infra.logger import get_logger 22 | from src.infra.misc.ngrok_service import NgrokService 23 | from src.services.chat_service import ChatService 24 | from src.services.content_service import ContentService 25 | from src.services.data_service import DataService 26 | from src.services.job_manager import JobManager 27 | 28 | logger = get_logger() 29 | 30 | 31 | class ServiceContainer: 32 | """Container object for all services that are initialized in the application.""" 33 | 34 | def __init__(self) -> None: 35 | """Initialize Kollektiv container attributes.""" 36 | self.job_manager: JobManager | None = None 37 | self.firecrawler: FireCrawler | None = None 38 | self.data_service: DataService | None = None 39 | self.content_service: ContentService | None = None 40 | self.repository: DataRepository | None = None 41 | self.supabase_manager: SupabaseManager | None = None 42 | self.llm_assistant: ClaudeAssistant | None = None 43 | self.vector_db: VectorDatabase | None = None 44 | self.chat_service: ChatService | None = None 45 | self.conversation_manager: ConversationManager | None = None 46 | self.retriever: Retriever | None = None 47 | self.reranker: Reranker | None = None 48 | self.async_redis_client: Redis | None = None 49 | self.redis_repository: RedisRepository | None = None 50 | self.embedding_manager: EmbeddingManager | None = None 51 | self.ngrok_service: NgrokService | None = None 52 | self.chroma_manager: ChromaManager | None = None 53 | self.event_publisher: EventPublisher | None = None 54 | self.event_consumer: EventConsumer | None = None 55 | self.summary_manager: SummaryManager | None = None 56 | 57 | async def initialize_services(self) -> None: 58 | """Initialize all services.""" 59 | try: 60 | # Database & Repository 61 | self.supabase_manager = await SupabaseManager.create_async() 62 | self.repository = DataRepository(supabase_manager=self.supabase_manager) 63 | self.data_service = DataService(repository=self.repository) 64 | 65 | # Redis 66 | self.async_redis_manager = await RedisManager.create_async() 67 | self.redis_repository = 
RedisRepository(manager=self.async_redis_manager) 68 | self.event_publisher = await EventPublisher.create_async(redis_manager=self.async_redis_manager) 69 | self.arq_redis_pool = await RedisPool.create_redis_pool() 70 | 71 | # Job & Content Services 72 | self.job_manager = JobManager(data_service=self.data_service) 73 | self.firecrawler = FireCrawler() 74 | self.content_service = ContentService( 75 | crawler=self.firecrawler, 76 | job_manager=self.job_manager, 77 | data_service=self.data_service, 78 | redis_manager=self.async_redis_manager, 79 | event_publisher=self.event_publisher, 80 | arq_redis_pool=self.arq_redis_pool, 81 | ) 82 | 83 | # Vector operations 84 | self.chroma_manager = await ChromaManager.create_async() 85 | self.embedding_manager = EmbeddingManager() 86 | self.vector_db = VectorDatabase( 87 | chroma_manager=self.chroma_manager, 88 | embedding_manager=self.embedding_manager, 89 | data_service=self.data_service, 90 | ) 91 | self.reranker = Reranker() 92 | self.retriever = Retriever(vector_db=self.vector_db, reranker=self.reranker) 93 | 94 | # Chat Services 95 | self.claude_assistant = ClaudeAssistant(retriever=self.retriever) 96 | self.conversation_manager = ConversationManager( 97 | redis_repository=self.redis_repository, data_service=self.data_service 98 | ) 99 | self.chat_service = ChatService( 100 | claude_assistant=self.claude_assistant, 101 | data_service=self.data_service, 102 | conversation_manager=self.conversation_manager, 103 | ) 104 | self.ngrok_service = await NgrokService.create() 105 | 106 | # Events 107 | self.event_consumer = await EventConsumer.create_async( 108 | redis_manager=self.async_redis_manager, content_service=self.content_service 109 | ) 110 | await self.event_consumer.start() 111 | 112 | # Source summary 113 | self.summary_manager = SummaryManager( 114 | data_service=self.data_service, 115 | ) 116 | 117 | # Log the successful initialization 118 | logger.info("✓ Initialized services successfully.") 119 | except Exception as e: 120 | logger.error(f"Error during service initialization: {e}", exc_info=True) 121 | raise 122 | 123 | @classmethod 124 | async def create(cls) -> ServiceContainer: 125 | """Create a new ServiceContainer instance and initialize services.""" 126 | container = cls() 127 | await container.initialize_services() 128 | return container 129 | 130 | # TODO: Implement proper service shutdown logic 131 | async def shutdown_services(self) -> None: 132 | """Shutdown all services.""" 133 | # First app layer (chat, content, job manager) 134 | # Event layer 135 | # Base layer - dbs, redis, chroms 136 | try: 137 | if self.ngrok_service is not None: 138 | await self.ngrok_service.stop_tunnel() 139 | 140 | if self.event_consumer is not None: 141 | await self.event_consumer.stop() 142 | 143 | except Exception as e: 144 | logger.error(f"Error during service shutdown: {e}", exc_info=True) 145 | -------------------------------------------------------------------------------- /tests/unit/infra/arq/test_task_definitions.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import AsyncMock, Mock, patch 2 | from uuid import uuid4 3 | 4 | import pytest 5 | from arq.jobs import Job 6 | 7 | from src.infra.arq.serializer import deserialize 8 | from src.infra.arq.task_definitions import ( 9 | KollektivTaskResult, 10 | KollektivTaskStatus, 11 | _create_job_reference, 12 | _gather_job_results, 13 | publish_event, 14 | ) 15 | from src.infra.events.channels import Channels 16 | from 
src.models.content_models import ContentProcessingEvent, SourceStage 17 | from src.models.pubsub_models import EventType 18 | 19 | 20 | @pytest.fixture 21 | def mock_context(): 22 | """Create a mock context with required components.""" 23 | return {"arq_redis": Mock(), "worker_services": Mock()} 24 | 25 | 26 | @pytest.fixture 27 | def mock_job(): 28 | """Create a mock ARQ job.""" 29 | job = Mock(spec=Job) 30 | job.job_id = str(uuid4()) 31 | return job 32 | 33 | 34 | @pytest.fixture 35 | def success_result(): 36 | """Create a successful task result.""" 37 | return KollektivTaskResult( 38 | status=KollektivTaskStatus.SUCCESS, message="Operation completed successfully", data={"key": "value"} 39 | ) 40 | 41 | 42 | @pytest.fixture 43 | def failure_result(): 44 | """Create a failed task result.""" 45 | return KollektivTaskResult( 46 | status=KollektivTaskStatus.FAILED, message="Operation failed", data={"error": "test error"} 47 | ) 48 | 49 | 50 | def test_create_job_reference_success(mock_context): 51 | """Test successful job reference creation.""" 52 | job_id = str(uuid4()) 53 | 54 | with patch("src.infra.arq.task_definitions.Job") as mock_job_class: 55 | _create_job_reference(mock_context, job_id) 56 | 57 | mock_job_class.assert_called_once_with( 58 | job_id=job_id, 59 | redis=mock_context["arq_redis"], 60 | _deserializer=deserialize, # Use the actual deserializer 61 | ) 62 | 63 | 64 | def test_create_job_reference_invalid_id(mock_context): 65 | """Test job reference creation with invalid job ID.""" 66 | with pytest.raises(ValueError, match="Invalid job ID"): # Changed to match the actual error 67 | with patch("src.infra.arq.task_definitions.Job", side_effect=ValueError("Invalid job ID")): 68 | _create_job_reference(mock_context, "invalid-id") 69 | 70 | 71 | @pytest.mark.asyncio 72 | async def test_gather_job_results_all_success(mock_context): 73 | """Test gathering results when all jobs succeed.""" 74 | job_ids = [str(uuid4()) for _ in range(3)] 75 | success_results = [ 76 | KollektivTaskResult(status=KollektivTaskStatus.SUCCESS, message=f"Success {i}") for i in range(3) 77 | ] 78 | 79 | mock_jobs = [] 80 | for job_id, result in zip(job_ids, success_results, strict=False): 81 | mock_job = Mock(spec=Job) 82 | mock_job.job_id = job_id 83 | mock_job.result = AsyncMock(return_value=result) 84 | mock_jobs.append(mock_job) 85 | 86 | with patch("src.infra.arq.task_definitions._create_job_reference", side_effect=mock_jobs): 87 | results = await _gather_job_results(mock_context, job_ids, "test_operation") 88 | 89 | assert len(results) == 3 90 | assert all(r.status == KollektivTaskStatus.SUCCESS for r in results) 91 | 92 | 93 | @pytest.mark.asyncio 94 | async def test_gather_job_results_with_failures(mock_context): 95 | """Test gathering results when some jobs fail.""" 96 | job_ids = [str(uuid4()) for _ in range(3)] 97 | mixed_results = [ 98 | KollektivTaskResult(status=KollektivTaskStatus.SUCCESS, message="Success"), 99 | KollektivTaskResult(status=KollektivTaskStatus.FAILED, message="Failed 1"), 100 | KollektivTaskResult(status=KollektivTaskStatus.FAILED, message="Failed 2"), 101 | ] 102 | 103 | mock_jobs = [] 104 | for job_id, result in zip(job_ids, mixed_results, strict=False): 105 | mock_job = Mock(spec=Job) 106 | mock_job.job_id = job_id 107 | mock_job.result = AsyncMock(return_value=result) 108 | mock_jobs.append(mock_job) 109 | 110 | with patch("src.infra.arq.task_definitions._create_job_reference", side_effect=mock_jobs): 111 | with pytest.raises(Exception, match="test_operation failed: 2 out 
of 3 jobs failed"): 112 | await _gather_job_results(mock_context, job_ids, "test_operation") 113 | 114 | 115 | @pytest.mark.asyncio 116 | async def test_gather_job_results_execution_error(mock_context): 117 | """Test gathering results when job execution fails.""" 118 | job_ids = [str(uuid4())] 119 | mock_job = Mock(spec=Job) 120 | mock_job.job_id = job_ids[0] 121 | mock_job.result = AsyncMock(side_effect=Exception("Execution failed")) 122 | 123 | with patch("src.infra.arq.task_definitions._create_job_reference", return_value=mock_job): 124 | with pytest.raises(Exception, match="test_operation failed: Execution failed"): 125 | await _gather_job_results(mock_context, job_ids, "test_operation") 126 | 127 | 128 | @pytest.mark.asyncio 129 | async def test_publish_event_success(mock_context): 130 | """Test successful event publishing.""" 131 | source_id = uuid4() 132 | event = ContentProcessingEvent( 133 | source_id=source_id, 134 | event_type=EventType.CONTENT_PROCESSING, 135 | stage=SourceStage.CREATED, # Fixed: Add required stage field 136 | ) 137 | 138 | mock_context["worker_services"].event_publisher.publish_event = AsyncMock() 139 | 140 | result = await publish_event(mock_context, event) 141 | 142 | mock_context["worker_services"].event_publisher.publish_event.assert_called_once_with( 143 | channel=Channels.content_processing_channel(source_id), message=event 144 | ) 145 | 146 | assert result.status == KollektivTaskStatus.SUCCESS 147 | assert "Event published successfully" in result.message 148 | 149 | 150 | @pytest.mark.asyncio 151 | async def test_publish_event_connection_error(mock_context): 152 | """Test event publishing with connection error.""" 153 | source_id = uuid4() 154 | event = ContentProcessingEvent( 155 | source_id=source_id, event_type=EventType.CONTENT_PROCESSING, stage=SourceStage.CREATED 156 | ) 157 | 158 | mock_context["worker_services"].event_publisher.publish_event = AsyncMock( 159 | side_effect=ConnectionError("Redis connection failed") 160 | ) 161 | 162 | result = await publish_event(mock_context, event) 163 | 164 | assert result.status == KollektivTaskStatus.FAILED 165 | assert "connection error" in result.message 166 | 167 | 168 | @pytest.mark.asyncio 169 | async def test_publish_event_unexpected_error(mock_context): 170 | """Test event publishing with unexpected error.""" 171 | source_id = uuid4() 172 | event = ContentProcessingEvent( 173 | source_id=source_id, event_type=EventType.CONTENT_PROCESSING, stage=SourceStage.CREATED 174 | ) 175 | 176 | mock_context["worker_services"].event_publisher.publish_event = AsyncMock(side_effect=Exception("Unexpected error")) 177 | 178 | result = await publish_event(mock_context, event) 179 | 180 | assert result.status == KollektivTaskStatus.FAILED 181 | assert "unexpected error" in result.message 182 | -------------------------------------------------------------------------------- /src/api/v0/endpoints/chat.py: -------------------------------------------------------------------------------- 1 | from collections.abc import AsyncIterator 2 | from typing import Annotated 3 | from uuid import UUID 4 | 5 | from fastapi import APIRouter, Depends, HTTPException, status 6 | from fastapi.exceptions import RequestValidationError 7 | from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer 8 | from sse_starlette.sse import EventSourceResponse 9 | 10 | from src.api.dependencies import ChatServiceDep, SupabaseManagerDep 11 | from src.api.routes import CURRENT_API_VERSION, Routes 12 | from src.api.v0.schemas.base_schemas import 
ErrorCode, ErrorResponse 13 | from src.core._exceptions import DatabaseError, EntityNotFoundError, NonRetryableLLMError, RetryableLLMError 14 | from src.infra.logger import get_logger 15 | from src.models.chat_models import ( 16 | ConversationHistoryResponse, 17 | ConversationListResponse, 18 | FrontendChatEvent, 19 | UserMessage, 20 | ) 21 | 22 | # Define routers with base prefix only 23 | chat_router = APIRouter(prefix=CURRENT_API_VERSION) 24 | conversations_router = APIRouter(prefix=CURRENT_API_VERSION) 25 | 26 | logger = get_logger() 27 | 28 | security = HTTPBearer() 29 | 30 | 31 | @chat_router.post( 32 | Routes.V0.Chat.CHAT, 33 | response_model=FrontendChatEvent, 34 | responses={ 35 | 200: {"model": FrontendChatEvent}, 36 | 400: {"model": ErrorResponse}, 37 | 500: {"model": ErrorResponse}, 38 | }, 39 | ) 40 | async def chat(request: UserMessage, chat_service: ChatServiceDep) -> EventSourceResponse: 41 | """ 42 | Sends a user message and gets a streaming response. 43 | 44 | Returns Server-Sent Events with tokens. 45 | """ 46 | try: 47 | logger.debug(f"POST /chat request for debugging: {request.model_dump(serialize_as_any=True)}") 48 | 49 | async def event_stream() -> AsyncIterator[str]: 50 | async for event in chat_service.get_response(user_message=request): 51 | yield event.model_dump_json(serialize_as_any=True) 52 | 53 | return EventSourceResponse(event_stream(), media_type="text/event-stream") 54 | 55 | except NonRetryableLLMError as e: 56 | raise HTTPException( 57 | status_code=500, 58 | detail=ErrorResponse( 59 | code=ErrorCode.SERVER_ERROR, 60 | detail=f"A non-retryable error occurred in the system: {str(e)}. We are on it.", 61 | ), 62 | ) from e 63 | except RetryableLLMError as e: 64 | raise HTTPException( 65 | status_code=500, 66 | detail=ErrorResponse( 67 | code=ErrorCode.SERVER_ERROR, 68 | detail=f"An error occurred in the system: {str(e)}.
Can you please try again?", 69 | ), 70 | ) from e 71 | 72 | 73 | # Get all conversations 74 | @conversations_router.get( 75 | Routes.V0.Conversations.LIST, 76 | response_model=ConversationListResponse, 77 | responses={ 78 | 200: {"model": ConversationListResponse}, 79 | 400: {"model": ErrorResponse}, 80 | 500: {"model": ErrorResponse}, 81 | }, 82 | ) 83 | async def list_conversations(user_id: UUID, chat_service: ChatServiceDep) -> ConversationListResponse: 84 | """Get grouped list of conversations.""" 85 | try: 86 | return await chat_service.get_conversations(user_id) 87 | except DatabaseError as e: 88 | logger.error(f"Database error while getting conversations for user {user_id}: {e}") 89 | raise HTTPException( 90 | status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, 91 | detail=ErrorResponse(code=ErrorCode.SERVER_ERROR, detail="Failed to retrieve conversations."), 92 | ) from e 93 | except RequestValidationError as e: 94 | logger.error(f"Validation error: {e}") 95 | raise HTTPException( 96 | status_code=status.HTTP_400_BAD_REQUEST, 97 | detail=ErrorResponse(code=ErrorCode.CLIENT_ERROR, detail="Invalid request data."), 98 | ) from e 99 | except Exception as e: 100 | logger.error(f"Unexpected error while getting conversations for user {user_id}: {e}", exc_info=True) 101 | raise HTTPException( 102 | status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, 103 | detail=ErrorResponse(code=ErrorCode.SERVER_ERROR, detail="An unexpected error occurred."), 104 | ) from e 105 | 106 | 107 | # TODO: Refactor user id into a UserContext service that would be accessible by any service / endpoint 108 | # TODO: API layer would set the user id in the request context 109 | # TODO: Chat service would get the user id from the user context service 110 | 111 | 112 | # Get messages in a conversation 113 | @conversations_router.get( 114 | Routes.V0.Conversations.GET, 115 | response_model=ConversationHistoryResponse, 116 | responses={ 117 | 200: {"model": ConversationHistoryResponse}, 118 | 400: {"model": ErrorResponse}, 119 | 404: {"model": ErrorResponse}, 120 | 500: {"model": ErrorResponse}, 121 | }, 122 | ) 123 | async def get_conversation( 124 | credentials: Annotated[HTTPAuthorizationCredentials, Depends(security)], 125 | conversation_id: UUID, 126 | chat_service: ChatServiceDep, 127 | supabase: SupabaseManagerDep, 128 | ) -> ConversationHistoryResponse: 129 | """Get all messages in a conversation.""" 130 | try: 131 | # Get the client 132 | supabase_client = await supabase.get_async_client() 133 | 134 | # Get the user 135 | user_response = await supabase_client.auth.get_user(credentials.credentials) 136 | user_id = UUID(user_response.user.id) 137 | logger.debug(f"User ID: {user_id}") 138 | 139 | # Get the conversation 140 | return await chat_service.get_conversation(conversation_id, user_id) 141 | # Handle case where conversation is not found 142 | except EntityNotFoundError as e: 143 | logger.warning(f"Conversation not found: {conversation_id}") 144 | raise HTTPException( 145 | status_code=status.HTTP_404_NOT_FOUND, 146 | detail=ErrorResponse(code=ErrorCode.CLIENT_ERROR, detail="Conversation not found."), 147 | ) from e 148 | # Handle all other database errors 149 | except DatabaseError as e: 150 | logger.error(f"Database error while getting conversation {conversation_id}: {e}") 151 | raise HTTPException( 152 | status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, 153 | detail=ErrorResponse(code=ErrorCode.SERVER_ERROR, detail="Failed to retrieve conversation."), 154 | ) from e 155 | # Handle case when the client
sends invalid request 156 | except RequestValidationError as e: 157 | logger.error(f"Validation error: {e}") 158 | raise HTTPException( 159 | status_code=status.HTTP_400_BAD_REQUEST, 160 | detail=ErrorResponse(code=ErrorCode.CLIENT_ERROR, detail="Invalid request data."), 161 | ) from e 162 | except Exception as e: 163 | logger.error(f"Unexpected error while getting conversation {conversation_id}: {e}", exc_info=True) 164 | raise HTTPException( 165 | status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, 166 | detail=ErrorResponse(code=ErrorCode.SERVER_ERROR, detail="An unexpected error occurred."), 167 | ) from e 168 | --------------------------------------------------------------------------------