├── .copier-answers.yml ├── .devcontainer └── devcontainer.json ├── .dockerignore ├── .github ├── dependabot.yml └── workflows │ ├── publish.yml │ └── test.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CHANGELOG.md ├── Dockerfile ├── LICENSE ├── README.md ├── docker-compose.yml ├── pyproject.toml ├── src └── raglite │ ├── __init__.py │ ├── _bench.py │ ├── _chainlit.py │ ├── _chatml_function_calling.py │ ├── _cli.py │ ├── _config.py │ ├── _database.py │ ├── _embed.py │ ├── _eval.py │ ├── _extract.py │ ├── _insert.py │ ├── _lazy_llama.py │ ├── _litellm.py │ ├── _markdown.py │ ├── _mcp.py │ ├── _query_adapter.py │ ├── _rag.py │ ├── _search.py │ ├── _split_chunklets.py │ ├── _split_chunks.py │ ├── _split_sentences.py │ ├── _typing.py │ └── py.typed └── tests ├── __init__.py ├── conftest.py ├── specrel.pdf ├── test_chatml_function_calling.py ├── test_database.py ├── test_embed.py ├── test_extract.py ├── test_import.py ├── test_insert.py ├── test_lazy_llama.py ├── test_markdown.py ├── test_query_adapter.py ├── test_rag.py ├── test_rerank.py ├── test_search.py ├── test_split_chunklets.py ├── test_split_chunks.py └── test_split_sentences.py /.copier-answers.yml: -------------------------------------------------------------------------------- 1 | _commit: v1.5.1 2 | _src_path: gh:superlinear-ai/substrate 3 | author_email: laurent@superlinear.eu 4 | author_name: Laurent Sorber 5 | project_description: A Python toolkit for Retrieval-Augmented Generation (RAG) with 6 | DuckDB or PostgreSQL. 7 | project_name: raglite 8 | project_type: package 9 | project_url: https://github.com/superlinear-ai/raglite 10 | python_version: '3.10' 11 | typing: strict 12 | with_conventional_commits: true 13 | with_typer_cli: true 14 | -------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "raglite", 3 | "dockerComposeFile": "../docker-compose.yml", 4 | "service": "devcontainer", 5 | "workspaceFolder": "/workspaces/${localWorkspaceFolderBasename}/", 6 | "features": { 7 | "ghcr.io/devcontainers-extra/features/starship:1": {} 8 | }, 9 | "overrideCommand": true, 10 | "remoteUser": "user", 11 | "postStartCommand": "sudo chown -R user:user /opt/ && uv sync --python ${localEnv:PYTHON_VERSION:3.10} --resolution ${localEnv:RESOLUTION_STRATEGY:highest} --all-extras && pre-commit install --install-hooks", 12 | "customizations": { 13 | "jetbrains": { 14 | "backend": "PyCharm", 15 | "plugins": [ 16 | "com.github.copilot" 17 | ] 18 | }, 19 | "vscode": { 20 | "extensions": [ 21 | "charliermarsh.ruff", 22 | "GitHub.copilot", 23 | "GitHub.copilot-chat", 24 | "GitHub.vscode-github-actions", 25 | "GitHub.vscode-pull-request-github", 26 | "ms-azuretools.vscode-docker", 27 | "ms-python.mypy-type-checker", 28 | "ms-python.python", 29 | "ms-toolsai.jupyter", 30 | "ryanluker.vscode-coverage-gutters", 31 | "tamasfe.even-better-toml", 32 | "visualstudioexptteam.vscodeintellicode" 33 | ], 34 | "settings": { 35 | "coverage-gutters.coverageFileNames": [ 36 | "reports/coverage.xml" 37 | ], 38 | "editor.codeActionsOnSave": { 39 | "source.fixAll": "explicit", 40 | "source.organizeImports": "explicit" 41 | }, 42 | "editor.formatOnSave": true, 43 | "[python]": { 44 | "editor.defaultFormatter": "charliermarsh.ruff" 45 | }, 46 | "[toml]": { 47 | "editor.formatOnSave": false 48 | }, 49 | "editor.rulers": [ 50 | 100 51 | ], 52 | "files.autoSave": "onFocusChange", 53 | "github.copilot.chat.agent.enabled": true, 54 | "github.copilot.chat.codesearch.enabled": true, 55 | "github.copilot.chat.edits.enabled": true, 56 | "github.copilot.nextEditSuggestions.enabled": true, 57 | "jupyter.kernels.excludePythonEnvironments": [ 58 | "/usr/local/bin/python" 59 | ], 60 | "mypy-type-checker.importStrategy": "fromEnvironment", 61 | "mypy-type-checker.preferDaemon": true, 62 | "notebook.codeActionsOnSave": { 63 | "notebook.source.fixAll": "explicit", 64 | "notebook.source.organizeImports": "explicit" 65 | }, 66 | "notebook.formatOnSave.enabled": true, 67 | "python.defaultInterpreterPath": "/opt/venv/bin/python", 68 | "python.terminal.activateEnvironment": false, 69 | "python.testing.pytestEnabled": true, 70 | "ruff.importStrategy": "fromEnvironment", 71 | "ruff.logLevel": "warning", 72 | "terminal.integrated.env.linux": { 73 | "GIT_EDITOR": "code --wait" 74 | }, 75 | "terminal.integrated.env.mac": { 76 | "GIT_EDITOR": "code --wait" 77 | } 78 | } 79 | } 80 | } 81 | } -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | # Caches 2 | .*_cache/ 3 | 4 | # Git 5 | .git/ 6 | 7 | # Python 8 | .venv/ 9 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | updates: 4 | - package-ecosystem: github-actions 5 | directory: / 6 | schedule: 7 | interval: monthly 8 | commit-message: 9 | prefix: "ci" 10 | prefix-development: "ci" 11 | include: scope 12 | groups: 13 | ci-dependencies: 14 | patterns: 15 | - "*" 16 | - package-ecosystem: pip 17 | directory: / 18 | schedule: 19 | interval: monthly 20 | commit-message: 21 | prefix: "chore" 22 | prefix-development: "build" 23 | include: scope 24 | allow: 25 | - dependency-type: development 26 | versioning-strategy: increase 27 | groups: 28 | development-dependencies: 29 | dependency-type: development 30 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish 2 | 3 | on: 4 | release: 5 | types: 6 | - created 7 | 8 | jobs: 9 | publish: 10 | runs-on: ubuntu-latest 11 | environment: pypi 12 | permissions: 13 | id-token: write 14 | 15 | steps: 16 | - name: Checkout 17 | uses: actions/checkout@v4 18 | 19 | - name: Install uv 20 | uses: astral-sh/setup-uv@v5 21 | 22 | - name: Publish package 23 | run: | 24 | uv build 25 | uv publish 26 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - master 8 | pull_request: 9 | 10 | jobs: 11 | test: 12 | runs-on: 13 | group: raglite 14 | 15 | strategy: 16 | fail-fast: false 17 | matrix: 18 | python-version: ["3.10", "3.12"] 19 | resolution-strategy: ["highest", "lowest-direct"] 20 | 21 | name: Python ${{ matrix.python-version }} (resolution=${{ matrix.resolution-strategy }}) 22 | 23 | steps: 24 | - name: Checkout 25 | uses: actions/checkout@v4 26 | 27 | - name: Set up Node.js 28 | uses: actions/setup-node@v4 29 | with: 30 | node-version: 23 31 | 32 | - name: Install @devcontainers/cli 33 | run: npm install --location=global @devcontainers/cli@0.73.0 34 | 35 | - name: Start Dev Container 36 | run: | 37 | git config --global init.defaultBranch main 38 | devcontainer up --workspace-folder . 39 | env: 40 | OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} 41 | PYTHON_VERSION: ${{ matrix.python-version }} 42 | RESOLUTION_STRATEGY: ${{ matrix.resolution-strategy }} 43 | 44 | - name: Lint package 45 | run: devcontainer exec --workspace-folder . poe lint 46 | 47 | - name: Test package 48 | run: devcontainer exec --workspace-folder . poe test 49 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Chainlit 2 | .chainlit/ 3 | .files/ 4 | chainlit.md 5 | 6 | # Coverage.py 7 | htmlcov/ 8 | reports/ 9 | 10 | # Copier 11 | *.rej 12 | 13 | # Data 14 | *.csv* 15 | *.dat* 16 | *.pickle* 17 | *.xls* 18 | *.zip* 19 | data/ 20 | 21 | # direnv 22 | .envrc 23 | 24 | # dotenv 25 | .env 26 | 27 | # rerankers 28 | .*_cache/ 29 | 30 | # Hypothesis 31 | .hypothesis/ 32 | 33 | # Jupyter 34 | *.ipynb 35 | .ipynb_checkpoints/ 36 | notebooks/ 37 | 38 | # macOS 39 | .DS_Store 40 | 41 | # mise 42 | mise.local.toml 43 | 44 | # mypy 45 | .dmypy.json 46 | .mypy_cache/ 47 | 48 | # Node.js 49 | node_modules/ 50 | 51 | # PyCharm 52 | .idea/ 53 | 54 | # pyenv 55 | .python-version 56 | 57 | # pytest 58 | .pytest_cache/ 59 | 60 | # Python 61 | __pycache__/ 62 | *.egg-info/ 63 | *.py[cdo] 64 | .venv/ 65 | dist/ 66 | 67 | # RAGLite 68 | *.db 69 | 70 | # Ruff 71 | .ruff_cache/ 72 | 73 | # Terraform 74 | .terraform/ 75 | 76 | # uv 77 | uv.lock 78 | 79 | # VS Code 80 | .vscode/ 81 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # https://pre-commit.com 2 | default_install_hook_types: [commit-msg, pre-commit] 3 | default_stages: [pre-commit, manual] 4 | fail_fast: true 5 | repos: 6 | - repo: meta 7 | hooks: 8 | - id: check-useless-excludes 9 | - repo: https://github.com/pre-commit/pygrep-hooks 10 | rev: v1.10.0 11 | hooks: 12 | - id: python-check-mock-methods 13 | - id: python-use-type-annotations 14 | - id: rst-backticks 15 | - id: rst-directive-colons 16 | - id: rst-inline-touching-normal 17 | - id: text-unicode-replacement-char 18 | - repo: https://github.com/pre-commit/pre-commit-hooks 19 | rev: v5.0.0 20 | hooks: 21 | - id: check-added-large-files 22 | - id: check-ast 23 | - id: check-builtin-literals 24 | - id: check-case-conflict 25 | - id: check-docstring-first 26 | - id: check-illegal-windows-names 27 | - id: check-json 28 | - id: check-merge-conflict 29 | - id: check-shebang-scripts-are-executable 30 | - id: check-symlinks 31 | - id: check-toml 32 | - id: check-vcs-permalinks 33 | - id: check-xml 34 | - id: check-yaml 35 | - id: debug-statements 36 | - id: destroyed-symlinks 37 | - id: detect-private-key 38 | - id: end-of-file-fixer 39 | types: [python] 40 | - id: fix-byte-order-marker 41 | - id: mixed-line-ending 42 | - id: name-tests-test 43 | args: [--pytest-test-first] 44 | - id: trailing-whitespace 45 | types: [python] 46 | - repo: local 47 | hooks: 48 | - id: commitizen 49 | name: commitizen 50 | entry: cz check 51 | args: [--commit-msg-file] 52 | require_serial: true 53 | language: system 54 | stages: [commit-msg] 55 | - id: ruff-check 56 | name: ruff check 57 | entry: ruff check 58 | args: ["--force-exclude", "--extend-fixable=ERA001,F401,F841,T201,T203"] 59 | require_serial: true 60 | language: system 61 | types_or: [python, pyi] 62 | - id: ruff-format 63 | name: ruff format 64 | entry: ruff format 65 | args: [--force-exclude] 66 | require_serial: true 67 | language: system 68 | types_or: [python, pyi] 69 | - id: mypy 70 | name: mypy 71 | entry: mypy 72 | language: system 73 | types: [python] 74 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | ## v0.7.0 (2025-03-17) 2 | 3 | ### Feat 4 | 5 | - replace post-processing with declarative optimization (#112) 6 | - compute optimal sentence boundaries (#110) 7 | - migrate from poetry-cookiecutter to substrate (#98) 8 | - make llama-cpp-python an optional dependency (#97) 9 | - add ability to directly insert Markdown content into the database (#96) 10 | - make importing faster (#86) 11 | 12 | ### Fix 13 | 14 | - fix CLI entrypoint regression (#111) 15 | - lazily raise module not found for optional deps (#109) 16 | - revert pandoc extra name (#106) 17 | - avoid conflicting chunk ids (#93) 18 | 19 | ## v0.6.2 (2025-01-06) 20 | 21 | ### Fix 22 | 23 | - remove unnecessary stop sequence (#84) 24 | 25 | ## v0.6.1 (2025-01-06) 26 | 27 | ### Fix 28 | 29 | - fix Markdown heading boundary probas (#81) 30 | - improve (re)insertion speed (#80) 31 | - **deps**: exclude litellm versions that break get_model_info (#78) 32 | - conditionally enable `LlamaRAMCache` (#83) 33 | 34 | ## v0.6.0 (2025-01-05) 35 | 36 | ### Feat 37 | 38 | - add support for Python 3.12 (#69) 39 | - upgrade from xx_sent_ud_sm to SaT (#74) 40 | - add streaming tool use to llama-cpp-python (#71) 41 | - improve sentence splitting (#72) 42 | 43 | ## v0.5.1 (2024-12-18) 44 | 45 | ### Fix 46 | 47 | - improve output for empty databases (#68) 48 | 49 | ## v0.5.0 (2024-12-17) 50 | 51 | ### Feat 52 | 53 | - add MCP server (#67) 54 | - let LLM choose whether to retrieve context (#62) 55 | 56 | ### Fix 57 | 58 | - support pgvector v0.7.0+ (#63) 59 | 60 | ## v0.4.1 (2024-12-05) 61 | 62 | ### Fix 63 | 64 | - add and enable OpenAI strict mode (#55) 65 | - support embedding with LiteLLM for Ragas (#56) 66 | 67 | ## v0.4.0 (2024-12-04) 68 | 69 | ### Feat 70 | 71 | - improve late chunking and optimize pgvector settings (#51) 72 | 73 | ## v0.3.0 (2024-12-03) 74 | 75 | ### Feat 76 | 77 | - support prompt caching and apply Anthropic's long-context prompt format (#52) 78 | 79 | ## v0.2.1 (2024-11-22) 80 | 81 | ### Fix 82 | 83 | - improve structured output extraction and query adapter updates (#34) 84 | - upgrade rerankers and remove flashrank patch (#47) 85 | - improve unpacking of keyword search results (#46) 86 | - add fallbacks for model info (#44) 87 | 88 | ## v0.2.0 (2024-10-21) 89 | 90 | ### Feat 91 | 92 | - add Chainlit frontend (#33) 93 | 94 | ## v0.1.4 (2024-10-15) 95 | 96 | ### Fix 97 | 98 | - fix optimal chunking edge cases (#32) 99 | 100 | ## v0.1.3 (2024-10-13) 101 | 102 | ### Fix 103 | 104 | - upgrade pdftext (#30) 105 | - improve chunk and segment ordering (#29) 106 | 107 | ## v0.1.2 (2024-10-08) 108 | 109 | ### Fix 110 | 111 | - avoid pdftext v0.3.11 (#27) 112 | 113 | ## v0.1.1 (2024-10-07) 114 | 115 | ### Fix 116 | 117 | - patch rerankers flashrank issue (#22) 118 | 119 | ## v0.1.0 (2024-10-07) 120 | 121 | ### Feat 122 | 123 | - add reranking (#20) 124 | - add LiteLLM and late chunking (#19) 125 | - add PostgreSQL support (#18) 126 | - make query adapter minimally invasive (#16) 127 | - upgrade default CPU model to Phi-3.5-mini (#15) 128 | - add evaluation (#14) 129 | - infer missing font sizes (#12) 130 | - automatically adjust number of RAG contexts (#10) 131 | - improve exception feedback for extraction (#9) 132 | - optimize config for CPU and GPU (#7) 133 | - simplify document insertion (#6) 134 | - implement basic features (#2) 135 | - initial commit 136 | 137 | ### Fix 138 | 139 | - lazily import optional dependencies (#11) 140 | - improve indexing of multiple documents (#8) 141 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # syntax=docker/dockerfile:1 2 | FROM ghcr.io/astral-sh/uv:python3.10-bookworm AS dev 3 | 4 | # Create and activate a virtual environment [1]. 5 | # [1] https://docs.astral.sh/uv/concepts/projects/config/#project-environment-path 6 | ENV VIRTUAL_ENV=/opt/venv 7 | ENV PATH=$VIRTUAL_ENV/bin:$PATH 8 | ENV UV_PROJECT_ENVIRONMENT=$VIRTUAL_ENV 9 | 10 | # Tell Git that the workspace is safe to avoid 'detected dubious ownership in repository' warnings. 11 | RUN git config --system --add safe.directory '*' 12 | 13 | # Create a non-root user and give it passwordless sudo access [1]. 14 | # [1] https://code.visualstudio.com/remote/advancedcontainers/add-nonroot-user 15 | RUN --mount=type=cache,target=/var/cache/apt/ \ 16 | --mount=type=cache,target=/var/lib/apt/ \ 17 | groupadd --gid 1000 user && \ 18 | useradd --create-home --no-log-init --gid 1000 --uid 1000 --shell /usr/bin/bash user && \ 19 | chown user:user /opt/ && \ 20 | apt-get update && apt-get install --no-install-recommends --yes sudo clang libomp-dev && \ 21 | echo 'user ALL=(root) NOPASSWD:ALL' > /etc/sudoers.d/user && chmod 0440 /etc/sudoers.d/user 22 | USER user 23 | 24 | # Configure the non-root user's shell. 25 | RUN mkdir ~/.history/ && \ 26 | echo 'HISTFILE=~/.history/.bash_history' >> ~/.bashrc && \ 27 | echo 'bind "\"\e[A\": history-search-backward"' >> ~/.bashrc && \ 28 | echo 'bind "\"\e[B\": history-search-forward"' >> ~/.bashrc && \ 29 | echo 'eval "$(starship init bash)"' >> ~/.bashrc 30 | 31 | # Explicitly configure compilers for llama-cpp-python. 32 | ENV CC=clang 33 | ENV CXX=clang++ 34 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | services: 2 | 3 | devcontainer: 4 | build: 5 | target: dev 6 | environment: 7 | - CI 8 | - OPENAI_API_KEY 9 | depends_on: 10 | - postgres 11 | volumes: 12 | - ..:/workspaces 13 | - command-history-volume:/home/user/.history/ 14 | 15 | postgres: 16 | image: pgvector/pgvector:pg17 17 | environment: 18 | POSTGRES_USER: raglite_user 19 | POSTGRES_PASSWORD: raglite_password 20 | tmpfs: 21 | - /var/lib/postgresql/data 22 | 23 | volumes: 24 | command-history-volume: 25 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] # https://docs.astral.sh/uv/concepts/projects/config/#build-systems 2 | requires = ["hatchling>=1.27.0"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] # https://packaging.python.org/en/latest/specifications/pyproject-toml/ 6 | name = "raglite" 7 | version = "0.7.0" 8 | description = "A Python toolkit for Retrieval-Augmented Generation (RAG) with DuckDB or PostgreSQL." 9 | readme = "README.md" 10 | authors = [ 11 | { name = "Laurent Sorber", email = "laurent@superlinear.eu" }, 12 | ] 13 | requires-python = ">=3.10,<4.0" 14 | dependencies = [ 15 | # Configuration: 16 | "platformdirs (>=4.0.0)", 17 | # Markdown conversion: 18 | "pdftext (>=0.4.1)", 19 | "scikit-learn (>=1.4.2)", 20 | # Sentence and chunk splitting: 21 | "markdown-it-py (>=3.0.0)", 22 | "numpy (>=1.26.4,<2.0.0)", 23 | "scipy (>=1.11.2,!=1.15.0.*,!=1.15.1,!=1.15.2)", 24 | "wtpsplit-lite (>=0.1.0)", 25 | # Large Language Models: 26 | "huggingface-hub (>=0.31.2)", 27 | "litellm (>=1.60.2)", 28 | "pydantic (>=2.7.0)", 29 | # Reranking: 30 | "langdetect (>=1.0.9)", 31 | "rerankers[api,flashrank] (>=0.10.0)", 32 | # Storage: 33 | "duckdb (>=1.1.3)", 34 | "duckdb-engine (>=0.16.0)", 35 | "pg8000 (>=1.31.2)", 36 | "sqlmodel-slim (>=0.0.21)", 37 | # Progress: 38 | "tqdm (>=4.66.0)", 39 | # CLI: 40 | "typer (>=0.15.1)", 41 | # Model Context Protocol: 42 | "fastmcp (>=2.0.0)", 43 | # Utilities: 44 | "packaging (>=23.0)", 45 | ] 46 | 47 | [project.scripts] # https://docs.astral.sh/uv/concepts/projects/config/#command-line-interfaces 48 | raglite = "raglite._cli:cli" 49 | 50 | [project.urls] # https://packaging.python.org/en/latest/specifications/well-known-project-urls/#well-known-labels 51 | homepage = "https://github.com/superlinear-ai/raglite" 52 | source = "https://github.com/superlinear-ai/raglite" 53 | changelog = "https://github.com/superlinear-ai/raglite/blob/main/CHANGELOG.md" 54 | releasenotes = "https://github.com/superlinear-ai/raglite/releases" 55 | documentation = "https://github.com/superlinear-ai/raglite" 56 | issues = "https://github.com/superlinear-ai/raglite/issues" 57 | 58 | [dependency-groups] # https://docs.astral.sh/uv/concepts/projects/dependencies/#development-dependencies 59 | dev = [ 60 | "commitizen (>=4.3.0)", 61 | "coverage[toml] (>=7.6.10)", 62 | "ipykernel (>=6.29.4)", 63 | "ipython (>=8.18.0)", 64 | "ipywidgets (>=8.1.2)", 65 | "mypy (>=1.14.1)", 66 | "pdoc (>=15.0.1)", 67 | "poethepoet (>=0.32.1)", 68 | "pre-commit (>=4.0.1)", 69 | "pytest (>=8.3.4)", 70 | "pytest-mock (>=3.14.0)", 71 | "pytest-xdist (>=3.6.1)", 72 | "ruff (>=0.10.0)", 73 | "typeguard (>=4.4.1)", 74 | ] 75 | 76 | [project.optional-dependencies] # https://packaging.python.org/en/latest/guides/writing-pyproject-toml/#dependencies-optional-dependencies 77 | # Frontend: 78 | chainlit = ["chainlit (>=2.0.0)"] 79 | # Large Language Models: 80 | llama-cpp-python = ["llama-cpp-python (>=0.3.9)"] 81 | # Markdown conversion: 82 | pandoc = ["pypandoc-binary (>=1.13)"] 83 | # Evaluation: 84 | ragas = ["pandas (>=2.1.1)", "ragas (>=0.1.12)"] 85 | # Benchmarking: 86 | bench = [ 87 | "faiss-cpu (>=1.11.0)", 88 | "ir_datasets (>=0.5.10)", 89 | "ir_measures (>=0.3.7)", 90 | "llama-index (>=0.12.39)", 91 | "llama-index-vector-stores-faiss (>=0.4.0)", 92 | "openai (>=1.75.0)", 93 | "pandas (>=2.1.1)", 94 | "python-slugify (>=8.0.4)", 95 | ] 96 | 97 | [tool.commitizen] # https://commitizen-tools.github.io/commitizen/config/ 98 | bump_message = "bump: v$current_version → v$new_version" 99 | tag_format = "v$version" 100 | update_changelog_on_bump = true 101 | version_provider = "uv" 102 | 103 | [tool.coverage.report] # https://coverage.readthedocs.io/en/latest/config.html#report 104 | fail_under = 50 105 | precision = 1 106 | show_missing = true 107 | skip_covered = true 108 | 109 | [tool.coverage.run] # https://coverage.readthedocs.io/en/latest/config.html#run 110 | branch = true 111 | command_line = "--module pytest" 112 | data_file = "reports/.coverage" 113 | source = ["src"] 114 | 115 | [tool.coverage.xml] # https://coverage.readthedocs.io/en/latest/config.html#xml 116 | output = "reports/coverage.xml" 117 | 118 | [tool.mypy] # https://mypy.readthedocs.io/en/latest/config_file.html 119 | junit_xml = "reports/mypy.xml" 120 | strict = true 121 | disallow_subclassing_any = false 122 | disallow_untyped_decorators = false 123 | ignore_missing_imports = true 124 | pretty = true 125 | show_column_numbers = true 126 | show_error_codes = true 127 | show_error_context = true 128 | warn_unreachable = true 129 | 130 | [tool.pytest.ini_options] # https://docs.pytest.org/en/latest/reference/reference.html#ini-options-ref 131 | addopts = "--color=yes --doctest-modules --ignore=src/raglite/_chainlit.py --exitfirst --failed-first --strict-config --strict-markers --verbosity=2 --junitxml=reports/pytest.xml" 132 | filterwarnings = ["error", "ignore::DeprecationWarning", "ignore::pytest.PytestUnraisableExceptionWarning"] 133 | markers = ["slow: mark test as slow"] 134 | testpaths = ["src", "tests"] 135 | xfail_strict = true 136 | 137 | [tool.ruff] # https://docs.astral.sh/ruff/settings/ 138 | fix = true 139 | line-length = 100 140 | src = ["src", "tests"] 141 | target-version = "py310" 142 | 143 | [tool.ruff.format] 144 | docstring-code-format = true 145 | 146 | [tool.ruff.lint] 147 | select = ["A", "ASYNC", "B", "BLE", "C4", "C90", "D", "DTZ", "E", "EM", "ERA", "F", "FBT", "FLY", "FURB", "G", "I", "ICN", "INP", "INT", "ISC", "LOG", "N", "NPY", "PERF", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "Q", "RET", "RSE", "RUF", "S", "SIM", "SLF", "SLOT", "T10", "T20", "TCH", "TID", "TRY", "UP", "W", "YTT"] 148 | ignore = ["D203", "D213", "E501", "RET504", "RUF002", "RUF022", "S101", "S307", "TC004"] 149 | unfixable = ["ERA001", "F401", "F841", "T201", "T203"] 150 | 151 | [tool.ruff.lint.flake8-tidy-imports] 152 | ban-relative-imports = "all" 153 | 154 | [tool.ruff.lint.pycodestyle] 155 | max-doc-length = 100 156 | 157 | [tool.ruff.lint.pydocstyle] 158 | convention = "numpy" 159 | 160 | [tool.poe.executor] # https://github.com/nat-n/poethepoet 161 | type = "simple" 162 | 163 | [tool.poe.tasks] 164 | 165 | [tool.poe.tasks.docs] 166 | help = "Generate this package's docs" 167 | cmd = """ 168 | pdoc 169 | --docformat $docformat 170 | --output-directory $outputdirectory 171 | raglite 172 | """ 173 | 174 | [[tool.poe.tasks.docs.args]] 175 | help = "The docstring style (default: numpy)" 176 | name = "docformat" 177 | options = ["--docformat"] 178 | default = "numpy" 179 | 180 | [[tool.poe.tasks.docs.args]] 181 | help = "The output directory (default: docs)" 182 | name = "outputdirectory" 183 | options = ["--output-directory"] 184 | default = "docs" 185 | 186 | [tool.poe.tasks.lint] 187 | help = "Lint this package" 188 | cmd = """ 189 | pre-commit run 190 | --all-files 191 | --color always 192 | """ 193 | 194 | [tool.poe.tasks.test] 195 | help = "Test this package" 196 | 197 | [[tool.poe.tasks.test.sequence]] 198 | cmd = "coverage run" 199 | 200 | [[tool.poe.tasks.test.sequence]] 201 | cmd = "coverage report" 202 | 203 | [[tool.poe.tasks.test.sequence]] 204 | cmd = "coverage xml" 205 | -------------------------------------------------------------------------------- /src/raglite/__init__.py: -------------------------------------------------------------------------------- 1 | """RAGLite.""" 2 | 3 | from raglite._config import RAGLiteConfig 4 | from raglite._database import Document 5 | from raglite._eval import answer_evals, evaluate, insert_evals 6 | from raglite._insert import insert_documents 7 | from raglite._query_adapter import update_query_adapter 8 | from raglite._rag import add_context, async_rag, rag, retrieve_context 9 | from raglite._search import ( 10 | hybrid_search, 11 | keyword_search, 12 | rerank_chunks, 13 | retrieve_chunk_spans, 14 | retrieve_chunks, 15 | search_and_rerank_chunk_spans, 16 | search_and_rerank_chunks, 17 | vector_search, 18 | ) 19 | 20 | __all__ = [ 21 | # Config 22 | "RAGLiteConfig", 23 | # Insert 24 | "Document", 25 | "insert_documents", 26 | # Search 27 | "hybrid_search", 28 | "keyword_search", 29 | "vector_search", 30 | "retrieve_chunks", 31 | "retrieve_chunk_spans", 32 | "rerank_chunks", 33 | "search_and_rerank_chunks", 34 | "search_and_rerank_chunk_spans", 35 | # RAG 36 | "retrieve_context", 37 | "add_context", 38 | "async_rag", 39 | "rag", 40 | # Query adapter 41 | "update_query_adapter", 42 | # Evaluate 43 | "answer_evals", 44 | "insert_evals", 45 | "evaluate", 46 | ] 47 | -------------------------------------------------------------------------------- /src/raglite/_bench.py: -------------------------------------------------------------------------------- 1 | """Benchmarking with TREC runs.""" 2 | 3 | import warnings 4 | from abc import ABC, abstractmethod 5 | from collections.abc import Generator 6 | from dataclasses import replace 7 | from functools import cached_property 8 | from pathlib import Path 9 | from typing import Any 10 | 11 | from ir_datasets.datasets.base import Dataset 12 | from ir_measures import ScoredDoc, read_trec_run 13 | from platformdirs import user_data_dir 14 | from slugify import slugify 15 | from tqdm.auto import tqdm 16 | 17 | from raglite._config import RAGLiteConfig 18 | 19 | 20 | class IREvaluator(ABC): 21 | def __init__( 22 | self, 23 | dataset: Dataset, 24 | *, 25 | num_results: int = 10, 26 | insert_variant: str | None = None, 27 | search_variant: str | None = None, 28 | ) -> None: 29 | self.dataset = dataset 30 | self.num_results = num_results 31 | self.insert_variant = insert_variant 32 | self.search_variant = search_variant 33 | self.insert_id = ( 34 | slugify(self.__class__.__name__.lower().replace("evaluator", "")) 35 | + (f"_{slugify(insert_variant)}" if insert_variant else "") 36 | + f"_{slugify(dataset.docs_namespace())}" 37 | ) 38 | self.search_id = ( 39 | self.insert_id 40 | + f"@{num_results}" 41 | + (f"_{slugify(search_variant)}" if search_variant else "") 42 | ) 43 | self.cwd = Path(user_data_dir("raglite", ensure_exists=True)) 44 | 45 | @abstractmethod 46 | def insert_documents(self, max_workers: int | None = None) -> None: 47 | """Insert all of the dataset's documents into the search index.""" 48 | raise NotImplementedError 49 | 50 | @abstractmethod 51 | def search(self, query_id: str, query: str, *, num_results: int = 10) -> list[ScoredDoc]: 52 | """Search for documents given a query.""" 53 | raise NotImplementedError 54 | 55 | @property 56 | def trec_run_filename(self) -> str: 57 | return f"{self.search_id}.trec" 58 | 59 | @property 60 | def trec_run_filepath(self) -> Path: 61 | return self.cwd / self.trec_run_filename 62 | 63 | def score(self) -> Generator[ScoredDoc, None, None]: 64 | """Read or compute a TREC run.""" 65 | if self.trec_run_filepath.exists(): 66 | yield from read_trec_run(self.trec_run_filepath.as_posix()) # type: ignore[no-untyped-call] 67 | return 68 | if not self.search("q0", next(self.dataset.queries_iter()).text): 69 | self.insert_documents() 70 | with self.trec_run_filepath.open(mode="w") as trec_run_file: 71 | for query in tqdm( 72 | self.dataset.queries_iter(), 73 | total=self.dataset.queries_count(), 74 | desc="Running queries", 75 | unit="query", 76 | dynamic_ncols=True, 77 | ): 78 | results = self.search(query.query_id, query.text, num_results=self.num_results) 79 | unique_results = {doc.doc_id: doc for doc in sorted(results, key=lambda d: d.score)} 80 | top_results = sorted(unique_results.values(), key=lambda d: d.score, reverse=True) 81 | top_results = top_results[: self.num_results] 82 | for rank, scored_doc in enumerate(top_results): 83 | trec_line = f"{query.query_id} 0 {scored_doc.doc_id} {rank} {scored_doc.score} {self.trec_run_filename}\n" 84 | trec_run_file.write(trec_line) 85 | yield scored_doc 86 | 87 | 88 | class RAGLiteEvaluator(IREvaluator): 89 | def __init__( 90 | self, 91 | dataset: Dataset, 92 | *, 93 | num_results: int = 10, 94 | insert_variant: str | None = None, 95 | search_variant: str | None = None, 96 | config: RAGLiteConfig | None = None, 97 | ): 98 | super().__init__( 99 | dataset, 100 | num_results=num_results, 101 | insert_variant=insert_variant, 102 | search_variant=search_variant, 103 | ) 104 | self.db_filepath = self.cwd / f"{self.insert_id}.db" 105 | db_url = f"duckdb:///{self.db_filepath.as_posix()}" 106 | self.config = replace(config or RAGLiteConfig(), db_url=db_url) 107 | 108 | def insert_documents(self, max_workers: int | None = None) -> None: 109 | from raglite import Document, insert_documents 110 | 111 | documents = [ 112 | Document.from_text(doc.text, id=doc.doc_id) for doc in self.dataset.docs_iter() 113 | ] 114 | insert_documents(documents, max_workers=max_workers, config=self.config) 115 | 116 | def update_query_adapter(self, num_evals: int = 1024) -> None: 117 | from raglite import insert_evals, update_query_adapter 118 | from raglite._database import IndexMetadata 119 | 120 | if ( 121 | self.config.vector_search_query_adapter 122 | and IndexMetadata.get(config=self.config).get("query_adapter") is None 123 | ): 124 | insert_evals(num_evals=num_evals, config=self.config) 125 | update_query_adapter(config=self.config) 126 | 127 | def search(self, query_id: str, query: str, *, num_results: int = 10) -> list[ScoredDoc]: 128 | from raglite import retrieve_chunks, vector_search 129 | 130 | self.update_query_adapter() 131 | chunk_ids, scores = vector_search(query, num_results=2 * num_results, config=self.config) 132 | chunks = retrieve_chunks(chunk_ids, config=self.config) 133 | scored_docs = [ 134 | ScoredDoc(query_id=query_id, doc_id=chunk.document.id, score=score) 135 | for chunk, score in zip(chunks, scores, strict=True) 136 | ] 137 | return scored_docs 138 | 139 | 140 | class LlamaIndexEvaluator(IREvaluator): 141 | def __init__( 142 | self, 143 | dataset: Dataset, 144 | *, 145 | num_results: int = 10, 146 | insert_variant: str | None = None, 147 | search_variant: str | None = None, 148 | ): 149 | super().__init__( 150 | dataset, 151 | num_results=num_results, 152 | insert_variant=insert_variant, 153 | search_variant=search_variant, 154 | ) 155 | self.embedder = "text-embedding-3-large" 156 | self.embedder_dim = 3072 157 | self.persist_path = self.cwd / self.insert_id 158 | 159 | def insert_documents(self, max_workers: int | None = None) -> None: 160 | # Adapted from https://docs.llamaindex.ai/en/stable/examples/vector_stores/FaissIndexDemo/. 161 | import faiss 162 | from llama_index.core import Document, StorageContext, VectorStoreIndex 163 | from llama_index.embeddings.openai import OpenAIEmbedding 164 | from llama_index.vector_stores.faiss import FaissVectorStore 165 | 166 | self.persist_path.mkdir(parents=True, exist_ok=True) 167 | faiss_index = faiss.IndexHNSWFlat(self.embedder_dim, 32, faiss.METRIC_INNER_PRODUCT) 168 | vector_store = FaissVectorStore(faiss_index=faiss_index) 169 | index = VectorStoreIndex.from_documents( 170 | [ 171 | Document(id_=doc.doc_id, text=doc.text, metadata={"filename": doc.doc_id}) 172 | for doc in self.dataset.docs_iter() 173 | ], 174 | storage_context=StorageContext.from_defaults(vector_store=vector_store), 175 | embed_model=OpenAIEmbedding(model=self.embedder, dimensions=self.embedder_dim), 176 | show_progress=True, 177 | ) 178 | index.storage_context.persist(persist_dir=self.persist_path) 179 | 180 | @cached_property 181 | def index(self) -> Any: 182 | from llama_index.core import StorageContext, load_index_from_storage 183 | from llama_index.embeddings.openai import OpenAIEmbedding 184 | from llama_index.vector_stores.faiss import FaissVectorStore 185 | 186 | vector_store = FaissVectorStore.from_persist_dir(persist_dir=self.persist_path.as_posix()) 187 | storage_context = StorageContext.from_defaults( 188 | vector_store=vector_store, persist_dir=self.persist_path.as_posix() 189 | ) 190 | embed_model = OpenAIEmbedding(model=self.embedder, dimensions=self.embedder_dim) 191 | index = load_index_from_storage(storage_context, embed_model=embed_model) 192 | return index 193 | 194 | def search(self, query_id: str, query: str, *, num_results: int = 10) -> list[ScoredDoc]: 195 | if not self.persist_path.exists(): 196 | self.insert_documents() 197 | retriever = self.index.as_retriever(similarity_top_k=2 * num_results) 198 | nodes = retriever.retrieve(query) 199 | scored_docs = [ 200 | ScoredDoc( 201 | query_id=query_id, 202 | doc_id=node.metadata.get("filename", node.id_), 203 | score=node.score if node.score is not None else 1.0, 204 | ) 205 | for node in nodes 206 | ] 207 | return scored_docs 208 | 209 | 210 | class OpenAIVectorStoreEvaluator(IREvaluator): 211 | def __init__( 212 | self, 213 | dataset: Dataset, 214 | *, 215 | num_results: int = 10, 216 | insert_variant: str | None = None, 217 | search_variant: str | None = None, 218 | ): 219 | super().__init__( 220 | dataset, 221 | num_results=num_results, 222 | insert_variant=insert_variant, 223 | search_variant=search_variant, 224 | ) 225 | self.vector_store_name = dataset.docs_namespace() + ( 226 | f"_{slugify(insert_variant)}" if insert_variant else "" 227 | ) 228 | 229 | @cached_property 230 | def client(self) -> Any: 231 | import openai 232 | 233 | return openai.OpenAI() 234 | 235 | @property 236 | def vector_store_id(self) -> str | None: 237 | vector_stores = self.client.vector_stores.list() 238 | vector_store = next((vs for vs in vector_stores if vs.name == self.vector_store_name), None) 239 | if vector_store is None: 240 | return None 241 | if vector_store.file_counts.failed > 0: 242 | warnings.warn( 243 | f"Vector store {vector_store.name} has {vector_store.file_counts.failed} failed files.", 244 | stacklevel=2, 245 | ) 246 | if vector_store.file_counts.in_progress > 0: 247 | error_message = f"Vector store {vector_store.name} has {vector_store.file_counts.in_progress} files in progress." 248 | raise RuntimeError(error_message) 249 | return vector_store.id # type: ignore[no-any-return] 250 | 251 | def insert_documents(self, max_workers: int | None = None) -> None: 252 | import tempfile 253 | from pathlib import Path 254 | 255 | vector_store = self.client.vector_stores.create(name=self.vector_store_name) 256 | files, max_files_per_batch = [], 32 257 | with tempfile.TemporaryDirectory() as temp_dir: 258 | for i, doc in tqdm( 259 | enumerate(self.dataset.docs_iter()), 260 | total=self.dataset.docs_count(), 261 | desc="Inserting documents", 262 | unit="document", 263 | dynamic_ncols=True, 264 | ): 265 | if not doc.text.strip(): 266 | continue 267 | temp_file = Path(temp_dir) / f"{slugify(doc.doc_id)}.txt" 268 | temp_file.write_text(doc.text) 269 | files.append(temp_file.open("rb")) 270 | if len(files) == max_files_per_batch or (i == self.dataset.docs_count() - 1): 271 | self.client.vector_stores.file_batches.upload_and_poll( 272 | vector_store_id=vector_store.id, files=files, max_concurrency=max_workers 273 | ) 274 | for f in files: 275 | f.close() 276 | files = [] 277 | 278 | @cached_property 279 | def filename_to_doc_id(self) -> dict[str, str]: 280 | return {f"{slugify(doc.doc_id)}.txt": doc.doc_id for doc in self.dataset.docs_iter()} 281 | 282 | def search(self, query_id: str, query: str, *, num_results: int = 10) -> list[ScoredDoc]: 283 | if not self.vector_store_id: 284 | return [] 285 | response = self.client.vector_stores.search( 286 | vector_store_id=self.vector_store_id, query=query, max_num_results=2 * num_results 287 | ) 288 | scored_docs = [ 289 | ScoredDoc( 290 | query_id=query_id, 291 | doc_id=self.filename_to_doc_id[result.filename], 292 | score=result.score, 293 | ) 294 | for result in response 295 | ] 296 | return scored_docs 297 | -------------------------------------------------------------------------------- /src/raglite/_chainlit.py: -------------------------------------------------------------------------------- 1 | """Chainlit frontend for RAGLite.""" 2 | 3 | import os 4 | from pathlib import Path 5 | 6 | import chainlit as cl 7 | from chainlit.input_widget import Switch, TextInput 8 | 9 | from raglite import ( 10 | Document, 11 | RAGLiteConfig, 12 | async_rag, 13 | hybrid_search, 14 | insert_documents, 15 | rerank_chunks, 16 | ) 17 | from raglite._markdown import document_to_markdown 18 | 19 | async_insert_documents = cl.make_async(insert_documents) 20 | async_hybrid_search = cl.make_async(hybrid_search) 21 | async_rerank_chunks = cl.make_async(rerank_chunks) 22 | 23 | 24 | @cl.on_chat_start 25 | async def start_chat() -> None: 26 | """Initialize the chat.""" 27 | # Set tokenizers parallelism to avoid a deadlock warning. 28 | os.environ["TOKENIZERS_PARALLELISM"] = "true" 29 | # Add Chainlit settings with which the user can configure the RAGLite config. 30 | default_config = RAGLiteConfig() 31 | config = RAGLiteConfig( 32 | db_url=os.environ.get("RAGLITE_DB_URL", default_config.db_url), 33 | llm=os.environ.get("RAGLITE_LLM", default_config.llm), 34 | embedder=os.environ.get("RAGLITE_EMBEDDER", default_config.embedder), 35 | ) 36 | settings = await cl.ChatSettings( # type: ignore[no-untyped-call] 37 | [ 38 | TextInput(id="db_url", label="Database URL", initial=str(config.db_url)), 39 | TextInput(id="llm", label="LLM", initial=config.llm), 40 | TextInput(id="embedder", label="Embedder", initial=config.embedder), 41 | Switch(id="vector_search_query_adapter", label="Query adapter", initial=True), 42 | ] 43 | ).send() 44 | await update_config(settings) 45 | 46 | 47 | @cl.on_settings_update # type: ignore[arg-type] 48 | async def update_config(settings: cl.ChatSettings) -> None: 49 | """Update the RAGLite config.""" 50 | # Update the RAGLite config given the Chainlit settings. 51 | config = RAGLiteConfig( 52 | db_url=settings["db_url"], # type: ignore[index] 53 | llm=settings["llm"], # type: ignore[index] 54 | embedder=settings["embedder"], # type: ignore[index] 55 | vector_search_query_adapter=settings["vector_search_query_adapter"], # type: ignore[index] 56 | ) 57 | cl.user_session.set("config", config) # type: ignore[no-untyped-call] 58 | # Run a search to prime the pipeline if it's a local pipeline. 59 | if config.embedder.startswith("llama-cpp-python"): 60 | query = "Hello world" 61 | chunk_ids, _ = await async_hybrid_search(query=query, config=config) 62 | _ = await async_rerank_chunks(query=query, chunk_ids=chunk_ids, config=config) 63 | 64 | 65 | @cl.on_message 66 | async def handle_message(user_message: cl.Message) -> None: 67 | """Respond to a user message.""" 68 | # Get the config and message history from the user session. 69 | config: RAGLiteConfig = cl.user_session.get("config") # type: ignore[no-untyped-call] 70 | # Determine what to do with the attachments. 71 | inline_attachments = [] 72 | for file in user_message.elements: 73 | if file.path: 74 | doc_md = document_to_markdown(Path(file.path)) 75 | if len(doc_md) // 3 <= 5 * (config.chunk_max_size // 3): 76 | # Document is small enough to attach to the context. 77 | inline_attachments.append(f"{Path(file.path).name}:\n\n{doc_md}") 78 | else: 79 | # Document is too large and must be inserted into the database. 80 | async with cl.Step(name="insert", type="run") as step: 81 | step.input = Path(file.path).name 82 | document = Document.from_path(Path(file.path)) 83 | await async_insert_documents([document], config=config) 84 | # Append any inline attachments to the user prompt. 85 | user_prompt = ( 86 | "\n\n".join( 87 | f'\n{attachment.strip()}\n' 88 | for i, attachment in enumerate(inline_attachments) 89 | ) 90 | + f"\n\n{user_message.content}" 91 | ).strip() 92 | # Stream the LLM response. 93 | assistant_message = cl.Message(content="") 94 | chunk_spans = [] 95 | messages: list[dict[str, str]] = cl.chat_context.to_openai()[:-1] # type: ignore[no-untyped-call] 96 | messages.append({"role": "user", "content": user_prompt}) 97 | async for token in async_rag( 98 | messages, on_retrieval=lambda x: chunk_spans.extend(x), config=config 99 | ): 100 | await assistant_message.stream_token(token) 101 | # Append RAG sources, if any. 102 | if chunk_spans: 103 | rag_sources: dict[str, list[str]] = {} 104 | for chunk_span in chunk_spans: 105 | rag_sources.setdefault(chunk_span.document.id, []) 106 | rag_sources[chunk_span.document.id].append(str(chunk_span)) 107 | assistant_message.content += "\n\nSources: " + ", ".join( # Rendered as hyperlinks. 108 | f"[{i + 1}]" for i in range(len(rag_sources)) 109 | ) 110 | assistant_message.elements = [ # Markdown content is rendered in sidebar. 111 | cl.Text(name=f"[{i + 1}]", content="\n\n---\n\n".join(content), display="side") # type: ignore[misc] 112 | for i, (_, content) in enumerate(rag_sources.items()) 113 | ] 114 | await assistant_message.update() # type: ignore[no-untyped-call] 115 | -------------------------------------------------------------------------------- /src/raglite/_cli.py: -------------------------------------------------------------------------------- 1 | """RAGLite CLI.""" 2 | 3 | import json 4 | import os 5 | from typing import ClassVar 6 | 7 | import typer 8 | from pydantic_settings import BaseSettings, SettingsConfigDict 9 | 10 | from raglite._config import RAGLiteConfig 11 | 12 | 13 | class RAGLiteCLIConfig(BaseSettings): 14 | """RAGLite CLI config.""" 15 | 16 | model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict( 17 | env_prefix="RAGLITE_", env_file=".env", extra="allow" 18 | ) 19 | 20 | mcp_server_name: str = "RAGLite" 21 | db_url: str = str(RAGLiteConfig().db_url) 22 | llm: str = RAGLiteConfig().llm 23 | embedder: str = RAGLiteConfig().embedder 24 | 25 | 26 | cli = typer.Typer() 27 | cli.add_typer(mcp_cli := typer.Typer(), name="mcp") 28 | 29 | 30 | @cli.callback() 31 | def main( 32 | ctx: typer.Context, 33 | db_url: str = typer.Option(RAGLiteCLIConfig().db_url, help="Database URL"), 34 | llm: str = typer.Option(RAGLiteCLIConfig().llm, help="LiteLLM LLM"), 35 | embedder: str = typer.Option(RAGLiteCLIConfig().embedder, help="LiteLLM embedder"), 36 | ) -> None: 37 | """RAGLite CLI.""" 38 | ctx.obj = {"db_url": db_url, "llm": llm, "embedder": embedder} 39 | 40 | 41 | @cli.command() 42 | def chainlit(ctx: typer.Context) -> None: 43 | """Serve a Chainlit frontend.""" 44 | # Set the environment variables for the Chainlit frontend. 45 | os.environ["RAGLITE_DB_URL"] = ctx.obj["db_url"] 46 | os.environ["RAGLITE_LLM"] = ctx.obj["llm"] 47 | os.environ["RAGLITE_EMBEDDER"] = ctx.obj["embedder"] 48 | # Import Chainlit here as it's an optional dependency. 49 | try: 50 | from chainlit.cli import run_chainlit 51 | except ModuleNotFoundError as error: 52 | error_message = "To serve a Chainlit frontend, please install the `chainlit` extra." 53 | raise ModuleNotFoundError(error_message) from error 54 | # Serve the frontend. 55 | run_chainlit(__file__.replace("_cli.py", "_chainlit.py")) 56 | 57 | 58 | @mcp_cli.command("install") 59 | def install_mcp_server( 60 | ctx: typer.Context, 61 | server_name: str = typer.Option(RAGLiteCLIConfig().mcp_server_name, help="MCP server name"), 62 | ) -> None: 63 | """Install MCP server in the Claude desktop app.""" 64 | from fastmcp.cli.claude import get_claude_config_path 65 | 66 | # Get the Claude config path. 67 | claude_config_path = get_claude_config_path() 68 | if not claude_config_path: 69 | typer.echo( 70 | "Please download the Claude desktop app from https://claude.ai/download before installing an MCP server." 71 | ) 72 | return 73 | claude_config_filepath = claude_config_path / "claude_desktop_config.json" 74 | # Parse the Claude config. 75 | claude_config = ( 76 | json.loads(claude_config_filepath.read_text()) if claude_config_filepath.exists() else {} 77 | ) 78 | # Update the Claude config with the MCP server. 79 | mcp_config = RAGLiteCLIConfig( 80 | mcp_server_name=server_name, 81 | db_url=ctx.obj["db_url"], 82 | llm=ctx.obj["llm"], 83 | embedder=ctx.obj["embedder"], 84 | ) 85 | claude_config["mcpServers"][server_name] = { 86 | "command": "uvx", 87 | "args": [ 88 | "--python", 89 | "3.11", 90 | "--with", 91 | "numpy<2.0.0", # TODO: Remove this constraint when uv no longer needs it to solve the environment. 92 | "raglite", 93 | "mcp", 94 | "run", 95 | ], 96 | "env": { 97 | f"RAGLITE_{key.upper()}" if key in RAGLiteCLIConfig.model_fields else key.upper(): value 98 | for key, value in mcp_config.model_dump().items() 99 | if value 100 | }, 101 | } 102 | # Write the updated Claude config to disk. 103 | claude_config_filepath.write_text(json.dumps(claude_config, indent=2)) 104 | 105 | 106 | @mcp_cli.command("run") 107 | def run_mcp_server( 108 | ctx: typer.Context, 109 | server_name: str = typer.Option(RAGLiteCLIConfig().mcp_server_name, help="MCP server name"), 110 | ) -> None: 111 | """Run MCP server.""" 112 | from raglite._mcp import create_mcp_server 113 | 114 | config = RAGLiteConfig( 115 | db_url=ctx.obj["db_url"], llm=ctx.obj["llm"], embedder=ctx.obj["embedder"] 116 | ) 117 | mcp = create_mcp_server(server_name, config=config) 118 | mcp.run() 119 | 120 | 121 | @cli.command() 122 | def bench( 123 | ctx: typer.Context, 124 | dataset_name: str = typer.Option( 125 | "nano-beir/hotpotqa", "--dataset", "-d", help="Dataset to use from https://ir-datasets.com/" 126 | ), 127 | measure: str = typer.Option( 128 | "AP@10", 129 | "--measure", 130 | "-m", 131 | help="Evaluation measure from https://ir-measur.es/en/latest/measures.html", 132 | ), 133 | ) -> None: 134 | """Run benchmark.""" 135 | import ir_datasets 136 | import ir_measures 137 | import pandas as pd 138 | 139 | from raglite._bench import ( 140 | IREvaluator, 141 | LlamaIndexEvaluator, 142 | OpenAIVectorStoreEvaluator, 143 | RAGLiteEvaluator, 144 | ) 145 | 146 | # Initialise the benchmark. 147 | evaluator: IREvaluator 148 | measures = [ir_measures.parse_measure(measure)] 149 | index, results = [], [] 150 | # Evaluate RAGLite (single-vector) + DuckDB HNSW + text-embedding-3-large. 151 | chunk_max_size = 2048 152 | config = RAGLiteConfig( 153 | embedder="text-embedding-3-large", 154 | chunk_max_size=chunk_max_size, 155 | vector_search_multivector=False, 156 | vector_search_query_adapter=False, 157 | ) 158 | dataset = ir_datasets.load(dataset_name) 159 | evaluator = RAGLiteEvaluator( 160 | dataset, insert_variant=f"single-vector-{chunk_max_size // 4}t", config=config 161 | ) 162 | index.append("RAGLite (single-vector)") 163 | results.append(ir_measures.calc_aggregate(measures, dataset.qrels_iter(), evaluator.score())) 164 | # Evaluate RAGLite (multi-vector) + DuckDB HNSW + text-embedding-3-large. 165 | config = RAGLiteConfig( 166 | embedder="text-embedding-3-large", 167 | chunk_max_size=chunk_max_size, 168 | vector_search_multivector=True, 169 | vector_search_query_adapter=False, 170 | ) 171 | dataset = ir_datasets.load(dataset_name) 172 | evaluator = RAGLiteEvaluator( 173 | dataset, insert_variant=f"multi-vector-{chunk_max_size // 4}t", config=config 174 | ) 175 | index.append("RAGLite (multi-vector)") 176 | results.append(ir_measures.calc_aggregate(measures, dataset.qrels_iter(), evaluator.score())) 177 | # Evaluate RAGLite (query adapter) + DuckDB HNSW + text-embedding-3-large. 178 | config = RAGLiteConfig( 179 | llm=(llm := "gpt-4.1"), 180 | embedder="text-embedding-3-large", 181 | chunk_max_size=chunk_max_size, 182 | vector_search_multivector=True, 183 | vector_search_query_adapter=True, 184 | ) 185 | dataset = ir_datasets.load(dataset_name) 186 | evaluator = RAGLiteEvaluator( 187 | dataset, 188 | insert_variant=f"multi-vector-{chunk_max_size // 4}t", 189 | search_variant=f"query-adapter-{llm}", 190 | config=config, 191 | ) 192 | index.append("RAGLite (query adapter)") 193 | results.append(ir_measures.calc_aggregate(measures, dataset.qrels_iter(), evaluator.score())) 194 | # Evaluate LLamaIndex + FAISS HNSW + text-embedding-3-large. 195 | dataset = ir_datasets.load(dataset_name) 196 | evaluator = LlamaIndexEvaluator(dataset) 197 | index.append("LlamaIndex") 198 | results.append(ir_measures.calc_aggregate(measures, dataset.qrels_iter(), evaluator.score())) 199 | # Evaluate OpenAI Vector Store. 200 | dataset = ir_datasets.load(dataset_name) 201 | evaluator = OpenAIVectorStoreEvaluator(dataset) 202 | index.append("OpenAI Vector Store") 203 | results.append(ir_measures.calc_aggregate(measures, dataset.qrels_iter(), evaluator.score())) 204 | # Print the results. 205 | results_df = pd.DataFrame.from_records(results, index=index) 206 | typer.echo(results_df) 207 | 208 | 209 | if __name__ == "__main__": 210 | cli() 211 | -------------------------------------------------------------------------------- /src/raglite/_config.py: -------------------------------------------------------------------------------- 1 | """RAGLite config.""" 2 | 3 | import contextlib 4 | import os 5 | from dataclasses import dataclass, field 6 | from io import StringIO 7 | from pathlib import Path 8 | from typing import Literal 9 | 10 | from platformdirs import user_data_dir 11 | from sqlalchemy.engine import URL 12 | 13 | from raglite._lazy_llama import llama_supports_gpu_offload 14 | from raglite._typing import ChunkId, SearchMethod 15 | 16 | # Suppress rerankers output on import until [1] is fixed. 17 | # [1] https://github.com/AnswerDotAI/rerankers/issues/36 18 | with contextlib.redirect_stdout(StringIO()): 19 | from rerankers.models.flashrank_ranker import FlashRankRanker 20 | from rerankers.models.ranker import BaseRanker 21 | 22 | 23 | cache_path = Path(user_data_dir("raglite", ensure_exists=True)) 24 | 25 | 26 | # Lazily load the default search method to avoid circular imports. 27 | # TODO: Replace with search_and_rerank_chunk_spans after benchmarking. 28 | def _vector_search( 29 | query: str, *, num_results: int = 8, config: "RAGLiteConfig | None" = None 30 | ) -> tuple[list[ChunkId], list[float]]: 31 | from raglite._search import vector_search 32 | 33 | return vector_search(query, num_results=num_results, config=config) 34 | 35 | 36 | @dataclass(frozen=True) 37 | class RAGLiteConfig: 38 | """RAGLite config.""" 39 | 40 | # Database config. 41 | db_url: str | URL = f"duckdb:///{(cache_path / 'raglite.db').as_posix()}" 42 | # LLM config used for generation. 43 | llm: str = field( 44 | default_factory=lambda: ( 45 | "llama-cpp-python/unsloth/Qwen3-8B-GGUF/*Q4_K_M.gguf@8192" 46 | if llama_supports_gpu_offload() 47 | else "llama-cpp-python/unsloth/Qwen3-4B-GGUF/*Q4_K_M.gguf@8192" 48 | ) 49 | ) 50 | llm_max_tries: int = 4 51 | # Embedder config used for indexing. 52 | embedder: str = field( 53 | default_factory=lambda: ( # Nomic-embed may be better if only English is used. 54 | "llama-cpp-python/lm-kit/bge-m3-gguf/*F16.gguf@512" 55 | if llama_supports_gpu_offload() or (os.cpu_count() or 1) >= 4 # noqa: PLR2004 56 | else "llama-cpp-python/lm-kit/bge-m3-gguf/*Q4_K_M.gguf@512" 57 | ) 58 | ) 59 | embedder_normalize: bool = True 60 | # Chunk config used to partition documents into chunks. 61 | chunk_max_size: int = 2048 # Max number of characters per chunk. 62 | # Vector search config. 63 | vector_search_distance_metric: Literal["cosine", "dot", "l2"] = "cosine" 64 | vector_search_multivector: bool = True 65 | vector_search_query_adapter: bool = True # Only supported for "cosine" and "dot" metrics. 66 | # Reranking config. 67 | reranker: BaseRanker | dict[str, BaseRanker] | None = field( 68 | default_factory=lambda: { 69 | "en": FlashRankRanker("ms-marco-MiniLM-L-12-v2", verbose=0, cache_dir=cache_path), 70 | "other": FlashRankRanker("ms-marco-MultiBERT-L-12", verbose=0, cache_dir=cache_path), 71 | }, 72 | compare=False, # Exclude the reranker from comparison to avoid lru_cache misses. 73 | ) 74 | # Search config: you can pick any search method that returns (list[ChunkId], list[float]), 75 | # list[Chunk], or list[ChunkSpan]. 76 | search_method: SearchMethod = field(default=_vector_search, compare=False) 77 | -------------------------------------------------------------------------------- /src/raglite/_embed.py: -------------------------------------------------------------------------------- 1 | """String embedder.""" 2 | 3 | from functools import partial 4 | from typing import Literal 5 | 6 | import numpy as np 7 | from litellm import embedding 8 | from tqdm.auto import tqdm, trange 9 | 10 | from raglite._config import RAGLiteConfig 11 | from raglite._lazy_llama import LLAMA_POOLING_TYPE_NONE, Llama 12 | from raglite._litellm import LlamaCppPythonLLM 13 | from raglite._typing import FloatMatrix, IntVector 14 | 15 | 16 | def embed_strings_with_late_chunking( # noqa: C901,PLR0915 17 | sentences: list[str], *, config: RAGLiteConfig | None = None 18 | ) -> FloatMatrix: 19 | """Embed a document's sentences with late chunking.""" 20 | 21 | def _count_tokens( 22 | sentences: list[str], embedder: Llama, sentinel_char: str, sentinel_tokens: list[int] 23 | ) -> list[int]: 24 | # Join the sentences with the sentinel token and tokenise the result. 25 | sentences_tokens = np.asarray( 26 | embedder.tokenize(sentinel_char.join(sentences).encode(), add_bos=False), dtype=np.intp 27 | ) 28 | # Map all sentinel token variants to the first one. 29 | for sentinel_token in sentinel_tokens[1:]: 30 | sentences_tokens[sentences_tokens == sentinel_token] = sentinel_tokens[0] 31 | # Count how many tokens there are in between sentinel tokens to recover the token counts. 32 | sentinel_indices = np.where(sentences_tokens == sentinel_tokens[0])[0] 33 | num_tokens = np.diff(sentinel_indices, prepend=0, append=len(sentences_tokens)) 34 | assert len(num_tokens) == len(sentences), f"Sentinel `{sentinel_char}` appears in document" 35 | num_tokens_list: list[int] = num_tokens.tolist() 36 | return num_tokens_list 37 | 38 | def _create_segment( 39 | content_start_index: int, 40 | max_tokens_preamble: int, 41 | max_tokens_content: int, 42 | num_tokens: IntVector, 43 | ) -> tuple[int, int]: 44 | # Compute the segment sentence start index so that the segment preamble has no more than 45 | # max_tokens_preamble tokens between [segment_start_index, content_start_index). 46 | cumsum_backwards = np.cumsum(num_tokens[:content_start_index][::-1]) 47 | offset_preamble = np.searchsorted(cumsum_backwards, max_tokens_preamble, side="right") 48 | segment_start_index = content_start_index - int(offset_preamble) 49 | # Allow a larger segment content if we didn't use all of the allowed preamble tokens. 50 | max_tokens_content = max_tokens_content + ( 51 | max_tokens_preamble - np.sum(num_tokens[segment_start_index:content_start_index]) 52 | ) 53 | # Compute the segment sentence end index so that the segment content has no more than 54 | # max_tokens_content tokens between [content_start_index, segment_end_index). 55 | cumsum_forwards = np.cumsum(num_tokens[content_start_index:]) 56 | offset_segment = np.searchsorted(cumsum_forwards, max_tokens_content, side="right") 57 | segment_end_index = content_start_index + int(offset_segment) 58 | return segment_start_index, segment_end_index 59 | 60 | # Assert that we're using a llama-cpp-python model, since API-based embedding models don't 61 | # support outputting token-level embeddings. 62 | config = config or RAGLiteConfig() 63 | assert config.embedder.startswith("llama-cpp-python") 64 | embedder = LlamaCppPythonLLM.llm( 65 | config.embedder, embedding=True, pooling_type=LLAMA_POOLING_TYPE_NONE 66 | ) 67 | n_ctx = embedder.n_ctx() 68 | n_batch = embedder.n_batch 69 | # Identify the tokens corresponding to a sentinel character. 70 | sentinel_char = "⊕" 71 | sentinel_test = f"A{sentinel_char}B {sentinel_char} C.\n{sentinel_char}D" 72 | sentinel_tokens = [ 73 | token 74 | for token in embedder.tokenize(sentinel_test.encode(), add_bos=False) 75 | if sentinel_char in embedder.detokenize([token]).decode() 76 | ] 77 | assert sentinel_tokens, f"Sentinel `{sentinel_char}` not supported by embedder" 78 | # Compute the number of tokens per sentence. We use a method based on a sentinel token to 79 | # minimise the number of calls to embedder.tokenize, which incurs a significant overhead 80 | # (presumably to load the tokenizer) [1]. 81 | # TODO: Make token counting faster and more robust once [1] is fixed. 82 | # [1] https://github.com/abetlen/llama-cpp-python/issues/1763 83 | num_tokens_list: list[int] = [] 84 | sentence_batch, sentence_batch_len = [], 0 85 | for i, sentence in enumerate(sentences): 86 | sentence_batch.append(sentence) 87 | sentence_batch_len += len(sentence) 88 | if i == len(sentences) - 1 or sentence_batch_len > (n_ctx // 2): 89 | num_tokens_list.extend( 90 | _count_tokens(sentence_batch, embedder, sentinel_char, sentinel_tokens) 91 | ) 92 | sentence_batch, sentence_batch_len = [], 0 93 | num_tokens = np.asarray(num_tokens_list, dtype=np.intp) 94 | # Compute the maximum number of tokens for each segment's preamble and content. 95 | # Unfortunately, llama-cpp-python truncates the input to n_batch tokens and crashes if you try 96 | # to increase it [1]. Until this is fixed, we have to limit max_tokens to n_batch. 97 | # TODO: Improve the context window size once [1] is fixed. 98 | # [1] https://github.com/abetlen/llama-cpp-python/issues/1762 99 | max_tokens = min(n_ctx, n_batch) - 16 100 | max_tokens_preamble = round(0.382 * max_tokens) # Golden ratio. 101 | max_tokens_content = max_tokens - max_tokens_preamble 102 | # Compute a list of segments, each consisting of a preamble and content. 103 | segments = [] 104 | content_start_index = 0 105 | while content_start_index < len(sentences): 106 | segment_start_index, segment_end_index = _create_segment( 107 | content_start_index, max_tokens_preamble, max_tokens_content, num_tokens 108 | ) 109 | segments.append((segment_start_index, content_start_index, segment_end_index)) 110 | content_start_index = segment_end_index 111 | # Embed the segments and apply late chunking. 112 | sentence_embeddings_list: list[FloatMatrix] = [] 113 | if len(segments) > 1 or segments[0][2] > 128: # noqa: PLR2004 114 | segments = tqdm(segments, desc="Embedding", unit="segment", dynamic_ncols=True, leave=False) 115 | for segment in segments: 116 | # Get the token embeddings of the entire segment, including preamble and content. 117 | segment_start_index, content_start_index, segment_end_index = segment 118 | segment_sentences = sentences[segment_start_index:segment_end_index] 119 | segment_embedding = np.asarray(embedder.embed("".join(segment_sentences))) 120 | # Split the segment embeddings into embedding matrices per sentence using the largest 121 | # remainder method. 122 | segment_tokens = num_tokens[segment_start_index:segment_end_index] 123 | sentence_size_frac = len(segment_embedding) * (segment_tokens / np.sum(segment_tokens)) 124 | sentence_size = np.floor(sentence_size_frac).astype(np.intp) 125 | remainder = len(segment_embedding) - np.sum(sentence_size) 126 | if remainder > 0: # Assign the remaining tokens to sentences with largest fractional parts. 127 | top_remainders = np.argsort(sentence_size_frac - sentence_size)[-remainder:] 128 | sentence_size[top_remainders] += 1 129 | sentence_matrices = np.split(segment_embedding, np.cumsum(sentence_size)[:-1]) 130 | # Compute the segment sentence embeddings by averaging the token embeddings. 131 | content_sentence_embeddings = [ 132 | np.mean(sentence_matrix, axis=0, keepdims=True) 133 | for sentence_matrix in sentence_matrices[content_start_index - segment_start_index :] 134 | ] 135 | sentence_embeddings_list.append(np.vstack(content_sentence_embeddings)) 136 | sentence_embeddings = np.vstack(sentence_embeddings_list) 137 | # Normalise the sentence embeddings to unit norm and cast to half precision. 138 | if config.embedder_normalize: 139 | sentence_embeddings /= np.linalg.norm(sentence_embeddings, axis=1, keepdims=True) 140 | sentence_embeddings = sentence_embeddings.astype(np.float16) 141 | return sentence_embeddings 142 | 143 | 144 | def _embed_string_batch(string_batch: list[str], *, config: RAGLiteConfig) -> FloatMatrix: 145 | """Embed a batch of text strings.""" 146 | if config.embedder.startswith("llama-cpp-python"): 147 | # LiteLLM doesn't yet support registering a custom embedder, so we handle it here. 148 | # Additionally, we explicitly manually pool the token embeddings to obtain sentence 149 | # embeddings because token embeddings are universally supported, while sequence 150 | # embeddings are only supported by some models. 151 | embedder = LlamaCppPythonLLM.llm( 152 | config.embedder, embedding=True, pooling_type=LLAMA_POOLING_TYPE_NONE 153 | ) 154 | embeddings = np.asarray([np.mean(row, axis=0) for row in embedder.embed(string_batch)]) 155 | else: 156 | # Use LiteLLM's API to embed the batch of strings. 157 | response = embedding(config.embedder, string_batch) 158 | embeddings = np.asarray([item["embedding"] for item in response["data"]]) 159 | # Normalise the embeddings to unit norm and cast to half precision. 160 | if config.embedder_normalize: 161 | eps = np.finfo(embeddings.dtype).eps 162 | norm = np.linalg.norm(embeddings, axis=1, keepdims=True) 163 | embeddings /= np.maximum(norm, eps) 164 | embeddings = embeddings.astype(np.float16) 165 | return embeddings 166 | 167 | 168 | def embed_strings_without_late_chunking( 169 | strings: list[str], *, config: RAGLiteConfig | None = None 170 | ) -> FloatMatrix: 171 | """Embed a list of text strings in batches.""" 172 | config = config or RAGLiteConfig() 173 | batch_size = 96 174 | batch_range = ( 175 | partial(trange, desc="Embedding", unit="batch", dynamic_ncols=True) 176 | if len(strings) > batch_size 177 | else range 178 | ) 179 | batch_embeddings = [ 180 | _embed_string_batch(strings[i : i + batch_size], config=config) 181 | for i in batch_range(0, len(strings), batch_size) 182 | ] 183 | string_embeddings = np.vstack(batch_embeddings) 184 | return string_embeddings 185 | 186 | 187 | def embedding_type(*, config: RAGLiteConfig | None = None) -> Literal["late_chunking", "standard"]: 188 | """Return the type of sentence embeddings.""" 189 | config = config or RAGLiteConfig() 190 | return "late_chunking" if config.embedder.startswith("llama-cpp-python") else "standard" 191 | 192 | 193 | def embed_strings(strings: list[str], *, config: RAGLiteConfig | None = None) -> FloatMatrix: 194 | """Embed the chunklets of a document as a NumPy matrix with one row per chunklet.""" 195 | config = config or RAGLiteConfig() 196 | if embedding_type(config=config) == "late_chunking": 197 | string_embeddings = embed_strings_with_late_chunking(strings, config=config) 198 | else: 199 | string_embeddings = embed_strings_without_late_chunking(strings, config=config) 200 | return string_embeddings 201 | -------------------------------------------------------------------------------- /src/raglite/_eval.py: -------------------------------------------------------------------------------- 1 | """Generation and evaluation of evals.""" 2 | 3 | import contextlib 4 | from concurrent.futures import ThreadPoolExecutor, as_completed 5 | from functools import partial 6 | from random import randint 7 | from typing import TYPE_CHECKING, ClassVar 8 | 9 | import numpy as np 10 | from pydantic import BaseModel, ConfigDict, Field, field_validator 11 | from sqlalchemy import text 12 | from sqlmodel import Session, func, select 13 | from tqdm.auto import tqdm 14 | 15 | if TYPE_CHECKING: 16 | import pandas as pd 17 | 18 | from raglite._config import RAGLiteConfig 19 | from raglite._database import Chunk, Document, Eval, create_database_engine 20 | from raglite._extract import extract_with_llm 21 | from raglite._rag import add_context, rag, retrieve_context 22 | from raglite._search import retrieve_chunk_spans, vector_search 23 | 24 | 25 | def generate_eval(*, max_chunks: int = 20, config: RAGLiteConfig | None = None) -> Eval: 26 | """Generate an eval.""" 27 | 28 | class QuestionResponse(BaseModel): 29 | """A specific question about the content of a set of document contexts.""" 30 | 31 | model_config = ConfigDict( 32 | extra="forbid" # Forbid extra attributes as required by OpenAI's strict mode. 33 | ) 34 | question: str = Field( 35 | ..., description="A specific question about the content of a set of document contexts." 36 | ) 37 | system_prompt: ClassVar[str] = """ 38 | You are given a set of contexts extracted from a document. 39 | You are a subject matter expert on the document's topic. 40 | Your task is to generate a question to quiz other subject matter experts on the information in the provided context. 41 | The question MUST satisfy ALL of the following criteria: 42 | - The question SHOULD integrate as much of the provided context as possible. 43 | - The question MUST NOT be a general or open question, but MUST instead be as specific to the provided context as possible. 44 | - The question MUST be completely answerable using ONLY the information in the provided context, without depending on any background information. 45 | - The question MUST be entirely self-contained and able to be understood in full WITHOUT access to the provided context. 46 | - The question MUST NOT reference the existence of the context, directly or indirectly. 47 | - The question MUST treat the context as if its contents are entirely part of your working memory. 48 | """.strip() 49 | 50 | @field_validator("question") 51 | @classmethod 52 | def validate_question(cls, value: str) -> str: 53 | """Validate the question.""" 54 | question = value.strip().lower() 55 | if "context" in question or "document" in question or "question" in question: 56 | raise ValueError 57 | if not question.endswith("?"): 58 | raise ValueError 59 | return value 60 | 61 | with Session(create_database_engine(config := config or RAGLiteConfig())) as session: 62 | # Sample a random document from the database. 63 | seed_document = session.exec(select(Document).order_by(func.random()).limit(1)).first() 64 | if seed_document is None: 65 | error_message = "First run `insert_documents()` before generating evals." 66 | raise ValueError(error_message) 67 | # Sample a random chunk from that document. 68 | seed_chunk = session.exec( 69 | select(Chunk) 70 | .where(Chunk.document_id == seed_document.id) 71 | .order_by(func.random()) 72 | .limit(1) 73 | ).first() 74 | assert isinstance(seed_chunk, Chunk) 75 | # Expand the seed chunk into a set of related chunks. 76 | related_chunk_ids, _ = vector_search( 77 | query=np.mean(seed_chunk.embedding_matrix, axis=0), 78 | num_results=randint(1, max_chunks), # noqa: S311 79 | config=config, 80 | ) 81 | related_chunks = [ 82 | str(chunk_spans) 83 | for chunk_spans in retrieve_chunk_spans(related_chunk_ids, config=config) 84 | ] 85 | # Extract a question from the seed chunk's related chunks. 86 | question = extract_with_llm( 87 | QuestionResponse, related_chunks, strict=True, config=config 88 | ).question 89 | # Search for candidate chunks to answer the generated question. 90 | candidate_chunk_ids, _ = vector_search( 91 | query=question, num_results=2 * max_chunks, config=config 92 | ) 93 | candidate_chunks = [session.get(Chunk, chunk_id) for chunk_id in candidate_chunk_ids] 94 | 95 | # Determine which candidate chunks are relevant to answer the generated question. 96 | class ContextEvalResponse(BaseModel): 97 | """Indicate whether the provided context can be used to answer a given question.""" 98 | 99 | model_config = ConfigDict( 100 | extra="forbid" # Forbid extra attributes as required by OpenAI's strict mode. 101 | ) 102 | hit: bool = Field( 103 | ..., 104 | description="True if the provided context contains (a part of) the answer to the given question, false otherwise.", 105 | ) 106 | system_prompt: ClassVar[str] = f""" 107 | You are given a context extracted from a document. 108 | You are a subject matter expert on the document's topic. 109 | Your task is to determine whether the provided context contains (a part of) the answer to this question: "{question}" 110 | An example of a context that does NOT contain (a part of) the answer is a table of contents. 111 | """.strip() 112 | 113 | relevant_chunks = [] 114 | for candidate_chunk in tqdm( 115 | candidate_chunks, 116 | desc="Evaluating chunks", 117 | unit="chunk", 118 | dynamic_ncols=True, 119 | leave=False, 120 | ): 121 | try: 122 | context_eval_response = extract_with_llm( 123 | ContextEvalResponse, str(candidate_chunk), strict=True, config=config 124 | ) 125 | except ValueError: # noqa: PERF203 126 | pass 127 | else: 128 | if context_eval_response.hit: 129 | relevant_chunks.append(candidate_chunk) 130 | if not relevant_chunks: 131 | error_message = "No relevant chunks found to answer the question." 132 | raise ValueError(error_message) 133 | 134 | # Answer the question using the relevant chunks. 135 | class AnswerResponse(BaseModel): 136 | """Answer a question using the provided context.""" 137 | 138 | model_config = ConfigDict( 139 | extra="forbid" # Forbid extra attributes as required by OpenAI's strict mode. 140 | ) 141 | answer: str = Field( 142 | ..., 143 | description="A complete answer to the given question using the provided context.", 144 | ) 145 | system_prompt: ClassVar[str] = f""" 146 | You are given a set of contexts extracted from a document. 147 | You are a subject matter expert on the document's topic. 148 | Your task is to generate a complete answer to the following question using the provided context: "{question}" 149 | The answer MUST satisfy ALL of the following criteria: 150 | - The answer MUST integrate as much of the provided context as possible. 151 | - The answer MUST be entirely self-contained and able to be understood in full WITHOUT access to the provided context. 152 | - The answer MUST NOT reference the existence of the context, directly or indirectly. 153 | - The answer MUST treat the context as if its contents are entirely part of your working memory. 154 | """.strip() 155 | 156 | answer = extract_with_llm( 157 | AnswerResponse, 158 | [str(relevant_chunk) for relevant_chunk in relevant_chunks], 159 | strict=True, 160 | config=config, 161 | ).answer 162 | # Construct the eval. 163 | eval_ = Eval.from_chunks(question=question, contexts=relevant_chunks, ground_truth=answer) 164 | return eval_ 165 | 166 | 167 | def insert_evals( 168 | *, 169 | num_evals: int = 100, 170 | max_chunks_per_eval: int = 20, 171 | max_workers: int | None = None, 172 | config: RAGLiteConfig | None = None, 173 | ) -> None: 174 | """Generate and insert evals into the database.""" 175 | with ( 176 | Session(engine := create_database_engine(config := config or RAGLiteConfig())) as session, 177 | ThreadPoolExecutor(max_workers=max_workers) as executor, 178 | tqdm(total=num_evals, desc="Generating evals", unit="eval", dynamic_ncols=True) as pbar, 179 | ): 180 | futures = [ 181 | executor.submit(partial(generate_eval, max_chunks=max_chunks_per_eval, config=config)) 182 | for _ in range(num_evals) 183 | ] 184 | for future in as_completed(futures): 185 | with contextlib.suppress(Exception): # Eval generation may fail for various reasons. 186 | eval_ = future.result() 187 | session.add(eval_) 188 | pbar.update() 189 | session.commit() 190 | if engine.dialect.name == "duckdb": 191 | session.execute(text("CHECKPOINT;")) 192 | 193 | 194 | def answer_evals( 195 | num_evals: int = 100, 196 | *, 197 | config: RAGLiteConfig | None = None, 198 | ) -> "pd.DataFrame": 199 | """Read evals from the database and answer them with RAG.""" 200 | try: 201 | import pandas as pd 202 | except ModuleNotFoundError as import_error: 203 | error_message = "To use the `answer_evals` function, please install the `ragas` extra." 204 | raise ModuleNotFoundError(error_message) from import_error 205 | 206 | # Read evals from the database. 207 | with Session(create_database_engine(config := config or RAGLiteConfig())) as session: 208 | evals = session.exec(select(Eval).limit(num_evals)).all() 209 | # Answer evals with RAG. 210 | answers: list[str] = [] 211 | contexts: list[list[str]] = [] 212 | for eval_ in tqdm(evals, desc="Answering evals", unit="eval", dynamic_ncols=True): 213 | chunk_spans = retrieve_context(query=eval_.question, config=config) 214 | messages = [add_context(user_prompt=eval_.question, context=chunk_spans)] 215 | response = rag(messages, config=config) 216 | answer = "".join(response) 217 | answers.append(answer) 218 | contexts.append([str(chunk_span) for chunk_span in chunk_spans]) 219 | # Collect the answered evals. 220 | answered_evals: dict[str, list[str] | list[list[str]]] = { 221 | "question": [eval_.question for eval_ in evals], 222 | "answer": answers, 223 | "contexts": contexts, 224 | "ground_truth": [eval_.ground_truth for eval_ in evals], 225 | "ground_truth_contexts": [eval_.contexts for eval_ in evals], 226 | } 227 | answered_evals_df = pd.DataFrame.from_dict(answered_evals) 228 | return answered_evals_df 229 | 230 | 231 | def evaluate( 232 | answered_evals: "pd.DataFrame | int" = 100, config: RAGLiteConfig | None = None 233 | ) -> "pd.DataFrame": 234 | """Evaluate the performance of a set of answered evals with Ragas.""" 235 | try: 236 | import pandas as pd 237 | from datasets import Dataset 238 | from langchain_community.chat_models import ChatLiteLLM 239 | from langchain_community.llms import LlamaCpp 240 | from ragas import RunConfig 241 | from ragas import evaluate as ragas_evaluate 242 | from ragas.embeddings import BaseRagasEmbeddings 243 | 244 | from raglite._config import RAGLiteConfig 245 | from raglite._embed import embed_strings 246 | from raglite._litellm import LlamaCppPythonLLM 247 | except ModuleNotFoundError as import_error: 248 | error_message = "To use the `evaluate` function, please install the `ragas` extra." 249 | raise ModuleNotFoundError(error_message) from import_error 250 | 251 | class RAGLiteRagasEmbeddings(BaseRagasEmbeddings): 252 | """A RAGLite embedder for Ragas.""" 253 | 254 | def __init__(self, config: RAGLiteConfig | None = None): 255 | self.config = config or RAGLiteConfig() 256 | 257 | def embed_query(self, text: str) -> list[float]: 258 | # Embed the input text with RAGLite's embedding function. 259 | embeddings = embed_strings([text], config=self.config) 260 | return embeddings[0].tolist() # type: ignore[no-any-return] 261 | 262 | def embed_documents(self, texts: list[str]) -> list[list[float]]: 263 | # Embed a list of documents with RAGLite's embedding function. 264 | embeddings = embed_strings(texts, config=self.config) 265 | return embeddings.tolist() # type: ignore[no-any-return] 266 | 267 | # Create a set of answered evals if not provided. 268 | config = config or RAGLiteConfig() 269 | answered_evals_df = ( 270 | answered_evals 271 | if isinstance(answered_evals, pd.DataFrame) 272 | else answer_evals(num_evals=answered_evals, config=config) 273 | ) 274 | # Load the LLM. 275 | if config.llm.startswith("llama-cpp-python"): 276 | llm = LlamaCppPythonLLM().llm(model=config.llm) 277 | lc_llm = LlamaCpp( 278 | model_path=llm.model_path, 279 | n_batch=llm.n_batch, 280 | n_ctx=llm.n_ctx(), 281 | n_gpu_layers=-1, 282 | verbose=llm.verbose, 283 | ) 284 | else: 285 | lc_llm = ChatLiteLLM(model=config.llm) 286 | embedder = RAGLiteRagasEmbeddings(config=config) 287 | # Evaluate the answered evals with Ragas. 288 | evaluation_df = ragas_evaluate( 289 | dataset=Dataset.from_pandas(answered_evals_df), 290 | llm=lc_llm, 291 | embeddings=embedder, 292 | run_config=RunConfig(max_workers=1), 293 | ).to_pandas() 294 | return evaluation_df 295 | -------------------------------------------------------------------------------- /src/raglite/_extract.py: -------------------------------------------------------------------------------- 1 | """Extract structured data from unstructured text with an LLM.""" 2 | 3 | from typing import Any, TypeVar 4 | 5 | from litellm import completion, get_supported_openai_params # type: ignore[attr-defined] 6 | from pydantic import BaseModel, ValidationError 7 | 8 | from raglite._config import RAGLiteConfig 9 | 10 | T = TypeVar("T", bound=BaseModel) 11 | 12 | 13 | def extract_with_llm( 14 | return_type: type[T], 15 | user_prompt: str | list[str], 16 | strict: bool = False, # noqa: FBT001, FBT002 17 | config: RAGLiteConfig | None = None, 18 | **kwargs: Any, 19 | ) -> T: 20 | """Extract structured data from unstructured text with an LLM. 21 | 22 | This function expects a `return_type.system_prompt: ClassVar[str]` that contains the system 23 | prompt to use. Example: 24 | 25 | from typing import ClassVar 26 | from pydantic import BaseModel, Field 27 | 28 | class MyNameResponse(BaseModel): 29 | my_name: str = Field(..., description="The user's name.") 30 | system_prompt: ClassVar[str] = "The system prompt to use (excluded from JSON schema)." 31 | 32 | my_name_response = extract_with_llm(MyNameResponse, "My name is Thomas A. Anderson.") 33 | """ 34 | # Load the default config if not provided. 35 | config = config or RAGLiteConfig() 36 | # Check if the LLM supports the response format. 37 | llm_supports_response_format = "response_format" in ( 38 | get_supported_openai_params(model=config.llm) or [] 39 | ) 40 | # Update the system prompt with the JSON schema of the return type to help the LLM. 41 | system_prompt = getattr(return_type, "system_prompt", "").strip() 42 | if not llm_supports_response_format or config.llm.startswith("llama-cpp-python"): 43 | system_prompt += f"\n\nFormat your response according to this JSON schema:\n{return_type.model_json_schema()!s}" 44 | # Constrain the response format to the JSON schema if it's supported by the LLM [1]. Strict mode 45 | # is disabled by default because it only supports a subset of JSON schema features [2]. 46 | # [1] https://docs.litellm.ai/docs/completion/json_mode 47 | # [2] https://platform.openai.com/docs/guides/structured-outputs#some-type-specific-keywords-are-not-yet-supported 48 | # TODO: Fall back to {"type": "json_object"} if JSON schema is not supported by the LLM. 49 | response_format: dict[str, Any] | None = ( 50 | { 51 | "type": "json_schema", 52 | "json_schema": { 53 | "name": return_type.__name__, 54 | "description": return_type.__doc__ or "", 55 | "schema": return_type.model_json_schema(), 56 | "strict": strict, 57 | }, 58 | } 59 | if llm_supports_response_format 60 | else None 61 | ) 62 | # Concatenate the user prompt if it is a list of strings. 63 | if isinstance(user_prompt, list): 64 | user_prompt = "\n\n".join( 65 | f'\n{chunk.strip()}\n' 66 | for i, chunk in enumerate(user_prompt) 67 | ) 68 | # Extract structured data from the unstructured input. 69 | for _ in range(config.llm_max_tries): 70 | response = completion( 71 | model=config.llm, 72 | messages=[ 73 | {"role": "system", "content": system_prompt}, 74 | {"role": "user", "content": user_prompt}, 75 | ], 76 | response_format=response_format, 77 | **kwargs, 78 | ) 79 | try: 80 | instance = return_type.model_validate_json(response["choices"][0]["message"]["content"]) 81 | except (KeyError, ValueError, ValidationError) as e: 82 | # Malformed response, not a JSON string, or not a valid instance of the return type. 83 | last_exception = e 84 | continue 85 | else: 86 | break 87 | else: 88 | error_message = f"Failed to extract {return_type} from input {user_prompt}." 89 | raise ValueError(error_message) from last_exception 90 | return instance 91 | -------------------------------------------------------------------------------- /src/raglite/_insert.py: -------------------------------------------------------------------------------- 1 | """Index documents.""" 2 | 3 | from concurrent.futures import ThreadPoolExecutor, as_completed 4 | from contextlib import nullcontext 5 | from functools import partial 6 | from pathlib import Path 7 | 8 | from filelock import FileLock 9 | from sqlalchemy import text 10 | from sqlalchemy.engine import make_url 11 | from sqlmodel import Session, col, select 12 | from tqdm.auto import tqdm 13 | 14 | from raglite._config import RAGLiteConfig 15 | from raglite._database import Chunk, ChunkEmbedding, Document, create_database_engine 16 | from raglite._embed import embed_strings, embed_strings_without_late_chunking, embedding_type 17 | from raglite._split_chunklets import split_chunklets 18 | from raglite._split_chunks import split_chunks 19 | from raglite._split_sentences import split_sentences 20 | 21 | 22 | def _create_chunk_records( 23 | document: Document, config: RAGLiteConfig 24 | ) -> tuple[Document, list[Chunk], list[list[ChunkEmbedding]]]: 25 | """Process chunks into chunk and chunk embedding records.""" 26 | # Partition the document into chunks. 27 | assert document.content is not None 28 | sentences = split_sentences(document.content, max_len=config.chunk_max_size) 29 | chunklets = split_chunklets(sentences, max_size=config.chunk_max_size) 30 | chunklet_embeddings = embed_strings(chunklets, config=config) 31 | chunks, chunk_embeddings = split_chunks( 32 | chunklets=chunklets, 33 | chunklet_embeddings=chunklet_embeddings, 34 | max_size=config.chunk_max_size, 35 | ) 36 | # Create the chunk records. 37 | chunk_records, headings = [], "" 38 | for i, chunk in enumerate(chunks): 39 | # Create and append the chunk record. 40 | record = Chunk.from_body( 41 | document=document, index=i, body=chunk, headings=headings, **document.metadata_ 42 | ) 43 | chunk_records.append(record) 44 | # Update the Markdown headings with those of this chunk. 45 | headings = record.extract_headings() 46 | # Create the chunk embedding records. 47 | chunk_embedding_records_list = [] 48 | if embedding_type(config=config) == "late_chunking": 49 | # Every chunk record is associated with a list of chunk embedding records, one for each of 50 | # the chunklets in the chunk. 51 | for chunk_record, chunk_embedding in zip(chunk_records, chunk_embeddings, strict=True): 52 | chunk_embedding_records_list.append( 53 | [ 54 | ChunkEmbedding(chunk_id=chunk_record.id, embedding=chunklet_embedding) 55 | for chunklet_embedding in chunk_embedding 56 | ] 57 | ) 58 | else: 59 | # Embed the full chunks, including the current Markdown headings. 60 | full_chunk_embeddings = embed_strings_without_late_chunking( 61 | [chunk_record.content for chunk_record in chunk_records], config=config 62 | ) 63 | # Every chunk record is associated with a list of chunk embedding records. The chunk 64 | # embedding records each correspond to a linear combination of a chunklet embedding and an 65 | # embedding of the full chunk with Markdown headings. 66 | α = 0.15 # Benchmark-optimised value. # noqa: PLC2401 67 | for chunk_record, chunk_embedding, full_chunk_embedding in zip( 68 | chunk_records, chunk_embeddings, full_chunk_embeddings, strict=True 69 | ): 70 | if config.vector_search_multivector: 71 | chunk_embedding_records_list.append( 72 | [ 73 | ChunkEmbedding( 74 | chunk_id=chunk_record.id, 75 | embedding=α * chunklet_embedding + (1 - α) * full_chunk_embedding, 76 | ) 77 | for chunklet_embedding in chunk_embedding 78 | ] 79 | ) 80 | else: 81 | chunk_embedding_records_list.append( 82 | [ 83 | ChunkEmbedding( 84 | chunk_id=chunk_record.id, 85 | embedding=full_chunk_embedding, 86 | ) 87 | ] 88 | ) 89 | return document, chunk_records, chunk_embedding_records_list 90 | 91 | 92 | def insert_documents( # noqa: C901 93 | documents: list[Document], 94 | *, 95 | max_workers: int | None = None, 96 | config: RAGLiteConfig | None = None, 97 | ) -> None: 98 | """Insert documents into the database and update the index. 99 | 100 | Parameters 101 | ---------- 102 | documents 103 | A list of documents to insert into the database. 104 | max_workers 105 | The maximum number of worker threads to use. 106 | config 107 | The RAGLite config to use to insert the documents into the database. 108 | 109 | Returns 110 | ------- 111 | None 112 | """ 113 | # Verify that all documents have content. 114 | if not all(isinstance(doc.content, str) for doc in documents): 115 | error_message = "Some or all documents have missing `document.content`." 116 | raise ValueError(error_message) 117 | # Early exit if no documents are provided. 118 | documents = [doc for doc in documents if doc.content.strip()] # type: ignore[union-attr] 119 | if not documents: 120 | return 121 | # Skip documents that are already in the database. 122 | batch_size = 128 123 | with Session(engine := create_database_engine(config := config or RAGLiteConfig())) as session: 124 | existing_doc_ids: set[str] = set() 125 | for i in range(0, len(documents), batch_size): 126 | doc_id_batch = [doc.id for doc in documents[i : i + batch_size]] 127 | existing_doc_ids.update( 128 | session.exec(select(Document.id).where(col(Document.id).in_(doc_id_batch))).all() 129 | ) 130 | documents = [doc for doc in documents if doc.id not in existing_doc_ids] 131 | if not documents: 132 | return 133 | # For DuckDB databases, acquire a lock on the database. 134 | if engine.dialect.name == "duckdb": 135 | db_url = make_url(config.db_url) if isinstance(config.db_url, str) else config.db_url 136 | db_lock = ( 137 | FileLock(Path(db_url.database).with_suffix(".lock")) 138 | if db_url.database 139 | else nullcontext() 140 | ) 141 | else: 142 | db_lock = nullcontext() 143 | # Create and insert the document, chunk, and chunk embedding records. 144 | with ( 145 | db_lock, 146 | Session(engine) as session, 147 | ThreadPoolExecutor(max_workers=max_workers) as executor, 148 | tqdm( 149 | total=len(documents), desc="Inserting documents", unit="document", dynamic_ncols=True 150 | ) as pbar, 151 | ): 152 | futures = [ 153 | executor.submit(partial(_create_chunk_records, config=config), doc) for doc in documents 154 | ] 155 | num_unflushed_embeddings = 0 156 | for future in as_completed(futures): 157 | try: 158 | document_record, chunk_records, chunk_embedding_records_list = future.result() 159 | except Exception as e: 160 | executor.shutdown(cancel_futures=True) # Cancel remaining work. 161 | session.rollback() # Cancel uncommitted changes. 162 | error_message = f"Error processing document: {e}" 163 | raise ValueError(error_message) from e 164 | session.add(document_record) 165 | session.add_all(chunk_records) 166 | for chunk_embedding_records in chunk_embedding_records_list: 167 | session.add_all(chunk_embedding_records) 168 | num_unflushed_embeddings += len(chunk_embedding_records) 169 | if num_unflushed_embeddings >= batch_size: 170 | session.flush() # Flush changes to the database. 171 | session.expunge_all() # Release memory of flushed changes. 172 | num_unflushed_embeddings = 0 173 | pbar.update() 174 | session.commit() 175 | if engine.dialect.name == "duckdb": 176 | # DuckDB does not automatically update its keyword search index [1], so we do it 177 | # manually after insertion. Additionally, we re-compact the HNSW index [2]. Finally, we 178 | # synchronize data in the write-ahead log (WAL) to the database data file with the 179 | # CHECKPOINT statement [3]. 180 | # [1] https://duckdb.org/docs/stable/extensions/full_text_search 181 | # [2] https://duckdb.org/docs/stable/core_extensions/vss.html#inserts-updates-deletes-and-re-compaction 182 | # [3] https://duckdb.org/docs/stable/sql/statements/checkpoint.html 183 | session.execute(text("PRAGMA create_fts_index('chunk', 'id', 'body', overwrite = 1);")) 184 | if len(documents) >= 8: # noqa: PLR2004 185 | session.execute(text("PRAGMA hnsw_compact_index('vector_search_chunk_index');")) 186 | session.commit() 187 | session.execute(text("CHECKPOINT;")) 188 | -------------------------------------------------------------------------------- /src/raglite/_lazy_llama.py: -------------------------------------------------------------------------------- 1 | """Import from llama-cpp-python with a lazy ModuleNotFoundError if it's not installed.""" 2 | 3 | from importlib import import_module 4 | from typing import TYPE_CHECKING, Any, NoReturn 5 | 6 | # When type checking, import everything normally. 7 | if TYPE_CHECKING: 8 | from llama_cpp import ( # type: ignore[attr-defined] 9 | LLAMA_POOLING_TYPE_NONE, 10 | Llama, 11 | LlamaRAMCache, 12 | llama, 13 | llama_chat_format, 14 | llama_grammar, 15 | llama_supports_gpu_offload, 16 | llama_types, 17 | ) 18 | 19 | # Explicitly export these names for static analysis. 20 | __all__ = [ 21 | "LLAMA_POOLING_TYPE_NONE", 22 | "Llama", 23 | "LlamaRAMCache", 24 | "llama", 25 | "llama_chat_format", 26 | "llama_grammar", 27 | "llama_supports_gpu_offload", 28 | "llama_types", 29 | ] 30 | 31 | 32 | def __getattr__(name: str) -> object: 33 | """Import from llama-cpp-python with a lazy ModuleNotFoundError if it's not installed.""" 34 | 35 | # Create a mock attribute and submodule that lazily raises an ModuleNotFoundError when accessed. 36 | class LazyAttributeError: 37 | error_message = "To use llama.cpp models, please install `llama-cpp-python`." 38 | 39 | def __init__(self, error: ModuleNotFoundError | None = None): 40 | self.error = error 41 | 42 | def __getattr__(self, name: str) -> NoReturn: 43 | raise ModuleNotFoundError(self.error_message) from self.error 44 | 45 | def __call__(self, *args: Any, **kwargs: Any) -> NoReturn: 46 | raise ModuleNotFoundError(self.error_message) from self.error 47 | 48 | class LazySubmoduleError: 49 | def __init__(self, error: ModuleNotFoundError): 50 | self.error = error 51 | 52 | def __getattr__(self, name: str) -> LazyAttributeError | type[LazyAttributeError]: 53 | return LazyAttributeError(self.error) if name == name.lower() else LazyAttributeError 54 | 55 | # Check if the attribute is a submodule. 56 | llama_cpp_submodules = ["llama", "llama_chat_format", "llama_grammar", "llama_types"] 57 | attr_is_submodule = name in llama_cpp_submodules 58 | try: 59 | # Import and return the requested submodule or attribute. 60 | module = import_module(f"llama_cpp.{name}" if attr_is_submodule else "llama_cpp") 61 | return module if attr_is_submodule else getattr(module, name) 62 | except ModuleNotFoundError as import_error: 63 | # Return a mock submodule or attribute that lazily raises an ModuleNotFoundError. 64 | return ( 65 | LazySubmoduleError(import_error) 66 | if attr_is_submodule 67 | else LazyAttributeError(import_error) 68 | ) 69 | -------------------------------------------------------------------------------- /src/raglite/_markdown.py: -------------------------------------------------------------------------------- 1 | """Convert any document to Markdown.""" 2 | 3 | import re 4 | from copy import deepcopy 5 | from pathlib import Path 6 | from typing import Any 7 | 8 | import numpy as np 9 | from pdftext.extraction import dictionary_output 10 | from sklearn.cluster import KMeans 11 | 12 | 13 | def parsed_pdf_to_markdown(pages: list[dict[str, Any]]) -> list[str]: # noqa: C901, PLR0915 14 | """Convert a PDF parsed with pdftext to Markdown.""" 15 | 16 | def add_heading_level_metadata(pages: list[dict[str, Any]]) -> list[dict[str, Any]]: # noqa: C901 17 | """Add heading level metadata to a PDF parsed with pdftext.""" 18 | 19 | def extract_font_size(span: dict[str, Any]) -> float: 20 | """Extract the font size from a text span.""" 21 | font_size: float = 1.0 22 | if span["font"]["size"] > 1: # A value of 1 appears to mean "unknown" in pdftext. 23 | font_size = span["font"]["size"] 24 | elif digit_sequences := re.findall(r"\d+", span["font"]["name"] or ""): 25 | font_size = float(digit_sequences[-1]) 26 | elif "\n" not in span["text"]: # Occasionally a span can contain a newline character. 27 | if round(span["rotation"]) in (0.0, 180.0, -180.0): 28 | font_size = span["bbox"][3] - span["bbox"][1] 29 | elif round(span["rotation"]) in (90.0, -90.0, 270.0, -270.0): 30 | font_size = span["bbox"][2] - span["bbox"][0] 31 | return font_size 32 | 33 | # Copy the pages. 34 | pages = deepcopy(pages) 35 | # Extract an array of all font sizes used by the text spans. 36 | font_sizes = np.asarray( 37 | [ 38 | extract_font_size(span) 39 | for page in pages 40 | for block in page["blocks"] 41 | for line in block["lines"] 42 | for span in line["spans"] 43 | ] 44 | ) 45 | font_sizes = np.round(font_sizes * 2) / 2 46 | unique_font_sizes, counts = np.unique(font_sizes, return_counts=True) 47 | # Determine the paragraph font size as the mode font size. 48 | tiny = unique_font_sizes < min(5, np.max(unique_font_sizes)) 49 | counts[tiny] = -counts[tiny] 50 | mode = np.argmax(counts) 51 | counts[tiny] = -counts[tiny] 52 | mode_font_size = unique_font_sizes[mode] 53 | # Determine (at most) 6 heading font sizes by clustering font sizes larger than the mode. 54 | heading_font_sizes = unique_font_sizes[mode + 1 :] 55 | if len(heading_font_sizes) > 0: 56 | heading_counts = counts[mode + 1 :] 57 | kmeans = KMeans(n_clusters=min(6, len(heading_font_sizes)), random_state=42) 58 | kmeans.fit(heading_font_sizes[:, np.newaxis], sample_weight=heading_counts) 59 | heading_font_sizes = np.sort(np.ravel(kmeans.cluster_centers_))[::-1] 60 | # Add heading level information to the text spans and lines. 61 | for page in pages: 62 | for block in page["blocks"]: 63 | for line in block["lines"]: 64 | if "md" not in line: 65 | line["md"] = {} 66 | heading_level = np.zeros(8) # 0-5:

-

, 6:

, 7: 67 | for span in line["spans"]: 68 | if "md" not in span: 69 | span["md"] = {} 70 | span_font_size = extract_font_size(span) 71 | if span_font_size < mode_font_size: 72 | idx = 7 73 | elif span_font_size == mode_font_size: 74 | idx = 6 75 | else: 76 | idx = np.argmin(np.abs(heading_font_sizes - span_font_size)) # type: ignore[assignment] 77 | span["md"]["heading_level"] = idx + 1 78 | heading_level[idx] += len(span["text"]) 79 | line["md"]["heading_level"] = np.argmax(heading_level) + 1 80 | return pages 81 | 82 | def add_emphasis_metadata(pages: list[dict[str, Any]]) -> list[dict[str, Any]]: 83 | """Add emphasis metadata such as bold and italic to a PDF parsed with pdftext.""" 84 | # Copy the pages. 85 | pages = deepcopy(pages) 86 | # Add emphasis metadata to the text spans. 87 | for page in pages: 88 | for block in page["blocks"]: 89 | for line in block["lines"]: 90 | if "md" not in line: 91 | line["md"] = {} 92 | for span in line["spans"]: 93 | if "md" not in span: 94 | span["md"] = {} 95 | span["md"]["bold"] = span["font"]["weight"] > 500 # noqa: PLR2004 96 | span["md"]["italic"] = "ital" in (span["font"]["name"] or "").lower() 97 | line["md"]["bold"] = all( 98 | span["md"]["bold"] for span in line["spans"] if span["text"].strip() 99 | ) 100 | line["md"]["italic"] = all( 101 | span["md"]["italic"] for span in line["spans"] if span["text"].strip() 102 | ) 103 | return pages 104 | 105 | def strip_page_numbers(pages: list[dict[str, Any]]) -> list[dict[str, Any]]: 106 | """Strip page numbers from a PDF parsed with pdftext.""" 107 | # Copy the pages. 108 | pages = deepcopy(pages) 109 | # Remove lines that only contain a page number. 110 | for page in pages: 111 | for block in page["blocks"]: 112 | block["lines"] = [ 113 | line 114 | for line in block["lines"] 115 | if not re.match( 116 | r"^\s*[#0]*\d+\s*$", "".join(span["text"] for span in line["spans"]) 117 | ) 118 | ] 119 | return pages 120 | 121 | def convert_to_markdown(pages: list[dict[str, Any]]) -> list[str]: # noqa: C901, PLR0912 122 | """Convert a list of pages to Markdown.""" 123 | pages_md = [] 124 | for page in pages: 125 | page_md = "" 126 | for block in page["blocks"]: 127 | block_text = "" 128 | for line in block["lines"]: 129 | # Build the line text and style the spans. 130 | line_text = "" 131 | for span in line["spans"]: 132 | if ( 133 | not line["md"]["bold"] 134 | and not line["md"]["italic"] 135 | and span["md"]["bold"] 136 | and span["md"]["italic"] 137 | ): 138 | line_text += f"***{span['text']}***" 139 | elif not line["md"]["bold"] and span["md"]["bold"]: 140 | line_text += f"**{span['text']}**" 141 | elif not line["md"]["italic"] and span["md"]["italic"]: 142 | line_text += f"*{span['text']}*" 143 | else: 144 | line_text += span["text"] 145 | # Add emphasis to the line (if it's not a heading or whitespace). 146 | line_text = line_text.rstrip() 147 | line_is_whitespace = not line_text.strip() 148 | line_is_heading = line["md"]["heading_level"] <= 6 # noqa: PLR2004 149 | if not line_is_heading and not line_is_whitespace: 150 | if line["md"]["bold"] and line["md"]["italic"]: 151 | line_text = f"***{line_text}***" 152 | elif line["md"]["bold"]: 153 | line_text = f"**{line_text}**" 154 | elif line["md"]["italic"]: 155 | line_text = f"*{line_text}*" 156 | # Set the heading level. 157 | if line_is_heading and not line_is_whitespace: 158 | line_text = f"{'#' * line['md']['heading_level']} {line_text}" 159 | line_text += "\n" 160 | block_text += line_text 161 | block_text = block_text.rstrip() + "\n\n" 162 | page_md += block_text 163 | pages_md.append(page_md.strip()) 164 | return pages_md 165 | 166 | def merge_split_headings(pages: list[str]) -> list[str]: 167 | """Merge headings that are split across lines.""" 168 | 169 | def _merge_split_headings(match: re.Match[str]) -> str: 170 | atx_headings = [line.strip("# ").strip() for line in match.group().splitlines()] 171 | return f"{match.group(1)} {' '.join(atx_headings)}\n\n" 172 | 173 | pages_md = [ 174 | re.sub( 175 | r"^(#+)[ \t]+[^\n]+\n+(?:^\1[ \t]+[^\n]+\n+)+", 176 | _merge_split_headings, 177 | page, 178 | flags=re.MULTILINE, 179 | ) 180 | for page in pages 181 | ] 182 | return pages_md 183 | 184 | # Add heading level metadata. 185 | pages = add_heading_level_metadata(pages) 186 | # Add emphasis metadata. 187 | pages = add_emphasis_metadata(pages) 188 | # Strip page numbers. 189 | pages = strip_page_numbers(pages) 190 | # Convert the pages to Markdown. 191 | pages_md = convert_to_markdown(pages) 192 | # Merge headings that are split across lines. 193 | pages_md = merge_split_headings(pages_md) 194 | return pages_md 195 | 196 | 197 | def document_to_markdown(doc_path: Path) -> str: 198 | """Convert any document to GitHub Flavored Markdown.""" 199 | # Convert the file's content to GitHub Flavored Markdown. 200 | if doc_path.suffix == ".pdf": 201 | # Parse the PDF with pdftext and convert it to Markdown. 202 | pages = dictionary_output(doc_path, sort=True, keep_chars=False) 203 | doc = "\n\n".join(parsed_pdf_to_markdown(pages)) 204 | elif doc_path.suffix in (".md", ".txt"): 205 | # Read the Markdown file. 206 | doc = doc_path.read_text() 207 | else: 208 | try: 209 | # Use pandoc for everything else. 210 | import pypandoc 211 | 212 | doc = pypandoc.convert_file(doc_path, to="gfm") 213 | except ModuleNotFoundError as error: 214 | error_message = ( 215 | "To convert files to Markdown with pandoc, please install the `pandoc` extra." 216 | ) 217 | raise ModuleNotFoundError(error_message) from error 218 | except RuntimeError: 219 | # File format not supported, fall back to reading the text. 220 | doc = doc_path.read_text() 221 | return doc 222 | -------------------------------------------------------------------------------- /src/raglite/_mcp.py: -------------------------------------------------------------------------------- 1 | """MCP server for RAGLite.""" 2 | 3 | from typing import Annotated, Any 4 | 5 | from fastmcp import FastMCP 6 | from pydantic import Field 7 | 8 | from raglite._config import RAGLiteConfig 9 | from raglite._rag import add_context, retrieve_context 10 | 11 | Query = Annotated[ 12 | str, 13 | Field( 14 | description=( 15 | "The `query` string MUST be a precise single-faceted question in the user's language.\n" 16 | "The `query` string MUST resolve all pronouns to explicit nouns." 17 | ) 18 | ), 19 | ] 20 | 21 | 22 | def create_mcp_server(server_name: str, *, config: RAGLiteConfig) -> FastMCP[Any]: 23 | """Create a RAGLite MCP server.""" 24 | mcp: FastMCP[Any] = FastMCP(server_name) 25 | 26 | @mcp.prompt() 27 | def kb(query: Query) -> str: 28 | """Answer a question with information from the knowledge base.""" 29 | chunk_spans = retrieve_context(query, config=config) 30 | rag_instruction = add_context(query, chunk_spans) 31 | return rag_instruction["content"] 32 | 33 | @mcp.tool() 34 | def search_knowledge_base(query: Query) -> str: 35 | """Search the knowledge base. 36 | 37 | IMPORTANT: You MAY NOT use this function if the question can be answered with common 38 | knowledge or straightforward reasoning. For multi-faceted questions, call this function once 39 | for each facet. 40 | """ 41 | chunk_spans = retrieve_context(query, config=config) 42 | rag_context = '{{"documents": [{elements}]}}'.format( 43 | elements=", ".join( 44 | chunk_span.to_json(index=i + 1) for i, chunk_span in enumerate(chunk_spans) 45 | ) 46 | ) 47 | return rag_context 48 | 49 | # Warm up the querying pipeline. 50 | if config.embedder.startswith("llama-cpp-python"): 51 | _ = retrieve_context("Hello world", config=config) 52 | 53 | return mcp 54 | -------------------------------------------------------------------------------- /src/raglite/_query_adapter.py: -------------------------------------------------------------------------------- 1 | """Compute and update an optimal query adapter.""" 2 | 3 | # ruff: noqa: N806 4 | 5 | from dataclasses import replace 6 | 7 | import numpy as np 8 | from scipy.optimize import lsq_linear 9 | from sqlalchemy import text 10 | from sqlalchemy.orm.attributes import flag_modified 11 | from sqlmodel import Session, col, select 12 | from tqdm.auto import tqdm 13 | 14 | from raglite._config import RAGLiteConfig 15 | from raglite._database import Chunk, ChunkEmbedding, Eval, IndexMetadata, create_database_engine 16 | from raglite._embed import embed_strings 17 | from raglite._search import vector_search 18 | from raglite._typing import FloatMatrix, FloatVector 19 | 20 | 21 | def _optimize_query_target( 22 | q: FloatVector, 23 | P: FloatMatrix, # noqa: N803, 24 | N: FloatMatrix, # noqa: N803, 25 | *, 26 | α: float = 0.05, # noqa: PLC2401 27 | ) -> FloatVector: 28 | # Convert to double precision for the optimizer. 29 | q_dtype = q.dtype 30 | q, P, N = q.astype(np.float64), P.astype(np.float64), N.astype(np.float64) 31 | # Construct the constraint matrix D := P - (1 + α) * N. # noqa: RUF003 32 | D = np.reshape(P[:, np.newaxis, :] - (1.0 + α) * N[np.newaxis, :, :], (-1, P.shape[1])) 33 | # Solve the dual problem min_μ ½ ‖q + Dᵀ μ‖² s.t. μ ≥ 0. 34 | A, b = D.T, -q 35 | μ_star = lsq_linear(A, b, bounds=(0.0, np.inf), tol=np.finfo(A.dtype).eps).x # noqa: PLC2401 36 | # Recover the primal solution q* = q + Dᵀ μ*. 37 | q_star: FloatVector = (q + D.T @ μ_star).astype(q_dtype) 38 | return q_star 39 | 40 | 41 | def update_query_adapter( 42 | *, 43 | max_evals: int = 4096, 44 | optimize_top_k: int = 40, 45 | optimize_gap: float = 0.05, 46 | config: RAGLiteConfig | None = None, 47 | ) -> FloatMatrix: 48 | """Compute an optimal query adapter and update the database with it. 49 | 50 | This function computes an optimal linear transform A, called a 'query adapter', that is used to 51 | transform a query embedding q as A @ q before searching for the nearest neighbouring chunks in 52 | order to improve the quality of the search results. 53 | 54 | Given a set of triplets (qᵢ, pᵢ, nᵢ), we want to find the query adapter A that increases the 55 | score pᵢᵀqᵢ of the positive chunk pᵢ and decreases the score nᵢᵀqᵢ of the negative chunk nᵢ. 56 | 57 | If the nearest neighbour search uses the dot product as its relevance score, we can find the 58 | optimal query adapter by solving the following relaxed Procrustes optimisation problem with a 59 | bound on the Frobenius norm of A: 60 | 61 | A* := argmax Σᵢ pᵢᵀ (A qᵢ) - nᵢᵀ (A qᵢ) 62 | Σᵢ (pᵢ - nᵢ)ᵀ A qᵢ 63 | trace[ (P - N) A Qᵀ ] where Q := [q₁ᵀ; ...; qₖᵀ] 64 | P := [p₁ᵀ; ...; pₖᵀ] 65 | N := [n₁ᵀ; ...; nₖᵀ] 66 | trace[ Qᵀ (P - N) A ] 67 | trace[ Mᵀ A ] where M := (P - N)ᵀ Q 68 | s.t. ||A||_F == 1 69 | = M / ||M||_F 70 | 71 | If the nearest neighbour search uses the cosine similarity as its relevance score, we can find 72 | the optimal query adapter by solving the following orthogonal Procrustes optimisation problem 73 | [1] with an orthogonality constraint on A: 74 | 75 | A* := argmax Σᵢ pᵢᵀ (A qᵢ) - nᵢᵀ (A qᵢ) 76 | Σᵢ (pᵢ - nᵢ)ᵀ A qᵢ 77 | trace[ (P - N) A Qᵀ ] 78 | trace[ Qᵀ (P - N) A ] 79 | trace[ Mᵀ A ] 80 | trace[ (U Σ V)ᵀ A ] where U Σ Vᵀ := M is the SVD of M 81 | trace[ Σ V A Uᵀ ] 82 | s.t. AᵀA == 𝕀 83 | = U Vᵀ 84 | 85 | The action of A* is to map a query embedding qᵢ to a target vector t := (pᵢ - nᵢ) that maximally 86 | separates the positive and negative chunks. For a given query embedding qᵢ, a retrieval method 87 | will yield a result set containing both positive and negative chunks. Instead of extracting 88 | multiple triplets (qᵢ, pᵢ, nᵢ) from each such result set, we can compute a single optimal target 89 | vector t* for the query embedding qᵢ as follows: 90 | 91 | t* := argmax ½ ||t - qᵢ||² 92 | s.t. Dᵢ t >= 0 93 | 94 | where the constraint matrix Dᵢ := [pₘᵀ - (1 + α) nₙᵀ]ₘₙ comprises all pairs of positive and 95 | negative chunk embeddings in the result set corresponding to the query embedding qᵢ. This 96 | optimisation problem expresses the idea that the target vector t* should be as close as 97 | possible to the query embedding qᵢ, while separating all positive and negative chunk embeddings 98 | in the result set by a margin of at least α. To solve this problem, we'll first introduce 99 | a Lagrangian with Lagrange multipliers μ: 100 | 101 | L(t, μ) := ½ ||t - qᵢ||² + μᵀ (-Dᵢ t) 102 | 103 | Now we can set the gradient of the Lagrangian to zero to find the optimal target vector t*: 104 | 105 | ∇ₜL = t - qᵢ - Dᵢᵀ μ = 0 106 | t* = qᵢ + Dᵢᵀ μ* 107 | 108 | where μ* is the solution to the dual nonnegative least squares problem 109 | 110 | μ* := argmin ½ ||qᵢ + Dᵢᵀ μ||² 111 | s.t. μ >= 0 112 | 113 | Parameters 114 | ---------- 115 | max_evals 116 | The maximum number of evals to use to compute the query adapter. Each eval corresponds to a 117 | rank-one update of the query adapter A. 118 | optimize_top_k 119 | The number of search results per eval to optimize. 120 | optimize_gap 121 | The strength of the query adapter, expressed as a nonnegative number. Should be large enough 122 | to correct incorrectly ranked results, but small enough to not affect correctly ranked 123 | results. 124 | config 125 | The RAGLite config to use to construct and store the query adapter. 126 | 127 | Raises 128 | ------ 129 | ValueError 130 | If no documents have been inserted into the database yet. 131 | ValueError 132 | If no evals have been inserted into the database yet. 133 | ValueError 134 | If the `config.vector_search_distance_metric` is not supported. 135 | 136 | Returns 137 | ------- 138 | FloatMatrix 139 | The query adapter. 140 | """ 141 | config = config or RAGLiteConfig() 142 | config_no_query_adapter = replace(config, vector_search_query_adapter=False) 143 | with Session(engine := create_database_engine(config)) as session: 144 | # Get random evals from the database. 145 | chunk_embedding = session.exec(select(ChunkEmbedding).limit(1)).first() 146 | if chunk_embedding is None: 147 | error_message = "First run `insert_documents()` to insert documents." 148 | raise ValueError(error_message) 149 | evals = session.exec(select(Eval).order_by(Eval.id).limit(max_evals)).all() 150 | if len(evals) == 0: 151 | error_message = "First run `insert_evals()` to generate evals." 152 | raise ValueError(error_message) 153 | # Construct the query and target matrices. 154 | Q = np.zeros((0, len(chunk_embedding.embedding))) 155 | T = np.zeros_like(Q) 156 | for eval_ in tqdm( 157 | evals, desc="Optimizing evals", unit="eval", dynamic_ncols=True, leave=False 158 | ): 159 | # Embed the question. 160 | q = embed_strings([eval_.question], config=config)[0] 161 | # Retrieve chunks that would be used to answer the question. 162 | chunk_ids, _ = vector_search( 163 | q, num_results=optimize_top_k, config=config_no_query_adapter 164 | ) 165 | retrieved_chunks = session.exec(select(Chunk).where(col(Chunk.id).in_(chunk_ids))).all() 166 | retrieved_chunks = sorted(retrieved_chunks, key=lambda chunk: chunk_ids.index(chunk.id)) 167 | # Skip this eval if it doesn't contain both relevant and irrelevant chunks. 168 | is_relevant = np.array([chunk.id in eval_.chunk_ids for chunk in retrieved_chunks]) 169 | if not np.any(is_relevant) or not np.any(~is_relevant): 170 | continue 171 | # Extract the positive and negative chunk embeddings. 172 | P = np.vstack( 173 | [ 174 | chunk.embedding_matrix[[np.argmax(chunk.embedding_matrix @ q)]] 175 | for chunk in np.array(retrieved_chunks)[is_relevant] 176 | ] 177 | ) 178 | N = np.vstack( 179 | [ 180 | chunk.embedding_matrix[[np.argmax(chunk.embedding_matrix @ q)]] 181 | for chunk in np.array(retrieved_chunks)[~is_relevant] 182 | ] 183 | ) 184 | # Compute the optimal target vector t for this query embedding q. 185 | t = _optimize_query_target(q, P, N, α=optimize_gap) 186 | Q = np.vstack([Q, q[np.newaxis, :]]) 187 | T = np.vstack([T, t[np.newaxis, :]]) 188 | # Normalise the rows of Q and T. 189 | Q /= np.linalg.norm(Q, axis=1, keepdims=True) 190 | if config.vector_search_distance_metric == "cosine": 191 | T /= np.linalg.norm(T, axis=1, keepdims=True) 192 | # Compute the optimal unconstrained query adapter M. 193 | n, d = Q.shape 194 | M = (1 / n) * T.T @ Q 195 | if n < d or np.linalg.matrix_rank(Q) < d: 196 | M += np.eye(d) - Q.T @ np.linalg.pinv(Q @ Q.T) @ Q 197 | # Compute the optimal constrained query adapter A* from M, given the distance metric. 198 | A_star: FloatMatrix 199 | if config.vector_search_distance_metric == "dot": 200 | # Use the relaxed Procrustes solution. 201 | A_star = M / np.linalg.norm(M, ord="fro") * np.sqrt(d) 202 | elif config.vector_search_distance_metric == "cosine": 203 | # Use the orthogonal Procrustes solution. 204 | U, _, VT = np.linalg.svd(M, full_matrices=False) 205 | A_star = U @ VT 206 | else: 207 | error_message = f"Unsupported metric: {config.vector_search_distance_metric}" 208 | raise ValueError(error_message) 209 | # Store the optimal query adapter in the database. 210 | index_metadata = session.get(IndexMetadata, "default") or IndexMetadata(id="default") 211 | index_metadata.metadata_["query_adapter"] = A_star 212 | flag_modified(index_metadata, "metadata_") 213 | session.add(index_metadata) 214 | session.commit() 215 | if engine.dialect.name == "duckdb": 216 | session.execute(text("CHECKPOINT;")) 217 | # Clear the index metadata cache to allow the new query adapter to be used. 218 | IndexMetadata._get.cache_clear() # noqa: SLF001 219 | return A_star 220 | -------------------------------------------------------------------------------- /src/raglite/_rag.py: -------------------------------------------------------------------------------- 1 | """Retrieval-augmented generation.""" 2 | 3 | import json 4 | from collections.abc import AsyncIterator, Callable, Iterator 5 | from typing import Any 6 | 7 | import numpy as np 8 | from litellm import ( # type: ignore[attr-defined] 9 | ChatCompletionMessageToolCall, 10 | acompletion, 11 | completion, 12 | stream_chunk_builder, 13 | supports_function_calling, 14 | ) 15 | 16 | from raglite._config import RAGLiteConfig 17 | from raglite._database import Chunk, ChunkSpan 18 | from raglite._litellm import get_context_size 19 | from raglite._search import retrieve_chunk_spans 20 | 21 | # The default RAG instruction template follows Anthropic's best practices [1]. 22 | # [1] https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips 23 | RAG_INSTRUCTION_TEMPLATE = """ 24 | --- 25 | You are a friendly and knowledgeable assistant that provides complete and insightful answers. 26 | Whenever possible, use only the provided context to respond to the question at the end. 27 | When responding, you MUST NOT reference the existence of the context, directly or indirectly. 28 | Instead, you MUST treat the context as if its contents are entirely part of your working memory. 29 | --- 30 | 31 | 32 | {context} 33 | 34 | 35 | {user_prompt} 36 | """.strip() 37 | 38 | 39 | def retrieve_context( 40 | query: str, *, num_chunks: int = 10, config: RAGLiteConfig | None = None 41 | ) -> list[ChunkSpan]: 42 | """Retrieve context for RAG.""" 43 | # Call the search method. 44 | config = config or RAGLiteConfig() 45 | results = config.search_method(query, num_results=num_chunks, config=config) 46 | # Convert results to chunk spans. 47 | chunk_spans = [] 48 | if isinstance(results, tuple): 49 | chunk_spans = retrieve_chunk_spans(results[0], config=config) 50 | elif all(isinstance(result, Chunk) for result in results): 51 | chunk_spans = retrieve_chunk_spans(results, config=config) # type: ignore[arg-type] 52 | elif all(isinstance(result, ChunkSpan) for result in results): 53 | chunk_spans = results # type: ignore[assignment] 54 | return chunk_spans 55 | 56 | 57 | def add_context( 58 | user_prompt: str, 59 | context: list[ChunkSpan], 60 | *, 61 | rag_instruction_template: str = RAG_INSTRUCTION_TEMPLATE, 62 | ) -> dict[str, str]: 63 | """Convert a user prompt to a RAG instruction. 64 | 65 | The RAG instruction's format follows Anthropic's best practices [1]. 66 | 67 | [1] https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips 68 | """ 69 | message = { 70 | "role": "user", 71 | "content": rag_instruction_template.format( 72 | context="\n".join( 73 | chunk_span.to_xml(index=i + 1) for i, chunk_span in enumerate(context) 74 | ), 75 | user_prompt=user_prompt.strip(), 76 | ), 77 | } 78 | return message 79 | 80 | 81 | def _clip(messages: list[dict[str, str]], max_tokens: int) -> list[dict[str, str]]: 82 | """Left clip a messages array to avoid hitting the context limit.""" 83 | cum_tokens = np.cumsum([len(message.get("content") or "") // 3 for message in messages][::-1]) 84 | first_message = -np.searchsorted(cum_tokens, max_tokens) 85 | return messages[first_message:] 86 | 87 | 88 | def _get_tools( 89 | messages: list[dict[str, str]], config: RAGLiteConfig 90 | ) -> tuple[list[dict[str, Any]] | None, dict[str, Any] | str | None]: 91 | """Get tools to search the knowledge base if no RAG context is provided in the messages.""" 92 | # Check if messages already contain RAG context or if the LLM supports tool use. 93 | final_message = messages[-1].get("content", "") 94 | messages_contain_rag_context = any( 95 | s in final_message for s in ("", "", "from_chunk_id") 96 | ) 97 | llm_supports_function_calling = supports_function_calling(config.llm) 98 | if not messages_contain_rag_context and not llm_supports_function_calling: 99 | error_message = "You must either explicitly provide RAG context in the last message, or use an LLM that supports function calling." 100 | raise ValueError(error_message) 101 | # Return a single tool to search the knowledge base if no RAG context is provided. 102 | tools: list[dict[str, Any]] | None = ( 103 | [ 104 | { 105 | "type": "function", 106 | "function": { 107 | "name": "search_knowledge_base", 108 | "description": ( 109 | "Search the knowledge base.\n" 110 | "IMPORTANT: You MAY NOT use this function if the question can be answered with common knowledge or straightforward reasoning.\n" 111 | "For multi-faceted questions, call this function once for each facet." 112 | ), 113 | "parameters": { 114 | "type": "object", 115 | "properties": { 116 | "query": { 117 | "type": "string", 118 | "description": ( 119 | "The `query` string MUST be a precise single-faceted question in the user's language.\n" 120 | "The `query` string MUST resolve all pronouns to explicit nouns." 121 | ), 122 | }, 123 | }, 124 | "required": ["query"], 125 | "additionalProperties": False, 126 | }, 127 | }, 128 | } 129 | ] 130 | if not messages_contain_rag_context 131 | else None 132 | ) 133 | tool_choice: dict[str, Any] | str | None = "auto" if tools else None 134 | return tools, tool_choice 135 | 136 | 137 | def _run_tools( 138 | tool_calls: list[ChatCompletionMessageToolCall], 139 | on_retrieval: Callable[[list[ChunkSpan]], None] | None, 140 | config: RAGLiteConfig, 141 | ) -> list[dict[str, Any]]: 142 | """Run tools to search the knowledge base for RAG context.""" 143 | tool_messages: list[dict[str, Any]] = [] 144 | for tool_call in tool_calls: 145 | if tool_call.function.name == "search_knowledge_base": 146 | kwargs = json.loads(tool_call.function.arguments) 147 | kwargs["config"] = config 148 | chunk_spans = retrieve_context(**kwargs) 149 | tool_messages.append( 150 | { 151 | "role": "tool", 152 | "content": '{{"documents": [{elements}]}}'.format( 153 | elements=", ".join( 154 | chunk_span.to_json(index=i + 1) 155 | for i, chunk_span in enumerate(chunk_spans) 156 | ) 157 | ), 158 | "tool_call_id": tool_call.id, 159 | } 160 | ) 161 | if chunk_spans and callable(on_retrieval): 162 | on_retrieval(chunk_spans) 163 | else: 164 | error_message = f"Unknown function `{tool_call.function.name}`." 165 | raise ValueError(error_message) 166 | return tool_messages 167 | 168 | 169 | def rag( 170 | messages: list[dict[str, str]], 171 | *, 172 | on_retrieval: Callable[[list[ChunkSpan]], None] | None = None, 173 | config: RAGLiteConfig, 174 | ) -> Iterator[str]: 175 | # If the final message does not contain RAG context, get a tool to search the knowledge base. 176 | max_tokens = get_context_size(config) 177 | tools, tool_choice = _get_tools(messages, config) 178 | # Stream the LLM response, which is either a tool call request or an assistant response. 179 | stream = completion( 180 | model=config.llm, 181 | messages=_clip(messages, max_tokens), 182 | tools=tools, 183 | tool_choice=tool_choice, 184 | stream=True, 185 | ) 186 | chunks = [] 187 | for chunk in stream: 188 | chunks.append(chunk) 189 | if isinstance(token := chunk.choices[0].delta.content, str): 190 | yield token 191 | # Check if there are tools to be called. 192 | response = stream_chunk_builder(chunks, messages) 193 | tool_calls = response.choices[0].message.tool_calls # type: ignore[union-attr] 194 | if tool_calls: 195 | # Add the tool call request to the message array. 196 | messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] 197 | # Run the tool calls to retrieve the RAG context and append the output to the message array. 198 | messages.extend(_run_tools(tool_calls, on_retrieval, config)) 199 | # Stream the assistant response. 200 | chunks = [] 201 | stream = completion(model=config.llm, messages=_clip(messages, max_tokens), stream=True) 202 | for chunk in stream: 203 | chunks.append(chunk) 204 | if isinstance(token := chunk.choices[0].delta.content, str): 205 | yield token 206 | # Append the assistant response to the message array. 207 | response = stream_chunk_builder(chunks, messages) 208 | messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] 209 | 210 | 211 | async def async_rag( 212 | messages: list[dict[str, str]], 213 | *, 214 | on_retrieval: Callable[[list[ChunkSpan]], None] | None = None, 215 | config: RAGLiteConfig, 216 | ) -> AsyncIterator[str]: 217 | # If the final message does not contain RAG context, get a tool to search the knowledge base. 218 | max_tokens = get_context_size(config) 219 | tools, tool_choice = _get_tools(messages, config) 220 | # Asynchronously stream the LLM response, which is either a tool call or an assistant response. 221 | async_stream = await acompletion( 222 | model=config.llm, 223 | messages=_clip(messages, max_tokens), 224 | tools=tools, 225 | tool_choice=tool_choice, 226 | stream=True, 227 | ) 228 | chunks = [] 229 | async for chunk in async_stream: 230 | chunks.append(chunk) 231 | if isinstance(token := chunk.choices[0].delta.content, str): 232 | yield token 233 | # Check if there are tools to be called. 234 | response = stream_chunk_builder(chunks, messages) 235 | tool_calls = response.choices[0].message.tool_calls # type: ignore[union-attr] 236 | if tool_calls: 237 | # Add the tool call requests to the message array. 238 | messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] 239 | # Run the tool calls to retrieve the RAG context and append the output to the message array. 240 | # TODO: Make this async. 241 | messages.extend(_run_tools(tool_calls, on_retrieval, config)) 242 | # Asynchronously stream the assistant response. 243 | chunks = [] 244 | async_stream = await acompletion( 245 | model=config.llm, messages=_clip(messages, max_tokens), stream=True 246 | ) 247 | async for chunk in async_stream: 248 | chunks.append(chunk) 249 | if isinstance(token := chunk.choices[0].delta.content, str): 250 | yield token 251 | # Append the assistant response to the message array. 252 | response = stream_chunk_builder(chunks, messages) 253 | messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] 254 | -------------------------------------------------------------------------------- /src/raglite/_search.py: -------------------------------------------------------------------------------- 1 | """Search and retrieve chunks.""" 2 | 3 | import contextlib 4 | import re 5 | import string 6 | from collections import defaultdict 7 | from itertools import groupby 8 | 9 | import numpy as np 10 | from langdetect import LangDetectException, detect 11 | from sqlalchemy.orm import joinedload 12 | from sqlmodel import Session, and_, col, func, or_, select, text 13 | 14 | from raglite._config import RAGLiteConfig 15 | from raglite._database import ( 16 | Chunk, 17 | ChunkEmbedding, 18 | ChunkSpan, 19 | IndexMetadata, 20 | create_database_engine, 21 | ) 22 | from raglite._embed import embed_strings 23 | from raglite._typing import BasicSearchMethod, ChunkId, FloatVector 24 | 25 | 26 | def vector_search( 27 | query: str | FloatVector, 28 | *, 29 | num_results: int = 3, 30 | oversample: int = 4, 31 | config: RAGLiteConfig | None = None, 32 | ) -> tuple[list[ChunkId], list[float]]: 33 | """Search chunks using ANN vector search.""" 34 | # Read the config. 35 | config = config or RAGLiteConfig() 36 | # Embed the query. 37 | query_embedding = ( 38 | embed_strings([query], config=config)[0, :] if isinstance(query, str) else np.ravel(query) 39 | ) 40 | # Apply the query adapter to the query embedding. 41 | if ( 42 | config.vector_search_query_adapter 43 | and (Q := IndexMetadata.get("default", config=config).get("query_adapter")) is not None # noqa: N806 44 | ): 45 | query_embedding = (Q @ query_embedding).astype(query_embedding.dtype) 46 | # Rank the chunks by relevance according to the L∞ norm of the similarities of the multi-vector 47 | # chunk embeddings to the query embedding with a single query. 48 | with Session(create_database_engine(config)) as session: 49 | corrected_oversample = oversample * config.chunk_max_size / RAGLiteConfig.chunk_max_size 50 | num_hits = round(corrected_oversample) * max(num_results, 10) 51 | dist = ChunkEmbedding.embedding.distance( # type: ignore[attr-defined] 52 | query_embedding, metric=config.vector_search_distance_metric 53 | ).label("dist") 54 | sim = (1.0 - dist).label("sim") 55 | top_vectors = select(ChunkEmbedding.chunk_id, sim).order_by(dist).limit(num_hits).subquery() 56 | sim_norm = func.max(top_vectors.c.sim).label("sim_norm") 57 | statement = ( 58 | select(top_vectors.c.chunk_id, sim_norm) 59 | .group_by(top_vectors.c.chunk_id) 60 | .order_by(sim_norm.desc()) 61 | .limit(num_results) 62 | ) 63 | rows = session.exec(statement).all() 64 | chunk_ids = [row[0] for row in rows] 65 | similarity = [float(row[1]) for row in rows] 66 | return chunk_ids, similarity 67 | 68 | 69 | def keyword_search( 70 | query: str, *, num_results: int = 3, config: RAGLiteConfig | None = None 71 | ) -> tuple[list[ChunkId], list[float]]: 72 | """Search chunks using BM25 keyword search.""" 73 | # Read the config. 74 | config = config or RAGLiteConfig() 75 | # Connect to the database. 76 | with Session(create_database_engine(config)) as session: 77 | dialect = session.get_bind().dialect.name 78 | if dialect == "postgresql": 79 | # Convert the query to a tsquery [1]. 80 | # [1] https://www.postgresql.org/docs/current/textsearch-controls.html 81 | query_escaped = re.sub(f"[{re.escape(string.punctuation)}]", " ", query) 82 | tsv_query = " | ".join(query_escaped.split()) 83 | # Perform keyword search with tsvector. 84 | statement = text(""" 85 | SELECT id as chunk_id, ts_rank(to_tsvector('simple', body), to_tsquery('simple', :query)) AS score 86 | FROM chunk 87 | WHERE to_tsvector('simple', body) @@ to_tsquery('simple', :query) 88 | ORDER BY score DESC 89 | LIMIT :limit; 90 | """) 91 | results = session.execute(statement, params={"query": tsv_query, "limit": num_results}) 92 | elif dialect == "duckdb": 93 | statement = text( 94 | """ 95 | SELECT chunk_id, score 96 | FROM ( 97 | SELECT id AS chunk_id, fts_main_chunk.match_bm25(id, :query) AS score 98 | FROM chunk 99 | ) sq 100 | WHERE score IS NOT NULL 101 | ORDER BY score DESC 102 | LIMIT :limit; 103 | """ 104 | ) 105 | results = session.execute(statement, params={"query": query, "limit": num_results}) 106 | # Unpack the results. 107 | results = list(results) # type: ignore[assignment] 108 | chunk_ids = [result.chunk_id for result in results] 109 | keyword_score = [result.score for result in results] 110 | return chunk_ids, keyword_score 111 | 112 | 113 | def reciprocal_rank_fusion( 114 | rankings: list[list[ChunkId]], *, k: int = 60, weights: list[float] | None = None 115 | ) -> tuple[list[ChunkId], list[float]]: 116 | """Reciprocal Rank Fusion.""" 117 | if weights is None: 118 | weights = [1.0] * len(rankings) 119 | if len(weights) != len(rankings): 120 | error = "The number of weights must match the number of rankings." 121 | raise ValueError(error) 122 | # Compute the RRF score. 123 | chunk_id_score: defaultdict[str, float] = defaultdict(float) 124 | for ranking, weight in zip(rankings, weights, strict=True): 125 | for i, chunk_id in enumerate(ranking): 126 | chunk_id_score[chunk_id] += weight / (k + i) 127 | # Exit early if there are no results to fuse. 128 | if not chunk_id_score: 129 | return [], [] 130 | # Rank RRF results according to descending RRF score. 131 | rrf_chunk_ids, rrf_score = zip( 132 | *sorted(chunk_id_score.items(), key=lambda x: x[1], reverse=True), strict=True 133 | ) 134 | return list(rrf_chunk_ids), list(rrf_score) 135 | 136 | 137 | def hybrid_search( # noqa: PLR0913 138 | query: str, 139 | *, 140 | num_results: int = 3, 141 | oversample: int = 2, 142 | vector_search_weight: float = 0.75, 143 | keyword_search_weight: float = 0.25, 144 | config: RAGLiteConfig | None = None, 145 | ) -> tuple[list[ChunkId], list[float]]: 146 | """Search chunks by combining ANN vector search with BM25 keyword search.""" 147 | # Run both searches. 148 | vs_chunk_ids, _ = vector_search(query, num_results=oversample * num_results, config=config) 149 | ks_chunk_ids, _ = keyword_search(query, num_results=oversample * num_results, config=config) 150 | # Combine the results with Reciprocal Rank Fusion (RRF). 151 | chunk_ids, hybrid_score = reciprocal_rank_fusion( 152 | [vs_chunk_ids, ks_chunk_ids], weights=[vector_search_weight, keyword_search_weight] 153 | ) 154 | chunk_ids, hybrid_score = chunk_ids[:num_results], hybrid_score[:num_results] 155 | return chunk_ids, hybrid_score 156 | 157 | 158 | def retrieve_chunks( 159 | chunk_ids: list[ChunkId], *, config: RAGLiteConfig | None = None 160 | ) -> list[Chunk]: 161 | """Retrieve chunks by their ids.""" 162 | if not chunk_ids: 163 | return [] 164 | with Session(create_database_engine(config := config or RAGLiteConfig())) as session: 165 | chunks = list( 166 | session.exec( 167 | select(Chunk) 168 | .where(col(Chunk.id).in_(chunk_ids)) 169 | # Eagerly load chunk.document. 170 | .options(joinedload(Chunk.document)) # type: ignore[arg-type] 171 | ).all() 172 | ) 173 | chunks = sorted(chunks, key=lambda chunk: chunk_ids.index(chunk.id)) 174 | return chunks 175 | 176 | 177 | def retrieve_chunk_spans( 178 | chunk_ids: list[ChunkId] | list[Chunk], 179 | *, 180 | neighbors: tuple[int, ...] | None = (-1, 1), 181 | config: RAGLiteConfig | None = None, 182 | ) -> list[ChunkSpan]: 183 | """Group chunks into contiguous chunk spans and retrieve them. 184 | 185 | Chunk spans are ordered according to the aggregate relevance of their underlying chunks, as 186 | determined by the order in which they are provided to this function. 187 | """ 188 | # Exit early if the input is empty. 189 | if not chunk_ids: 190 | return [] 191 | # Retrieve the chunks. 192 | config = config or RAGLiteConfig() 193 | chunks: list[Chunk] = ( 194 | retrieve_chunks(chunk_ids, config=config) # type: ignore[arg-type,assignment] 195 | if all(isinstance(chunk_id, ChunkId) for chunk_id in chunk_ids) 196 | else chunk_ids 197 | ) 198 | # Assign a reciprocal ranking score to each chunk based on its position in the original list. 199 | chunk_id_to_score = {chunk.id: 1 / (i + 1) for i, chunk in enumerate(chunks)} 200 | # Extend the chunks with their neighbouring chunks. 201 | with Session(create_database_engine(config)) as session: 202 | if neighbors: 203 | neighbor_conditions = [ 204 | and_(Chunk.document_id == chunk.document_id, Chunk.index == chunk.index + offset) 205 | for chunk in chunks 206 | for offset in neighbors 207 | ] 208 | chunks += list( 209 | session.exec( 210 | select(Chunk) 211 | .where(or_(*neighbor_conditions)) 212 | # Eagerly load chunk.document. 213 | .options(joinedload(Chunk.document)) # type: ignore[arg-type] 214 | ).all() 215 | ) 216 | # Deduplicate and sort the chunks by document_id and index (needed for groupby). 217 | unique_chunks = sorted(set(chunks), key=lambda chunk: (chunk.document_id, chunk.index)) 218 | # Group the chunks into contiguous segments. 219 | chunk_spans: list[ChunkSpan] = [] 220 | for _, group in groupby(unique_chunks, key=lambda chunk: chunk.document_id): 221 | chunk_sequence: list[Chunk] = [] 222 | for chunk in group: 223 | if not chunk_sequence or chunk.index == chunk_sequence[-1].index + 1: 224 | chunk_sequence.append(chunk) 225 | else: 226 | chunk_spans.append(ChunkSpan(chunks=chunk_sequence)) 227 | chunk_sequence = [chunk] 228 | chunk_spans.append(ChunkSpan(chunks=chunk_sequence)) 229 | # Rank segments according to the aggregate relevance of their chunks. 230 | chunk_spans.sort( 231 | key=lambda chunk_span: sum( 232 | chunk_id_to_score.get(chunk.id, 0.0) for chunk in chunk_span.chunks 233 | ), 234 | reverse=True, 235 | ) 236 | return chunk_spans 237 | 238 | 239 | def rerank_chunks( 240 | query: str, chunk_ids: list[ChunkId] | list[Chunk], *, config: RAGLiteConfig | None = None 241 | ) -> list[Chunk]: 242 | """Rerank chunks according to their relevance to a given query.""" 243 | # Retrieve the chunks. 244 | config = config or RAGLiteConfig() 245 | chunks: list[Chunk] = ( 246 | retrieve_chunks(chunk_ids, config=config) # type: ignore[arg-type,assignment] 247 | if all(isinstance(chunk_id, ChunkId) for chunk_id in chunk_ids) 248 | else chunk_ids 249 | ) 250 | # Exit early if no reranker is configured or if the input is empty. 251 | if not config.reranker or not chunks: 252 | return chunks 253 | # Select the reranker. 254 | if isinstance(config.reranker, dict): 255 | # Detect the languages of the chunks and queries. 256 | with contextlib.suppress(LangDetectException): 257 | langs = {detect(str(chunk)) for chunk in chunks} 258 | langs.add(detect(query)) 259 | # If all chunks and the query are in the same language, use a language-specific reranker. 260 | rerankers = config.reranker 261 | if len(langs) == 1 and (lang := next(iter(langs))) in rerankers: 262 | reranker = rerankers[lang] 263 | else: 264 | reranker = rerankers.get("other") 265 | else: 266 | # A specific reranker was configured. 267 | reranker = config.reranker 268 | # Rerank the chunks. 269 | if reranker: 270 | results = reranker.rank(query=query, docs=[str(chunk) for chunk in chunks]) 271 | chunks = [chunks[result.doc_id] for result in results.results] 272 | return chunks 273 | 274 | 275 | def search_and_rerank_chunks( 276 | query: str, 277 | *, 278 | num_results: int = 8, 279 | oversample: int = 4, 280 | search: BasicSearchMethod = hybrid_search, 281 | config: RAGLiteConfig | None = None, 282 | ) -> list[Chunk]: 283 | """Search and rerank chunks.""" 284 | chunk_ids, _ = search(query, num_results=oversample * num_results, config=config) 285 | chunks = rerank_chunks(query, chunk_ids, config=config)[:num_results] 286 | return chunks 287 | 288 | 289 | def search_and_rerank_chunk_spans( # noqa: PLR0913 290 | query: str, 291 | *, 292 | num_results: int = 8, 293 | oversample: int = 4, 294 | neighbors: tuple[int, ...] | None = (-1, 1), 295 | search: BasicSearchMethod = hybrid_search, 296 | config: RAGLiteConfig | None = None, 297 | ) -> list[ChunkSpan]: 298 | """Search and rerank chunks, and then collate into chunk spans.""" 299 | chunk_ids, _ = search(query, num_results=oversample * num_results, config=config) 300 | chunks = rerank_chunks(query, chunk_ids, config=config)[:num_results] 301 | chunk_spans = retrieve_chunk_spans(chunks, neighbors=neighbors, config=config) 302 | return chunk_spans 303 | -------------------------------------------------------------------------------- /src/raglite/_split_chunklets.py: -------------------------------------------------------------------------------- 1 | """Split a document into chunklets.""" 2 | 3 | from collections.abc import Callable 4 | 5 | import numpy as np 6 | from markdown_it import MarkdownIt 7 | 8 | from raglite._typing import FloatVector 9 | 10 | 11 | def markdown_chunklet_boundaries(sentences: list[str]) -> FloatVector: 12 | """Estimate chunklet boundary probabilities given a Markdown document.""" 13 | # Parse the document. 14 | doc = "".join(sentences) 15 | md = MarkdownIt() 16 | tokens = md.parse(doc) 17 | # Identify the character index of each line in the document. 18 | lines = doc.splitlines(keepends=True) 19 | line_start_char = [0] 20 | for line in lines[:-1]: 21 | line_start_char.append(line_start_char[-1] + len(line)) 22 | # Identify the character index of each sentence in the document. 23 | sentence_start_char = [0] 24 | for sentence in sentences: 25 | sentence_start_char.append(sentence_start_char[-1] + len(sentence)) 26 | # Map each line index to a corresponding sentence index. 27 | line_to_sentence = np.searchsorted(sentence_start_char, line_start_char, side="right") - 1 28 | # Configure probabilities for token types to be chunklet boundaries. 29 | token_type_to_proba = { 30 | "blockquote_open": 0.75, 31 | "bullet_list_open": 0.25, 32 | "heading_open": 1.0, 33 | "paragraph_open": 0.5, 34 | "ordered_list_open": 0.25, 35 | } 36 | # Compute the boundary probabilities for each sentence. 37 | last_sentence = -1 38 | boundary_probas = np.zeros(len(sentences)) 39 | for token in tokens: 40 | if token.type in token_type_to_proba: 41 | start_line, _ = token.map # type: ignore[misc] 42 | if (i := line_to_sentence[start_line]) != last_sentence: 43 | boundary_probas[i] = token_type_to_proba[token.type] 44 | last_sentence = i # type: ignore[assignment] 45 | # For segments of consecutive boundaries, encourage splitting on the largest boundary in the 46 | # segment by setting the other boundary probabilities in the segment to zero. 47 | mask = boundary_probas != 0.0 48 | split_indices = np.flatnonzero(mask[1:] != mask[:-1]) + 1 49 | segments = np.split(boundary_probas, split_indices) 50 | for segment in segments: 51 | max_idx, max_proba = np.argmax(segment), np.max(segment) 52 | segment[:] = 0.0 53 | segment[max_idx] = max_proba 54 | boundary_probas = np.concatenate(segments) 55 | return boundary_probas 56 | 57 | 58 | def compute_num_statements(sentences: list[str]) -> FloatVector: 59 | """Compute the approximate number of statements of each sentence in a list of sentences.""" 60 | sentence_word_length = np.asarray( 61 | [len(sentence.split()) for sentence in sentences], dtype=np.float64 62 | ) 63 | q25, q75 = np.quantile(sentence_word_length, [0.25, 0.75]) 64 | q25 = max(q25, np.sqrt(np.finfo(np.float64).eps)) 65 | q75 = max(q75, q25 + np.sqrt(np.finfo(np.float64).eps)) 66 | num_statements = np.piecewise( 67 | sentence_word_length, 68 | [sentence_word_length <= q25, sentence_word_length > q25], 69 | [lambda n: 0.75 * n / q25, lambda n: 0.75 + 0.5 * (n - q25) / (q75 - q25)], 70 | ) 71 | return num_statements 72 | 73 | 74 | def split_chunklets( 75 | sentences: list[str], 76 | boundary_cost: Callable[[FloatVector], float] = lambda p: (1.0 - p[0]) + np.sum(p[1:]), 77 | statement_cost: Callable[[float], float] = lambda s: ((s - 3) ** 2 / np.sqrt(max(s, 1e-6)) / 2), 78 | max_size: int = 2048, 79 | ) -> list[str]: 80 | """Split sentences into optimal chunklets. 81 | 82 | A chunklet is a concatenated contiguous list of sentences from a document. This function 83 | optimally partitions a document into chunklets using dynamic programming. 84 | 85 | A chunklet is considered optimal when it contains as close to 3 statements as possible, when the 86 | first sentence in the chunklet is a Markdown boundary such as the start of a heading or 87 | paragraph, and when the remaining sentences in the chunklet are not Markdown boundaries. 88 | 89 | Here, we define the number of statements in a sentence as a measure of the sentence's 90 | information content. A sentence is said to contain 1 statement if it contains the median number 91 | of words per sentence, across sentences in the document. 92 | 93 | The given document of sentences is optimally partitioned into chunklets by solving a dynamic 94 | programming problem that assigns a cost to each chunklet given the 95 | `boundary_cost(boundary_probas)` function and the `statement_cost(num_statements)` function. The 96 | former outputs the cost associated with the boundaries of the chunklet's sentences given the 97 | sentences' Markdown boundary probabilities, while the latter outputs the cost of the total 98 | number of statements in the chunklet. 99 | 100 | Parameters 101 | ---------- 102 | sentences 103 | The input document as a list of sentences. 104 | boundary_cost 105 | A function that computes the boundary cost of a chunklet given the boundary probabilities of 106 | its sentences. The total cost of a chunklet is the sum of its boundary and statement cost. 107 | statement_cost 108 | A function that computes the statement cost of a chunklet given its number of statements. 109 | The total cost of a chunklet is the sum of its boundary and statement cost. 110 | max_size 111 | The maximum size of a chunklet in characters. 112 | 113 | Returns 114 | ------- 115 | list[str] 116 | The document optimally partitioned into chunklets. 117 | """ 118 | # Precompute chunklet boundary probabilities and each sentence's number of statements. 119 | boundary_probas = markdown_chunklet_boundaries(sentences) 120 | num_statements = compute_num_statements(sentences) 121 | # Initialize a dynamic programming table and backpointers. The dynamic programming table dp[i] 122 | # is defined as the minimum cost to segment the first i sentences (i.e., sentences[:i]). 123 | num_sentences = len(sentences) 124 | dp = np.full(num_sentences + 1, np.inf) 125 | dp[0] = 0.0 126 | back = -np.ones(num_sentences + 1, dtype=np.intp) 127 | # Compute the cost of partitioning sentences into chunklets. 128 | for i in range(1, num_sentences + 1): 129 | for j in range(i): 130 | # Limit the chunklets to a maximum size. 131 | if sum(len(s) for s in sentences[j:i]) > max_size: 132 | continue 133 | # Compute the cost of partitioning sentences[j:i] into a single chunklet. 134 | cost_ji = boundary_cost(boundary_probas[j:i]) 135 | cost_ji += statement_cost(np.sum(num_statements[j:i])) 136 | # Compute the cost of partitioning sentences[:i] if we were to split at j. 137 | cost_0i = dp[j] + cost_ji 138 | # If the cost is less than the current minimum, update the DP table and backpointer. 139 | if cost_0i < dp[i]: 140 | dp[i] = cost_0i 141 | back[i] = j 142 | # Recover the optimal partitioning. 143 | partition_indices: list[int] = [] 144 | i = back[num_sentences] 145 | while i > 0: 146 | partition_indices.append(i) 147 | i = back[i] 148 | partition_indices.reverse() 149 | # Split the sentences into optimal chunklets. 150 | chunklets = [ 151 | "".join(sentences[i:j]) 152 | for i, j in zip([0, *partition_indices], [*partition_indices, len(sentences)], strict=True) 153 | ] 154 | return chunklets 155 | -------------------------------------------------------------------------------- /src/raglite/_split_chunks.py: -------------------------------------------------------------------------------- 1 | """Split a document into semantic chunks.""" 2 | 3 | import re 4 | 5 | import numpy as np 6 | from scipy.optimize import linprog 7 | from scipy.sparse import coo_matrix 8 | 9 | from raglite._typing import FloatMatrix 10 | 11 | 12 | def split_chunks( # noqa: C901, PLR0915 13 | chunklets: list[str], 14 | chunklet_embeddings: FloatMatrix, 15 | max_size: int = 2048, 16 | ) -> tuple[list[str], list[FloatMatrix]]: 17 | """Split chunklets into optimal semantic chunks with corresponding chunklet embeddings. 18 | 19 | A chunk is a concatenated contiguous list of chunklets from a document. This function 20 | optimally partitions a document into chunks using binary integer programming. 21 | 22 | A partioning of a document into chunks is considered optimal if the total cost of partitioning 23 | the document into chunks is minimized. The cost of adding a partition point is given by the 24 | cosine similarity of the chunklet embedding before and after the partition point, corrected by 25 | the discourse vector of the chunklet embeddings across the document. 26 | 27 | Parameters 28 | ---------- 29 | chunklets 30 | The input document as a list of chunklets. 31 | chunklet_embeddings 32 | A NumPy array wherein the i'th row is an embedding vector corresponding to the i'th 33 | chunklet. Embedding vectors are expected to have nonzero length. 34 | max_size 35 | The maximum size of a chunk in characters. 36 | 37 | Returns 38 | ------- 39 | tuple[list[str], list[FloatMatrix]] 40 | The document and chunklet embeddings optimally partitioned into chunks. 41 | """ 42 | # Validate the input. 43 | chunklet_size = np.asarray([len(chunklet) for chunklet in chunklets]) 44 | if not np.all(chunklet_size <= max_size): 45 | error_message = "Chunklet larger than chunk max_size detected." 46 | raise ValueError(error_message) 47 | if not np.all(np.linalg.norm(chunklet_embeddings, axis=1) > 0.0): 48 | error_message = "Chunklet embeddings with zero norm detected." 49 | raise ValueError(error_message) 50 | # Exit early if there is only one chunk to return. 51 | if len(chunklets) <= 1 or sum(chunklet_size) <= max_size: 52 | return ["".join(chunklets)] if chunklets else chunklets, [chunklet_embeddings] 53 | # Normalise the chunklet embeddings to unit norm. 54 | X = chunklet_embeddings.astype(np.float32) # noqa: N806 55 | X = X / np.linalg.norm(X, axis=1, keepdims=True) # noqa: N806 56 | # Select nonoutlying chunklets and remove the discourse vector. 57 | q15, q85 = np.quantile(chunklet_size, [0.15, 0.85]) 58 | nonoutlying_chunklets = (q15 <= chunklet_size) & (chunklet_size <= q85) 59 | if np.any(nonoutlying_chunklets): 60 | discourse = np.mean(X[nonoutlying_chunklets, :], axis=0) 61 | discourse = discourse / np.linalg.norm(discourse) 62 | X_modulo = X - np.outer(X @ discourse, discourse) # noqa: N806 63 | if not np.any(np.linalg.norm(X_modulo, axis=1) <= np.finfo(X.dtype).eps): 64 | X = X_modulo # noqa: N806 65 | X = X / np.linalg.norm(X, axis=1, keepdims=True) # noqa: N806 66 | # For each partition point in the list of chunklets, compute the similarity of chunklet before 67 | # and after the partition point. 68 | partition_similarity = np.sum(X[:-1] * X[1:], axis=1) 69 | # Make partition similarity nonnegative before modification and optimisation. 70 | partition_similarity = np.maximum( 71 | (partition_similarity + 1) / 2, np.sqrt(np.finfo(X.dtype).eps) 72 | ) 73 | # Modify the partition similarity to encourage splitting on Markdown headings. 74 | prev_chunklet_is_heading = True 75 | for i, chunklet in enumerate(chunklets[:-1]): 76 | is_heading = bool(re.match(r"^#+\s", chunklet.replace("\n", "").strip())) 77 | if is_heading: 78 | # Encourage splitting before a heading. 79 | if not prev_chunklet_is_heading: 80 | partition_similarity[i - 1] = partition_similarity[i - 1] / 4 81 | # Don't split immediately after a heading. 82 | partition_similarity[i] = 1.0 83 | prev_chunklet_is_heading = is_heading 84 | # Solve an optimisation problem to find the best partition points. 85 | chunklet_size_cumsum = np.cumsum(chunklet_size) 86 | row_indices = [] 87 | col_indices = [] 88 | data = [] 89 | for i in range(len(chunklets) - 1): 90 | r = chunklet_size_cumsum[i - 1] if i > 0 else 0 91 | idx = np.searchsorted(chunklet_size_cumsum - r, max_size, side="right") 92 | assert idx > i 93 | if idx == len(chunklet_size_cumsum): 94 | break 95 | cols = list(range(i, idx)) 96 | col_indices.extend(cols) 97 | row_indices.extend([i] * len(cols)) 98 | data.extend([1] * len(cols)) 99 | A = coo_matrix( # noqa: N806 100 | (data, (row_indices, col_indices)), 101 | shape=(max(row_indices) + 1, len(chunklets) - 1), 102 | dtype=np.float32, 103 | ) 104 | b_ub = np.ones(A.shape[0], dtype=np.float32) 105 | res = linprog( 106 | partition_similarity, 107 | A_ub=-A, 108 | b_ub=-b_ub, 109 | bounds=(0, 1), 110 | integrality=[1] * A.shape[1], 111 | ) 112 | if not res.success: 113 | error_message = "Optimization of chunk partitions failed." 114 | raise ValueError(error_message) 115 | # Split the chunklets and their window embeddings into optimal chunks. 116 | partition_indices = (np.where(res.x)[0] + 1).tolist() 117 | chunks = [ 118 | "".join(chunklets[i:j]) 119 | for i, j in zip([0, *partition_indices], [*partition_indices, len(chunklets)], strict=True) 120 | ] 121 | chunk_embeddings = np.split(chunklet_embeddings, partition_indices) 122 | return chunks, chunk_embeddings 123 | -------------------------------------------------------------------------------- /src/raglite/_split_sentences.py: -------------------------------------------------------------------------------- 1 | """Split a document into sentences.""" 2 | 3 | import warnings 4 | from collections.abc import Callable 5 | from functools import cache 6 | from typing import Any 7 | 8 | import numpy as np 9 | from markdown_it import MarkdownIt 10 | from scipy import sparse 11 | from scipy.optimize import OptimizeWarning, linprog 12 | from wtpsplit_lite import SaT 13 | 14 | from raglite._typing import FloatVector 15 | 16 | 17 | @cache 18 | def _load_sat() -> tuple[SaT, dict[str, Any]]: 19 | """Load a Segment any Text (SaT) model.""" 20 | sat = SaT("sat-3l-sm") # This model makes the best trade-off between speed and accuracy. 21 | sat_kwargs = {"stride": 128, "block_size": 256, "weighting": "hat"} 22 | return sat, sat_kwargs 23 | 24 | 25 | def markdown_sentence_boundaries(doc: str) -> FloatVector: 26 | """Determine known sentence boundaries from a Markdown document.""" 27 | 28 | def get_markdown_heading_indexes(doc: str) -> list[tuple[int, int]]: 29 | """Get the indexes of the headings in a Markdown document.""" 30 | md = MarkdownIt() 31 | tokens = md.parse(doc) 32 | headings = [] 33 | lines = doc.splitlines(keepends=True) 34 | line_start_char = [0] 35 | for line in lines: 36 | line_start_char.append(line_start_char[-1] + len(line)) 37 | for token in tokens: 38 | if token.type == "heading_open": 39 | start_line, end_line = token.map # type: ignore[misc] 40 | heading_start = line_start_char[start_line] 41 | heading_end = line_start_char[end_line] 42 | headings.append((heading_start, heading_end + 1)) 43 | return headings 44 | 45 | # Get the start and end character indexes of the headings in the document. 46 | headings = get_markdown_heading_indexes(doc) 47 | # Indicate that each heading is a contiguous sentence by setting the boundary probabilities. 48 | boundary_probas = np.full(len(doc), np.nan) 49 | for heading_start, heading_end in headings: 50 | if 0 <= heading_start - 1 < len(boundary_probas): 51 | boundary_probas[heading_start - 1] = 1 # First heading character starts a sentence. 52 | boundary_probas[heading_start : heading_end - 1] = 0 # Body does not contain boundaries. 53 | if 0 <= heading_end - 1 < len(boundary_probas): 54 | boundary_probas[heading_end - 1] = 1 # Last heading character is the end of a sentence. 55 | return boundary_probas 56 | 57 | 58 | def _split_sentences( 59 | doc: str, probas: FloatVector, *, min_len: int, max_len: int | None = None 60 | ) -> list[str]: 61 | # Solve an optimisation problem to find the best sentence boundaries given the predicted 62 | # boundary probabilities. The objective is to select boundaries that maximise the sum of the 63 | # boundary probabilities above a given threshold, subject to the resulting sentences not being 64 | # shorter or longer than the given minimum or maximum length, respectively. 65 | sentence_threshold = 0.25 # Default threshold for -sm models. 66 | c = probas - sentence_threshold 67 | N = len(probas) # noqa: N806 68 | M = N - min_len + 1 # noqa: N806 69 | diagonals = [np.ones(M, dtype=np.float32) for _ in range(min_len)] 70 | offsets = list(range(min_len)) 71 | A_min = sparse.diags(diagonals, offsets, shape=(M, N), format="csr") # noqa: N806 72 | b_min = np.ones(M, dtype=np.float32) 73 | bounds = [(0, 1)] * N 74 | bounds[: min_len - 1] = [(0, 0)] * (min_len - 1) # Prevent short leading sentences. 75 | bounds[-min_len:] = [(0, 0)] * min_len # Prevent short trailing sentences. 76 | if max_len is not None and (M := N - max_len + 1) > 0: # noqa: N806 77 | diagonals = [np.ones(M, dtype=np.float32) for _ in range(max_len)] 78 | offsets = list(range(max_len)) 79 | A_max = sparse.diags(diagonals, offsets, shape=(M, N), format="csr") # noqa: N806 80 | b_max = np.ones(M, dtype=np.float32) 81 | A_ub = sparse.vstack([A_min, -A_max], format="csr") # noqa: N806 82 | b_ub = np.hstack([b_min, -b_max]) 83 | else: 84 | A_ub = A_min # noqa: N806 85 | b_ub = b_min 86 | x0 = (probas >= sentence_threshold).astype(np.float32) 87 | with warnings.catch_warnings(): 88 | warnings.filterwarnings("ignore", category=OptimizeWarning) # Ignore x0 not being used. 89 | res = linprog(-c, A_ub=A_ub, b_ub=b_ub, bounds=bounds, x0=x0, integrality=[1] * N) 90 | if not res.success: 91 | error_message = "Optimization of sentence partitions failed." 92 | raise ValueError(error_message) 93 | # Split the document into sentences where the boundary probability exceeds a threshold. 94 | partition_indices = np.where(res.x > 0.5)[0] + 1 # noqa: PLR2004 95 | sentences = [ 96 | doc[i:j] for i, j in zip([0, *partition_indices], [*partition_indices, None], strict=True) 97 | ] 98 | return sentences 99 | 100 | 101 | def split_sentences( 102 | doc: str, 103 | min_len: int = 4, 104 | max_len: int | None = None, 105 | boundary_probas: FloatVector | Callable[[str], FloatVector] = markdown_sentence_boundaries, 106 | ) -> list[str]: 107 | """Split a document into sentences. 108 | 109 | Parameters 110 | ---------- 111 | doc 112 | The document to split into sentences. 113 | min_len 114 | The minimum number of characters in a sentence. 115 | max_len 116 | The maximum number of characters in a sentence, with no maximum by default. 117 | boundary_probas 118 | Any known sentence boundary probabilities to override the model's predicted sentence 119 | boundary probabilities. If an element of the probability vector with index k is 1 (0), then 120 | the character at index k + 1 is (not) the start of a sentence. Elements set to NaN will not 121 | override the predicted probabilities. By default, the known sentence boundary probabilities 122 | are determined from the document's Markdown headings. 123 | 124 | Returns 125 | ------- 126 | list[str] 127 | The input document partitioned into sentences. All sentences are constructed to contain at 128 | least one non-whitespace character, not have any leading whitespace (except for the first 129 | sentence if the document itself has leading whitespace), and respect the minimum and maximum 130 | sentence lengths. 131 | """ 132 | # Exit early if there is only one sentence to return. 133 | if len(doc) <= min_len: 134 | return [doc] 135 | # Compute the sentence boundary probabilities with a wtpsplit Segment any Text (SaT) model. 136 | sat, sat_kwargs = _load_sat() 137 | predicted_probas = sat.predict_proba(doc, **sat_kwargs) 138 | # Override the predicted boundary probabilities with the known boundary probabilities. 139 | known_probas = boundary_probas(doc) if callable(boundary_probas) else boundary_probas 140 | probas = predicted_probas.copy() 141 | probas[np.isfinite(known_probas)] = known_probas[np.isfinite(known_probas)] 142 | # Propagate the boundary probabilities so that whitespace is always trailing and never leading. 143 | is_space = np.array([char.isspace() for char in doc], dtype=np.bool_) 144 | start = np.where(np.insert(~is_space[:-1] & is_space[1:], len(is_space) - 1, False))[0] 145 | end = np.where(np.insert(~is_space[1:] & is_space[:-1], 0, False))[0] 146 | start = start[start < np.max(end, initial=-1)] 147 | end = end[end > np.min(start, initial=len(is_space))] 148 | for i, j in zip(start, end, strict=True): 149 | min_proba, max_proba = np.min(probas[i:j]), np.max(probas[i:j]) 150 | probas[i : j - 1] = min_proba # From the non-whitespace to the penultimate whitespace char. 151 | probas[j - 1] = max_proba # The last whitespace char. 152 | # Solve an optimization problem to find optimal sentences with no maximum length. We delay the 153 | # maximum length constraint to a subsequent step to avoid blowing up memory usage. 154 | sentences = _split_sentences(doc, probas, min_len=min_len, max_len=None) 155 | # For each sentence that exceeds the maximum length, solve the optimization problem again with 156 | # a maximum length constraint. 157 | if max_len is not None: 158 | sentences = [ 159 | subsentence 160 | for sentence in sentences 161 | for subsentence in ( 162 | [sentence] 163 | if len(sentence) <= max_len 164 | else _split_sentences( 165 | sentence, 166 | probas[doc.index(sentence) : doc.index(sentence) + len(sentence)], 167 | min_len=min_len, 168 | max_len=max_len, 169 | ) 170 | ) 171 | ] 172 | return sentences 173 | -------------------------------------------------------------------------------- /src/raglite/_typing.py: -------------------------------------------------------------------------------- 1 | """RAGLite typing.""" 2 | 3 | import io 4 | import pickle 5 | from collections.abc import Callable 6 | from typing import TYPE_CHECKING, Any, Literal, Protocol 7 | 8 | import numpy as np 9 | from sqlalchemy import literal 10 | from sqlalchemy.engine import Dialect 11 | from sqlalchemy.ext.compiler import compiles 12 | from sqlalchemy.sql.functions import FunctionElement 13 | from sqlalchemy.sql.operators import Operators 14 | from sqlalchemy.types import Float, LargeBinary, TypeDecorator, TypeEngine, UserDefinedType 15 | 16 | if TYPE_CHECKING: 17 | from raglite._config import RAGLiteConfig 18 | from raglite._database import Chunk, ChunkSpan 19 | 20 | ChunkId = str 21 | DocumentId = str 22 | EvalId = str 23 | IndexId = str 24 | 25 | DistanceMetric = Literal["cosine", "dot", "l1", "l2"] 26 | 27 | FloatMatrix = np.ndarray[tuple[int, int], np.dtype[np.floating[Any]]] 28 | FloatVector = np.ndarray[tuple[int], np.dtype[np.floating[Any]]] 29 | IntVector = np.ndarray[tuple[int], np.dtype[np.intp]] 30 | 31 | 32 | class BasicSearchMethod(Protocol): 33 | def __call__( 34 | self, query: str, *, num_results: int, config: "RAGLiteConfig | None" = None 35 | ) -> tuple[list[ChunkId], list[float]]: ... 36 | 37 | 38 | class SearchMethod(Protocol): 39 | def __call__( 40 | self, query: str, *, num_results: int, config: "RAGLiteConfig | None" = None 41 | ) -> tuple[list[ChunkId], list[float]] | list["Chunk"] | list["ChunkSpan"]: ... 42 | 43 | 44 | class NumpyArray(TypeDecorator[np.ndarray[Any, np.dtype[np.floating[Any]]]]): 45 | """A NumPy array column type for SQLAlchemy.""" 46 | 47 | impl = LargeBinary 48 | 49 | def process_bind_param( 50 | self, value: np.ndarray[Any, np.dtype[np.floating[Any]]] | None, dialect: Dialect 51 | ) -> bytes | None: 52 | """Convert a NumPy array to bytes.""" 53 | if value is None: 54 | return None 55 | buffer = io.BytesIO() 56 | np.save(buffer, value, allow_pickle=False, fix_imports=False) 57 | return buffer.getvalue() 58 | 59 | def process_result_value( 60 | self, value: bytes | None, dialect: Dialect 61 | ) -> np.ndarray[Any, np.dtype[np.floating[Any]]] | None: 62 | """Convert bytes to a NumPy array.""" 63 | if value is None: 64 | return None 65 | return np.load(io.BytesIO(value), allow_pickle=False, fix_imports=False) # type: ignore[no-any-return] 66 | 67 | 68 | class PickledObject(TypeDecorator[object]): 69 | """A pickled object column type for SQLAlchemy.""" 70 | 71 | impl = LargeBinary 72 | 73 | def process_bind_param(self, value: object | None, dialect: Dialect) -> bytes | None: 74 | """Convert a Python object to bytes.""" 75 | if value is None: 76 | return None 77 | return pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL, fix_imports=False) 78 | 79 | def process_result_value(self, value: bytes | None, dialect: Dialect) -> object | None: 80 | """Convert bytes to a Python object.""" 81 | if value is None: 82 | return None 83 | return pickle.loads(value, fix_imports=False) # type: ignore[no-any-return] # noqa: S301 84 | 85 | 86 | class EmbeddingDistance(FunctionElement[float]): 87 | """SQL expression that renders a distance operator per dialect.""" 88 | 89 | inherit_cache = True 90 | type = Float() # The result is always a scalar float. 91 | 92 | def __init__(self, left: Any, right: Any, metric: DistanceMetric) -> None: 93 | self.metric = metric 94 | super().__init__(left, right) 95 | 96 | 97 | @compiles(EmbeddingDistance, "postgresql") 98 | def _embedding_distance_postgresql(element: EmbeddingDistance, compiler: Any, **kwargs: Any) -> str: 99 | op_map: dict[DistanceMetric, str] = { 100 | "cosine": "<=>", 101 | "dot": "<#>", 102 | "l1": "<+>", 103 | "l2": "<->", 104 | } 105 | left, right = list(element.clauses) 106 | operator = op_map[element.metric] 107 | return f"({compiler.process(left)} {operator} {compiler.process(right)})" 108 | 109 | 110 | @compiles(EmbeddingDistance, "duckdb") 111 | def _embedding_distance_duckdb(element: EmbeddingDistance, compiler: Any, **kwargs: Any) -> str: 112 | func_map: dict[DistanceMetric, str] = { 113 | "cosine": "array_cosine_distance", 114 | "dot": "array_negative_inner_product", 115 | "l2": "array_distance", 116 | } 117 | left, right = list(element.clauses) 118 | dim = left.type.dim # type: ignore[attr-defined] 119 | func_name = func_map[element.metric] 120 | right_cast = f"{compiler.process(right)}::FLOAT[{dim}]" 121 | return f"{func_name}({compiler.process(left)}, {right_cast})" 122 | 123 | 124 | class EmbeddingComparator(UserDefinedType.Comparator[FloatVector]): 125 | """An embedding distance comparator.""" 126 | 127 | def distance(self, other: FloatVector, *, metric: DistanceMetric) -> Operators: 128 | rhs = literal(other, type_=self.expr.type) 129 | return EmbeddingDistance(self.expr, rhs, metric) 130 | 131 | 132 | class PostgresHalfVec(UserDefinedType[FloatVector]): 133 | """A PostgreSQL half-precision vector column type for SQLAlchemy.""" 134 | 135 | cache_ok = True 136 | 137 | def __init__(self, dim: int | None = None) -> None: 138 | super().__init__() 139 | self.dim = dim 140 | 141 | def get_col_spec(self, **kwargs: Any) -> str: 142 | return f"halfvec({self.dim})" 143 | 144 | def bind_processor(self, dialect: Dialect) -> Callable[[FloatVector | None], str | None]: 145 | """Process NumPy ndarray to PostgreSQL halfvec format for bound parameters.""" 146 | 147 | def process(value: FloatVector | None) -> str | None: 148 | return f"[{','.join(str(x) for x in np.ravel(value))}]" if value is not None else None 149 | 150 | return process 151 | 152 | def result_processor( 153 | self, dialect: Dialect, coltype: Any 154 | ) -> Callable[[str | None], FloatVector | None]: 155 | """Process PostgreSQL halfvec format to NumPy ndarray.""" 156 | 157 | def process(value: str | None) -> FloatVector | None: 158 | if value is None: 159 | return None 160 | return np.fromstring(value.strip("[]"), sep=",", dtype=np.float16) 161 | 162 | return process 163 | 164 | 165 | class DuckDBSingleVec(UserDefinedType[FloatVector]): 166 | """A DuckDB single precision vector column type for SQLAlchemy.""" 167 | 168 | cache_ok = True 169 | 170 | def __init__(self, dim: int | None = None) -> None: 171 | super().__init__() 172 | self.dim = dim 173 | 174 | def get_col_spec(self, **kwargs: Any) -> str: 175 | return f"FLOAT[{self.dim}]" if self.dim is not None else "FLOAT[]" 176 | 177 | def bind_processor( 178 | self, dialect: Dialect 179 | ) -> Callable[[FloatVector | None], list[float] | None]: 180 | """Process NumPy ndarray to DuckDB single precision vector format for bound parameters.""" 181 | 182 | def process(value: FloatVector | None) -> list[float] | None: 183 | return np.ravel(value).tolist() if value is not None else None 184 | 185 | return process 186 | 187 | def result_processor( 188 | self, dialect: Dialect, coltype: Any 189 | ) -> Callable[[list[float] | None], FloatVector | None]: 190 | """Process DuckDB single precision vector format to NumPy ndarray.""" 191 | 192 | def process(value: list[float] | None) -> FloatVector | None: 193 | return np.asarray(value, dtype=np.float32) if value is not None else None 194 | 195 | return process 196 | 197 | 198 | class Embedding(TypeDecorator[FloatVector]): 199 | """An embedding column type for SQLAlchemy.""" 200 | 201 | cache_ok = True 202 | impl = NumpyArray 203 | comparator_factory: type[EmbeddingComparator] = EmbeddingComparator 204 | 205 | def __init__(self, dim: int = -1): 206 | super().__init__() 207 | self.dim = dim 208 | 209 | def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[FloatVector]: 210 | if dialect.name == "postgresql": 211 | return dialect.type_descriptor(PostgresHalfVec(self.dim)) 212 | if dialect.name == "duckdb": 213 | return dialect.type_descriptor(DuckDBSingleVec(self.dim)) 214 | return dialect.type_descriptor(NumpyArray()) 215 | -------------------------------------------------------------------------------- /src/raglite/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/superlinear-ai/raglite/17cf038a597c23911412d6e3f34a686c0eca933a/src/raglite/py.typed -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """RAGLite test suite.""" 2 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | """Fixtures for the tests.""" 2 | 3 | import os 4 | import socket 5 | import tempfile 6 | from collections.abc import Generator 7 | from pathlib import Path 8 | 9 | import pytest 10 | from sqlalchemy import create_engine, text 11 | 12 | from raglite import Document, RAGLiteConfig, insert_documents 13 | 14 | POSTGRES_URL = "postgresql+pg8000://raglite_user:raglite_password@postgres:5432/postgres" 15 | 16 | 17 | def pytest_configure(config: pytest.Config) -> None: 18 | """Configure pytest to skip slow tests in CI.""" 19 | if os.environ.get("CI"): 20 | markexpr = "not slow" 21 | if config.option.markexpr: 22 | markexpr = f"({config.option.markexpr}) and ({markexpr})" 23 | config.option.markexpr = markexpr 24 | 25 | 26 | def is_postgres_running() -> bool: 27 | """Check if PostgreSQL is running.""" 28 | try: 29 | with socket.create_connection(("postgres", 5432), timeout=1): 30 | return True 31 | except OSError: 32 | return False 33 | 34 | 35 | def is_openai_available() -> bool: 36 | """Check if an OpenAI API key is set.""" 37 | return bool(os.environ.get("OPENAI_API_KEY")) 38 | 39 | 40 | def pytest_sessionstart(session: pytest.Session) -> None: 41 | """Reset the PostgreSQL database.""" 42 | if is_postgres_running(): 43 | engine = create_engine(POSTGRES_URL, isolation_level="AUTOCOMMIT") 44 | with engine.connect() as conn: 45 | for variant in ["local", "remote"]: 46 | conn.execute(text(f"DROP DATABASE IF EXISTS raglite_test_{variant}")) 47 | conn.execute(text(f"CREATE DATABASE raglite_test_{variant}")) 48 | 49 | 50 | @pytest.fixture(scope="session") 51 | def duckdb_url() -> Generator[str, None, None]: 52 | """Create a temporary DuckDB database file and return the database URL.""" 53 | with tempfile.TemporaryDirectory() as temp_dir: 54 | db_file = Path(temp_dir) / "raglite_test.db" 55 | yield f"duckdb:///{db_file}" 56 | 57 | 58 | @pytest.fixture( 59 | scope="session", 60 | params=[ 61 | pytest.param("duckdb", id="duckdb"), 62 | pytest.param( 63 | POSTGRES_URL, 64 | id="postgres", 65 | marks=pytest.mark.skipif(not is_postgres_running(), reason="PostgreSQL is not running"), 66 | ), 67 | ], 68 | ) 69 | def database(request: pytest.FixtureRequest) -> str: 70 | """Get a database URL to test RAGLite with.""" 71 | db_url: str = ( 72 | request.getfixturevalue("duckdb_url") if request.param == "duckdb" else request.param 73 | ) 74 | return db_url 75 | 76 | 77 | @pytest.fixture( 78 | scope="session", 79 | params=[ 80 | pytest.param( 81 | ( 82 | "llama-cpp-python/unsloth/Qwen3-4B-GGUF/*Q4_K_M.gguf@8192", 83 | "llama-cpp-python/lm-kit/bge-m3-gguf/*Q4_K_M.gguf@512", # More context degrades performance. 84 | ), 85 | id="qwen3_4B-bge_m3", 86 | ), 87 | pytest.param( 88 | ("gpt-4o-mini", "text-embedding-3-small"), 89 | id="gpt_4o_mini-text_embedding_3_small", 90 | marks=pytest.mark.skipif(not is_openai_available(), reason="OpenAI API key is not set"), 91 | ), 92 | ], 93 | ) 94 | def llm_embedder(request: pytest.FixtureRequest) -> str: 95 | """Get an LLM and embedder pair to test RAGLite with.""" 96 | llm_embedder: str = request.param 97 | return llm_embedder 98 | 99 | 100 | @pytest.fixture(scope="session") 101 | def llm(llm_embedder: tuple[str, str]) -> str: 102 | """Get an LLM to test RAGLite with.""" 103 | llm, _ = llm_embedder 104 | return llm 105 | 106 | 107 | @pytest.fixture(scope="session") 108 | def embedder(llm_embedder: tuple[str, str]) -> str: 109 | """Get an embedder to test RAGLite with.""" 110 | _, embedder = llm_embedder 111 | return embedder 112 | 113 | 114 | @pytest.fixture(scope="session") 115 | def raglite_test_config(database: str, llm: str, embedder: str) -> RAGLiteConfig: 116 | """Create a lightweight in-memory config for testing DuckDB and PostgreSQL.""" 117 | # Select the database based on the embedder. 118 | variant = "local" if embedder.startswith("llama-cpp-python") else "remote" 119 | if "postgres" in database: 120 | database = database.replace("/postgres", f"/raglite_test_{variant}") 121 | elif "duckdb" in database: 122 | database = database.replace(".db", f"_{variant}.db") 123 | # Create a RAGLite config for the given database and embedder. 124 | db_config = RAGLiteConfig(db_url=database, llm=llm, embedder=embedder) 125 | # Insert a document and update the index. 126 | doc_path = Path(__file__).parent / "specrel.pdf" # Einstein's special relativity paper. 127 | insert_documents([Document.from_path(doc_path)], config=db_config) 128 | return db_config 129 | -------------------------------------------------------------------------------- /tests/specrel.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/superlinear-ai/raglite/17cf038a597c23911412d6e3f34a686c0eca933a/tests/specrel.pdf -------------------------------------------------------------------------------- /tests/test_chatml_function_calling.py: -------------------------------------------------------------------------------- 1 | """Test RAGLite's upgraded chatml-function-calling llama-cpp-python chat handler.""" 2 | 3 | import os 4 | from typing import TYPE_CHECKING, cast 5 | 6 | if TYPE_CHECKING: 7 | from collections.abc import Iterator 8 | 9 | import pytest 10 | from typeguard import ForwardRefPolicy, check_type 11 | 12 | from raglite._chatml_function_calling import chatml_function_calling_with_streaming 13 | from raglite._lazy_llama import ( 14 | Llama, 15 | llama_supports_gpu_offload, 16 | llama_types, 17 | ) 18 | 19 | 20 | def is_accelerator_available() -> bool: 21 | """Check if an accelerator is available.""" 22 | if llama_supports_gpu_offload(): 23 | return True 24 | if os.environ.get("CI"): 25 | return False 26 | return (os.cpu_count() or 1) >= 8 # noqa: PLR2004 27 | 28 | 29 | @pytest.mark.parametrize( 30 | "stream", 31 | [ 32 | pytest.param(True, id="stream=True"), 33 | pytest.param(False, id="stream=False"), 34 | ], 35 | ) 36 | @pytest.mark.parametrize( 37 | "tool_choice", 38 | [ 39 | pytest.param("none", id="tool_choice=none"), 40 | pytest.param("auto", id="tool_choice=auto"), 41 | pytest.param( 42 | {"type": "function", "function": {"name": "get_weather"}}, id="tool_choice=fixed" 43 | ), 44 | ], 45 | ) 46 | @pytest.mark.parametrize( 47 | "user_prompt_expected_tool_calls", 48 | [ 49 | pytest.param( 50 | ("Is 7 a prime number?", 0), 51 | id="expected_tool_calls=0", 52 | ), 53 | pytest.param( 54 | ("What's the weather like in Paris today?", 1), 55 | id="expected_tool_calls=1", 56 | ), 57 | pytest.param( 58 | ("What's the weather like in Paris today? What about New York?", 2), 59 | id="expected_tool_calls=2", 60 | ), 61 | ], 62 | ) 63 | @pytest.mark.parametrize( 64 | "llm_repo_id", 65 | [ 66 | pytest.param("unsloth/Qwen3-4B-GGUF", id="qwen3_4B"), 67 | pytest.param( 68 | "unsloth/Qwen3-8B-GGUF", 69 | id="qwen3_8B", 70 | marks=pytest.mark.skipif( 71 | not is_accelerator_available(), reason="Accelerator not available" 72 | ), 73 | ), 74 | ], 75 | ) 76 | @pytest.mark.slow 77 | def test_llama_cpp_python_tool_use( 78 | llm_repo_id: str, 79 | user_prompt_expected_tool_calls: tuple[str, int], 80 | tool_choice: llama_types.ChatCompletionToolChoiceOption, 81 | stream: bool, # noqa: FBT001 82 | ) -> None: 83 | """Test the upgraded chatml-function-calling llama-cpp-python chat handler.""" 84 | user_prompt, expected_tool_calls = user_prompt_expected_tool_calls 85 | if isinstance(tool_choice, dict) and expected_tool_calls == 0: 86 | pytest.skip("Nonsensical") 87 | llm = Llama.from_pretrained( 88 | repo_id=llm_repo_id, 89 | filename="*Q4_K_M.gguf", 90 | n_ctx=8192, 91 | n_gpu_layers=-1, 92 | verbose=False, 93 | chat_handler=chatml_function_calling_with_streaming, 94 | ) 95 | messages: list[llama_types.ChatCompletionRequestMessage] = [ 96 | {"role": "user", "content": user_prompt} 97 | ] 98 | tools: list[llama_types.ChatCompletionTool] = [ 99 | { 100 | "type": "function", 101 | "function": { 102 | "name": "get_weather", 103 | "description": "Get the weather for a location.", 104 | "parameters": { 105 | "type": "object", 106 | "properties": {"location": {"type": "string", "description": "A city name."}}, 107 | }, 108 | }, 109 | } 110 | ] 111 | response = llm.create_chat_completion( 112 | messages=messages, tools=tools, tool_choice=tool_choice, stream=stream 113 | ) 114 | if stream: 115 | response = cast("Iterator[llama_types.CreateChatCompletionStreamResponse]", response) 116 | num_tool_calls = 0 117 | for chunk in response: 118 | check_type(chunk, llama_types.CreateChatCompletionStreamResponse) 119 | tool_calls = chunk["choices"][0]["delta"].get("tool_calls") 120 | if isinstance(tool_calls, list): 121 | num_tool_calls = max(tool_call["index"] for tool_call in tool_calls) + 1 122 | assert num_tool_calls == (expected_tool_calls if tool_choice != "none" else 0) 123 | else: 124 | response = cast("llama_types.CreateChatCompletionResponse", response) 125 | check_type( 126 | response, 127 | llama_types.CreateChatCompletionResponse, 128 | forward_ref_policy=ForwardRefPolicy.IGNORE, 129 | ) 130 | if expected_tool_calls == 0 or tool_choice == "none": 131 | assert response["choices"][0]["message"].get("tool_calls") is None 132 | else: 133 | assert len(response["choices"][0]["message"]["tool_calls"]) == expected_tool_calls 134 | assert all( 135 | tool_call["function"]["name"] == tools[0]["function"]["name"] 136 | for tool_call in response["choices"][0]["message"]["tool_calls"] 137 | ) 138 | -------------------------------------------------------------------------------- /tests/test_database.py: -------------------------------------------------------------------------------- 1 | """Test RAGLite's database engine creation.""" 2 | 3 | from pathlib import Path 4 | 5 | from raglite import RAGLiteConfig 6 | from raglite._database import create_database_engine 7 | 8 | 9 | def test_in_memory_duckdb_creation(tmp_path: Path) -> None: 10 | """Test creating an in-memory DuckDB database.""" 11 | config = RAGLiteConfig(db_url="duckdb:///:memory:") 12 | create_database_engine(config) 13 | 14 | 15 | def test_repeated_duckdb_creation(tmp_path: Path) -> None: 16 | """Test creating the same DuckDB database engine twice.""" 17 | duckdb_filepath = tmp_path / "test.db" 18 | config = RAGLiteConfig(db_url=f"duckdb:///{duckdb_filepath.as_posix()}") 19 | create_database_engine(config) 20 | create_database_engine.cache_clear() 21 | create_database_engine(config) 22 | -------------------------------------------------------------------------------- /tests/test_embed.py: -------------------------------------------------------------------------------- 1 | """Test RAGLite's embedding functionality.""" 2 | 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | 7 | from raglite import RAGLiteConfig 8 | from raglite._embed import embed_strings 9 | from raglite._markdown import document_to_markdown 10 | from raglite._split_sentences import split_sentences 11 | 12 | 13 | def test_embed(embedder: str) -> None: 14 | """Test embedding a document.""" 15 | raglite_test_config = RAGLiteConfig(embedder=embedder, embedder_normalize=True) 16 | doc_path = Path(__file__).parent / "specrel.pdf" # Einstein's special relativity paper. 17 | doc = document_to_markdown(doc_path) 18 | sentences = split_sentences(doc, max_len=raglite_test_config.chunk_max_size) 19 | sentence_embeddings = embed_strings(sentences, config=raglite_test_config) 20 | assert isinstance(sentences, list) 21 | assert isinstance(sentence_embeddings, np.ndarray) 22 | assert len(sentences) == len(sentence_embeddings) 23 | assert sentence_embeddings.shape[1] >= 128 # noqa: PLR2004 24 | assert sentence_embeddings.dtype == np.float16 25 | assert np.all(np.isfinite(sentence_embeddings)) 26 | assert np.allclose(np.linalg.norm(sentence_embeddings, axis=1), 1.0, rtol=1e-3) 27 | -------------------------------------------------------------------------------- /tests/test_extract.py: -------------------------------------------------------------------------------- 1 | """Test RAGLite's structured output extraction.""" 2 | 3 | from typing import ClassVar 4 | 5 | import pytest 6 | from pydantic import BaseModel, ConfigDict, Field 7 | 8 | from raglite import RAGLiteConfig 9 | from raglite._extract import extract_with_llm 10 | 11 | 12 | @pytest.mark.parametrize( 13 | "strict", [pytest.param(False, id="strict=False"), pytest.param(True, id="strict=True")] 14 | ) 15 | def test_extract(llm: str, strict: bool) -> None: # noqa: FBT001 16 | """Test extracting structured data.""" 17 | # Set the LLM. 18 | config = RAGLiteConfig(llm=llm) 19 | 20 | # Define the JSON schema of the response. 21 | class LoginResponse(BaseModel): 22 | """The response to a login request.""" 23 | 24 | model_config = ConfigDict(extra="forbid" if strict else "allow") 25 | username: str = Field(..., description="The username.") 26 | password: str = Field(..., description="The password.") 27 | system_prompt: ClassVar[str] = "Extract the username and password from the input." 28 | 29 | # Extract structured data. 30 | username, password = "cypher", "steak" 31 | login_response = extract_with_llm( 32 | LoginResponse, f"username: {username}\npassword: {password}", strict=strict, config=config 33 | ) 34 | # Validate the response. 35 | assert isinstance(login_response, LoginResponse) 36 | assert login_response.username == username 37 | assert login_response.password == password 38 | -------------------------------------------------------------------------------- /tests/test_import.py: -------------------------------------------------------------------------------- 1 | """Test RAGLite.""" 2 | 3 | import raglite 4 | 5 | 6 | def test_import() -> None: 7 | """Test that the package can be imported.""" 8 | assert isinstance(raglite.__name__, str) 9 | -------------------------------------------------------------------------------- /tests/test_insert.py: -------------------------------------------------------------------------------- 1 | """Test RAGLite's document insertion.""" 2 | 3 | from pathlib import Path 4 | 5 | from sqlmodel import Session, select 6 | from tqdm import tqdm 7 | 8 | from raglite._config import RAGLiteConfig 9 | from raglite._database import Chunk, Document, create_database_engine 10 | from raglite._markdown import document_to_markdown 11 | 12 | 13 | def test_insert(raglite_test_config: RAGLiteConfig) -> None: 14 | """Test document insertion.""" 15 | # Open a session to extract document and chunks from the existing database. 16 | with Session(create_database_engine(raglite_test_config)) as session: 17 | # Get the first document from the database (already inserted by the fixture). 18 | document = session.exec(select(Document)).first() 19 | assert document is not None, "No document found in the database" 20 | # Get the existing chunks for this document. 21 | chunks = session.exec( 22 | select(Chunk).where(Chunk.document_id == document.id).order_by(Chunk.index) # type: ignore[arg-type] 23 | ).all() 24 | assert len(chunks) > 0, "No chunks found for the document" 25 | restored_document = "" 26 | for chunk in tqdm(chunks, desc="Processing chunks", leave=False): 27 | # Body should not contain the heading string (except if heading is empty). 28 | if chunk.headings.strip() != "": 29 | assert chunk.headings.strip() not in chunk.body.strip(), ( 30 | f"Chunk body contains heading: '{chunk.headings.strip()}'\n" 31 | f"Chunk body: '{chunk.body.strip()}'" 32 | ) 33 | # Body that starts with a # should not have a heading. 34 | if chunk.body.strip().startswith("# "): 35 | assert chunk.headings.strip() == "", ( 36 | f"Chunk body starts with a heading: '{chunk.body.strip()}'\n" 37 | f"Chunk headings: '{chunk.headings.strip()}'" 38 | ) 39 | restored_document += chunk.body 40 | # Combining the chunks should yield the original document. 41 | restored_document = restored_document.replace("\n", "").strip() 42 | doc_path = Path(__file__).parent / "specrel.pdf" # Einstein's special relativity paper. 43 | doc = document_to_markdown(doc_path) 44 | doc = doc.replace("\n", "").strip() 45 | assert restored_document == doc, "Restored document does not match the original input." 46 | -------------------------------------------------------------------------------- /tests/test_lazy_llama.py: -------------------------------------------------------------------------------- 1 | """Test that llama-cpp-python package is an optional dependency for RAGLite.""" 2 | 3 | import builtins 4 | import sys 5 | from typing import Any 6 | 7 | import pytest 8 | 9 | 10 | def test_raglite_import_without_llama_cpp(monkeypatch: pytest.MonkeyPatch) -> None: 11 | """Test that RAGLite can be imported without llama-cpp-python being available.""" 12 | # Unimport raglite and llama_cpp. 13 | module_names = list(sys.modules) 14 | for module_name in module_names: 15 | if module_name.startswith(("llama_cpp", "raglite", "sqlmodel")): 16 | monkeypatch.delitem(sys.modules, module_name, raising=False) 17 | 18 | # Save the original __import__ function. 19 | original_import = builtins.__import__ 20 | 21 | # Define a fake import function that raises ModuleNotFoundError when trying to import llama_cpp. 22 | def fake_import(name: str, *args: Any) -> Any: 23 | if name.startswith("llama_cpp"): 24 | import_error = f"No module named '{name}'" 25 | raise ModuleNotFoundError(import_error) 26 | return original_import(name, *args) 27 | 28 | # Monkey patch __import__ with the fake import function. 29 | monkeypatch.setattr(builtins, "__import__", fake_import) 30 | 31 | # Verify that importing raglite does not raise an error. 32 | import raglite # noqa: F401 33 | from raglite._config import llama_supports_gpu_offload # type: ignore[attr-defined] 34 | 35 | # Verify that lazily using llama-cpp-python raises a ModuleNotFoundError. 36 | with pytest.raises(ModuleNotFoundError, match="llama.cpp models"): 37 | llama_supports_gpu_offload() 38 | -------------------------------------------------------------------------------- /tests/test_markdown.py: -------------------------------------------------------------------------------- 1 | """Test Markdown conversion.""" 2 | 3 | from pathlib import Path 4 | 5 | from raglite._markdown import document_to_markdown 6 | 7 | 8 | def test_pdf_with_missing_font_sizes() -> None: 9 | """Test conversion of a PDF with missing font sizes.""" 10 | # Convert a PDF whose parsed font sizes are all equal to 1. 11 | doc_path = Path(__file__).parent / "specrel.pdf" # Einstein's special relativity paper. 12 | doc = document_to_markdown(doc_path) 13 | # Verify that we can reconstruct the font sizes and heading levels regardless of the missing 14 | # font size data. 15 | expected_heading = "# ON THE ELECTRODYNAMICS OF MOVING BODIES \n\n## By A. EINSTEIN June 30, 1905 \n\nIt is known that Maxwell" 16 | assert doc.startswith(expected_heading) 17 | -------------------------------------------------------------------------------- /tests/test_query_adapter.py: -------------------------------------------------------------------------------- 1 | """Test RAGLite's query adapter.""" 2 | 3 | from dataclasses import replace 4 | 5 | import numpy as np 6 | import pytest 7 | 8 | from raglite import RAGLiteConfig, insert_evals, update_query_adapter, vector_search 9 | from raglite._database import IndexMetadata 10 | 11 | 12 | @pytest.mark.slow 13 | def test_query_adapter(raglite_test_config: RAGLiteConfig) -> None: 14 | """Test the query adapter update functionality.""" 15 | # Create a config with and without the query adapter enabled. 16 | config_with_query_adapter = replace(raglite_test_config, vector_search_query_adapter=True) 17 | config_without_query_adapter = replace(raglite_test_config, vector_search_query_adapter=False) 18 | # Verify that there is no query adapter in the database. 19 | Q = IndexMetadata.get("default", config=config_without_query_adapter).get("query_adapter") # noqa: N806 20 | assert Q is None 21 | # Insert evals. 22 | insert_evals(num_evals=2, max_chunks_per_eval=10, config=config_with_query_adapter) 23 | # Update the query adapter. 24 | A = update_query_adapter(config=config_with_query_adapter) # noqa: N806 25 | assert isinstance(A, np.ndarray) 26 | assert A.ndim == 2 # noqa: PLR2004 27 | assert A.shape[0] == A.shape[1] 28 | assert np.isfinite(A).all() 29 | # Verify that there is a query adapter in the database. 30 | Q = IndexMetadata.get("default", config=config_without_query_adapter).get("query_adapter") # noqa: N806 31 | assert isinstance(Q, np.ndarray) 32 | assert Q.ndim == 2 # noqa: PLR2004 33 | assert Q.shape[0] == Q.shape[1] 34 | assert np.isfinite(Q).all() 35 | assert np.all(A == Q) 36 | # Verify that the query adapter affects the results of vector search. 37 | query = "How does Einstein define 'simultaneous events' in his special relativity paper?" 38 | _, scores_qa = vector_search(query, config=config_with_query_adapter) 39 | _, scores_no_qa = vector_search(query, config=config_without_query_adapter) 40 | assert scores_qa != scores_no_qa 41 | -------------------------------------------------------------------------------- /tests/test_rag.py: -------------------------------------------------------------------------------- 1 | """Test RAGLite's RAG functionality.""" 2 | 3 | import json 4 | 5 | from raglite import ( 6 | RAGLiteConfig, 7 | add_context, 8 | retrieve_context, 9 | ) 10 | from raglite._database import ChunkSpan 11 | from raglite._rag import rag 12 | 13 | 14 | def test_rag_manual(raglite_test_config: RAGLiteConfig) -> None: 15 | """Test Retrieval-Augmented Generation with manual retrieval.""" 16 | # Answer a question with manual RAG. 17 | user_prompt = "How does Einstein define 'simultaneous events' in his special relativity paper?" 18 | chunk_spans = retrieve_context(query=user_prompt, config=raglite_test_config) 19 | messages = [add_context(user_prompt, context=chunk_spans)] 20 | stream = rag(messages, config=raglite_test_config) 21 | answer = "" 22 | for update in stream: 23 | assert isinstance(update, str) 24 | answer += update 25 | assert "event" in answer.lower() 26 | # Verify that no RAG context was retrieved through tool use. 27 | assert [message["role"] for message in messages] == ["user", "assistant"] 28 | 29 | 30 | def test_rag_auto_with_retrieval(raglite_test_config: RAGLiteConfig) -> None: 31 | """Test Retrieval-Augmented Generation with automatic retrieval.""" 32 | # Answer a question that requires RAG. 33 | user_prompt = "How does Einstein define 'simultaneous events' in his special relativity paper?" 34 | messages = [{"role": "user", "content": user_prompt}] 35 | chunk_spans = [] 36 | stream = rag(messages, on_retrieval=lambda x: chunk_spans.extend(x), config=raglite_test_config) 37 | answer = "" 38 | for update in stream: 39 | assert isinstance(update, str) 40 | answer += update 41 | assert "event" in answer.lower() 42 | # Verify that RAG context was retrieved automatically. 43 | assert [message["role"] for message in messages] == ["user", "assistant", "tool", "assistant"] 44 | assert json.loads(messages[-2]["content"]) 45 | assert chunk_spans 46 | assert all(isinstance(chunk_span, ChunkSpan) for chunk_span in chunk_spans) 47 | 48 | 49 | def test_rag_auto_without_retrieval(raglite_test_config: RAGLiteConfig) -> None: 50 | """Test Retrieval-Augmented Generation with automatic retrieval.""" 51 | # Answer a question that does not require RAG. 52 | user_prompt = "Is 7 a prime number?" 53 | messages = [{"role": "user", "content": user_prompt}] 54 | chunk_spans = [] 55 | stream = rag(messages, on_retrieval=lambda x: chunk_spans.extend(x), config=raglite_test_config) 56 | answer = "" 57 | for update in stream: 58 | assert isinstance(update, str) 59 | answer += update 60 | # Verify that no RAG context was retrieved. 61 | assert [message["role"] for message in messages] == ["user", "assistant"] 62 | assert not chunk_spans 63 | -------------------------------------------------------------------------------- /tests/test_rerank.py: -------------------------------------------------------------------------------- 1 | """Test RAGLite's reranking functionality.""" 2 | 3 | import random 4 | from typing import TypeVar 5 | 6 | import pytest 7 | from rerankers.models.flashrank_ranker import FlashRankRanker 8 | from rerankers.models.ranker import BaseRanker 9 | from scipy.stats import kendalltau 10 | 11 | from raglite import RAGLiteConfig, rerank_chunks, retrieve_chunks, vector_search 12 | from raglite._database import Chunk 13 | 14 | T = TypeVar("T") 15 | 16 | 17 | def kendall_tau(a: list[T], b: list[T]) -> float: 18 | """Measure the Kendall rank correlation coefficient between two lists.""" 19 | τ: float = kendalltau(range(len(a)), [a.index(el) for el in b])[0] # noqa: PLC2401 20 | return τ 21 | 22 | 23 | @pytest.fixture( 24 | params=[ 25 | pytest.param(None, id="no_reranker"), 26 | pytest.param(FlashRankRanker("ms-marco-MiniLM-L-12-v2", verbose=0), id="flashrank_english"), 27 | pytest.param( 28 | { 29 | "en": FlashRankRanker("ms-marco-MiniLM-L-12-v2", verbose=0), 30 | "other": FlashRankRanker("ms-marco-MultiBERT-L-12", verbose=0), 31 | }, 32 | id="flashrank_multilingual", 33 | ), 34 | ], 35 | ) 36 | def reranker( 37 | request: pytest.FixtureRequest, 38 | ) -> BaseRanker | dict[str, BaseRanker] | None: 39 | """Get a reranker to test RAGLite with.""" 40 | reranker: BaseRanker | dict[str, BaseRanker] | None = request.param 41 | return reranker 42 | 43 | 44 | def test_reranker( 45 | raglite_test_config: RAGLiteConfig, 46 | reranker: BaseRanker | dict[str, BaseRanker] | None, 47 | ) -> None: 48 | """Test inserting a document, updating the indexes, and searching for a query.""" 49 | # Update the config with the reranker. 50 | raglite_test_config = RAGLiteConfig( 51 | db_url=raglite_test_config.db_url, embedder=raglite_test_config.embedder, reranker=reranker 52 | ) 53 | # Search for a query. 54 | query = "What does it mean for two events to be simultaneous?" 55 | chunk_ids, _ = vector_search(query, num_results=40, config=raglite_test_config) 56 | # Retrieve the chunks. 57 | chunks = retrieve_chunks(chunk_ids, config=raglite_test_config) 58 | assert all(isinstance(chunk, Chunk) for chunk in chunks) 59 | assert all(chunk_id == chunk.id for chunk_id, chunk in zip(chunk_ids, chunks, strict=True)) 60 | # Randomly shuffle the chunks. 61 | random.seed(42) 62 | chunks_random = random.sample(chunks, len(chunks)) 63 | # Rerank the chunks starting from a pathological order and verify that it improves the ranking. 64 | for arg in (chunks[::-1], chunk_ids[::-1]): 65 | reranked_chunks = rerank_chunks(query, arg, config=raglite_test_config) 66 | if reranker: 67 | τ_search = kendall_tau(chunks, reranked_chunks) # noqa: PLC2401 68 | τ_inverse = kendall_tau(chunks[::-1], reranked_chunks) # noqa: PLC2401 69 | τ_random = kendall_tau(chunks_random, reranked_chunks) # noqa: PLC2401 70 | assert τ_search >= τ_random >= τ_inverse 71 | -------------------------------------------------------------------------------- /tests/test_search.py: -------------------------------------------------------------------------------- 1 | """Test RAGLite's search functionality.""" 2 | 3 | import pytest 4 | 5 | from raglite import ( 6 | RAGLiteConfig, 7 | hybrid_search, 8 | keyword_search, 9 | retrieve_chunk_spans, 10 | retrieve_chunks, 11 | vector_search, 12 | ) 13 | from raglite._database import Chunk, ChunkSpan, Document 14 | from raglite._typing import BasicSearchMethod 15 | 16 | 17 | @pytest.fixture( 18 | params=[ 19 | pytest.param(keyword_search, id="keyword_search"), 20 | pytest.param(vector_search, id="vector_search"), 21 | pytest.param(hybrid_search, id="hybrid_search"), 22 | ], 23 | ) 24 | def search_method( 25 | request: pytest.FixtureRequest, 26 | ) -> BasicSearchMethod: 27 | """Get a search method to test RAGLite with.""" 28 | search_method: BasicSearchMethod = request.param 29 | return search_method 30 | 31 | 32 | def test_search(raglite_test_config: RAGLiteConfig, search_method: BasicSearchMethod) -> None: 33 | """Test searching for a query.""" 34 | # Search for a query. 35 | query = "What does it mean for two events to be simultaneous?" 36 | num_results = 5 37 | chunk_ids, scores = search_method(query, num_results=num_results, config=raglite_test_config) 38 | assert len(chunk_ids) == len(scores) == num_results 39 | assert all(isinstance(chunk_id, str) for chunk_id in chunk_ids) 40 | assert all(isinstance(score, float) for score in scores) 41 | # Retrieve the chunks. 42 | chunks = retrieve_chunks(chunk_ids, config=raglite_test_config) 43 | assert all(isinstance(chunk, Chunk) for chunk in chunks) 44 | assert all(chunk_id == chunk.id for chunk_id, chunk in zip(chunk_ids, chunks, strict=True)) 45 | assert any("Definition of Simultaneity" in str(chunk) for chunk in chunks), ( 46 | "Expected 'Definition of Simultaneity' in chunks but got:\n" 47 | + "\n".join(f"- Chunk {i + 1}:\n{chunk!s}\n{'-' * 80}" for i, chunk in enumerate(chunks)) 48 | ) 49 | assert all(isinstance(chunk.document, Document) for chunk in chunks) 50 | # Extend the chunks with their neighbours and group them into contiguous segments. 51 | chunk_spans = retrieve_chunk_spans(chunk_ids, neighbors=(-1, 1), config=raglite_test_config) 52 | assert all(isinstance(chunk_span, ChunkSpan) for chunk_span in chunk_spans) 53 | assert all(isinstance(chunk_span.document, Document) for chunk_span in chunk_spans) 54 | chunk_spans = retrieve_chunk_spans(chunks, neighbors=(-1, 1), config=raglite_test_config) 55 | assert all(isinstance(chunk_span, ChunkSpan) for chunk_span in chunk_spans) 56 | assert all(isinstance(chunk_span.document, Document) for chunk_span in chunk_spans) 57 | 58 | 59 | def test_search_no_results( 60 | raglite_test_config: RAGLiteConfig, search_method: BasicSearchMethod 61 | ) -> None: 62 | """Test searching for a query with no keyword search results.""" 63 | query = "supercalifragilisticexpialidocious" 64 | num_results = 5 65 | chunk_ids, scores = search_method(query, num_results=num_results, config=raglite_test_config) 66 | num_results_expected = 0 if search_method == keyword_search else num_results 67 | assert len(chunk_ids) == len(scores) == num_results_expected 68 | assert all(isinstance(chunk_id, str) for chunk_id in chunk_ids) 69 | assert all(isinstance(score, float) for score in scores) 70 | 71 | 72 | def test_search_empty_database(llm: str, embedder: str, search_method: BasicSearchMethod) -> None: 73 | """Test searching for a query with an empty database.""" 74 | raglite_test_config = RAGLiteConfig(db_url="duckdb:///:memory:", llm=llm, embedder=embedder) 75 | query = "supercalifragilisticexpialidocious" 76 | num_results = 5 77 | chunk_ids, scores = search_method(query, num_results=num_results, config=raglite_test_config) 78 | num_results_expected = 0 79 | assert len(chunk_ids) == len(scores) == num_results_expected 80 | assert all(isinstance(chunk_id, str) for chunk_id in chunk_ids) 81 | assert all(isinstance(score, float) for score in scores) 82 | -------------------------------------------------------------------------------- /tests/test_split_chunklets.py: -------------------------------------------------------------------------------- 1 | """Test RAGLite's chunk splitting functionality.""" 2 | 3 | import pytest 4 | 5 | from raglite._split_chunklets import split_chunklets 6 | 7 | 8 | @pytest.mark.parametrize( 9 | "sentences_splits", 10 | [ 11 | pytest.param( 12 | ( 13 | [ 14 | # Sentence 1: 15 | "It is known that Maxwell’s electrodynamics—as usually understood at the\n" # noqa: RUF001 16 | "present time—when applied to moving bodies, leads to asymmetries which do\n\n" 17 | "not appear to be inherent in the phenomena. ", 18 | # Sentence 2: 19 | "Take, for example, the recipro-\ncal electrodynamic action of a magnet and a conductor. \n\n", 20 | # Sentence 3 (heading): 21 | "# ON THE ELECTRODYNAMICS OF MOVING BODIES\n\n", 22 | # Sentence 4 (heading): 23 | "## By A. EINSTEIN June 30, 1905\n\n", 24 | # Sentence 5 (paragraph boundary): 25 | "The observable phe-\n" 26 | "nomenon here depends only on the relative motion of the conductor and the\n" 27 | "magnet, whereas the customary view draws a sharp distinction between the two\n" 28 | "cases in which either the one or the other of these bodies is in motion. ", 29 | # Sentence 6: 30 | "For if the\n" 31 | "magnet is in motion and the conductor at rest, there arises in the neighbour-\n" 32 | "hood of the magnet an electric field with a certain definite energy, producing\n" 33 | "a current at the places where parts of the conductor are situated. ", 34 | ], 35 | [2], 36 | ), 37 | id="consecutive_boundaries", 38 | ), 39 | pytest.param( 40 | ( 41 | [ 42 | # Sentence 1: 43 | "The theory to be developed is based—like all electrodynamics—on the kine-\n" 44 | "matics of the rigid body, since the assertions of any such theory have to do\n" 45 | "with the relationships between rigid bodies (systems of co-ordinates), clocks,\n" 46 | "and electromagnetic processes. ", 47 | # Sentence 2: 48 | "Insufficient consideration of this circumstance\n" 49 | "lies at the root of the difficulties which the electrodynamics of moving bodies\n" 50 | "at present encounters.\n\n", 51 | # Sentence 3 (paragraph boundary): 52 | "The observable phe-\n" 53 | "nomenon here depends only on the relative motion of the conductor and the\n" 54 | "magnet, whereas the customary view draws a sharp distinction between the two\n" 55 | "cases in which either the one or the other of these bodies is in motion. ", 56 | # Sentence 4: 57 | "For if the\n" 58 | "magnet is in motion and the conductor at rest, there arises in the neighbour-\n" 59 | "hood of the magnet an electric field with a certain definite energy, producing\n" 60 | "a current at the places where parts of the conductor are situated. ", 61 | # Sentence 5: 62 | "But if the\n" 63 | "magnet is stationary and the conductor in motion, no electric field arises in the\n" 64 | "neighbourhood of the magnet. ", 65 | ], 66 | [2], 67 | ), 68 | id="paragraph_boundary", 69 | ), 70 | ], 71 | ) 72 | def test_split_chunklets(sentences_splits: tuple[list[str], list[int]]) -> None: 73 | """Test chunklet splitting.""" 74 | sentences, splits = sentences_splits 75 | chunklets = split_chunklets(sentences) 76 | expected_chunklets = [ 77 | "".join(sentences[i:j]) 78 | for i, j in zip([0, *splits], [*splits, len(sentences)], strict=True) 79 | ] 80 | assert isinstance(chunklets, list) 81 | assert all(isinstance(chunklet, str) for chunklet in chunklets) 82 | assert sum(len(chunklet) for chunklet in chunklets) == sum( 83 | len(sentence) for sentence in sentences 84 | ) 85 | assert all( 86 | chunklet == expected_chunklet 87 | for chunklet, expected_chunklet in zip(chunklets, expected_chunklets, strict=True) 88 | ) 89 | -------------------------------------------------------------------------------- /tests/test_split_chunks.py: -------------------------------------------------------------------------------- 1 | """Test RAGLite's chunk splitting functionality.""" 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | from raglite._split_chunks import split_chunks 7 | 8 | 9 | @pytest.mark.parametrize( 10 | "chunklets", 11 | [ 12 | pytest.param([], id="one_chunk:no_chunklets"), 13 | pytest.param(["Hello world"], id="one_chunk:one_chunklet"), 14 | pytest.param(["Hello world"] * 2, id="one_chunk:two_chunklets"), 15 | pytest.param(["Hello world"] * 3, id="one_chunk:three_chunklets"), 16 | pytest.param(["Hello world"] * 100, id="one_chunk:many_chunklets"), 17 | pytest.param(["Hello world", "X" * 1000], id="n_chunks:two_chunklets_a"), 18 | pytest.param(["X" * 1000, "Hello world"], id="n_chunks:two_chunklets_b"), 19 | pytest.param(["Hello world", "X" * 1000, "X" * 1000], id="n_chunks:three_chunklets_a"), 20 | pytest.param(["X" * 1000, "Hello world", "X" * 1000], id="n_chunks:three_chunklets_b"), 21 | pytest.param(["X" * 1000, "X" * 1000, "Hello world"], id="n_chunks:three_chunklets_c"), 22 | pytest.param(["X" * 1000] * 100, id="n_chunks:many_chunklets_a"), 23 | pytest.param(["X" * 100] * 1000, id="n_chunks:many_chunklets_b"), 24 | ], 25 | ) 26 | def test_edge_cases(chunklets: list[str]) -> None: 27 | """Test chunk splitting edge cases.""" 28 | chunklet_embeddings = np.ones((len(chunklets), 768)).astype(np.float16) 29 | chunks, chunk_embeddings = split_chunks(chunklets, chunklet_embeddings, max_size=1440) 30 | assert isinstance(chunks, list) 31 | assert isinstance(chunk_embeddings, list) 32 | assert len(chunk_embeddings) == (len(chunks) if chunklets else 1) 33 | assert all(isinstance(chunk, str) for chunk in chunks) 34 | assert all(isinstance(chunk_embedding, np.ndarray) for chunk_embedding in chunk_embeddings) 35 | assert all(ce.dtype == chunklet_embeddings.dtype for ce in chunk_embeddings) 36 | assert sum(ce.shape[0] for ce in chunk_embeddings) == chunklet_embeddings.shape[0] 37 | assert all(ce.shape[1] == chunklet_embeddings.shape[1] for ce in chunk_embeddings) 38 | 39 | 40 | @pytest.mark.parametrize( 41 | "chunklets", 42 | [ 43 | pytest.param(["Hello world" * 1000] + ["X"] * 100, id="first"), 44 | pytest.param(["X"] * 50 + ["Hello world" * 1000] + ["X"] * 50, id="middle"), 45 | pytest.param(["X"] * 100 + ["Hello world" * 1000], id="last"), 46 | ], 47 | ) 48 | def test_long_chunklet(chunklets: list[str]) -> None: 49 | """Test chunking on chunklets that are too long.""" 50 | chunklet_embeddings = np.ones((len(chunklets), 768)).astype(np.float16) 51 | with pytest.raises(ValueError, match="Chunklet larger than chunk max_size detected."): 52 | _ = split_chunks(chunklets, chunklet_embeddings, max_size=1440) 53 | -------------------------------------------------------------------------------- /tests/test_split_sentences.py: -------------------------------------------------------------------------------- 1 | """Test RAGLite's sentence splitting functionality.""" 2 | 3 | from pathlib import Path 4 | 5 | import pytest 6 | 7 | from raglite._markdown import document_to_markdown 8 | from raglite._split_sentences import split_sentences 9 | 10 | 11 | def test_split_sentences() -> None: 12 | """Test splitting a document into sentences.""" 13 | doc_path = Path(__file__).parent / "specrel.pdf" # Einstein's special relativity paper. 14 | doc = document_to_markdown(doc_path) 15 | sentences = split_sentences(doc) 16 | expected_sentences = [ 17 | "# ON THE ELECTRODYNAMICS OF MOVING BODIES \n\n", 18 | "## By A. EINSTEIN June 30, 1905 \n\n", 19 | "It is known that Maxwell’s electrodynamics—as usually understood at the\npresent time—when applied to moving bodies, leads to asymmetries which do\n\nnot appear to be inherent in the phenomena. ", # noqa: RUF001 20 | "Take, for example, the recipro-\ncal electrodynamic action of a magnet and a conductor. ", 21 | "The observable phe-\nnomenon here depends only on the relative motion of the conductor and the\nmagnet, whereas the customary view draws a sharp distinction between the two\ncases in which either the one or the other of these bodies is in motion. ", 22 | "For if the\nmagnet is in motion and the conductor at rest, there arises in the neighbour-\nhood of the magnet an electric field with a certain definite energy, producing\na current at the places where parts of the conductor are situated. ", 23 | "But if the\n\nmagnet is stationary and the conductor in motion, no electric field arises in the\nneighbourhood of the magnet. ", 24 | "In the conductor, however, we find an electro-\nmotive force, to which in itself there is no corresponding energy, but which gives\nrise—assuming equality of relative motion in the two cases discussed—to elec-\n\ntric currents of the same path and intensity as those produced by the electric\nforces in the former case.\n\n", 25 | "Examples of this sort, together with the unsuccessful attempts to discover\nany motion of the earth relatively to the “light medium,” suggest that the\n\nphenomena of electrodynamics as well as of mechanics possess no properties\ncorresponding to the idea of absolute rest. ", 26 | "They suggest rather that, as has\nalready been shown to the first order of small quantities, the same laws of\nelectrodynamics and optics will be valid for all frames of reference for which the\nequations of mechanics hold good.1 ", 27 | "We will raise this conjecture (the purport\nof which will hereafter be called the “Principle of Relativity”) to the status\n\nof a postulate, and also introduce another postulate, which is only apparently\nirreconcilable with the former, namely, that light is always propagated in empty\nspace with a definite velocity c which is independent of the state of motion of the\nemitting body. ", 28 | "These two postulates suffice for the attainment of a simple and\nconsistent theory of the electrodynamics of moving bodies based on Maxwell’s\ntheory for stationary bodies. ", # noqa: RUF001 29 | "The introduction of a “luminiferous ether” will\nprove to be superfluous inasmuch as the view here to be developed will not\nrequire an “absolutely stationary space” provided with special properties, nor\n1", 30 | "The preceding memoir by Lorentz was not at this time known to the author.\n\n", 31 | "assign a velocity-vector to a point of the empty space in which electromagnetic\nprocesses take place.\n\n", 32 | "The theory to be developed is based—like all electrodynamics—on the kine-\nmatics of the rigid body, since the assertions of any such theory have to do\nwith the relationships between rigid bodies (systems of co-ordinates), clocks,\nand electromagnetic processes. ", 33 | "Insufficient consideration of this circumstance\nlies at the root of the difficulties which the electrodynamics of moving bodies\nat present encounters.\n\n", 34 | "## I. KINEMATICAL PART § **1. Definition of Simultaneity** \n\n", 35 | "Let us take a system of co-ordinates in which the equations of Newtonian\nmechanics hold good.2 ", 36 | ] 37 | assert isinstance(sentences, list) 38 | assert all(not sentence.isspace() for sentence in sentences) 39 | assert all( 40 | sentence == expected_sentence 41 | for sentence, expected_sentence in zip( 42 | sentences[: len(expected_sentences)], expected_sentences, strict=True 43 | ) 44 | ) 45 | 46 | 47 | @pytest.mark.parametrize( 48 | "case", 49 | [ 50 | pytest.param(("", [""], (4, None)), id="tiny-0"), 51 | pytest.param(("Hi!", ["Hi!"], (4, None)), id="tiny-1a"), 52 | pytest.param(("Yes? No!", ["Yes? No!"], (4, None)), id="tiny-1b"), 53 | pytest.param(("Yes? No!", ["Yes? ", "No!"], (3, None)), id="tiny-2a"), 54 | pytest.param(("\n\nYes?\n\nNo!\n\n", ["\n\nYes?\n\n", "No!\n\n"], (3, None)), id="tiny-2b"), 55 | pytest.param( 56 | ("X" * 768 + "\n\n" + "X" * 768, ["X" * 768 + "\n\n" + "X" * 768], (4, None)), 57 | id="huge-1", 58 | ), 59 | pytest.param( 60 | ("X" * 768 + "\n\n" + "X" * 768, ["X" * 768 + "\n\n", "X" * 768], (4, 1024)), 61 | id="huge-2a", 62 | ), 63 | pytest.param( 64 | ("X" * 768 + " " + "X" * 768, ["X" * 768 + " ", "X" * 768], (4, 1024)), 65 | id="huge-2b", 66 | ), 67 | ], 68 | ) 69 | def test_split_sentences_edge_cases(case: tuple[str, list[str], tuple[int, int | None]]) -> None: 70 | """Test edge cases of splitting a document into sentences.""" 71 | doc, expected_sentences, (min_len, max_len) = case 72 | sentences = split_sentences(doc, min_len=min_len, max_len=max_len) 73 | assert isinstance(sentences, list) 74 | assert all( 75 | sentence == expected_sentence 76 | for sentence, expected_sentence in zip( 77 | sentences[: len(expected_sentences)], expected_sentences, strict=True 78 | ) 79 | ) 80 | --------------------------------------------------------------------------------