├── .copier-answers.yml
├── .devcontainer
│   └── devcontainer.json
├── .dockerignore
├── .github
│   ├── dependabot.yml
│   └── workflows
│       ├── publish.yml
│       └── test.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CHANGELOG.md
├── Dockerfile
├── LICENSE
├── README.md
├── docker-compose.yml
├── pyproject.toml
├── src
│   └── raglite
│       ├── __init__.py
│       ├── _bench.py
│       ├── _chainlit.py
│       ├── _chatml_function_calling.py
│       ├── _cli.py
│       ├── _config.py
│       ├── _database.py
│       ├── _embed.py
│       ├── _eval.py
│       ├── _extract.py
│       ├── _insert.py
│       ├── _lazy_llama.py
│       ├── _litellm.py
│       ├── _markdown.py
│       ├── _mcp.py
│       ├── _query_adapter.py
│       ├── _rag.py
│       ├── _search.py
│       ├── _split_chunklets.py
│       ├── _split_chunks.py
│       ├── _split_sentences.py
│       ├── _typing.py
│       └── py.typed
└── tests
    ├── __init__.py
    ├── conftest.py
    ├── specrel.pdf
    ├── test_chatml_function_calling.py
    ├── test_database.py
    ├── test_embed.py
    ├── test_extract.py
    ├── test_import.py
    ├── test_insert.py
    ├── test_lazy_llama.py
    ├── test_markdown.py
    ├── test_query_adapter.py
    ├── test_rag.py
    ├── test_rerank.py
    ├── test_search.py
    ├── test_split_chunklets.py
    ├── test_split_chunks.py
    └── test_split_sentences.py
/.copier-answers.yml:
--------------------------------------------------------------------------------
1 | _commit: v1.5.1
2 | _src_path: gh:superlinear-ai/substrate
3 | author_email: laurent@superlinear.eu
4 | author_name: Laurent Sorber
5 | project_description: A Python toolkit for Retrieval-Augmented Generation (RAG) with
6 | DuckDB or PostgreSQL.
7 | project_name: raglite
8 | project_type: package
9 | project_url: https://github.com/superlinear-ai/raglite
10 | python_version: '3.10'
11 | typing: strict
12 | with_conventional_commits: true
13 | with_typer_cli: true
14 |
--------------------------------------------------------------------------------
/.devcontainer/devcontainer.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "raglite",
3 | "dockerComposeFile": "../docker-compose.yml",
4 | "service": "devcontainer",
5 | "workspaceFolder": "/workspaces/${localWorkspaceFolderBasename}/",
6 | "features": {
7 | "ghcr.io/devcontainers-extra/features/starship:1": {}
8 | },
9 | "overrideCommand": true,
10 | "remoteUser": "user",
11 | "postStartCommand": "sudo chown -R user:user /opt/ && uv sync --python ${localEnv:PYTHON_VERSION:3.10} --resolution ${localEnv:RESOLUTION_STRATEGY:highest} --all-extras && pre-commit install --install-hooks",
12 | "customizations": {
13 | "jetbrains": {
14 | "backend": "PyCharm",
15 | "plugins": [
16 | "com.github.copilot"
17 | ]
18 | },
19 | "vscode": {
20 | "extensions": [
21 | "charliermarsh.ruff",
22 | "GitHub.copilot",
23 | "GitHub.copilot-chat",
24 | "GitHub.vscode-github-actions",
25 | "GitHub.vscode-pull-request-github",
26 | "ms-azuretools.vscode-docker",
27 | "ms-python.mypy-type-checker",
28 | "ms-python.python",
29 | "ms-toolsai.jupyter",
30 | "ryanluker.vscode-coverage-gutters",
31 | "tamasfe.even-better-toml",
32 | "visualstudioexptteam.vscodeintellicode"
33 | ],
34 | "settings": {
35 | "coverage-gutters.coverageFileNames": [
36 | "reports/coverage.xml"
37 | ],
38 | "editor.codeActionsOnSave": {
39 | "source.fixAll": "explicit",
40 | "source.organizeImports": "explicit"
41 | },
42 | "editor.formatOnSave": true,
43 | "[python]": {
44 | "editor.defaultFormatter": "charliermarsh.ruff"
45 | },
46 | "[toml]": {
47 | "editor.formatOnSave": false
48 | },
49 | "editor.rulers": [
50 | 100
51 | ],
52 | "files.autoSave": "onFocusChange",
53 | "github.copilot.chat.agent.enabled": true,
54 | "github.copilot.chat.codesearch.enabled": true,
55 | "github.copilot.chat.edits.enabled": true,
56 | "github.copilot.nextEditSuggestions.enabled": true,
57 | "jupyter.kernels.excludePythonEnvironments": [
58 | "/usr/local/bin/python"
59 | ],
60 | "mypy-type-checker.importStrategy": "fromEnvironment",
61 | "mypy-type-checker.preferDaemon": true,
62 | "notebook.codeActionsOnSave": {
63 | "notebook.source.fixAll": "explicit",
64 | "notebook.source.organizeImports": "explicit"
65 | },
66 | "notebook.formatOnSave.enabled": true,
67 | "python.defaultInterpreterPath": "/opt/venv/bin/python",
68 | "python.terminal.activateEnvironment": false,
69 | "python.testing.pytestEnabled": true,
70 | "ruff.importStrategy": "fromEnvironment",
71 | "ruff.logLevel": "warning",
72 | "terminal.integrated.env.linux": {
73 | "GIT_EDITOR": "code --wait"
74 | },
75 | "terminal.integrated.env.mac": {
76 | "GIT_EDITOR": "code --wait"
77 | }
78 | }
79 | }
80 | }
81 | }
--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
1 | # Caches
2 | .*_cache/
3 |
4 | # Git
5 | .git/
6 |
7 | # Python
8 | .venv/
9 |
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 |
3 | updates:
4 | - package-ecosystem: github-actions
5 | directory: /
6 | schedule:
7 | interval: monthly
8 | commit-message:
9 | prefix: "ci"
10 | prefix-development: "ci"
11 | include: scope
12 | groups:
13 | ci-dependencies:
14 | patterns:
15 | - "*"
16 | - package-ecosystem: pip
17 | directory: /
18 | schedule:
19 | interval: monthly
20 | commit-message:
21 | prefix: "chore"
22 | prefix-development: "build"
23 | include: scope
24 | allow:
25 | - dependency-type: development
26 | versioning-strategy: increase
27 | groups:
28 | development-dependencies:
29 | dependency-type: development
30 |
--------------------------------------------------------------------------------
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish
2 |
3 | on:
4 | release:
5 | types:
6 | - created
7 |
8 | jobs:
9 | publish:
10 | runs-on: ubuntu-latest
11 | environment: pypi
12 | permissions:
13 | id-token: write
14 |
15 | steps:
16 | - name: Checkout
17 | uses: actions/checkout@v4
18 |
19 | - name: Install uv
20 | uses: astral-sh/setup-uv@v5
21 |
22 | - name: Publish package
23 | run: |
24 | uv build
25 | uv publish
26 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: Test
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | - master
8 | pull_request:
9 |
10 | jobs:
11 | test:
12 | runs-on:
13 | group: raglite
14 |
15 | strategy:
16 | fail-fast: false
17 | matrix:
18 | python-version: ["3.10", "3.12"]
19 | resolution-strategy: ["highest", "lowest-direct"]
20 |
21 | name: Python ${{ matrix.python-version }} (resolution=${{ matrix.resolution-strategy }})
22 |
23 | steps:
24 | - name: Checkout
25 | uses: actions/checkout@v4
26 |
27 | - name: Set up Node.js
28 | uses: actions/setup-node@v4
29 | with:
30 | node-version: 23
31 |
32 | - name: Install @devcontainers/cli
33 | run: npm install --location=global @devcontainers/cli@0.73.0
34 |
35 | - name: Start Dev Container
36 | run: |
37 | git config --global init.defaultBranch main
38 | devcontainer up --workspace-folder .
39 | env:
40 | OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
41 | PYTHON_VERSION: ${{ matrix.python-version }}
42 | RESOLUTION_STRATEGY: ${{ matrix.resolution-strategy }}
43 |
44 | - name: Lint package
45 | run: devcontainer exec --workspace-folder . poe lint
46 |
47 | - name: Test package
48 | run: devcontainer exec --workspace-folder . poe test
49 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Chainlit
2 | .chainlit/
3 | .files/
4 | chainlit.md
5 |
6 | # Coverage.py
7 | htmlcov/
8 | reports/
9 |
10 | # Copier
11 | *.rej
12 |
13 | # Data
14 | *.csv*
15 | *.dat*
16 | *.pickle*
17 | *.xls*
18 | *.zip*
19 | data/
20 |
21 | # direnv
22 | .envrc
23 |
24 | # dotenv
25 | .env
26 |
27 | # rerankers
28 | .*_cache/
29 |
30 | # Hypothesis
31 | .hypothesis/
32 |
33 | # Jupyter
34 | *.ipynb
35 | .ipynb_checkpoints/
36 | notebooks/
37 |
38 | # macOS
39 | .DS_Store
40 |
41 | # mise
42 | mise.local.toml
43 |
44 | # mypy
45 | .dmypy.json
46 | .mypy_cache/
47 |
48 | # Node.js
49 | node_modules/
50 |
51 | # PyCharm
52 | .idea/
53 |
54 | # pyenv
55 | .python-version
56 |
57 | # pytest
58 | .pytest_cache/
59 |
60 | # Python
61 | __pycache__/
62 | *.egg-info/
63 | *.py[cdo]
64 | .venv/
65 | dist/
66 |
67 | # RAGLite
68 | *.db
69 |
70 | # Ruff
71 | .ruff_cache/
72 |
73 | # Terraform
74 | .terraform/
75 |
76 | # uv
77 | uv.lock
78 |
79 | # VS Code
80 | .vscode/
81 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # https://pre-commit.com
2 | default_install_hook_types: [commit-msg, pre-commit]
3 | default_stages: [pre-commit, manual]
4 | fail_fast: true
5 | repos:
6 | - repo: meta
7 | hooks:
8 | - id: check-useless-excludes
9 | - repo: https://github.com/pre-commit/pygrep-hooks
10 | rev: v1.10.0
11 | hooks:
12 | - id: python-check-mock-methods
13 | - id: python-use-type-annotations
14 | - id: rst-backticks
15 | - id: rst-directive-colons
16 | - id: rst-inline-touching-normal
17 | - id: text-unicode-replacement-char
18 | - repo: https://github.com/pre-commit/pre-commit-hooks
19 | rev: v5.0.0
20 | hooks:
21 | - id: check-added-large-files
22 | - id: check-ast
23 | - id: check-builtin-literals
24 | - id: check-case-conflict
25 | - id: check-docstring-first
26 | - id: check-illegal-windows-names
27 | - id: check-json
28 | - id: check-merge-conflict
29 | - id: check-shebang-scripts-are-executable
30 | - id: check-symlinks
31 | - id: check-toml
32 | - id: check-vcs-permalinks
33 | - id: check-xml
34 | - id: check-yaml
35 | - id: debug-statements
36 | - id: destroyed-symlinks
37 | - id: detect-private-key
38 | - id: end-of-file-fixer
39 | types: [python]
40 | - id: fix-byte-order-marker
41 | - id: mixed-line-ending
42 | - id: name-tests-test
43 | args: [--pytest-test-first]
44 | - id: trailing-whitespace
45 | types: [python]
46 | - repo: local
47 | hooks:
48 | - id: commitizen
49 | name: commitizen
50 | entry: cz check
51 | args: [--commit-msg-file]
52 | require_serial: true
53 | language: system
54 | stages: [commit-msg]
55 | - id: ruff-check
56 | name: ruff check
57 | entry: ruff check
58 | args: ["--force-exclude", "--extend-fixable=ERA001,F401,F841,T201,T203"]
59 | require_serial: true
60 | language: system
61 | types_or: [python, pyi]
62 | - id: ruff-format
63 | name: ruff format
64 | entry: ruff format
65 | args: [--force-exclude]
66 | require_serial: true
67 | language: system
68 | types_or: [python, pyi]
69 | - id: mypy
70 | name: mypy
71 | entry: mypy
72 | language: system
73 | types: [python]
74 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | ## v0.7.0 (2025-03-17)
2 |
3 | ### Feat
4 |
5 | - replace post-processing with declarative optimization (#112)
6 | - compute optimal sentence boundaries (#110)
7 | - migrate from poetry-cookiecutter to substrate (#98)
8 | - make llama-cpp-python an optional dependency (#97)
9 | - add ability to directly insert Markdown content into the database (#96)
10 | - make importing faster (#86)
11 |
12 | ### Fix
13 |
14 | - fix CLI entrypoint regression (#111)
15 | - lazily raise module not found for optional deps (#109)
16 | - revert pandoc extra name (#106)
17 | - avoid conflicting chunk ids (#93)
18 |
19 | ## v0.6.2 (2025-01-06)
20 |
21 | ### Fix
22 |
23 | - remove unnecessary stop sequence (#84)
24 |
25 | ## v0.6.1 (2025-01-06)
26 |
27 | ### Fix
28 |
29 | - fix Markdown heading boundary probas (#81)
30 | - improve (re)insertion speed (#80)
31 | - **deps**: exclude litellm versions that break get_model_info (#78)
32 | - conditionally enable `LlamaRAMCache` (#83)
33 |
34 | ## v0.6.0 (2025-01-05)
35 |
36 | ### Feat
37 |
38 | - add support for Python 3.12 (#69)
39 | - upgrade from xx_sent_ud_sm to SaT (#74)
40 | - add streaming tool use to llama-cpp-python (#71)
41 | - improve sentence splitting (#72)
42 |
43 | ## v0.5.1 (2024-12-18)
44 |
45 | ### Fix
46 |
47 | - improve output for empty databases (#68)
48 |
49 | ## v0.5.0 (2024-12-17)
50 |
51 | ### Feat
52 |
53 | - add MCP server (#67)
54 | - let LLM choose whether to retrieve context (#62)
55 |
56 | ### Fix
57 |
58 | - support pgvector v0.7.0+ (#63)
59 |
60 | ## v0.4.1 (2024-12-05)
61 |
62 | ### Fix
63 |
64 | - add and enable OpenAI strict mode (#55)
65 | - support embedding with LiteLLM for Ragas (#56)
66 |
67 | ## v0.4.0 (2024-12-04)
68 |
69 | ### Feat
70 |
71 | - improve late chunking and optimize pgvector settings (#51)
72 |
73 | ## v0.3.0 (2024-12-03)
74 |
75 | ### Feat
76 |
77 | - support prompt caching and apply Anthropic's long-context prompt format (#52)
78 |
79 | ## v0.2.1 (2024-11-22)
80 |
81 | ### Fix
82 |
83 | - improve structured output extraction and query adapter updates (#34)
84 | - upgrade rerankers and remove flashrank patch (#47)
85 | - improve unpacking of keyword search results (#46)
86 | - add fallbacks for model info (#44)
87 |
88 | ## v0.2.0 (2024-10-21)
89 |
90 | ### Feat
91 |
92 | - add Chainlit frontend (#33)
93 |
94 | ## v0.1.4 (2024-10-15)
95 |
96 | ### Fix
97 |
98 | - fix optimal chunking edge cases (#32)
99 |
100 | ## v0.1.3 (2024-10-13)
101 |
102 | ### Fix
103 |
104 | - upgrade pdftext (#30)
105 | - improve chunk and segment ordering (#29)
106 |
107 | ## v0.1.2 (2024-10-08)
108 |
109 | ### Fix
110 |
111 | - avoid pdftext v0.3.11 (#27)
112 |
113 | ## v0.1.1 (2024-10-07)
114 |
115 | ### Fix
116 |
117 | - patch rerankers flashrank issue (#22)
118 |
119 | ## v0.1.0 (2024-10-07)
120 |
121 | ### Feat
122 |
123 | - add reranking (#20)
124 | - add LiteLLM and late chunking (#19)
125 | - add PostgreSQL support (#18)
126 | - make query adapter minimally invasive (#16)
127 | - upgrade default CPU model to Phi-3.5-mini (#15)
128 | - add evaluation (#14)
129 | - infer missing font sizes (#12)
130 | - automatically adjust number of RAG contexts (#10)
131 | - improve exception feedback for extraction (#9)
132 | - optimize config for CPU and GPU (#7)
133 | - simplify document insertion (#6)
134 | - implement basic features (#2)
135 | - initial commit
136 |
137 | ### Fix
138 |
139 | - lazily import optional dependencies (#11)
140 | - improve indexing of multiple documents (#8)
141 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # syntax=docker/dockerfile:1
2 | FROM ghcr.io/astral-sh/uv:python3.10-bookworm AS dev
3 |
4 | # Create and activate a virtual environment [1].
5 | # [1] https://docs.astral.sh/uv/concepts/projects/config/#project-environment-path
6 | ENV VIRTUAL_ENV=/opt/venv
7 | ENV PATH=$VIRTUAL_ENV/bin:$PATH
8 | ENV UV_PROJECT_ENVIRONMENT=$VIRTUAL_ENV
9 |
10 | # Tell Git that the workspace is safe to avoid 'detected dubious ownership in repository' warnings.
11 | RUN git config --system --add safe.directory '*'
12 |
13 | # Create a non-root user and give it passwordless sudo access [1].
14 | # [1] https://code.visualstudio.com/remote/advancedcontainers/add-nonroot-user
15 | RUN --mount=type=cache,target=/var/cache/apt/ \
16 | --mount=type=cache,target=/var/lib/apt/ \
17 | groupadd --gid 1000 user && \
18 | useradd --create-home --no-log-init --gid 1000 --uid 1000 --shell /usr/bin/bash user && \
19 | chown user:user /opt/ && \
20 | apt-get update && apt-get install --no-install-recommends --yes sudo clang libomp-dev && \
21 | echo 'user ALL=(root) NOPASSWD:ALL' > /etc/sudoers.d/user && chmod 0440 /etc/sudoers.d/user
22 | USER user
23 |
24 | # Configure the non-root user's shell.
25 | RUN mkdir ~/.history/ && \
26 | echo 'HISTFILE=~/.history/.bash_history' >> ~/.bashrc && \
27 | echo 'bind "\"\e[A\": history-search-backward"' >> ~/.bashrc && \
28 | echo 'bind "\"\e[B\": history-search-forward"' >> ~/.bashrc && \
29 | echo 'eval "$(starship init bash)"' >> ~/.bashrc
30 |
31 | # Explicitly configure compilers for llama-cpp-python.
32 | ENV CC=clang
33 | ENV CXX=clang++
34 |
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | services:
2 |
3 | devcontainer:
4 | build:
5 | target: dev
6 | environment:
7 | - CI
8 | - OPENAI_API_KEY
9 | depends_on:
10 | - postgres
11 | volumes:
12 | - ..:/workspaces
13 | - command-history-volume:/home/user/.history/
14 |
15 | postgres:
16 | image: pgvector/pgvector:pg17
17 | environment:
18 | POSTGRES_USER: raglite_user
19 | POSTGRES_PASSWORD: raglite_password
20 | tmpfs:
21 | - /var/lib/postgresql/data
22 |
23 | volumes:
24 | command-history-volume:
25 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system] # https://docs.astral.sh/uv/concepts/projects/config/#build-systems
2 | requires = ["hatchling>=1.27.0"]
3 | build-backend = "hatchling.build"
4 |
5 | [project] # https://packaging.python.org/en/latest/specifications/pyproject-toml/
6 | name = "raglite"
7 | version = "0.7.0"
8 | description = "A Python toolkit for Retrieval-Augmented Generation (RAG) with DuckDB or PostgreSQL."
9 | readme = "README.md"
10 | authors = [
11 | { name = "Laurent Sorber", email = "laurent@superlinear.eu" },
12 | ]
13 | requires-python = ">=3.10,<4.0"
14 | dependencies = [
15 | # Configuration:
16 | "platformdirs (>=4.0.0)",
17 | # Markdown conversion:
18 | "pdftext (>=0.4.1)",
19 | "scikit-learn (>=1.4.2)",
20 | # Sentence and chunk splitting:
21 | "markdown-it-py (>=3.0.0)",
22 | "numpy (>=1.26.4,<2.0.0)",
23 | "scipy (>=1.11.2,!=1.15.0.*,!=1.15.1,!=1.15.2)",
24 | "wtpsplit-lite (>=0.1.0)",
25 | # Large Language Models:
26 | "huggingface-hub (>=0.31.2)",
27 | "litellm (>=1.60.2)",
28 | "pydantic (>=2.7.0)",
29 | # Reranking:
30 | "langdetect (>=1.0.9)",
31 | "rerankers[api,flashrank] (>=0.10.0)",
32 | # Storage:
33 | "duckdb (>=1.1.3)",
34 | "duckdb-engine (>=0.16.0)",
35 | "pg8000 (>=1.31.2)",
36 | "sqlmodel-slim (>=0.0.21)",
37 | # Progress:
38 | "tqdm (>=4.66.0)",
39 | # CLI:
40 | "typer (>=0.15.1)",
41 | # Model Context Protocol:
42 | "fastmcp (>=2.0.0)",
43 | # Utilities:
44 | "packaging (>=23.0)",
45 | ]
46 |
47 | [project.scripts] # https://docs.astral.sh/uv/concepts/projects/config/#command-line-interfaces
48 | raglite = "raglite._cli:cli"
49 |
50 | [project.urls] # https://packaging.python.org/en/latest/specifications/well-known-project-urls/#well-known-labels
51 | homepage = "https://github.com/superlinear-ai/raglite"
52 | source = "https://github.com/superlinear-ai/raglite"
53 | changelog = "https://github.com/superlinear-ai/raglite/blob/main/CHANGELOG.md"
54 | releasenotes = "https://github.com/superlinear-ai/raglite/releases"
55 | documentation = "https://github.com/superlinear-ai/raglite"
56 | issues = "https://github.com/superlinear-ai/raglite/issues"
57 |
58 | [dependency-groups] # https://docs.astral.sh/uv/concepts/projects/dependencies/#development-dependencies
59 | dev = [
60 | "commitizen (>=4.3.0)",
61 | "coverage[toml] (>=7.6.10)",
62 | "ipykernel (>=6.29.4)",
63 | "ipython (>=8.18.0)",
64 | "ipywidgets (>=8.1.2)",
65 | "mypy (>=1.14.1)",
66 | "pdoc (>=15.0.1)",
67 | "poethepoet (>=0.32.1)",
68 | "pre-commit (>=4.0.1)",
69 | "pytest (>=8.3.4)",
70 | "pytest-mock (>=3.14.0)",
71 | "pytest-xdist (>=3.6.1)",
72 | "ruff (>=0.10.0)",
73 | "typeguard (>=4.4.1)",
74 | ]
75 |
76 | [project.optional-dependencies] # https://packaging.python.org/en/latest/guides/writing-pyproject-toml/#dependencies-optional-dependencies
77 | # Frontend:
78 | chainlit = ["chainlit (>=2.0.0)"]
79 | # Large Language Models:
80 | llama-cpp-python = ["llama-cpp-python (>=0.3.9)"]
81 | # Markdown conversion:
82 | pandoc = ["pypandoc-binary (>=1.13)"]
83 | # Evaluation:
84 | ragas = ["pandas (>=2.1.1)", "ragas (>=0.1.12)"]
85 | # Benchmarking:
86 | bench = [
87 | "faiss-cpu (>=1.11.0)",
88 | "ir_datasets (>=0.5.10)",
89 | "ir_measures (>=0.3.7)",
90 | "llama-index (>=0.12.39)",
91 | "llama-index-vector-stores-faiss (>=0.4.0)",
92 | "openai (>=1.75.0)",
93 | "pandas (>=2.1.1)",
94 | "python-slugify (>=8.0.4)",
95 | ]
96 |
97 | [tool.commitizen] # https://commitizen-tools.github.io/commitizen/config/
98 | bump_message = "bump: v$current_version → v$new_version"
99 | tag_format = "v$version"
100 | update_changelog_on_bump = true
101 | version_provider = "uv"
102 |
103 | [tool.coverage.report] # https://coverage.readthedocs.io/en/latest/config.html#report
104 | fail_under = 50
105 | precision = 1
106 | show_missing = true
107 | skip_covered = true
108 |
109 | [tool.coverage.run] # https://coverage.readthedocs.io/en/latest/config.html#run
110 | branch = true
111 | command_line = "--module pytest"
112 | data_file = "reports/.coverage"
113 | source = ["src"]
114 |
115 | [tool.coverage.xml] # https://coverage.readthedocs.io/en/latest/config.html#xml
116 | output = "reports/coverage.xml"
117 |
118 | [tool.mypy] # https://mypy.readthedocs.io/en/latest/config_file.html
119 | junit_xml = "reports/mypy.xml"
120 | strict = true
121 | disallow_subclassing_any = false
122 | disallow_untyped_decorators = false
123 | ignore_missing_imports = true
124 | pretty = true
125 | show_column_numbers = true
126 | show_error_codes = true
127 | show_error_context = true
128 | warn_unreachable = true
129 |
130 | [tool.pytest.ini_options] # https://docs.pytest.org/en/latest/reference/reference.html#ini-options-ref
131 | addopts = "--color=yes --doctest-modules --ignore=src/raglite/_chainlit.py --exitfirst --failed-first --strict-config --strict-markers --verbosity=2 --junitxml=reports/pytest.xml"
132 | filterwarnings = ["error", "ignore::DeprecationWarning", "ignore::pytest.PytestUnraisableExceptionWarning"]
133 | markers = ["slow: mark test as slow"]
134 | testpaths = ["src", "tests"]
135 | xfail_strict = true
136 |
137 | [tool.ruff] # https://docs.astral.sh/ruff/settings/
138 | fix = true
139 | line-length = 100
140 | src = ["src", "tests"]
141 | target-version = "py310"
142 |
143 | [tool.ruff.format]
144 | docstring-code-format = true
145 |
146 | [tool.ruff.lint]
147 | select = ["A", "ASYNC", "B", "BLE", "C4", "C90", "D", "DTZ", "E", "EM", "ERA", "F", "FBT", "FLY", "FURB", "G", "I", "ICN", "INP", "INT", "ISC", "LOG", "N", "NPY", "PERF", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "Q", "RET", "RSE", "RUF", "S", "SIM", "SLF", "SLOT", "T10", "T20", "TCH", "TID", "TRY", "UP", "W", "YTT"]
148 | ignore = ["D203", "D213", "E501", "RET504", "RUF002", "RUF022", "S101", "S307", "TC004"]
149 | unfixable = ["ERA001", "F401", "F841", "T201", "T203"]
150 |
151 | [tool.ruff.lint.flake8-tidy-imports]
152 | ban-relative-imports = "all"
153 |
154 | [tool.ruff.lint.pycodestyle]
155 | max-doc-length = 100
156 |
157 | [tool.ruff.lint.pydocstyle]
158 | convention = "numpy"
159 |
160 | [tool.poe.executor] # https://github.com/nat-n/poethepoet
161 | type = "simple"
162 |
163 | [tool.poe.tasks]
164 |
165 | [tool.poe.tasks.docs]
166 | help = "Generate this package's docs"
167 | cmd = """
168 | pdoc
169 | --docformat $docformat
170 | --output-directory $outputdirectory
171 | raglite
172 | """
173 |
174 | [[tool.poe.tasks.docs.args]]
175 | help = "The docstring style (default: numpy)"
176 | name = "docformat"
177 | options = ["--docformat"]
178 | default = "numpy"
179 |
180 | [[tool.poe.tasks.docs.args]]
181 | help = "The output directory (default: docs)"
182 | name = "outputdirectory"
183 | options = ["--output-directory"]
184 | default = "docs"
185 |
186 | [tool.poe.tasks.lint]
187 | help = "Lint this package"
188 | cmd = """
189 | pre-commit run
190 | --all-files
191 | --color always
192 | """
193 |
194 | [tool.poe.tasks.test]
195 | help = "Test this package"
196 |
197 | [[tool.poe.tasks.test.sequence]]
198 | cmd = "coverage run"
199 |
200 | [[tool.poe.tasks.test.sequence]]
201 | cmd = "coverage report"
202 |
203 | [[tool.poe.tasks.test.sequence]]
204 | cmd = "coverage xml"
205 |
--------------------------------------------------------------------------------
/src/raglite/__init__.py:
--------------------------------------------------------------------------------
1 | """RAGLite."""
2 |
3 | from raglite._config import RAGLiteConfig
4 | from raglite._database import Document
5 | from raglite._eval import answer_evals, evaluate, insert_evals
6 | from raglite._insert import insert_documents
7 | from raglite._query_adapter import update_query_adapter
8 | from raglite._rag import add_context, async_rag, rag, retrieve_context
9 | from raglite._search import (
10 | hybrid_search,
11 | keyword_search,
12 | rerank_chunks,
13 | retrieve_chunk_spans,
14 | retrieve_chunks,
15 | search_and_rerank_chunk_spans,
16 | search_and_rerank_chunks,
17 | vector_search,
18 | )
19 |
20 | __all__ = [
21 | # Config
22 | "RAGLiteConfig",
23 | # Insert
24 | "Document",
25 | "insert_documents",
26 | # Search
27 | "hybrid_search",
28 | "keyword_search",
29 | "vector_search",
30 | "retrieve_chunks",
31 | "retrieve_chunk_spans",
32 | "rerank_chunks",
33 | "search_and_rerank_chunks",
34 | "search_and_rerank_chunk_spans",
35 | # RAG
36 | "retrieve_context",
37 | "add_context",
38 | "async_rag",
39 | "rag",
40 | # Query adapter
41 | "update_query_adapter",
42 | # Evaluate
43 | "answer_evals",
44 | "insert_evals",
45 | "evaluate",
46 | ]
47 |
--------------------------------------------------------------------------------
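The `__all__` list in `__init__.py` above defines the package's public API. Below is a minimal usage sketch of that API, assembled from how the functions are called elsewhere in this repository (`_chainlit.py`, `_bench.py`). The document path and the assumption that `rag` streams tokens the same way `async_rag` does in `_chainlit.py` are illustrative assumptions, not taken verbatim from the source.

# Usage sketch (not a file in the repository); assumes `rag` mirrors `async_rag`'s streaming interface.
from pathlib import Path

from raglite import (
    Document,
    RAGLiteConfig,
    hybrid_search,
    insert_documents,
    rag,
    rerank_chunks,
)

config = RAGLiteConfig()  # Defaults to a local DuckDB database and llama-cpp-python models.

# Insert a document into the database (illustrative path).
insert_documents([Document.from_path(Path("specrel.pdf"))], config=config)

# Search and rerank, as done in _chainlit.py.
query = "What does the document say about simultaneity?"
chunk_ids, scores = hybrid_search(query=query, config=config)
chunks = rerank_chunks(query=query, chunk_ids=chunk_ids, config=config)

# Stream a RAG answer (assumed to yield tokens like async_rag in _chainlit.py).
messages = [{"role": "user", "content": query}]
for token in rag(messages, config=config):
    print(token, end="")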
/src/raglite/_bench.py:
--------------------------------------------------------------------------------
1 | """Benchmarking with TREC runs."""
2 |
3 | import warnings
4 | from abc import ABC, abstractmethod
5 | from collections.abc import Generator
6 | from dataclasses import replace
7 | from functools import cached_property
8 | from pathlib import Path
9 | from typing import Any
10 |
11 | from ir_datasets.datasets.base import Dataset
12 | from ir_measures import ScoredDoc, read_trec_run
13 | from platformdirs import user_data_dir
14 | from slugify import slugify
15 | from tqdm.auto import tqdm
16 |
17 | from raglite._config import RAGLiteConfig
18 |
19 |
20 | class IREvaluator(ABC):
21 | def __init__(
22 | self,
23 | dataset: Dataset,
24 | *,
25 | num_results: int = 10,
26 | insert_variant: str | None = None,
27 | search_variant: str | None = None,
28 | ) -> None:
29 | self.dataset = dataset
30 | self.num_results = num_results
31 | self.insert_variant = insert_variant
32 | self.search_variant = search_variant
33 | self.insert_id = (
34 | slugify(self.__class__.__name__.lower().replace("evaluator", ""))
35 | + (f"_{slugify(insert_variant)}" if insert_variant else "")
36 | + f"_{slugify(dataset.docs_namespace())}"
37 | )
38 | self.search_id = (
39 | self.insert_id
40 | + f"@{num_results}"
41 | + (f"_{slugify(search_variant)}" if search_variant else "")
42 | )
43 | self.cwd = Path(user_data_dir("raglite", ensure_exists=True))
44 |
45 | @abstractmethod
46 | def insert_documents(self, max_workers: int | None = None) -> None:
47 | """Insert all of the dataset's documents into the search index."""
48 | raise NotImplementedError
49 |
50 | @abstractmethod
51 | def search(self, query_id: str, query: str, *, num_results: int = 10) -> list[ScoredDoc]:
52 | """Search for documents given a query."""
53 | raise NotImplementedError
54 |
55 | @property
56 | def trec_run_filename(self) -> str:
57 | return f"{self.search_id}.trec"
58 |
59 | @property
60 | def trec_run_filepath(self) -> Path:
61 | return self.cwd / self.trec_run_filename
62 |
63 | def score(self) -> Generator[ScoredDoc, None, None]:
64 | """Read or compute a TREC run."""
65 | if self.trec_run_filepath.exists():
66 | yield from read_trec_run(self.trec_run_filepath.as_posix()) # type: ignore[no-untyped-call]
67 | return
68 | if not self.search("q0", next(self.dataset.queries_iter()).text):
69 | self.insert_documents()
70 | with self.trec_run_filepath.open(mode="w") as trec_run_file:
71 | for query in tqdm(
72 | self.dataset.queries_iter(),
73 | total=self.dataset.queries_count(),
74 | desc="Running queries",
75 | unit="query",
76 | dynamic_ncols=True,
77 | ):
78 | results = self.search(query.query_id, query.text, num_results=self.num_results)
79 | unique_results = {doc.doc_id: doc for doc in sorted(results, key=lambda d: d.score)}
80 | top_results = sorted(unique_results.values(), key=lambda d: d.score, reverse=True)
81 | top_results = top_results[: self.num_results]
82 | for rank, scored_doc in enumerate(top_results):
83 | trec_line = f"{query.query_id} 0 {scored_doc.doc_id} {rank} {scored_doc.score} {self.trec_run_filename}\n"
84 | trec_run_file.write(trec_line)
85 | yield scored_doc
86 |
87 |
88 | class RAGLiteEvaluator(IREvaluator):
89 | def __init__(
90 | self,
91 | dataset: Dataset,
92 | *,
93 | num_results: int = 10,
94 | insert_variant: str | None = None,
95 | search_variant: str | None = None,
96 | config: RAGLiteConfig | None = None,
97 | ):
98 | super().__init__(
99 | dataset,
100 | num_results=num_results,
101 | insert_variant=insert_variant,
102 | search_variant=search_variant,
103 | )
104 | self.db_filepath = self.cwd / f"{self.insert_id}.db"
105 | db_url = f"duckdb:///{self.db_filepath.as_posix()}"
106 | self.config = replace(config or RAGLiteConfig(), db_url=db_url)
107 |
108 | def insert_documents(self, max_workers: int | None = None) -> None:
109 | from raglite import Document, insert_documents
110 |
111 | documents = [
112 | Document.from_text(doc.text, id=doc.doc_id) for doc in self.dataset.docs_iter()
113 | ]
114 | insert_documents(documents, max_workers=max_workers, config=self.config)
115 |
116 | def update_query_adapter(self, num_evals: int = 1024) -> None:
117 | from raglite import insert_evals, update_query_adapter
118 | from raglite._database import IndexMetadata
119 |
120 | if (
121 | self.config.vector_search_query_adapter
122 | and IndexMetadata.get(config=self.config).get("query_adapter") is None
123 | ):
124 | insert_evals(num_evals=num_evals, config=self.config)
125 | update_query_adapter(config=self.config)
126 |
127 | def search(self, query_id: str, query: str, *, num_results: int = 10) -> list[ScoredDoc]:
128 | from raglite import retrieve_chunks, vector_search
129 |
130 | self.update_query_adapter()
131 | chunk_ids, scores = vector_search(query, num_results=2 * num_results, config=self.config)
132 | chunks = retrieve_chunks(chunk_ids, config=self.config)
133 | scored_docs = [
134 | ScoredDoc(query_id=query_id, doc_id=chunk.document.id, score=score)
135 | for chunk, score in zip(chunks, scores, strict=True)
136 | ]
137 | return scored_docs
138 |
139 |
140 | class LlamaIndexEvaluator(IREvaluator):
141 | def __init__(
142 | self,
143 | dataset: Dataset,
144 | *,
145 | num_results: int = 10,
146 | insert_variant: str | None = None,
147 | search_variant: str | None = None,
148 | ):
149 | super().__init__(
150 | dataset,
151 | num_results=num_results,
152 | insert_variant=insert_variant,
153 | search_variant=search_variant,
154 | )
155 | self.embedder = "text-embedding-3-large"
156 | self.embedder_dim = 3072
157 | self.persist_path = self.cwd / self.insert_id
158 |
159 | def insert_documents(self, max_workers: int | None = None) -> None:
160 | # Adapted from https://docs.llamaindex.ai/en/stable/examples/vector_stores/FaissIndexDemo/.
161 | import faiss
162 | from llama_index.core import Document, StorageContext, VectorStoreIndex
163 | from llama_index.embeddings.openai import OpenAIEmbedding
164 | from llama_index.vector_stores.faiss import FaissVectorStore
165 |
166 | self.persist_path.mkdir(parents=True, exist_ok=True)
167 | faiss_index = faiss.IndexHNSWFlat(self.embedder_dim, 32, faiss.METRIC_INNER_PRODUCT)
168 | vector_store = FaissVectorStore(faiss_index=faiss_index)
169 | index = VectorStoreIndex.from_documents(
170 | [
171 | Document(id_=doc.doc_id, text=doc.text, metadata={"filename": doc.doc_id})
172 | for doc in self.dataset.docs_iter()
173 | ],
174 | storage_context=StorageContext.from_defaults(vector_store=vector_store),
175 | embed_model=OpenAIEmbedding(model=self.embedder, dimensions=self.embedder_dim),
176 | show_progress=True,
177 | )
178 | index.storage_context.persist(persist_dir=self.persist_path)
179 |
180 | @cached_property
181 | def index(self) -> Any:
182 | from llama_index.core import StorageContext, load_index_from_storage
183 | from llama_index.embeddings.openai import OpenAIEmbedding
184 | from llama_index.vector_stores.faiss import FaissVectorStore
185 |
186 | vector_store = FaissVectorStore.from_persist_dir(persist_dir=self.persist_path.as_posix())
187 | storage_context = StorageContext.from_defaults(
188 | vector_store=vector_store, persist_dir=self.persist_path.as_posix()
189 | )
190 | embed_model = OpenAIEmbedding(model=self.embedder, dimensions=self.embedder_dim)
191 | index = load_index_from_storage(storage_context, embed_model=embed_model)
192 | return index
193 |
194 | def search(self, query_id: str, query: str, *, num_results: int = 10) -> list[ScoredDoc]:
195 | if not self.persist_path.exists():
196 | self.insert_documents()
197 | retriever = self.index.as_retriever(similarity_top_k=2 * num_results)
198 | nodes = retriever.retrieve(query)
199 | scored_docs = [
200 | ScoredDoc(
201 | query_id=query_id,
202 | doc_id=node.metadata.get("filename", node.id_),
203 | score=node.score if node.score is not None else 1.0,
204 | )
205 | for node in nodes
206 | ]
207 | return scored_docs
208 |
209 |
210 | class OpenAIVectorStoreEvaluator(IREvaluator):
211 | def __init__(
212 | self,
213 | dataset: Dataset,
214 | *,
215 | num_results: int = 10,
216 | insert_variant: str | None = None,
217 | search_variant: str | None = None,
218 | ):
219 | super().__init__(
220 | dataset,
221 | num_results=num_results,
222 | insert_variant=insert_variant,
223 | search_variant=search_variant,
224 | )
225 | self.vector_store_name = dataset.docs_namespace() + (
226 | f"_{slugify(insert_variant)}" if insert_variant else ""
227 | )
228 |
229 | @cached_property
230 | def client(self) -> Any:
231 | import openai
232 |
233 | return openai.OpenAI()
234 |
235 | @property
236 | def vector_store_id(self) -> str | None:
237 | vector_stores = self.client.vector_stores.list()
238 | vector_store = next((vs for vs in vector_stores if vs.name == self.vector_store_name), None)
239 | if vector_store is None:
240 | return None
241 | if vector_store.file_counts.failed > 0:
242 | warnings.warn(
243 | f"Vector store {vector_store.name} has {vector_store.file_counts.failed} failed files.",
244 | stacklevel=2,
245 | )
246 | if vector_store.file_counts.in_progress > 0:
247 | error_message = f"Vector store {vector_store.name} has {vector_store.file_counts.in_progress} files in progress."
248 | raise RuntimeError(error_message)
249 | return vector_store.id # type: ignore[no-any-return]
250 |
251 | def insert_documents(self, max_workers: int | None = None) -> None:
252 | import tempfile
253 | from pathlib import Path
254 |
255 | vector_store = self.client.vector_stores.create(name=self.vector_store_name)
256 | files, max_files_per_batch = [], 32
257 | with tempfile.TemporaryDirectory() as temp_dir:
258 | for i, doc in tqdm(
259 | enumerate(self.dataset.docs_iter()),
260 | total=self.dataset.docs_count(),
261 | desc="Inserting documents",
262 | unit="document",
263 | dynamic_ncols=True,
264 | ):
265 | if not doc.text.strip():
266 | continue
267 | temp_file = Path(temp_dir) / f"{slugify(doc.doc_id)}.txt"
268 | temp_file.write_text(doc.text)
269 | files.append(temp_file.open("rb"))
270 | if len(files) == max_files_per_batch or (i == self.dataset.docs_count() - 1):
271 | self.client.vector_stores.file_batches.upload_and_poll(
272 | vector_store_id=vector_store.id, files=files, max_concurrency=max_workers
273 | )
274 | for f in files:
275 | f.close()
276 | files = []
277 |
278 | @cached_property
279 | def filename_to_doc_id(self) -> dict[str, str]:
280 | return {f"{slugify(doc.doc_id)}.txt": doc.doc_id for doc in self.dataset.docs_iter()}
281 |
282 | def search(self, query_id: str, query: str, *, num_results: int = 10) -> list[ScoredDoc]:
283 | if not self.vector_store_id:
284 | return []
285 | response = self.client.vector_stores.search(
286 | vector_store_id=self.vector_store_id, query=query, max_num_results=2 * num_results
287 | )
288 | scored_docs = [
289 | ScoredDoc(
290 | query_id=query_id,
291 | doc_id=self.filename_to_doc_id[result.filename],
292 | score=result.score,
293 | )
294 | for result in response
295 | ]
296 | return scored_docs
297 |
--------------------------------------------------------------------------------
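The evaluators above either read a cached TREC run or insert the dataset's documents, run all queries, and write a new run under the user data directory. The sketch below condenses how they are driven, mirroring the `bench` command in `_cli.py`; the dataset name and measure are the illustrative defaults from that command.

# Benchmark sketch, following the bench command in _cli.py.
import ir_datasets
import ir_measures

from raglite import RAGLiteConfig
from raglite._bench import RAGLiteEvaluator

dataset = ir_datasets.load("nano-beir/hotpotqa")
measures = [ir_measures.parse_measure("AP@10")]
config = RAGLiteConfig(
    embedder="text-embedding-3-large",
    chunk_max_size=2048,
    vector_search_multivector=True,
    vector_search_query_adapter=False,
)
evaluator = RAGLiteEvaluator(dataset, insert_variant="multi-vector-512t", config=config)
# score() reads the cached TREC run if it exists; otherwise it inserts the documents,
# runs every query, and writes the run to disk while yielding ScoredDocs.
results = ir_measures.calc_aggregate(measures, dataset.qrels_iter(), evaluator.score())
print(results)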
/src/raglite/_chainlit.py:
--------------------------------------------------------------------------------
1 | """Chainlit frontend for RAGLite."""
2 |
3 | import os
4 | from pathlib import Path
5 |
6 | import chainlit as cl
7 | from chainlit.input_widget import Switch, TextInput
8 |
9 | from raglite import (
10 | Document,
11 | RAGLiteConfig,
12 | async_rag,
13 | hybrid_search,
14 | insert_documents,
15 | rerank_chunks,
16 | )
17 | from raglite._markdown import document_to_markdown
18 |
19 | async_insert_documents = cl.make_async(insert_documents)
20 | async_hybrid_search = cl.make_async(hybrid_search)
21 | async_rerank_chunks = cl.make_async(rerank_chunks)
22 |
23 |
24 | @cl.on_chat_start
25 | async def start_chat() -> None:
26 | """Initialize the chat."""
27 | # Set tokenizers parallelism to avoid a deadlock warning.
28 | os.environ["TOKENIZERS_PARALLELISM"] = "true"
29 | # Add Chainlit settings with which the user can configure the RAGLite config.
30 | default_config = RAGLiteConfig()
31 | config = RAGLiteConfig(
32 | db_url=os.environ.get("RAGLITE_DB_URL", default_config.db_url),
33 | llm=os.environ.get("RAGLITE_LLM", default_config.llm),
34 | embedder=os.environ.get("RAGLITE_EMBEDDER", default_config.embedder),
35 | )
36 | settings = await cl.ChatSettings( # type: ignore[no-untyped-call]
37 | [
38 | TextInput(id="db_url", label="Database URL", initial=str(config.db_url)),
39 | TextInput(id="llm", label="LLM", initial=config.llm),
40 | TextInput(id="embedder", label="Embedder", initial=config.embedder),
41 | Switch(id="vector_search_query_adapter", label="Query adapter", initial=True),
42 | ]
43 | ).send()
44 | await update_config(settings)
45 |
46 |
47 | @cl.on_settings_update # type: ignore[arg-type]
48 | async def update_config(settings: cl.ChatSettings) -> None:
49 | """Update the RAGLite config."""
50 | # Update the RAGLite config given the Chainlit settings.
51 | config = RAGLiteConfig(
52 | db_url=settings["db_url"], # type: ignore[index]
53 | llm=settings["llm"], # type: ignore[index]
54 | embedder=settings["embedder"], # type: ignore[index]
55 | vector_search_query_adapter=settings["vector_search_query_adapter"], # type: ignore[index]
56 | )
57 | cl.user_session.set("config", config) # type: ignore[no-untyped-call]
58 | # Run a search to prime the pipeline if it's a local pipeline.
59 | if config.embedder.startswith("llama-cpp-python"):
60 | query = "Hello world"
61 | chunk_ids, _ = await async_hybrid_search(query=query, config=config)
62 | _ = await async_rerank_chunks(query=query, chunk_ids=chunk_ids, config=config)
63 |
64 |
65 | @cl.on_message
66 | async def handle_message(user_message: cl.Message) -> None:
67 | """Respond to a user message."""
68 | # Get the config and message history from the user session.
69 | config: RAGLiteConfig = cl.user_session.get("config") # type: ignore[no-untyped-call]
70 | # Determine what to do with the attachments.
71 | inline_attachments = []
72 | for file in user_message.elements:
73 | if file.path:
74 | doc_md = document_to_markdown(Path(file.path))
75 | if len(doc_md) // 3 <= 5 * (config.chunk_max_size // 3):
76 | # Document is small enough to attach to the context.
77 | inline_attachments.append(f"{Path(file.path).name}:\n\n{doc_md}")
78 | else:
79 | # Document is too large and must be inserted into the database.
80 | async with cl.Step(name="insert", type="run") as step:
81 | step.input = Path(file.path).name
82 | document = Document.from_path(Path(file.path))
83 | await async_insert_documents([document], config=config)
84 | # Append any inline attachments to the user prompt.
85 | user_prompt = (
86 | "\n\n".join(
87 | f'<attachment index="{i}">\n{attachment.strip()}\n</attachment>'
88 | for i, attachment in enumerate(inline_attachments)
89 | )
90 | + f"\n\n{user_message.content}"
91 | ).strip()
92 | # Stream the LLM response.
93 | assistant_message = cl.Message(content="")
94 | chunk_spans = []
95 | messages: list[dict[str, str]] = cl.chat_context.to_openai()[:-1] # type: ignore[no-untyped-call]
96 | messages.append({"role": "user", "content": user_prompt})
97 | async for token in async_rag(
98 | messages, on_retrieval=lambda x: chunk_spans.extend(x), config=config
99 | ):
100 | await assistant_message.stream_token(token)
101 | # Append RAG sources, if any.
102 | if chunk_spans:
103 | rag_sources: dict[str, list[str]] = {}
104 | for chunk_span in chunk_spans:
105 | rag_sources.setdefault(chunk_span.document.id, [])
106 | rag_sources[chunk_span.document.id].append(str(chunk_span))
107 | assistant_message.content += "\n\nSources: " + ", ".join( # Rendered as hyperlinks.
108 | f"[{i + 1}]" for i in range(len(rag_sources))
109 | )
110 | assistant_message.elements = [ # Markdown content is rendered in sidebar.
111 | cl.Text(name=f"[{i + 1}]", content="\n\n---\n\n".join(content), display="side") # type: ignore[misc]
112 | for i, (_, content) in enumerate(rag_sources.items())
113 | ]
114 | await assistant_message.update() # type: ignore[no-untyped-call]
115 |
--------------------------------------------------------------------------------

/src/raglite/_cli.py:
--------------------------------------------------------------------------------
1 | """RAGLite CLI."""
2 |
3 | import json
4 | import os
5 | from typing import ClassVar
6 |
7 | import typer
8 | from pydantic_settings import BaseSettings, SettingsConfigDict
9 |
10 | from raglite._config import RAGLiteConfig
11 |
12 |
13 | class RAGLiteCLIConfig(BaseSettings):
14 | """RAGLite CLI config."""
15 |
16 | model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict(
17 | env_prefix="RAGLITE_", env_file=".env", extra="allow"
18 | )
19 |
20 | mcp_server_name: str = "RAGLite"
21 | db_url: str = str(RAGLiteConfig().db_url)
22 | llm: str = RAGLiteConfig().llm
23 | embedder: str = RAGLiteConfig().embedder
24 |
25 |
26 | cli = typer.Typer()
27 | cli.add_typer(mcp_cli := typer.Typer(), name="mcp")
28 |
29 |
30 | @cli.callback()
31 | def main(
32 | ctx: typer.Context,
33 | db_url: str = typer.Option(RAGLiteCLIConfig().db_url, help="Database URL"),
34 | llm: str = typer.Option(RAGLiteCLIConfig().llm, help="LiteLLM LLM"),
35 | embedder: str = typer.Option(RAGLiteCLIConfig().embedder, help="LiteLLM embedder"),
36 | ) -> None:
37 | """RAGLite CLI."""
38 | ctx.obj = {"db_url": db_url, "llm": llm, "embedder": embedder}
39 |
40 |
41 | @cli.command()
42 | def chainlit(ctx: typer.Context) -> None:
43 | """Serve a Chainlit frontend."""
44 | # Set the environment variables for the Chainlit frontend.
45 | os.environ["RAGLITE_DB_URL"] = ctx.obj["db_url"]
46 | os.environ["RAGLITE_LLM"] = ctx.obj["llm"]
47 | os.environ["RAGLITE_EMBEDDER"] = ctx.obj["embedder"]
48 | # Import Chainlit here as it's an optional dependency.
49 | try:
50 | from chainlit.cli import run_chainlit
51 | except ModuleNotFoundError as error:
52 | error_message = "To serve a Chainlit frontend, please install the `chainlit` extra."
53 | raise ModuleNotFoundError(error_message) from error
54 | # Serve the frontend.
55 | run_chainlit(__file__.replace("_cli.py", "_chainlit.py"))
56 |
57 |
58 | @mcp_cli.command("install")
59 | def install_mcp_server(
60 | ctx: typer.Context,
61 | server_name: str = typer.Option(RAGLiteCLIConfig().mcp_server_name, help="MCP server name"),
62 | ) -> None:
63 | """Install MCP server in the Claude desktop app."""
64 | from fastmcp.cli.claude import get_claude_config_path
65 |
66 | # Get the Claude config path.
67 | claude_config_path = get_claude_config_path()
68 | if not claude_config_path:
69 | typer.echo(
70 | "Please download the Claude desktop app from https://claude.ai/download before installing an MCP server."
71 | )
72 | return
73 | claude_config_filepath = claude_config_path / "claude_desktop_config.json"
74 | # Parse the Claude config.
75 | claude_config = (
76 | json.loads(claude_config_filepath.read_text()) if claude_config_filepath.exists() else {}
77 | )
78 | # Update the Claude config with the MCP server.
79 | mcp_config = RAGLiteCLIConfig(
80 | mcp_server_name=server_name,
81 | db_url=ctx.obj["db_url"],
82 | llm=ctx.obj["llm"],
83 | embedder=ctx.obj["embedder"],
84 | )
85 | claude_config["mcpServers"][server_name] = {
86 | "command": "uvx",
87 | "args": [
88 | "--python",
89 | "3.11",
90 | "--with",
91 | "numpy<2.0.0", # TODO: Remove this constraint when uv no longer needs it to solve the environment.
92 | "raglite",
93 | "mcp",
94 | "run",
95 | ],
96 | "env": {
97 | f"RAGLITE_{key.upper()}" if key in RAGLiteCLIConfig.model_fields else key.upper(): value
98 | for key, value in mcp_config.model_dump().items()
99 | if value
100 | },
101 | }
102 | # Write the updated Claude config to disk.
103 | claude_config_filepath.write_text(json.dumps(claude_config, indent=2))
104 |
105 |
106 | @mcp_cli.command("run")
107 | def run_mcp_server(
108 | ctx: typer.Context,
109 | server_name: str = typer.Option(RAGLiteCLIConfig().mcp_server_name, help="MCP server name"),
110 | ) -> None:
111 | """Run MCP server."""
112 | from raglite._mcp import create_mcp_server
113 |
114 | config = RAGLiteConfig(
115 | db_url=ctx.obj["db_url"], llm=ctx.obj["llm"], embedder=ctx.obj["embedder"]
116 | )
117 | mcp = create_mcp_server(server_name, config=config)
118 | mcp.run()
119 |
120 |
121 | @cli.command()
122 | def bench(
123 | ctx: typer.Context,
124 | dataset_name: str = typer.Option(
125 | "nano-beir/hotpotqa", "--dataset", "-d", help="Dataset to use from https://ir-datasets.com/"
126 | ),
127 | measure: str = typer.Option(
128 | "AP@10",
129 | "--measure",
130 | "-m",
131 | help="Evaluation measure from https://ir-measur.es/en/latest/measures.html",
132 | ),
133 | ) -> None:
134 | """Run benchmark."""
135 | import ir_datasets
136 | import ir_measures
137 | import pandas as pd
138 |
139 | from raglite._bench import (
140 | IREvaluator,
141 | LlamaIndexEvaluator,
142 | OpenAIVectorStoreEvaluator,
143 | RAGLiteEvaluator,
144 | )
145 |
146 | # Initialise the benchmark.
147 | evaluator: IREvaluator
148 | measures = [ir_measures.parse_measure(measure)]
149 | index, results = [], []
150 | # Evaluate RAGLite (single-vector) + DuckDB HNSW + text-embedding-3-large.
151 | chunk_max_size = 2048
152 | config = RAGLiteConfig(
153 | embedder="text-embedding-3-large",
154 | chunk_max_size=chunk_max_size,
155 | vector_search_multivector=False,
156 | vector_search_query_adapter=False,
157 | )
158 | dataset = ir_datasets.load(dataset_name)
159 | evaluator = RAGLiteEvaluator(
160 | dataset, insert_variant=f"single-vector-{chunk_max_size // 4}t", config=config
161 | )
162 | index.append("RAGLite (single-vector)")
163 | results.append(ir_measures.calc_aggregate(measures, dataset.qrels_iter(), evaluator.score()))
164 | # Evaluate RAGLite (multi-vector) + DuckDB HNSW + text-embedding-3-large.
165 | config = RAGLiteConfig(
166 | embedder="text-embedding-3-large",
167 | chunk_max_size=chunk_max_size,
168 | vector_search_multivector=True,
169 | vector_search_query_adapter=False,
170 | )
171 | dataset = ir_datasets.load(dataset_name)
172 | evaluator = RAGLiteEvaluator(
173 | dataset, insert_variant=f"multi-vector-{chunk_max_size // 4}t", config=config
174 | )
175 | index.append("RAGLite (multi-vector)")
176 | results.append(ir_measures.calc_aggregate(measures, dataset.qrels_iter(), evaluator.score()))
177 | # Evaluate RAGLite (query adapter) + DuckDB HNSW + text-embedding-3-large.
178 | config = RAGLiteConfig(
179 | llm=(llm := "gpt-4.1"),
180 | embedder="text-embedding-3-large",
181 | chunk_max_size=chunk_max_size,
182 | vector_search_multivector=True,
183 | vector_search_query_adapter=True,
184 | )
185 | dataset = ir_datasets.load(dataset_name)
186 | evaluator = RAGLiteEvaluator(
187 | dataset,
188 | insert_variant=f"multi-vector-{chunk_max_size // 4}t",
189 | search_variant=f"query-adapter-{llm}",
190 | config=config,
191 | )
192 | index.append("RAGLite (query adapter)")
193 | results.append(ir_measures.calc_aggregate(measures, dataset.qrels_iter(), evaluator.score()))
194 | # Evaluate LlamaIndex + FAISS HNSW + text-embedding-3-large.
195 | dataset = ir_datasets.load(dataset_name)
196 | evaluator = LlamaIndexEvaluator(dataset)
197 | index.append("LlamaIndex")
198 | results.append(ir_measures.calc_aggregate(measures, dataset.qrels_iter(), evaluator.score()))
199 | # Evaluate OpenAI Vector Store.
200 | dataset = ir_datasets.load(dataset_name)
201 | evaluator = OpenAIVectorStoreEvaluator(dataset)
202 | index.append("OpenAI Vector Store")
203 | results.append(ir_measures.calc_aggregate(measures, dataset.qrels_iter(), evaluator.score()))
204 | # Print the results.
205 | results_df = pd.DataFrame.from_records(results, index=index)
206 | typer.echo(results_df)
207 |
208 |
209 | if __name__ == "__main__":
210 | cli()
211 |
--------------------------------------------------------------------------------
/src/raglite/_config.py:
--------------------------------------------------------------------------------
1 | """RAGLite config."""
2 |
3 | import contextlib
4 | import os
5 | from dataclasses import dataclass, field
6 | from io import StringIO
7 | from pathlib import Path
8 | from typing import Literal
9 |
10 | from platformdirs import user_data_dir
11 | from sqlalchemy.engine import URL
12 |
13 | from raglite._lazy_llama import llama_supports_gpu_offload
14 | from raglite._typing import ChunkId, SearchMethod
15 |
16 | # Suppress rerankers output on import until [1] is fixed.
17 | # [1] https://github.com/AnswerDotAI/rerankers/issues/36
18 | with contextlib.redirect_stdout(StringIO()):
19 | from rerankers.models.flashrank_ranker import FlashRankRanker
20 | from rerankers.models.ranker import BaseRanker
21 |
22 |
23 | cache_path = Path(user_data_dir("raglite", ensure_exists=True))
24 |
25 |
26 | # Lazily load the default search method to avoid circular imports.
27 | # TODO: Replace with search_and_rerank_chunk_spans after benchmarking.
28 | def _vector_search(
29 | query: str, *, num_results: int = 8, config: "RAGLiteConfig | None" = None
30 | ) -> tuple[list[ChunkId], list[float]]:
31 | from raglite._search import vector_search
32 |
33 | return vector_search(query, num_results=num_results, config=config)
34 |
35 |
36 | @dataclass(frozen=True)
37 | class RAGLiteConfig:
38 | """RAGLite config."""
39 |
40 | # Database config.
41 | db_url: str | URL = f"duckdb:///{(cache_path / 'raglite.db').as_posix()}"
42 | # LLM config used for generation.
43 | llm: str = field(
44 | default_factory=lambda: (
45 | "llama-cpp-python/unsloth/Qwen3-8B-GGUF/*Q4_K_M.gguf@8192"
46 | if llama_supports_gpu_offload()
47 | else "llama-cpp-python/unsloth/Qwen3-4B-GGUF/*Q4_K_M.gguf@8192"
48 | )
49 | )
50 | llm_max_tries: int = 4
51 | # Embedder config used for indexing.
52 | embedder: str = field(
53 | default_factory=lambda: ( # Nomic-embed may be better if only English is used.
54 | "llama-cpp-python/lm-kit/bge-m3-gguf/*F16.gguf@512"
55 | if llama_supports_gpu_offload() or (os.cpu_count() or 1) >= 4 # noqa: PLR2004
56 | else "llama-cpp-python/lm-kit/bge-m3-gguf/*Q4_K_M.gguf@512"
57 | )
58 | )
59 | embedder_normalize: bool = True
60 | # Chunk config used to partition documents into chunks.
61 | chunk_max_size: int = 2048 # Max number of characters per chunk.
62 | # Vector search config.
63 | vector_search_distance_metric: Literal["cosine", "dot", "l2"] = "cosine"
64 | vector_search_multivector: bool = True
65 | vector_search_query_adapter: bool = True # Only supported for "cosine" and "dot" metrics.
66 | # Reranking config.
67 | reranker: BaseRanker | dict[str, BaseRanker] | None = field(
68 | default_factory=lambda: {
69 | "en": FlashRankRanker("ms-marco-MiniLM-L-12-v2", verbose=0, cache_dir=cache_path),
70 | "other": FlashRankRanker("ms-marco-MultiBERT-L-12", verbose=0, cache_dir=cache_path),
71 | },
72 | compare=False, # Exclude the reranker from comparison to avoid lru_cache misses.
73 | )
74 | # Search config: you can pick any search method that returns (list[ChunkId], list[float]),
75 | # list[Chunk], or list[ChunkSpan].
76 | search_method: SearchMethod = field(default=_vector_search, compare=False)
77 |
--------------------------------------------------------------------------------
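`RAGLiteConfig` is a frozen dataclass whose defaults point to a local DuckDB file under the user data directory and to llama-cpp-python models. The sketch below shows overriding those defaults; the PostgreSQL URL is an assumption based on the pg8000 dependency in pyproject.toml and the credentials in docker-compose.yml (host, port, and database name are illustrative).

# Configuration sketch (not part of the repository).
from raglite import RAGLiteConfig

# Default: local DuckDB database, llama-cpp-python LLM and embedder.
local_config = RAGLiteConfig()

# Override the database and models, e.g. to target the compose Postgres service and OpenAI models.
postgres_config = RAGLiteConfig(
    db_url="postgresql+pg8000://raglite_user:raglite_password@postgres:5432/postgres",  # Assumed URL scheme.
    llm="gpt-4.1",
    embedder="text-embedding-3-large",
    chunk_max_size=2048,
)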
/src/raglite/_embed.py:
--------------------------------------------------------------------------------
1 | """String embedder."""
2 |
3 | from functools import partial
4 | from typing import Literal
5 |
6 | import numpy as np
7 | from litellm import embedding
8 | from tqdm.auto import tqdm, trange
9 |
10 | from raglite._config import RAGLiteConfig
11 | from raglite._lazy_llama import LLAMA_POOLING_TYPE_NONE, Llama
12 | from raglite._litellm import LlamaCppPythonLLM
13 | from raglite._typing import FloatMatrix, IntVector
14 |
15 |
16 | def embed_strings_with_late_chunking( # noqa: C901,PLR0915
17 | sentences: list[str], *, config: RAGLiteConfig | None = None
18 | ) -> FloatMatrix:
19 | """Embed a document's sentences with late chunking."""
20 |
21 | def _count_tokens(
22 | sentences: list[str], embedder: Llama, sentinel_char: str, sentinel_tokens: list[int]
23 | ) -> list[int]:
24 | # Join the sentences with the sentinel token and tokenise the result.
25 | sentences_tokens = np.asarray(
26 | embedder.tokenize(sentinel_char.join(sentences).encode(), add_bos=False), dtype=np.intp
27 | )
28 | # Map all sentinel token variants to the first one.
29 | for sentinel_token in sentinel_tokens[1:]:
30 | sentences_tokens[sentences_tokens == sentinel_token] = sentinel_tokens[0]
31 | # Count how many tokens there are in between sentinel tokens to recover the token counts.
32 | sentinel_indices = np.where(sentences_tokens == sentinel_tokens[0])[0]
33 | num_tokens = np.diff(sentinel_indices, prepend=0, append=len(sentences_tokens))
34 | assert len(num_tokens) == len(sentences), f"Sentinel `{sentinel_char}` appears in document"
35 | num_tokens_list: list[int] = num_tokens.tolist()
36 | return num_tokens_list
37 |
38 | def _create_segment(
39 | content_start_index: int,
40 | max_tokens_preamble: int,
41 | max_tokens_content: int,
42 | num_tokens: IntVector,
43 | ) -> tuple[int, int]:
44 | # Compute the segment sentence start index so that the segment preamble has no more than
45 | # max_tokens_preamble tokens between [segment_start_index, content_start_index).
46 | cumsum_backwards = np.cumsum(num_tokens[:content_start_index][::-1])
47 | offset_preamble = np.searchsorted(cumsum_backwards, max_tokens_preamble, side="right")
48 | segment_start_index = content_start_index - int(offset_preamble)
49 | # Allow a larger segment content if we didn't use all of the allowed preamble tokens.
50 | max_tokens_content = max_tokens_content + (
51 | max_tokens_preamble - np.sum(num_tokens[segment_start_index:content_start_index])
52 | )
53 | # Compute the segment sentence end index so that the segment content has no more than
54 | # max_tokens_content tokens between [content_start_index, segment_end_index).
55 | cumsum_forwards = np.cumsum(num_tokens[content_start_index:])
56 | offset_segment = np.searchsorted(cumsum_forwards, max_tokens_content, side="right")
57 | segment_end_index = content_start_index + int(offset_segment)
58 | return segment_start_index, segment_end_index
59 |
60 | # Assert that we're using a llama-cpp-python model, since API-based embedding models don't
61 | # support outputting token-level embeddings.
62 | config = config or RAGLiteConfig()
63 | assert config.embedder.startswith("llama-cpp-python")
64 | embedder = LlamaCppPythonLLM.llm(
65 | config.embedder, embedding=True, pooling_type=LLAMA_POOLING_TYPE_NONE
66 | )
67 | n_ctx = embedder.n_ctx()
68 | n_batch = embedder.n_batch
69 | # Identify the tokens corresponding to a sentinel character.
70 | sentinel_char = "⊕"
71 | sentinel_test = f"A{sentinel_char}B {sentinel_char} C.\n{sentinel_char}D"
72 | sentinel_tokens = [
73 | token
74 | for token in embedder.tokenize(sentinel_test.encode(), add_bos=False)
75 | if sentinel_char in embedder.detokenize([token]).decode()
76 | ]
77 | assert sentinel_tokens, f"Sentinel `{sentinel_char}` not supported by embedder"
78 | # Compute the number of tokens per sentence. We use a method based on a sentinel token to
79 | # minimise the number of calls to embedder.tokenize, which incurs a significant overhead
80 | # (presumably to load the tokenizer) [1].
81 | # TODO: Make token counting faster and more robust once [1] is fixed.
82 | # [1] https://github.com/abetlen/llama-cpp-python/issues/1763
83 | num_tokens_list: list[int] = []
84 | sentence_batch, sentence_batch_len = [], 0
85 | for i, sentence in enumerate(sentences):
86 | sentence_batch.append(sentence)
87 | sentence_batch_len += len(sentence)
88 | if i == len(sentences) - 1 or sentence_batch_len > (n_ctx // 2):
89 | num_tokens_list.extend(
90 | _count_tokens(sentence_batch, embedder, sentinel_char, sentinel_tokens)
91 | )
92 | sentence_batch, sentence_batch_len = [], 0
93 | num_tokens = np.asarray(num_tokens_list, dtype=np.intp)
94 | # Compute the maximum number of tokens for each segment's preamble and content.
95 | # Unfortunately, llama-cpp-python truncates the input to n_batch tokens and crashes if you try
96 | # to increase it [1]. Until this is fixed, we have to limit max_tokens to n_batch.
97 | # TODO: Improve the context window size once [1] is fixed.
98 | # [1] https://github.com/abetlen/llama-cpp-python/issues/1762
99 | max_tokens = min(n_ctx, n_batch) - 16
100 | max_tokens_preamble = round(0.382 * max_tokens) # Golden ratio.
101 | max_tokens_content = max_tokens - max_tokens_preamble
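# Illustrative example: with n_ctx == n_batch == 512, max_tokens is 496, so the preamble budget
# is round(0.382 * 496) == 189 tokens and the content budget is the remaining 307 tokens.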
102 | # Compute a list of segments, each consisting of a preamble and content.
103 | segments = []
104 | content_start_index = 0
105 | while content_start_index < len(sentences):
106 | segment_start_index, segment_end_index = _create_segment(
107 | content_start_index, max_tokens_preamble, max_tokens_content, num_tokens
108 | )
109 | segments.append((segment_start_index, content_start_index, segment_end_index))
110 | content_start_index = segment_end_index
111 | # Embed the segments and apply late chunking.
112 | sentence_embeddings_list: list[FloatMatrix] = []
113 | if len(segments) > 1 or segments[0][2] > 128: # noqa: PLR2004
114 | segments = tqdm(segments, desc="Embedding", unit="segment", dynamic_ncols=True, leave=False)
115 | for segment in segments:
116 | # Get the token embeddings of the entire segment, including preamble and content.
117 | segment_start_index, content_start_index, segment_end_index = segment
118 | segment_sentences = sentences[segment_start_index:segment_end_index]
119 | segment_embedding = np.asarray(embedder.embed("".join(segment_sentences)))
120 | # Split the segment embeddings into embedding matrices per sentence using the largest
121 | # remainder method.
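# Illustrative example: with 9 token embeddings and per-sentence token counts [3, 2, 5], the
# fractional sizes are [2.7, 1.8, 4.5]; the floors [2, 1, 4] leave 2 rows to assign, which go
# to the two sentences with the largest fractional parts, yielding sizes [3, 2, 4].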
122 | segment_tokens = num_tokens[segment_start_index:segment_end_index]
123 | sentence_size_frac = len(segment_embedding) * (segment_tokens / np.sum(segment_tokens))
124 | sentence_size = np.floor(sentence_size_frac).astype(np.intp)
125 | remainder = len(segment_embedding) - np.sum(sentence_size)
126 | if remainder > 0: # Assign the remaining embedding rows to the sentences with the largest fractional parts.
127 | top_remainders = np.argsort(sentence_size_frac - sentence_size)[-remainder:]
128 | sentence_size[top_remainders] += 1
129 | sentence_matrices = np.split(segment_embedding, np.cumsum(sentence_size)[:-1])
130 | # Compute the segment sentence embeddings by averaging the token embeddings.
131 | content_sentence_embeddings = [
132 | np.mean(sentence_matrix, axis=0, keepdims=True)
133 | for sentence_matrix in sentence_matrices[content_start_index - segment_start_index :]
134 | ]
135 | sentence_embeddings_list.append(np.vstack(content_sentence_embeddings))
136 | sentence_embeddings = np.vstack(sentence_embeddings_list)
137 | # Normalise the sentence embeddings to unit norm and cast to half precision.
138 | if config.embedder_normalize:
139 | sentence_embeddings /= np.linalg.norm(sentence_embeddings, axis=1, keepdims=True)
140 | sentence_embeddings = sentence_embeddings.astype(np.float16)
141 | return sentence_embeddings
142 |
143 |
144 | def _embed_string_batch(string_batch: list[str], *, config: RAGLiteConfig) -> FloatMatrix:
145 | """Embed a batch of text strings."""
146 | if config.embedder.startswith("llama-cpp-python"):
147 | # LiteLLM doesn't yet support registering a custom embedder, so we handle it here.
148 | # Additionally, we explicitly manually pool the token embeddings to obtain sentence
149 | # embeddings because token embeddings are universally supported, while sequence
150 | # embeddings are only supported by some models.
151 | embedder = LlamaCppPythonLLM.llm(
152 | config.embedder, embedding=True, pooling_type=LLAMA_POOLING_TYPE_NONE
153 | )
154 | embeddings = np.asarray([np.mean(row, axis=0) for row in embedder.embed(string_batch)])
155 | else:
156 | # Use LiteLLM's API to embed the batch of strings.
157 | response = embedding(config.embedder, string_batch)
158 | embeddings = np.asarray([item["embedding"] for item in response["data"]])
159 | # Normalise the embeddings to unit norm and cast to half precision.
160 | if config.embedder_normalize:
161 | eps = np.finfo(embeddings.dtype).eps
162 | norm = np.linalg.norm(embeddings, axis=1, keepdims=True)
163 | embeddings /= np.maximum(norm, eps)
164 | embeddings = embeddings.astype(np.float16)
165 | return embeddings
166 |
167 |
168 | def embed_strings_without_late_chunking(
169 | strings: list[str], *, config: RAGLiteConfig | None = None
170 | ) -> FloatMatrix:
171 | """Embed a list of text strings in batches."""
172 | config = config or RAGLiteConfig()
173 | batch_size = 96
174 | batch_range = (
175 | partial(trange, desc="Embedding", unit="batch", dynamic_ncols=True)
176 | if len(strings) > batch_size
177 | else range
178 | )
179 | batch_embeddings = [
180 | _embed_string_batch(strings[i : i + batch_size], config=config)
181 | for i in batch_range(0, len(strings), batch_size)
182 | ]
183 | string_embeddings = np.vstack(batch_embeddings)
184 | return string_embeddings
185 |
186 |
187 | def embedding_type(*, config: RAGLiteConfig | None = None) -> Literal["late_chunking", "standard"]:
188 | """Return the type of sentence embeddings."""
189 | config = config or RAGLiteConfig()
190 | return "late_chunking" if config.embedder.startswith("llama-cpp-python") else "standard"
191 |
192 |
193 | def embed_strings(strings: list[str], *, config: RAGLiteConfig | None = None) -> FloatMatrix:
194 | """Embed the chunklets of a document as a NumPy matrix with one row per chunklet."""
195 | config = config or RAGLiteConfig()
196 | if embedding_type(config=config) == "late_chunking":
197 | string_embeddings = embed_strings_with_late_chunking(strings, config=config)
198 | else:
199 | string_embeddings = embed_strings_without_late_chunking(strings, config=config)
200 | return string_embeddings
201 |
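if __name__ == "__main__":
    # Illustrative usage sketch: embed two strings with the default config. Assumes the default
    # llama-cpp-python embedder can be downloaded and loaded, in which case late chunking is used;
    # any other embedder falls back to standard batched embedding.
    demo_embeddings = embed_strings(["RAGLite is a RAG toolkit.", "It supports DuckDB."])
    print(demo_embeddings.shape, demo_embeddings.dtype)  # One row per input string, float16.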
--------------------------------------------------------------------------------
/src/raglite/_eval.py:
--------------------------------------------------------------------------------
1 | """Generation and evaluation of evals."""
2 |
3 | import contextlib
4 | from concurrent.futures import ThreadPoolExecutor, as_completed
5 | from functools import partial
6 | from random import randint
7 | from typing import TYPE_CHECKING, ClassVar
8 |
9 | import numpy as np
10 | from pydantic import BaseModel, ConfigDict, Field, field_validator
11 | from sqlalchemy import text
12 | from sqlmodel import Session, func, select
13 | from tqdm.auto import tqdm
14 |
15 | if TYPE_CHECKING:
16 | import pandas as pd
17 |
18 | from raglite._config import RAGLiteConfig
19 | from raglite._database import Chunk, Document, Eval, create_database_engine
20 | from raglite._extract import extract_with_llm
21 | from raglite._rag import add_context, rag, retrieve_context
22 | from raglite._search import retrieve_chunk_spans, vector_search
23 |
24 |
25 | def generate_eval(*, max_chunks: int = 20, config: RAGLiteConfig | None = None) -> Eval:
26 | """Generate an eval."""
27 |
28 | class QuestionResponse(BaseModel):
29 | """A specific question about the content of a set of document contexts."""
30 |
31 | model_config = ConfigDict(
32 | extra="forbid" # Forbid extra attributes as required by OpenAI's strict mode.
33 | )
34 | question: str = Field(
35 | ..., description="A specific question about the content of a set of document contexts."
36 | )
37 | system_prompt: ClassVar[str] = """
38 | You are given a set of contexts extracted from a document.
39 | You are a subject matter expert on the document's topic.
40 | Your task is to generate a question to quiz other subject matter experts on the information in the provided context.
41 | The question MUST satisfy ALL of the following criteria:
42 | - The question SHOULD integrate as much of the provided context as possible.
43 | - The question MUST NOT be a general or open question, but MUST instead be as specific to the provided context as possible.
44 | - The question MUST be completely answerable using ONLY the information in the provided context, without depending on any background information.
45 | - The question MUST be entirely self-contained and able to be understood in full WITHOUT access to the provided context.
46 | - The question MUST NOT reference the existence of the context, directly or indirectly.
47 | - The question MUST treat the context as if its contents are entirely part of your working memory.
48 | """.strip()
49 |
50 | @field_validator("question")
51 | @classmethod
52 | def validate_question(cls, value: str) -> str:
53 | """Validate the question."""
54 | question = value.strip().lower()
55 | if "context" in question or "document" in question or "question" in question:
56 | raise ValueError
57 | if not question.endswith("?"):
58 | raise ValueError
59 | return value
60 |
61 | with Session(create_database_engine(config := config or RAGLiteConfig())) as session:
62 | # Sample a random document from the database.
63 | seed_document = session.exec(select(Document).order_by(func.random()).limit(1)).first()
64 | if seed_document is None:
65 | error_message = "First run `insert_documents()` before generating evals."
66 | raise ValueError(error_message)
67 | # Sample a random chunk from that document.
68 | seed_chunk = session.exec(
69 | select(Chunk)
70 | .where(Chunk.document_id == seed_document.id)
71 | .order_by(func.random())
72 | .limit(1)
73 | ).first()
74 | assert isinstance(seed_chunk, Chunk)
75 | # Expand the seed chunk into a set of related chunks.
76 | related_chunk_ids, _ = vector_search(
77 | query=np.mean(seed_chunk.embedding_matrix, axis=0),
78 | num_results=randint(1, max_chunks), # noqa: S311
79 | config=config,
80 | )
81 | related_chunks = [
82 | str(chunk_spans)
83 | for chunk_spans in retrieve_chunk_spans(related_chunk_ids, config=config)
84 | ]
85 | # Extract a question from the seed chunk's related chunks.
86 | question = extract_with_llm(
87 | QuestionResponse, related_chunks, strict=True, config=config
88 | ).question
89 | # Search for candidate chunks to answer the generated question.
90 | candidate_chunk_ids, _ = vector_search(
91 | query=question, num_results=2 * max_chunks, config=config
92 | )
93 | candidate_chunks = [session.get(Chunk, chunk_id) for chunk_id in candidate_chunk_ids]
94 |
95 | # Determine which candidate chunks are relevant to answer the generated question.
96 | class ContextEvalResponse(BaseModel):
97 | """Indicate whether the provided context can be used to answer a given question."""
98 |
99 | model_config = ConfigDict(
100 | extra="forbid" # Forbid extra attributes as required by OpenAI's strict mode.
101 | )
102 | hit: bool = Field(
103 | ...,
104 | description="True if the provided context contains (a part of) the answer to the given question, false otherwise.",
105 | )
106 | system_prompt: ClassVar[str] = f"""
107 | You are given a context extracted from a document.
108 | You are a subject matter expert on the document's topic.
109 | Your task is to determine whether the provided context contains (a part of) the answer to this question: "{question}"
110 | An example of a context that does NOT contain (a part of) the answer is a table of contents.
111 | """.strip()
112 |
113 | relevant_chunks = []
114 | for candidate_chunk in tqdm(
115 | candidate_chunks,
116 | desc="Evaluating chunks",
117 | unit="chunk",
118 | dynamic_ncols=True,
119 | leave=False,
120 | ):
121 | try:
122 | context_eval_response = extract_with_llm(
123 | ContextEvalResponse, str(candidate_chunk), strict=True, config=config
124 | )
125 | except ValueError: # noqa: PERF203
126 | pass
127 | else:
128 | if context_eval_response.hit:
129 | relevant_chunks.append(candidate_chunk)
130 | if not relevant_chunks:
131 | error_message = "No relevant chunks found to answer the question."
132 | raise ValueError(error_message)
133 |
134 | # Answer the question using the relevant chunks.
135 | class AnswerResponse(BaseModel):
136 | """Answer a question using the provided context."""
137 |
138 | model_config = ConfigDict(
139 | extra="forbid" # Forbid extra attributes as required by OpenAI's strict mode.
140 | )
141 | answer: str = Field(
142 | ...,
143 | description="A complete answer to the given question using the provided context.",
144 | )
145 | system_prompt: ClassVar[str] = f"""
146 | You are given a set of contexts extracted from a document.
147 | You are a subject matter expert on the document's topic.
148 | Your task is to generate a complete answer to the following question using the provided context: "{question}"
149 | The answer MUST satisfy ALL of the following criteria:
150 | - The answer MUST integrate as much of the provided context as possible.
151 | - The answer MUST be entirely self-contained and able to be understood in full WITHOUT access to the provided context.
152 | - The answer MUST NOT reference the existence of the context, directly or indirectly.
153 | - The answer MUST treat the context as if its contents are entirely part of your working memory.
154 | """.strip()
155 |
156 | answer = extract_with_llm(
157 | AnswerResponse,
158 | [str(relevant_chunk) for relevant_chunk in relevant_chunks],
159 | strict=True,
160 | config=config,
161 | ).answer
162 | # Construct the eval.
163 | eval_ = Eval.from_chunks(question=question, contexts=relevant_chunks, ground_truth=answer)
164 | return eval_
165 |
166 |
167 | def insert_evals(
168 | *,
169 | num_evals: int = 100,
170 | max_chunks_per_eval: int = 20,
171 | max_workers: int | None = None,
172 | config: RAGLiteConfig | None = None,
173 | ) -> None:
174 | """Generate and insert evals into the database."""
175 | with (
176 | Session(engine := create_database_engine(config := config or RAGLiteConfig())) as session,
177 | ThreadPoolExecutor(max_workers=max_workers) as executor,
178 | tqdm(total=num_evals, desc="Generating evals", unit="eval", dynamic_ncols=True) as pbar,
179 | ):
180 | futures = [
181 | executor.submit(partial(generate_eval, max_chunks=max_chunks_per_eval, config=config))
182 | for _ in range(num_evals)
183 | ]
184 | for future in as_completed(futures):
185 | with contextlib.suppress(Exception): # Eval generation may fail for various reasons.
186 | eval_ = future.result()
187 | session.add(eval_)
188 | pbar.update()
189 | session.commit()
190 | if engine.dialect.name == "duckdb":
191 | session.execute(text("CHECKPOINT;"))
192 |
193 |
194 | def answer_evals(
195 | num_evals: int = 100,
196 | *,
197 | config: RAGLiteConfig | None = None,
198 | ) -> "pd.DataFrame":
199 | """Read evals from the database and answer them with RAG."""
200 | try:
201 | import pandas as pd
202 | except ModuleNotFoundError as import_error:
203 | error_message = "To use the `answer_evals` function, please install the `ragas` extra."
204 | raise ModuleNotFoundError(error_message) from import_error
205 |
206 | # Read evals from the database.
207 | with Session(create_database_engine(config := config or RAGLiteConfig())) as session:
208 | evals = session.exec(select(Eval).limit(num_evals)).all()
209 | # Answer evals with RAG.
210 | answers: list[str] = []
211 | contexts: list[list[str]] = []
212 | for eval_ in tqdm(evals, desc="Answering evals", unit="eval", dynamic_ncols=True):
213 | chunk_spans = retrieve_context(query=eval_.question, config=config)
214 | messages = [add_context(user_prompt=eval_.question, context=chunk_spans)]
215 | response = rag(messages, config=config)
216 | answer = "".join(response)
217 | answers.append(answer)
218 | contexts.append([str(chunk_span) for chunk_span in chunk_spans])
219 | # Collect the answered evals.
220 | answered_evals: dict[str, list[str] | list[list[str]]] = {
221 | "question": [eval_.question for eval_ in evals],
222 | "answer": answers,
223 | "contexts": contexts,
224 | "ground_truth": [eval_.ground_truth for eval_ in evals],
225 | "ground_truth_contexts": [eval_.contexts for eval_ in evals],
226 | }
227 | answered_evals_df = pd.DataFrame.from_dict(answered_evals)
228 | return answered_evals_df
229 |
230 |
231 | def evaluate(
232 | answered_evals: "pd.DataFrame | int" = 100, config: RAGLiteConfig | None = None
233 | ) -> "pd.DataFrame":
234 | """Evaluate the performance of a set of answered evals with Ragas."""
235 | try:
236 | import pandas as pd
237 | from datasets import Dataset
238 | from langchain_community.chat_models import ChatLiteLLM
239 | from langchain_community.llms import LlamaCpp
240 | from ragas import RunConfig
241 | from ragas import evaluate as ragas_evaluate
242 | from ragas.embeddings import BaseRagasEmbeddings
243 |
244 | from raglite._config import RAGLiteConfig
245 | from raglite._embed import embed_strings
246 | from raglite._litellm import LlamaCppPythonLLM
247 | except ModuleNotFoundError as import_error:
248 | error_message = "To use the `evaluate` function, please install the `ragas` extra."
249 | raise ModuleNotFoundError(error_message) from import_error
250 |
251 | class RAGLiteRagasEmbeddings(BaseRagasEmbeddings):
252 | """A RAGLite embedder for Ragas."""
253 |
254 | def __init__(self, config: RAGLiteConfig | None = None):
255 | self.config = config or RAGLiteConfig()
256 |
257 | def embed_query(self, text: str) -> list[float]:
258 | # Embed the input text with RAGLite's embedding function.
259 | embeddings = embed_strings([text], config=self.config)
260 | return embeddings[0].tolist() # type: ignore[no-any-return]
261 |
262 | def embed_documents(self, texts: list[str]) -> list[list[float]]:
263 | # Embed a list of documents with RAGLite's embedding function.
264 | embeddings = embed_strings(texts, config=self.config)
265 | return embeddings.tolist() # type: ignore[no-any-return]
266 |
267 | # Create a set of answered evals if not provided.
268 | config = config or RAGLiteConfig()
269 | answered_evals_df = (
270 | answered_evals
271 | if isinstance(answered_evals, pd.DataFrame)
272 | else answer_evals(num_evals=answered_evals, config=config)
273 | )
274 | # Load the LLM.
275 | if config.llm.startswith("llama-cpp-python"):
276 | llm = LlamaCppPythonLLM().llm(model=config.llm)
277 | lc_llm = LlamaCpp(
278 | model_path=llm.model_path,
279 | n_batch=llm.n_batch,
280 | n_ctx=llm.n_ctx(),
281 | n_gpu_layers=-1,
282 | verbose=llm.verbose,
283 | )
284 | else:
285 | lc_llm = ChatLiteLLM(model=config.llm)
286 | embedder = RAGLiteRagasEmbeddings(config=config)
287 | # Evaluate the answered evals with Ragas.
288 | evaluation_df = ragas_evaluate(
289 | dataset=Dataset.from_pandas(answered_evals_df),
290 | llm=lc_llm,
291 | embeddings=embedder,
292 | run_config=RunConfig(max_workers=1),
293 | ).to_pandas()
294 | return evaluation_df
295 |
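if __name__ == "__main__":
    # Illustrative usage sketch (assumes documents were already added with `insert_documents()`
    # and that the `ragas` extra is installed for `evaluate()`).
    insert_evals(num_evals=10)
    answered_evals_df = answer_evals(num_evals=10)
    evaluation_df = evaluate(answered_evals_df)
    print(evaluation_df.head())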
--------------------------------------------------------------------------------
/src/raglite/_extract.py:
--------------------------------------------------------------------------------
1 | """Extract structured data from unstructured text with an LLM."""
2 |
3 | from typing import Any, TypeVar
4 |
5 | from litellm import completion, get_supported_openai_params # type: ignore[attr-defined]
6 | from pydantic import BaseModel, ValidationError
7 |
8 | from raglite._config import RAGLiteConfig
9 |
10 | T = TypeVar("T", bound=BaseModel)
11 |
12 |
13 | def extract_with_llm(
14 | return_type: type[T],
15 | user_prompt: str | list[str],
16 | strict: bool = False, # noqa: FBT001, FBT002
17 | config: RAGLiteConfig | None = None,
18 | **kwargs: Any,
19 | ) -> T:
20 | """Extract structured data from unstructured text with an LLM.
21 |
22 | This function expects a `return_type.system_prompt: ClassVar[str]` that contains the system
23 | prompt to use. Example:
24 |
25 | from typing import ClassVar
26 | from pydantic import BaseModel, Field
27 |
28 | class MyNameResponse(BaseModel):
29 | my_name: str = Field(..., description="The user's name.")
30 | system_prompt: ClassVar[str] = "The system prompt to use (excluded from JSON schema)."
31 |
32 | my_name_response = extract_with_llm(MyNameResponse, "My name is Thomas A. Anderson.")
33 | """
34 | # Load the default config if not provided.
35 | config = config or RAGLiteConfig()
36 | # Check if the LLM supports the response format.
37 | llm_supports_response_format = "response_format" in (
38 | get_supported_openai_params(model=config.llm) or []
39 | )
40 | # Update the system prompt with the JSON schema of the return type to help the LLM.
41 | system_prompt = getattr(return_type, "system_prompt", "").strip()
42 | if not llm_supports_response_format or config.llm.startswith("llama-cpp-python"):
43 | system_prompt += f"\n\nFormat your response according to this JSON schema:\n{return_type.model_json_schema()!s}"
44 | # Constrain the response format to the JSON schema if it's supported by the LLM [1]. Strict mode
45 | # is disabled by default because it only supports a subset of JSON schema features [2].
46 | # [1] https://docs.litellm.ai/docs/completion/json_mode
47 | # [2] https://platform.openai.com/docs/guides/structured-outputs#some-type-specific-keywords-are-not-yet-supported
48 | # TODO: Fall back to {"type": "json_object"} if JSON schema is not supported by the LLM.
49 | response_format: dict[str, Any] | None = (
50 | {
51 | "type": "json_schema",
52 | "json_schema": {
53 | "name": return_type.__name__,
54 | "description": return_type.__doc__ or "",
55 | "schema": return_type.model_json_schema(),
56 | "strict": strict,
57 | },
58 | }
59 | if llm_supports_response_format
60 | else None
61 | )
62 | # Concatenate the user prompt if it is a list of strings.
63 | if isinstance(user_prompt, list):
64 | user_prompt = "\n\n".join(
65 | f'<context index="{i + 1}">\n{chunk.strip()}\n</context>'
66 | for i, chunk in enumerate(user_prompt)
67 | )
68 | # Extract structured data from the unstructured input.
69 | for _ in range(config.llm_max_tries):
70 | response = completion(
71 | model=config.llm,
72 | messages=[
73 | {"role": "system", "content": system_prompt},
74 | {"role": "user", "content": user_prompt},
75 | ],
76 | response_format=response_format,
77 | **kwargs,
78 | )
79 | try:
80 | instance = return_type.model_validate_json(response["choices"][0]["message"]["content"])
81 | except (KeyError, ValueError, ValidationError) as e:
82 | # Malformed response, not a JSON string, or not a valid instance of the return type.
83 | last_exception = e
84 | continue
85 | else:
86 | break
87 | else:
88 | error_message = f"Failed to extract {return_type} from input {user_prompt}."
89 | raise ValueError(error_message) from last_exception
90 | return instance
91 |
--------------------------------------------------------------------------------
/src/raglite/_insert.py:
--------------------------------------------------------------------------------
1 | """Index documents."""
2 |
3 | from concurrent.futures import ThreadPoolExecutor, as_completed
4 | from contextlib import nullcontext
5 | from functools import partial
6 | from pathlib import Path
7 |
8 | from filelock import FileLock
9 | from sqlalchemy import text
10 | from sqlalchemy.engine import make_url
11 | from sqlmodel import Session, col, select
12 | from tqdm.auto import tqdm
13 |
14 | from raglite._config import RAGLiteConfig
15 | from raglite._database import Chunk, ChunkEmbedding, Document, create_database_engine
16 | from raglite._embed import embed_strings, embed_strings_without_late_chunking, embedding_type
17 | from raglite._split_chunklets import split_chunklets
18 | from raglite._split_chunks import split_chunks
19 | from raglite._split_sentences import split_sentences
20 |
21 |
22 | def _create_chunk_records(
23 | document: Document, config: RAGLiteConfig
24 | ) -> tuple[Document, list[Chunk], list[list[ChunkEmbedding]]]:
25 | """Process chunks into chunk and chunk embedding records."""
26 | # Partition the document into chunks.
27 | assert document.content is not None
28 | sentences = split_sentences(document.content, max_len=config.chunk_max_size)
29 | chunklets = split_chunklets(sentences, max_size=config.chunk_max_size)
30 | chunklet_embeddings = embed_strings(chunklets, config=config)
31 | chunks, chunk_embeddings = split_chunks(
32 | chunklets=chunklets,
33 | chunklet_embeddings=chunklet_embeddings,
34 | max_size=config.chunk_max_size,
35 | )
36 | # Create the chunk records.
37 | chunk_records, headings = [], ""
38 | for i, chunk in enumerate(chunks):
39 | # Create and append the chunk record.
40 | record = Chunk.from_body(
41 | document=document, index=i, body=chunk, headings=headings, **document.metadata_
42 | )
43 | chunk_records.append(record)
44 | # Update the Markdown headings with those of this chunk.
45 | headings = record.extract_headings()
46 | # Create the chunk embedding records.
47 | chunk_embedding_records_list = []
48 | if embedding_type(config=config) == "late_chunking":
49 | # Every chunk record is associated with a list of chunk embedding records, one for each of
50 | # the chunklets in the chunk.
51 | for chunk_record, chunk_embedding in zip(chunk_records, chunk_embeddings, strict=True):
52 | chunk_embedding_records_list.append(
53 | [
54 | ChunkEmbedding(chunk_id=chunk_record.id, embedding=chunklet_embedding)
55 | for chunklet_embedding in chunk_embedding
56 | ]
57 | )
58 | else:
59 | # Embed the full chunks, including the current Markdown headings.
60 | full_chunk_embeddings = embed_strings_without_late_chunking(
61 | [chunk_record.content for chunk_record in chunk_records], config=config
62 | )
63 | # Every chunk record is associated with a list of chunk embedding records. The chunk
64 | # embedding records each correspond to a linear combination of a chunklet embedding and an
65 | # embedding of the full chunk with Markdown headings.
66 | α = 0.15 # Benchmark-optimised value. # noqa: PLC2401
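# Illustrative example: with α == 0.15 and multivector search enabled, each stored embedding
# below is 0.15 * chunklet_embedding + 0.85 * full_chunk_embedding, so retrieval is dominated
# by the full-chunk embedding while remaining sensitive to the individual chunklet.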
67 | for chunk_record, chunk_embedding, full_chunk_embedding in zip(
68 | chunk_records, chunk_embeddings, full_chunk_embeddings, strict=True
69 | ):
70 | if config.vector_search_multivector:
71 | chunk_embedding_records_list.append(
72 | [
73 | ChunkEmbedding(
74 | chunk_id=chunk_record.id,
75 | embedding=α * chunklet_embedding + (1 - α) * full_chunk_embedding,
76 | )
77 | for chunklet_embedding in chunk_embedding
78 | ]
79 | )
80 | else:
81 | chunk_embedding_records_list.append(
82 | [
83 | ChunkEmbedding(
84 | chunk_id=chunk_record.id,
85 | embedding=full_chunk_embedding,
86 | )
87 | ]
88 | )
89 | return document, chunk_records, chunk_embedding_records_list
90 |
91 |
92 | def insert_documents( # noqa: C901
93 | documents: list[Document],
94 | *,
95 | max_workers: int | None = None,
96 | config: RAGLiteConfig | None = None,
97 | ) -> None:
98 | """Insert documents into the database and update the index.
99 |
100 | Parameters
101 | ----------
102 | documents
103 | A list of documents to insert into the database.
104 | max_workers
105 | The maximum number of worker threads to use.
106 | config
107 | The RAGLite config to use to insert the documents into the database.
108 |
109 | Returns
110 | -------
111 | None
112 | """
113 | # Verify that all documents have content.
114 | if not all(isinstance(doc.content, str) for doc in documents):
115 | error_message = "Some or all documents have missing `document.content`."
116 | raise ValueError(error_message)
117 | # Early exit if no documents are provided.
118 | documents = [doc for doc in documents if doc.content.strip()] # type: ignore[union-attr]
119 | if not documents:
120 | return
121 | # Skip documents that are already in the database.
122 | batch_size = 128
123 | with Session(engine := create_database_engine(config := config or RAGLiteConfig())) as session:
124 | existing_doc_ids: set[str] = set()
125 | for i in range(0, len(documents), batch_size):
126 | doc_id_batch = [doc.id for doc in documents[i : i + batch_size]]
127 | existing_doc_ids.update(
128 | session.exec(select(Document.id).where(col(Document.id).in_(doc_id_batch))).all()
129 | )
130 | documents = [doc for doc in documents if doc.id not in existing_doc_ids]
131 | if not documents:
132 | return
133 | # For DuckDB databases, acquire a lock on the database.
134 | if engine.dialect.name == "duckdb":
135 | db_url = make_url(config.db_url) if isinstance(config.db_url, str) else config.db_url
136 | db_lock = (
137 | FileLock(Path(db_url.database).with_suffix(".lock"))
138 | if db_url.database
139 | else nullcontext()
140 | )
141 | else:
142 | db_lock = nullcontext()
143 | # Create and insert the document, chunk, and chunk embedding records.
144 | with (
145 | db_lock,
146 | Session(engine) as session,
147 | ThreadPoolExecutor(max_workers=max_workers) as executor,
148 | tqdm(
149 | total=len(documents), desc="Inserting documents", unit="document", dynamic_ncols=True
150 | ) as pbar,
151 | ):
152 | futures = [
153 | executor.submit(partial(_create_chunk_records, config=config), doc) for doc in documents
154 | ]
155 | num_unflushed_embeddings = 0
156 | for future in as_completed(futures):
157 | try:
158 | document_record, chunk_records, chunk_embedding_records_list = future.result()
159 | except Exception as e:
160 | executor.shutdown(cancel_futures=True) # Cancel remaining work.
161 | session.rollback() # Cancel uncommitted changes.
162 | error_message = f"Error processing document: {e}"
163 | raise ValueError(error_message) from e
164 | session.add(document_record)
165 | session.add_all(chunk_records)
166 | for chunk_embedding_records in chunk_embedding_records_list:
167 | session.add_all(chunk_embedding_records)
168 | num_unflushed_embeddings += len(chunk_embedding_records)
169 | if num_unflushed_embeddings >= batch_size:
170 | session.flush() # Flush changes to the database.
171 | session.expunge_all() # Release memory of flushed changes.
172 | num_unflushed_embeddings = 0
173 | pbar.update()
174 | session.commit()
175 | if engine.dialect.name == "duckdb":
176 | # DuckDB does not automatically update its keyword search index [1], so we do it
177 | # manually after insertion. Additionally, we re-compact the HNSW index [2]. Finally, we
178 | # synchronize data in the write-ahead log (WAL) to the database data file with the
179 | # CHECKPOINT statement [3].
180 | # [1] https://duckdb.org/docs/stable/extensions/full_text_search
181 | # [2] https://duckdb.org/docs/stable/core_extensions/vss.html#inserts-updates-deletes-and-re-compaction
182 | # [3] https://duckdb.org/docs/stable/sql/statements/checkpoint.html
183 | session.execute(text("PRAGMA create_fts_index('chunk', 'id', 'body', overwrite = 1);"))
184 | if len(documents) >= 8: # noqa: PLR2004
185 | session.execute(text("PRAGMA hnsw_compact_index('vector_search_chunk_index');"))
186 | session.commit()
187 | session.execute(text("CHECKPOINT;"))
188 |
--------------------------------------------------------------------------------
/src/raglite/_lazy_llama.py:
--------------------------------------------------------------------------------
1 | """Import from llama-cpp-python with a lazy ModuleNotFoundError if it's not installed."""
2 |
3 | from importlib import import_module
4 | from typing import TYPE_CHECKING, Any, NoReturn
5 |
6 | # When type checking, import everything normally.
7 | if TYPE_CHECKING:
8 | from llama_cpp import ( # type: ignore[attr-defined]
9 | LLAMA_POOLING_TYPE_NONE,
10 | Llama,
11 | LlamaRAMCache,
12 | llama,
13 | llama_chat_format,
14 | llama_grammar,
15 | llama_supports_gpu_offload,
16 | llama_types,
17 | )
18 |
19 | # Explicitly export these names for static analysis.
20 | __all__ = [
21 | "LLAMA_POOLING_TYPE_NONE",
22 | "Llama",
23 | "LlamaRAMCache",
24 | "llama",
25 | "llama_chat_format",
26 | "llama_grammar",
27 | "llama_supports_gpu_offload",
28 | "llama_types",
29 | ]
30 |
31 |
32 | def __getattr__(name: str) -> object:
33 | """Import from llama-cpp-python with a lazy ModuleNotFoundError if it's not installed."""
34 |
35 | # Create a mock attribute and submodule that lazily raise a ModuleNotFoundError when accessed.
36 | class LazyAttributeError:
37 | error_message = "To use llama.cpp models, please install `llama-cpp-python`."
38 |
39 | def __init__(self, error: ModuleNotFoundError | None = None):
40 | self.error = error
41 |
42 | def __getattr__(self, name: str) -> NoReturn:
43 | raise ModuleNotFoundError(self.error_message) from self.error
44 |
45 | def __call__(self, *args: Any, **kwargs: Any) -> NoReturn:
46 | raise ModuleNotFoundError(self.error_message) from self.error
47 |
48 | class LazySubmoduleError:
49 | def __init__(self, error: ModuleNotFoundError):
50 | self.error = error
51 |
52 | def __getattr__(self, name: str) -> LazyAttributeError | type[LazyAttributeError]:
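# If the requested name is all lowercase (e.g. a function or submodule attribute), return an
# instance that raises on access or call; otherwise (e.g. a class name), return the
# LazyAttributeError class itself.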
53 | return LazyAttributeError(self.error) if name == name.lower() else LazyAttributeError
54 |
55 | # Check if the attribute is a submodule.
56 | llama_cpp_submodules = ["llama", "llama_chat_format", "llama_grammar", "llama_types"]
57 | attr_is_submodule = name in llama_cpp_submodules
58 | try:
59 | # Import and return the requested submodule or attribute.
60 | module = import_module(f"llama_cpp.{name}" if attr_is_submodule else "llama_cpp")
61 | return module if attr_is_submodule else getattr(module, name)
62 | except ModuleNotFoundError as import_error:
63 | # Return a mock submodule or attribute that lazily raises a ModuleNotFoundError.
64 | return (
65 | LazySubmoduleError(import_error)
66 | if attr_is_submodule
67 | else LazyAttributeError(import_error)
68 | )
69 |
--------------------------------------------------------------------------------
/src/raglite/_markdown.py:
--------------------------------------------------------------------------------
1 | """Convert any document to Markdown."""
2 |
3 | import re
4 | from copy import deepcopy
5 | from pathlib import Path
6 | from typing import Any
7 |
8 | import numpy as np
9 | from pdftext.extraction import dictionary_output
10 | from sklearn.cluster import KMeans
11 |
12 |
13 | def parsed_pdf_to_markdown(pages: list[dict[str, Any]]) -> list[str]: # noqa: C901, PLR0915
14 | """Convert a PDF parsed with pdftext to Markdown."""
15 |
16 | def add_heading_level_metadata(pages: list[dict[str, Any]]) -> list[dict[str, Any]]: # noqa: C901
17 | """Add heading level metadata to a PDF parsed with pdftext."""
18 |
19 | def extract_font_size(span: dict[str, Any]) -> float:
20 | """Extract the font size from a text span."""
21 | font_size: float = 1.0
22 | if span["font"]["size"] > 1: # A value of 1 appears to mean "unknown" in pdftext.
23 | font_size = span["font"]["size"]
24 | elif digit_sequences := re.findall(r"\d+", span["font"]["name"] or ""):
25 | font_size = float(digit_sequences[-1])
26 | elif "\n" not in span["text"]: # Occasionally a span can contain a newline character.
27 | if round(span["rotation"]) in (0.0, 180.0, -180.0):
28 | font_size = span["bbox"][3] - span["bbox"][1]
29 | elif round(span["rotation"]) in (90.0, -90.0, 270.0, -270.0):
30 | font_size = span["bbox"][2] - span["bbox"][0]
31 | return font_size
32 |
33 | # Copy the pages.
34 | pages = deepcopy(pages)
35 | # Extract an array of all font sizes used by the text spans.
36 | font_sizes = np.asarray(
37 | [
38 | extract_font_size(span)
39 | for page in pages
40 | for block in page["blocks"]
41 | for line in block["lines"]
42 | for span in line["spans"]
43 | ]
44 | )
45 | font_sizes = np.round(font_sizes * 2) / 2
46 | unique_font_sizes, counts = np.unique(font_sizes, return_counts=True)
47 | # Determine the paragraph font size as the mode font size.
48 | tiny = unique_font_sizes < min(5, np.max(unique_font_sizes))
49 | counts[tiny] = -counts[tiny]
50 | mode = np.argmax(counts)
51 | counts[tiny] = -counts[tiny]
52 | mode_font_size = unique_font_sizes[mode]
53 | # Determine (at most) 6 heading font sizes by clustering font sizes larger than the mode.
54 | heading_font_sizes = unique_font_sizes[mode + 1 :]
55 | if len(heading_font_sizes) > 0:
56 | heading_counts = counts[mode + 1 :]
57 | kmeans = KMeans(n_clusters=min(6, len(heading_font_sizes)), random_state=42)
58 | kmeans.fit(heading_font_sizes[:, np.newaxis], sample_weight=heading_counts)
59 | heading_font_sizes = np.sort(np.ravel(kmeans.cluster_centers_))[::-1]
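# Illustrative example: if the font sizes above the mode are {14, 18, 24}, KMeans yields (up to)
# three cluster centres, sorted in descending order so that spans of size 24 map to <h1>, 18 to
# <h2>, and 14 to <h3> below.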
60 | # Add heading level information to the text spans and lines.
61 | for page in pages:
62 | for block in page["blocks"]:
63 | for line in block["lines"]:
64 | if "md" not in line:
65 | line["md"] = {}
66 | heading_level = np.zeros(8) # 0-5: <h1>-<h6>, 6: <p>, 7: <small>
67 | for span in line["spans"]:
68 | if "md" not in span:
69 | span["md"] = {}
70 | span_font_size = extract_font_size(span)
71 | if span_font_size < mode_font_size:
72 | idx = 7
73 | elif span_font_size == mode_font_size:
74 | idx = 6
75 | else:
76 | idx = np.argmin(np.abs(heading_font_sizes - span_font_size)) # type: ignore[assignment]
77 | span["md"]["heading_level"] = idx + 1
78 | heading_level[idx] += len(span["text"])
79 | line["md"]["heading_level"] = np.argmax(heading_level) + 1
80 | return pages
81 |
82 | def add_emphasis_metadata(pages: list[dict[str, Any]]) -> list[dict[str, Any]]:
83 | """Add emphasis metadata such as bold and italic to a PDF parsed with pdftext."""
84 | # Copy the pages.
85 | pages = deepcopy(pages)
86 | # Add emphasis metadata to the text spans.
87 | for page in pages:
88 | for block in page["blocks"]:
89 | for line in block["lines"]:
90 | if "md" not in line:
91 | line["md"] = {}
92 | for span in line["spans"]:
93 | if "md" not in span:
94 | span["md"] = {}
95 | span["md"]["bold"] = span["font"]["weight"] > 500 # noqa: PLR2004
96 | span["md"]["italic"] = "ital" in (span["font"]["name"] or "").lower()
97 | line["md"]["bold"] = all(
98 | span["md"]["bold"] for span in line["spans"] if span["text"].strip()
99 | )
100 | line["md"]["italic"] = all(
101 | span["md"]["italic"] for span in line["spans"] if span["text"].strip()
102 | )
103 | return pages
104 |
105 | def strip_page_numbers(pages: list[dict[str, Any]]) -> list[dict[str, Any]]:
106 | """Strip page numbers from a PDF parsed with pdftext."""
107 | # Copy the pages.
108 | pages = deepcopy(pages)
109 | # Remove lines that only contain a page number.
110 | for page in pages:
111 | for block in page["blocks"]:
112 | block["lines"] = [
113 | line
114 | for line in block["lines"]
115 | if not re.match(
116 | r"^\s*[#0]*\d+\s*$", "".join(span["text"] for span in line["spans"])
117 | )
118 | ]
119 | return pages
120 |
121 | def convert_to_markdown(pages: list[dict[str, Any]]) -> list[str]: # noqa: C901, PLR0912
122 | """Convert a list of pages to Markdown."""
123 | pages_md = []
124 | for page in pages:
125 | page_md = ""
126 | for block in page["blocks"]:
127 | block_text = ""
128 | for line in block["lines"]:
129 | # Build the line text and style the spans.
130 | line_text = ""
131 | for span in line["spans"]:
132 | if (
133 | not line["md"]["bold"]
134 | and not line["md"]["italic"]
135 | and span["md"]["bold"]
136 | and span["md"]["italic"]
137 | ):
138 | line_text += f"***{span['text']}***"
139 | elif not line["md"]["bold"] and span["md"]["bold"]:
140 | line_text += f"**{span['text']}**"
141 | elif not line["md"]["italic"] and span["md"]["italic"]:
142 | line_text += f"*{span['text']}*"
143 | else:
144 | line_text += span["text"]
145 | # Add emphasis to the line (if it's not a heading or whitespace).
146 | line_text = line_text.rstrip()
147 | line_is_whitespace = not line_text.strip()
148 | line_is_heading = line["md"]["heading_level"] <= 6 # noqa: PLR2004
149 | if not line_is_heading and not line_is_whitespace:
150 | if line["md"]["bold"] and line["md"]["italic"]:
151 | line_text = f"***{line_text}***"
152 | elif line["md"]["bold"]:
153 | line_text = f"**{line_text}**"
154 | elif line["md"]["italic"]:
155 | line_text = f"*{line_text}*"
156 | # Set the heading level.
157 | if line_is_heading and not line_is_whitespace:
158 | line_text = f"{'#' * line['md']['heading_level']} {line_text}"
159 | line_text += "\n"
160 | block_text += line_text
161 | block_text = block_text.rstrip() + "\n\n"
162 | page_md += block_text
163 | pages_md.append(page_md.strip())
164 | return pages_md
165 |
166 | def merge_split_headings(pages: list[str]) -> list[str]:
167 | """Merge headings that are split across lines."""
168 |
169 | def _merge_split_headings(match: re.Match[str]) -> str:
170 | atx_headings = [line.strip("# ").strip() for line in match.group().splitlines()]
171 | return f"{match.group(1)} {' '.join(atx_headings)}\n\n"
172 |
173 | pages_md = [
174 | re.sub(
175 | r"^(#+)[ \t]+[^\n]+\n+(?:^\1[ \t]+[^\n]+\n+)+",
176 | _merge_split_headings,
177 | page,
178 | flags=re.MULTILINE,
179 | )
180 | for page in pages
181 | ]
182 | return pages_md
183 |
184 | # Add heading level metadata.
185 | pages = add_heading_level_metadata(pages)
186 | # Add emphasis metadata.
187 | pages = add_emphasis_metadata(pages)
188 | # Strip page numbers.
189 | pages = strip_page_numbers(pages)
190 | # Convert the pages to Markdown.
191 | pages_md = convert_to_markdown(pages)
192 | # Merge headings that are split across lines.
193 | pages_md = merge_split_headings(pages_md)
194 | return pages_md
195 |
196 |
197 | def document_to_markdown(doc_path: Path) -> str:
198 | """Convert any document to GitHub Flavored Markdown."""
199 | # Convert the file's content to GitHub Flavored Markdown.
200 | if doc_path.suffix == ".pdf":
201 | # Parse the PDF with pdftext and convert it to Markdown.
202 | pages = dictionary_output(doc_path, sort=True, keep_chars=False)
203 | doc = "\n\n".join(parsed_pdf_to_markdown(pages))
204 | elif doc_path.suffix in (".md", ".txt"):
205 | # Read the Markdown file.
206 | doc = doc_path.read_text()
207 | else:
208 | try:
209 | # Use pandoc for everything else.
210 | import pypandoc
211 |
212 | doc = pypandoc.convert_file(doc_path, to="gfm")
213 | except ModuleNotFoundError as error:
214 | error_message = (
215 | "To convert files to Markdown with pandoc, please install the `pandoc` extra."
216 | )
217 | raise ModuleNotFoundError(error_message) from error
218 | except RuntimeError:
219 | # File format not supported, fall back to reading the text.
220 | doc = doc_path.read_text()
221 | return doc
222 |
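if __name__ == "__main__":
    # Illustrative usage sketch (assumes this is run from the repository root, where the test PDF
    # from the tests directory is available).
    markdown = document_to_markdown(Path("tests/specrel.pdf"))
    print(markdown[:1000])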
--------------------------------------------------------------------------------
/src/raglite/_mcp.py:
--------------------------------------------------------------------------------
1 | """MCP server for RAGLite."""
2 |
3 | from typing import Annotated, Any
4 |
5 | from fastmcp import FastMCP
6 | from pydantic import Field
7 |
8 | from raglite._config import RAGLiteConfig
9 | from raglite._rag import add_context, retrieve_context
10 |
11 | Query = Annotated[
12 | str,
13 | Field(
14 | description=(
15 | "The `query` string MUST be a precise single-faceted question in the user's language.\n"
16 | "The `query` string MUST resolve all pronouns to explicit nouns."
17 | )
18 | ),
19 | ]
20 |
21 |
22 | def create_mcp_server(server_name: str, *, config: RAGLiteConfig) -> FastMCP[Any]:
23 | """Create a RAGLite MCP server."""
24 | mcp: FastMCP[Any] = FastMCP(server_name)
25 |
26 | @mcp.prompt()
27 | def kb(query: Query) -> str:
28 | """Answer a question with information from the knowledge base."""
29 | chunk_spans = retrieve_context(query, config=config)
30 | rag_instruction = add_context(query, chunk_spans)
31 | return rag_instruction["content"]
32 |
33 | @mcp.tool()
34 | def search_knowledge_base(query: Query) -> str:
35 | """Search the knowledge base.
36 |
37 | IMPORTANT: You MAY NOT use this function if the question can be answered with common
38 | knowledge or straightforward reasoning. For multi-faceted questions, call this function once
39 | for each facet.
40 | """
41 | chunk_spans = retrieve_context(query, config=config)
42 | rag_context = '{{"documents": [{elements}]}}'.format(
43 | elements=", ".join(
44 | chunk_span.to_json(index=i + 1) for i, chunk_span in enumerate(chunk_spans)
45 | )
46 | )
47 | return rag_context
48 |
49 | # Warm up the querying pipeline.
50 | if config.embedder.startswith("llama-cpp-python"):
51 | _ = retrieve_context("Hello world", config=config)
52 |
53 | return mcp
54 |
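if __name__ == "__main__":
    # Illustrative usage sketch: serve the default knowledge base over MCP. FastMCP's `run()`
    # defaults to the stdio transport, which is what most MCP clients expect.
    mcp_server = create_mcp_server("raglite", config=RAGLiteConfig())
    mcp_server.run()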
--------------------------------------------------------------------------------
/src/raglite/_query_adapter.py:
--------------------------------------------------------------------------------
1 | """Compute and update an optimal query adapter."""
2 |
3 | # ruff: noqa: N806
4 |
5 | from dataclasses import replace
6 |
7 | import numpy as np
8 | from scipy.optimize import lsq_linear
9 | from sqlalchemy import text
10 | from sqlalchemy.orm.attributes import flag_modified
11 | from sqlmodel import Session, col, select
12 | from tqdm.auto import tqdm
13 |
14 | from raglite._config import RAGLiteConfig
15 | from raglite._database import Chunk, ChunkEmbedding, Eval, IndexMetadata, create_database_engine
16 | from raglite._embed import embed_strings
17 | from raglite._search import vector_search
18 | from raglite._typing import FloatMatrix, FloatVector
19 |
20 |
21 | def _optimize_query_target(
22 | q: FloatVector,
23 | P: FloatMatrix, # noqa: N803,
24 | N: FloatMatrix, # noqa: N803,
25 | *,
26 | α: float = 0.05, # noqa: PLC2401
27 | ) -> FloatVector:
28 | # Convert to double precision for the optimizer.
29 | q_dtype = q.dtype
30 | q, P, N = q.astype(np.float64), P.astype(np.float64), N.astype(np.float64)
31 | # Construct the constraint matrix D := P - (1 + α) * N. # noqa: RUF003
32 | D = np.reshape(P[:, np.newaxis, :] - (1.0 + α) * N[np.newaxis, :, :], (-1, P.shape[1]))
33 | # Solve the dual problem min_μ ½ ‖q + Dᵀ μ‖² s.t. μ ≥ 0.
34 | A, b = D.T, -q
35 | μ_star = lsq_linear(A, b, bounds=(0.0, np.inf), tol=np.finfo(A.dtype).eps).x # noqa: PLC2401
36 | # Recover the primal solution q* = q + Dᵀ μ*.
37 | q_star: FloatVector = (q + D.T @ μ_star).astype(q_dtype)
38 | return q_star
39 |
40 |
41 | def update_query_adapter(
42 | *,
43 | max_evals: int = 4096,
44 | optimize_top_k: int = 40,
45 | optimize_gap: float = 0.05,
46 | config: RAGLiteConfig | None = None,
47 | ) -> FloatMatrix:
48 | """Compute an optimal query adapter and update the database with it.
49 |
50 | This function computes an optimal linear transform A, called a 'query adapter', that is used to
51 | transform a query embedding q as A @ q before searching for the nearest neighbouring chunks in
52 | order to improve the quality of the search results.
53 |
54 | Given a set of triplets (qᵢ, pᵢ, nᵢ), we want to find the query adapter A that increases the
55 | score pᵢᵀqᵢ of the positive chunk pᵢ and decreases the score nᵢᵀqᵢ of the negative chunk nᵢ.
56 |
57 | If the nearest neighbour search uses the dot product as its relevance score, we can find the
58 | optimal query adapter by solving the following relaxed Procrustes optimisation problem with a
59 | bound on the Frobenius norm of A:
60 |
61 | A* := argmax Σᵢ pᵢᵀ (A qᵢ) - nᵢᵀ (A qᵢ)
62 | Σᵢ (pᵢ - nᵢ)ᵀ A qᵢ
63 | trace[ (P - N) A Qᵀ ] where Q := [q₁ᵀ; ...; qₖᵀ]
64 | P := [p₁ᵀ; ...; pₖᵀ]
65 | N := [n₁ᵀ; ...; nₖᵀ]
66 | trace[ Qᵀ (P - N) A ]
67 | trace[ Mᵀ A ] where M := (P - N)ᵀ Q
68 | s.t. ||A||_F == 1
69 | = M / ||M||_F
70 |
71 | If the nearest neighbour search uses the cosine similarity as its relevance score, we can find
72 | the optimal query adapter by solving the following orthogonal Procrustes optimisation problem
73 | [1] with an orthogonality constraint on A:
74 |
75 | A* := argmax Σᵢ pᵢᵀ (A qᵢ) - nᵢᵀ (A qᵢ)
76 | Σᵢ (pᵢ - nᵢ)ᵀ A qᵢ
77 | trace[ (P - N) A Qᵀ ]
78 | trace[ Qᵀ (P - N) A ]
79 | trace[ Mᵀ A ]
80 | trace[ (U Σ V)ᵀ A ] where U Σ Vᵀ := M is the SVD of M
81 | trace[ Σ V A Uᵀ ]
82 | s.t. AᵀA == 𝕀
83 | = U Vᵀ
84 |
85 | The action of A* is to map a query embedding qᵢ to a target vector t := (pᵢ - nᵢ) that maximally
86 | separates the positive and negative chunks. For a given query embedding qᵢ, a retrieval method
87 | will yield a result set containing both positive and negative chunks. Instead of extracting
88 | multiple triplets (qᵢ, pᵢ, nᵢ) from each such result set, we can compute a single optimal target
89 | vector t* for the query embedding qᵢ as follows:
90 |
91 | t* := argmin ½ ||t - qᵢ||²
92 | s.t. Dᵢ t >= 0
93 |
94 | where the constraint matrix Dᵢ := [pₘᵀ - (1 + α) nₙᵀ]ₘₙ comprises all pairs of positive and
95 | negative chunk embeddings in the result set corresponding to the query embedding qᵢ. This
96 | optimisation problem expresses the idea that the target vector t* should be as close as
97 | possible to the query embedding qᵢ, while separating all positive and negative chunk embeddings
98 | in the result set by a margin of at least α. To solve this problem, we'll first introduce
99 | a Lagrangian with Lagrange multipliers μ:
100 |
101 | L(t, μ) := ½ ||t - qᵢ||² + μᵀ (-Dᵢ t)
102 |
103 | Now we can set the gradient of the Lagrangian to zero to find the optimal target vector t*:
104 |
105 | ∇ₜL = t - qᵢ - Dᵢᵀ μ = 0
106 | t* = qᵢ + Dᵢᵀ μ*
107 |
108 | where μ* is the solution to the dual nonnegative least squares problem
109 |
110 | μ* := argmin ½ ||qᵢ + Dᵢᵀ μ||²
111 | s.t. μ >= 0
112 |
113 | Parameters
114 | ----------
115 | max_evals
116 | The maximum number of evals to use to compute the query adapter. Each eval corresponds to a
117 | rank-one update of the query adapter A.
118 | optimize_top_k
119 | The number of search results per eval to optimize.
120 | optimize_gap
121 | The strength of the query adapter, expressed as a nonnegative number. Should be large enough
122 | to correct incorrectly ranked results, but small enough to not affect correctly ranked
123 | results.
124 | config
125 | The RAGLite config to use to construct and store the query adapter.
126 |
127 | Raises
128 | ------
129 | ValueError
130 | If no documents have been inserted into the database yet.
131 | ValueError
132 | If no evals have been inserted into the database yet.
133 | ValueError
134 | If the `config.vector_search_distance_metric` is not supported.
135 |
136 | Returns
137 | -------
138 | FloatMatrix
139 | The query adapter.
140 | """
141 | config = config or RAGLiteConfig()
142 | config_no_query_adapter = replace(config, vector_search_query_adapter=False)
143 | with Session(engine := create_database_engine(config)) as session:
144 | # Get random evals from the database.
145 | chunk_embedding = session.exec(select(ChunkEmbedding).limit(1)).first()
146 | if chunk_embedding is None:
147 | error_message = "First run `insert_documents()` to insert documents."
148 | raise ValueError(error_message)
149 | evals = session.exec(select(Eval).order_by(Eval.id).limit(max_evals)).all()
150 | if len(evals) == 0:
151 | error_message = "First run `insert_evals()` to generate evals."
152 | raise ValueError(error_message)
153 | # Construct the query and target matrices.
154 | Q = np.zeros((0, len(chunk_embedding.embedding)))
155 | T = np.zeros_like(Q)
156 | for eval_ in tqdm(
157 | evals, desc="Optimizing evals", unit="eval", dynamic_ncols=True, leave=False
158 | ):
159 | # Embed the question.
160 | q = embed_strings([eval_.question], config=config)[0]
161 | # Retrieve chunks that would be used to answer the question.
162 | chunk_ids, _ = vector_search(
163 | q, num_results=optimize_top_k, config=config_no_query_adapter
164 | )
165 | retrieved_chunks = session.exec(select(Chunk).where(col(Chunk.id).in_(chunk_ids))).all()
166 | retrieved_chunks = sorted(retrieved_chunks, key=lambda chunk: chunk_ids.index(chunk.id))
167 | # Skip this eval if it doesn't contain both relevant and irrelevant chunks.
168 | is_relevant = np.array([chunk.id in eval_.chunk_ids for chunk in retrieved_chunks])
169 | if not np.any(is_relevant) or not np.any(~is_relevant):
170 | continue
171 | # Extract the positive and negative chunk embeddings.
172 | P = np.vstack(
173 | [
174 | chunk.embedding_matrix[[np.argmax(chunk.embedding_matrix @ q)]]
175 | for chunk in np.array(retrieved_chunks)[is_relevant]
176 | ]
177 | )
178 | N = np.vstack(
179 | [
180 | chunk.embedding_matrix[[np.argmax(chunk.embedding_matrix @ q)]]
181 | for chunk in np.array(retrieved_chunks)[~is_relevant]
182 | ]
183 | )
184 | # Compute the optimal target vector t for this query embedding q.
185 | t = _optimize_query_target(q, P, N, α=optimize_gap)
186 | Q = np.vstack([Q, q[np.newaxis, :]])
187 | T = np.vstack([T, t[np.newaxis, :]])
188 | # Normalise the rows of Q and T.
189 | Q /= np.linalg.norm(Q, axis=1, keepdims=True)
190 | if config.vector_search_distance_metric == "cosine":
191 | T /= np.linalg.norm(T, axis=1, keepdims=True)
192 | # Compute the optimal unconstrained query adapter M.
193 | n, d = Q.shape
194 | M = (1 / n) * T.T @ Q
195 | if n < d or np.linalg.matrix_rank(Q) < d:
196 | M += np.eye(d) - Q.T @ np.linalg.pinv(Q @ Q.T) @ Q
197 | # Compute the optimal constrained query adapter A* from M, given the distance metric.
198 | A_star: FloatMatrix
199 | if config.vector_search_distance_metric == "dot":
200 | # Use the relaxed Procrustes solution.
201 | A_star = M / np.linalg.norm(M, ord="fro") * np.sqrt(d)
202 | elif config.vector_search_distance_metric == "cosine":
203 | # Use the orthogonal Procrustes solution.
204 | U, _, VT = np.linalg.svd(M, full_matrices=False)
205 | A_star = U @ VT
206 | else:
207 | error_message = f"Unsupported metric: {config.vector_search_distance_metric}"
208 | raise ValueError(error_message)
209 | # Store the optimal query adapter in the database.
210 | index_metadata = session.get(IndexMetadata, "default") or IndexMetadata(id="default")
211 | index_metadata.metadata_["query_adapter"] = A_star
212 | flag_modified(index_metadata, "metadata_")
213 | session.add(index_metadata)
214 | session.commit()
215 | if engine.dialect.name == "duckdb":
216 | session.execute(text("CHECKPOINT;"))
217 | # Clear the index metadata cache to allow the new query adapter to be used.
218 | IndexMetadata._get.cache_clear() # noqa: SLF001
219 | return A_star
220 |
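if __name__ == "__main__":
    # Illustrative usage sketch (assumes `insert_documents()` and `insert_evals()` have been run
    # so that there are evals to optimise against).
    A = update_query_adapter(config=RAGLiteConfig())
    print(A.shape)  # A (d, d) matrix for d-dimensional embeddings.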
--------------------------------------------------------------------------------
/src/raglite/_rag.py:
--------------------------------------------------------------------------------
1 | """Retrieval-augmented generation."""
2 |
3 | import json
4 | from collections.abc import AsyncIterator, Callable, Iterator
5 | from typing import Any
6 |
7 | import numpy as np
8 | from litellm import ( # type: ignore[attr-defined]
9 | ChatCompletionMessageToolCall,
10 | acompletion,
11 | completion,
12 | stream_chunk_builder,
13 | supports_function_calling,
14 | )
15 |
16 | from raglite._config import RAGLiteConfig
17 | from raglite._database import Chunk, ChunkSpan
18 | from raglite._litellm import get_context_size
19 | from raglite._search import retrieve_chunk_spans
20 |
21 | # The default RAG instruction template follows Anthropic's best practices [1].
22 | # [1] https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips
23 | RAG_INSTRUCTION_TEMPLATE = """
24 | ---
25 | You are a friendly and knowledgeable assistant that provides complete and insightful answers.
26 | Whenever possible, use only the provided context to respond to the question at the end.
27 | When responding, you MUST NOT reference the existence of the context, directly or indirectly.
28 | Instead, you MUST treat the context as if its contents are entirely part of your working memory.
29 | ---
30 |
31 | <context>
32 | {context}
33 | </context>
34 |
35 | {user_prompt}
36 | """.strip()
37 |
38 |
39 | def retrieve_context(
40 | query: str, *, num_chunks: int = 10, config: RAGLiteConfig | None = None
41 | ) -> list[ChunkSpan]:
42 | """Retrieve context for RAG."""
43 | # Call the search method.
44 | config = config or RAGLiteConfig()
45 | results = config.search_method(query, num_results=num_chunks, config=config)
46 | # Convert results to chunk spans.
47 | chunk_spans = []
48 | if isinstance(results, tuple):
49 | chunk_spans = retrieve_chunk_spans(results[0], config=config)
50 | elif all(isinstance(result, Chunk) for result in results):
51 | chunk_spans = retrieve_chunk_spans(results, config=config) # type: ignore[arg-type]
52 | elif all(isinstance(result, ChunkSpan) for result in results):
53 | chunk_spans = results # type: ignore[assignment]
54 | return chunk_spans
55 |
56 |
57 | def add_context(
58 | user_prompt: str,
59 | context: list[ChunkSpan],
60 | *,
61 | rag_instruction_template: str = RAG_INSTRUCTION_TEMPLATE,
62 | ) -> dict[str, str]:
63 | """Convert a user prompt to a RAG instruction.
64 |
65 | The RAG instruction's format follows Anthropic's best practices [1].
66 |
67 | [1] https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips
68 | """
69 | message = {
70 | "role": "user",
71 | "content": rag_instruction_template.format(
72 | context="\n".join(
73 | chunk_span.to_xml(index=i + 1) for i, chunk_span in enumerate(context)
74 | ),
75 | user_prompt=user_prompt.strip(),
76 | ),
77 | }
78 | return message
79 |
80 |
81 | def _clip(messages: list[dict[str, str]], max_tokens: int) -> list[dict[str, str]]:
82 | """Left clip a messages array to avoid hitting the context limit."""
83 | cum_tokens = np.cumsum([len(message.get("content") or "") // 3 for message in messages][::-1])
84 | first_message = -np.searchsorted(cum_tokens, max_tokens)
85 | return messages[first_message:]
86 |
87 |
88 | def _get_tools(
89 | messages: list[dict[str, str]], config: RAGLiteConfig
90 | ) -> tuple[list[dict[str, Any]] | None, dict[str, Any] | str | None]:
91 | """Get tools to search the knowledge base if no RAG context is provided in the messages."""
92 | # Check if messages already contain RAG context or if the LLM supports tool use.
93 | final_message = messages[-1].get("content", "")
94 | messages_contain_rag_context = any(
95 | s in final_message for s in ("<documents>", "</documents>", "from_chunk_id")
96 | )
97 | llm_supports_function_calling = supports_function_calling(config.llm)
98 | if not messages_contain_rag_context and not llm_supports_function_calling:
99 | error_message = "You must either explicitly provide RAG context in the last message, or use an LLM that supports function calling."
100 | raise ValueError(error_message)
101 | # Return a single tool to search the knowledge base if no RAG context is provided.
102 | tools: list[dict[str, Any]] | None = (
103 | [
104 | {
105 | "type": "function",
106 | "function": {
107 | "name": "search_knowledge_base",
108 | "description": (
109 | "Search the knowledge base.\n"
110 | "IMPORTANT: You MAY NOT use this function if the question can be answered with common knowledge or straightforward reasoning.\n"
111 | "For multi-faceted questions, call this function once for each facet."
112 | ),
113 | "parameters": {
114 | "type": "object",
115 | "properties": {
116 | "query": {
117 | "type": "string",
118 | "description": (
119 | "The `query` string MUST be a precise single-faceted question in the user's language.\n"
120 | "The `query` string MUST resolve all pronouns to explicit nouns."
121 | ),
122 | },
123 | },
124 | "required": ["query"],
125 | "additionalProperties": False,
126 | },
127 | },
128 | }
129 | ]
130 | if not messages_contain_rag_context
131 | else None
132 | )
133 | tool_choice: dict[str, Any] | str | None = "auto" if tools else None
134 | return tools, tool_choice
135 |
136 |
137 | def _run_tools(
138 | tool_calls: list[ChatCompletionMessageToolCall],
139 | on_retrieval: Callable[[list[ChunkSpan]], None] | None,
140 | config: RAGLiteConfig,
141 | ) -> list[dict[str, Any]]:
142 | """Run tools to search the knowledge base for RAG context."""
143 | tool_messages: list[dict[str, Any]] = []
144 | for tool_call in tool_calls:
145 | if tool_call.function.name == "search_knowledge_base":
146 | kwargs = json.loads(tool_call.function.arguments)
147 | kwargs["config"] = config
148 | chunk_spans = retrieve_context(**kwargs)
149 | tool_messages.append(
150 | {
151 | "role": "tool",
152 | "content": '{{"documents": [{elements}]}}'.format(
153 | elements=", ".join(
154 | chunk_span.to_json(index=i + 1)
155 | for i, chunk_span in enumerate(chunk_spans)
156 | )
157 | ),
158 | "tool_call_id": tool_call.id,
159 | }
160 | )
161 | if chunk_spans and callable(on_retrieval):
162 | on_retrieval(chunk_spans)
163 | else:
164 | error_message = f"Unknown function `{tool_call.function.name}`."
165 | raise ValueError(error_message)
166 | return tool_messages
167 |
168 |
169 | def rag(
170 | messages: list[dict[str, str]],
171 | *,
172 | on_retrieval: Callable[[list[ChunkSpan]], None] | None = None,
173 | config: RAGLiteConfig,
174 | ) -> Iterator[str]:
175 | # If the final message does not contain RAG context, get a tool to search the knowledge base.
176 | max_tokens = get_context_size(config)
177 | tools, tool_choice = _get_tools(messages, config)
178 | # Stream the LLM response, which is either a tool call request or an assistant response.
179 | stream = completion(
180 | model=config.llm,
181 | messages=_clip(messages, max_tokens),
182 | tools=tools,
183 | tool_choice=tool_choice,
184 | stream=True,
185 | )
186 | chunks = []
187 | for chunk in stream:
188 | chunks.append(chunk)
189 | if isinstance(token := chunk.choices[0].delta.content, str):
190 | yield token
191 | # Check if there are tools to be called.
192 | response = stream_chunk_builder(chunks, messages)
193 | tool_calls = response.choices[0].message.tool_calls # type: ignore[union-attr]
194 | if tool_calls:
195 | # Add the tool call request to the message array.
196 | messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr]
197 | # Run the tool calls to retrieve the RAG context and append the output to the message array.
198 | messages.extend(_run_tools(tool_calls, on_retrieval, config))
199 | # Stream the assistant response.
200 | chunks = []
201 | stream = completion(model=config.llm, messages=_clip(messages, max_tokens), stream=True)
202 | for chunk in stream:
203 | chunks.append(chunk)
204 | if isinstance(token := chunk.choices[0].delta.content, str):
205 | yield token
206 | # Append the assistant response to the message array.
207 | response = stream_chunk_builder(chunks, messages)
208 | messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr]
209 |
210 |
211 | async def async_rag(
212 | messages: list[dict[str, str]],
213 | *,
214 | on_retrieval: Callable[[list[ChunkSpan]], None] | None = None,
215 | config: RAGLiteConfig,
216 | ) -> AsyncIterator[str]:
217 | # If the final message does not contain RAG context, get a tool to search the knowledge base.
218 | max_tokens = get_context_size(config)
219 | tools, tool_choice = _get_tools(messages, config)
220 | # Asynchronously stream the LLM response, which is either a tool call or an assistant response.
221 | async_stream = await acompletion(
222 | model=config.llm,
223 | messages=_clip(messages, max_tokens),
224 | tools=tools,
225 | tool_choice=tool_choice,
226 | stream=True,
227 | )
228 | chunks = []
229 | async for chunk in async_stream:
230 | chunks.append(chunk)
231 | if isinstance(token := chunk.choices[0].delta.content, str):
232 | yield token
233 | # Check if there are tools to be called.
234 | response = stream_chunk_builder(chunks, messages)
235 | tool_calls = response.choices[0].message.tool_calls # type: ignore[union-attr]
236 | if tool_calls:
237 | # Add the tool call requests to the message array.
238 | messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr]
239 | # Run the tool calls to retrieve the RAG context and append the output to the message array.
240 | # TODO: Make this async.
241 | messages.extend(_run_tools(tool_calls, on_retrieval, config))
242 | # Asynchronously stream the assistant response.
243 | chunks = []
244 | async_stream = await acompletion(
245 | model=config.llm, messages=_clip(messages, max_tokens), stream=True
246 | )
247 | async for chunk in async_stream:
248 | chunks.append(chunk)
249 | if isinstance(token := chunk.choices[0].delta.content, str):
250 | yield token
251 | # Append the assistant response to the message array.
252 | response = stream_chunk_builder(chunks, messages)
253 | messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr]
254 |
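To tie the two code paths above together, a brief usage sketch (an editor's illustration with placeholder config values, not part of the module):

from raglite import RAGLiteConfig
from raglite._rag import add_context, rag, retrieve_context

config = RAGLiteConfig(db_url="duckdb:///raglite.db", llm="gpt-4o-mini", embedder="text-embedding-3-small")
question = "What does the paper say about the relativity of simultaneity?"
# Path 1: retrieve context up front and inject it into the user message with add_context.
chunk_spans = retrieve_context(question, num_chunks=10, config=config)
messages = [add_context(user_prompt=question, context=chunk_spans)]
answer = "".join(rag(messages, config=config))
# Path 2: send the bare question and let rag() call the search_knowledge_base tool on demand.
messages = [{"role": "user", "content": question}]
answer = "".join(
    rag(messages, on_retrieval=lambda spans: print(f"Retrieved {len(spans)} spans"), config=config)
)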
--------------------------------------------------------------------------------
/src/raglite/_search.py:
--------------------------------------------------------------------------------
1 | """Search and retrieve chunks."""
2 |
3 | import contextlib
4 | import re
5 | import string
6 | from collections import defaultdict
7 | from itertools import groupby
8 |
9 | import numpy as np
10 | from langdetect import LangDetectException, detect
11 | from sqlalchemy.orm import joinedload
12 | from sqlmodel import Session, and_, col, func, or_, select, text
13 |
14 | from raglite._config import RAGLiteConfig
15 | from raglite._database import (
16 | Chunk,
17 | ChunkEmbedding,
18 | ChunkSpan,
19 | IndexMetadata,
20 | create_database_engine,
21 | )
22 | from raglite._embed import embed_strings
23 | from raglite._typing import BasicSearchMethod, ChunkId, FloatVector
24 |
25 |
26 | def vector_search(
27 | query: str | FloatVector,
28 | *,
29 | num_results: int = 3,
30 | oversample: int = 4,
31 | config: RAGLiteConfig | None = None,
32 | ) -> tuple[list[ChunkId], list[float]]:
33 | """Search chunks using ANN vector search."""
34 | # Read the config.
35 | config = config or RAGLiteConfig()
36 | # Embed the query.
37 | query_embedding = (
38 | embed_strings([query], config=config)[0, :] if isinstance(query, str) else np.ravel(query)
39 | )
40 | # Apply the query adapter to the query embedding.
41 | if (
42 | config.vector_search_query_adapter
43 | and (Q := IndexMetadata.get("default", config=config).get("query_adapter")) is not None # noqa: N806
44 | ):
45 | query_embedding = (Q @ query_embedding).astype(query_embedding.dtype)
46 | # Rank the chunks by relevance according to the L∞ norm of the similarities of their multi-vector
47 | # chunk embeddings to the query embedding, all within a single SQL query.
48 | with Session(create_database_engine(config)) as session:
49 | corrected_oversample = oversample * config.chunk_max_size / RAGLiteConfig.chunk_max_size
50 | num_hits = round(corrected_oversample) * max(num_results, 10)
51 | dist = ChunkEmbedding.embedding.distance( # type: ignore[attr-defined]
52 | query_embedding, metric=config.vector_search_distance_metric
53 | ).label("dist")
54 | sim = (1.0 - dist).label("sim")
55 | top_vectors = select(ChunkEmbedding.chunk_id, sim).order_by(dist).limit(num_hits).subquery()
56 | sim_norm = func.max(top_vectors.c.sim).label("sim_norm")
57 | statement = (
58 | select(top_vectors.c.chunk_id, sim_norm)
59 | .group_by(top_vectors.c.chunk_id)
60 | .order_by(sim_norm.desc())
61 | .limit(num_results)
62 | )
63 | rows = session.exec(statement).all()
64 | chunk_ids = [row[0] for row in rows]
65 | similarity = [float(row[1]) for row in rows]
66 | return chunk_ids, similarity
67 |
68 |
69 | def keyword_search(
70 | query: str, *, num_results: int = 3, config: RAGLiteConfig | None = None
71 | ) -> tuple[list[ChunkId], list[float]]:
72 | """Search chunks using BM25 keyword search."""
73 | # Read the config.
74 | config = config or RAGLiteConfig()
75 | # Connect to the database.
76 | with Session(create_database_engine(config)) as session:
77 | dialect = session.get_bind().dialect.name
78 | if dialect == "postgresql":
79 | # Convert the query to a tsquery [1].
80 | # [1] https://www.postgresql.org/docs/current/textsearch-controls.html
81 | query_escaped = re.sub(f"[{re.escape(string.punctuation)}]", " ", query)
82 | tsv_query = " | ".join(query_escaped.split())
83 | # Perform keyword search with tsvector.
84 | statement = text("""
85 | SELECT id as chunk_id, ts_rank(to_tsvector('simple', body), to_tsquery('simple', :query)) AS score
86 | FROM chunk
87 | WHERE to_tsvector('simple', body) @@ to_tsquery('simple', :query)
88 | ORDER BY score DESC
89 | LIMIT :limit;
90 | """)
91 | results = session.execute(statement, params={"query": tsv_query, "limit": num_results})
92 | elif dialect == "duckdb":
93 | statement = text(
94 | """
95 | SELECT chunk_id, score
96 | FROM (
97 | SELECT id AS chunk_id, fts_main_chunk.match_bm25(id, :query) AS score
98 | FROM chunk
99 | ) sq
100 | WHERE score IS NOT NULL
101 | ORDER BY score DESC
102 | LIMIT :limit;
103 | """
104 | )
105 | results = session.execute(statement, params={"query": query, "limit": num_results})
106 | # Unpack the results.
107 | results = list(results) # type: ignore[assignment]
108 | chunk_ids = [result.chunk_id for result in results]
109 | keyword_score = [result.score for result in results]
110 | return chunk_ids, keyword_score
111 |
112 |
113 | def reciprocal_rank_fusion(
114 | rankings: list[list[ChunkId]], *, k: int = 60, weights: list[float] | None = None
115 | ) -> tuple[list[ChunkId], list[float]]:
116 | """Reciprocal Rank Fusion."""
117 | if weights is None:
118 | weights = [1.0] * len(rankings)
119 | if len(weights) != len(rankings):
120 | error = "The number of weights must match the number of rankings."
121 | raise ValueError(error)
122 | # Compute the RRF score.
123 | chunk_id_score: defaultdict[str, float] = defaultdict(float)
124 | for ranking, weight in zip(rankings, weights, strict=True):
125 | for i, chunk_id in enumerate(ranking):
126 | chunk_id_score[chunk_id] += weight / (k + i)
127 | # Exit early if there are no results to fuse.
128 | if not chunk_id_score:
129 | return [], []
130 | # Rank RRF results according to descending RRF score.
131 | rrf_chunk_ids, rrf_score = zip(
132 | *sorted(chunk_id_score.items(), key=lambda x: x[1], reverse=True), strict=True
133 | )
134 | return list(rrf_chunk_ids), list(rrf_score)
135 |
136 |
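As a worked example of the fusion above, where each chunk scores sum_i weight_i / (k + rank_i) over the rankings that contain it (an editor's sketch with made-up chunk ids, calling the function just defined):

vector_ranking = ["c1", "c2", "c3"]
keyword_ranking = ["c2", "c4"]
chunk_ids, scores = reciprocal_rank_fusion(
    [vector_ranking, keyword_ranking], weights=[0.75, 0.25]
)
# "c2" scores 0.75 / 61 + 0.25 / 60 ≈ 0.0165 and ranks first, ahead of "c1" with 0.75 / 60 = 0.0125.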
137 | def hybrid_search( # noqa: PLR0913
138 | query: str,
139 | *,
140 | num_results: int = 3,
141 | oversample: int = 2,
142 | vector_search_weight: float = 0.75,
143 | keyword_search_weight: float = 0.25,
144 | config: RAGLiteConfig | None = None,
145 | ) -> tuple[list[ChunkId], list[float]]:
146 | """Search chunks by combining ANN vector search with BM25 keyword search."""
147 | # Run both searches.
148 | vs_chunk_ids, _ = vector_search(query, num_results=oversample * num_results, config=config)
149 | ks_chunk_ids, _ = keyword_search(query, num_results=oversample * num_results, config=config)
150 | # Combine the results with Reciprocal Rank Fusion (RRF).
151 | chunk_ids, hybrid_score = reciprocal_rank_fusion(
152 | [vs_chunk_ids, ks_chunk_ids], weights=[vector_search_weight, keyword_search_weight]
153 | )
154 | chunk_ids, hybrid_score = chunk_ids[:num_results], hybrid_score[:num_results]
155 | return chunk_ids, hybrid_score
156 |
157 |
158 | def retrieve_chunks(
159 | chunk_ids: list[ChunkId], *, config: RAGLiteConfig | None = None
160 | ) -> list[Chunk]:
161 | """Retrieve chunks by their ids."""
162 | if not chunk_ids:
163 | return []
164 | with Session(create_database_engine(config := config or RAGLiteConfig())) as session:
165 | chunks = list(
166 | session.exec(
167 | select(Chunk)
168 | .where(col(Chunk.id).in_(chunk_ids))
169 | # Eagerly load chunk.document.
170 | .options(joinedload(Chunk.document)) # type: ignore[arg-type]
171 | ).all()
172 | )
173 | chunks = sorted(chunks, key=lambda chunk: chunk_ids.index(chunk.id))
174 | return chunks
175 |
176 |
177 | def retrieve_chunk_spans(
178 | chunk_ids: list[ChunkId] | list[Chunk],
179 | *,
180 | neighbors: tuple[int, ...] | None = (-1, 1),
181 | config: RAGLiteConfig | None = None,
182 | ) -> list[ChunkSpan]:
183 | """Group chunks into contiguous chunk spans and retrieve them.
184 |
185 | Chunk spans are ordered according to the aggregate relevance of their underlying chunks, as
186 | determined by the order in which they are provided to this function.
187 | """
188 | # Exit early if the input is empty.
189 | if not chunk_ids:
190 | return []
191 | # Retrieve the chunks.
192 | config = config or RAGLiteConfig()
193 | chunks: list[Chunk] = (
194 | retrieve_chunks(chunk_ids, config=config) # type: ignore[arg-type,assignment]
195 | if all(isinstance(chunk_id, ChunkId) for chunk_id in chunk_ids)
196 | else chunk_ids
197 | )
198 | # Assign a reciprocal ranking score to each chunk based on its position in the original list.
199 | chunk_id_to_score = {chunk.id: 1 / (i + 1) for i, chunk in enumerate(chunks)}
200 | # Extend the chunks with their neighbouring chunks.
201 | with Session(create_database_engine(config)) as session:
202 | if neighbors:
203 | neighbor_conditions = [
204 | and_(Chunk.document_id == chunk.document_id, Chunk.index == chunk.index + offset)
205 | for chunk in chunks
206 | for offset in neighbors
207 | ]
208 | chunks += list(
209 | session.exec(
210 | select(Chunk)
211 | .where(or_(*neighbor_conditions))
212 | # Eagerly load chunk.document.
213 | .options(joinedload(Chunk.document)) # type: ignore[arg-type]
214 | ).all()
215 | )
216 | # Deduplicate and sort the chunks by document_id and index (needed for groupby).
217 | unique_chunks = sorted(set(chunks), key=lambda chunk: (chunk.document_id, chunk.index))
218 | # Group the chunks into contiguous segments.
219 | chunk_spans: list[ChunkSpan] = []
220 | for _, group in groupby(unique_chunks, key=lambda chunk: chunk.document_id):
221 | chunk_sequence: list[Chunk] = []
222 | for chunk in group:
223 | if not chunk_sequence or chunk.index == chunk_sequence[-1].index + 1:
224 | chunk_sequence.append(chunk)
225 | else:
226 | chunk_spans.append(ChunkSpan(chunks=chunk_sequence))
227 | chunk_sequence = [chunk]
228 | chunk_spans.append(ChunkSpan(chunks=chunk_sequence))
229 | # Rank segments according to the aggregate relevance of their chunks.
230 | chunk_spans.sort(
231 | key=lambda chunk_span: sum(
232 | chunk_id_to_score.get(chunk.id, 0.0) for chunk in chunk_span.chunks
233 | ),
234 | reverse=True,
235 | )
236 | return chunk_spans
237 |
238 |
239 | def rerank_chunks(
240 | query: str, chunk_ids: list[ChunkId] | list[Chunk], *, config: RAGLiteConfig | None = None
241 | ) -> list[Chunk]:
242 | """Rerank chunks according to their relevance to a given query."""
243 | # Retrieve the chunks.
244 | config = config or RAGLiteConfig()
245 | chunks: list[Chunk] = (
246 | retrieve_chunks(chunk_ids, config=config) # type: ignore[arg-type,assignment]
247 | if all(isinstance(chunk_id, ChunkId) for chunk_id in chunk_ids)
248 | else chunk_ids
249 | )
250 | # Exit early if no reranker is configured or if the input is empty.
251 | if not config.reranker or not chunks:
252 | return chunks
253 | # Select the reranker.
254 | if isinstance(config.reranker, dict):
255 | # Detect the languages of the chunks and queries.
256 | with contextlib.suppress(LangDetectException):
257 | langs = {detect(str(chunk)) for chunk in chunks}
258 | langs.add(detect(query))
259 | # If all chunks and the query are in the same language, use a language-specific reranker.
260 | rerankers = config.reranker
261 | if len(langs) == 1 and (lang := next(iter(langs))) in rerankers:
262 | reranker = rerankers[lang]
263 | else:
264 | reranker = rerankers.get("other")
265 | else:
266 | # A specific reranker was configured.
267 | reranker = config.reranker
268 | # Rerank the chunks.
269 | if reranker:
270 | results = reranker.rank(query=query, docs=[str(chunk) for chunk in chunks])
271 | chunks = [chunks[result.doc_id] for result in results.results]
272 | return chunks
273 |
274 |
275 | def search_and_rerank_chunks(
276 | query: str,
277 | *,
278 | num_results: int = 8,
279 | oversample: int = 4,
280 | search: BasicSearchMethod = hybrid_search,
281 | config: RAGLiteConfig | None = None,
282 | ) -> list[Chunk]:
283 | """Search and rerank chunks."""
284 | chunk_ids, _ = search(query, num_results=oversample * num_results, config=config)
285 | chunks = rerank_chunks(query, chunk_ids, config=config)[:num_results]
286 | return chunks
287 |
288 |
289 | def search_and_rerank_chunk_spans( # noqa: PLR0913
290 | query: str,
291 | *,
292 | num_results: int = 8,
293 | oversample: int = 4,
294 | neighbors: tuple[int, ...] | None = (-1, 1),
295 | search: BasicSearchMethod = hybrid_search,
296 | config: RAGLiteConfig | None = None,
297 | ) -> list[ChunkSpan]:
298 | """Search and rerank chunks, and then collate into chunk spans."""
299 | chunk_ids, _ = search(query, num_results=oversample * num_results, config=config)
300 | chunks = rerank_chunks(query, chunk_ids, config=config)[:num_results]
301 | chunk_spans = retrieve_chunk_spans(chunks, neighbors=neighbors, config=config)
302 | return chunk_spans
303 |
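End to end, the search functions above compose as follows (an editor's sketch with placeholder config values and query):

from raglite import RAGLiteConfig
from raglite._search import hybrid_search, retrieve_chunk_spans, search_and_rerank_chunk_spans

config = RAGLiteConfig(db_url="duckdb:///raglite.db", embedder="text-embedding-3-small")
# Low level: fuse vector and keyword search, then expand the hits into chunk spans.
chunk_ids, scores = hybrid_search("inertial frames of reference", num_results=5, config=config)
chunk_spans = retrieve_chunk_spans(chunk_ids, neighbors=(-1, 1), config=config)
# High level: the same pipeline plus optional reranking in a single call.
chunk_spans = search_and_rerank_chunk_spans("inertial frames of reference", num_results=5, config=config)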
--------------------------------------------------------------------------------
/src/raglite/_split_chunklets.py:
--------------------------------------------------------------------------------
1 | """Split a document into chunklets."""
2 |
3 | from collections.abc import Callable
4 |
5 | import numpy as np
6 | from markdown_it import MarkdownIt
7 |
8 | from raglite._typing import FloatVector
9 |
10 |
11 | def markdown_chunklet_boundaries(sentences: list[str]) -> FloatVector:
12 | """Estimate chunklet boundary probabilities given a Markdown document."""
13 | # Parse the document.
14 | doc = "".join(sentences)
15 | md = MarkdownIt()
16 | tokens = md.parse(doc)
17 | # Identify the character index of each line in the document.
18 | lines = doc.splitlines(keepends=True)
19 | line_start_char = [0]
20 | for line in lines[:-1]:
21 | line_start_char.append(line_start_char[-1] + len(line))
22 | # Identify the character index of each sentence in the document.
23 | sentence_start_char = [0]
24 | for sentence in sentences:
25 | sentence_start_char.append(sentence_start_char[-1] + len(sentence))
26 | # Map each line index to a corresponding sentence index.
27 | line_to_sentence = np.searchsorted(sentence_start_char, line_start_char, side="right") - 1
28 | # Configure probabilities for token types to be chunklet boundaries.
29 | token_type_to_proba = {
30 | "blockquote_open": 0.75,
31 | "bullet_list_open": 0.25,
32 | "heading_open": 1.0,
33 | "paragraph_open": 0.5,
34 | "ordered_list_open": 0.25,
35 | }
36 | # Compute the boundary probabilities for each sentence.
37 | last_sentence = -1
38 | boundary_probas = np.zeros(len(sentences))
39 | for token in tokens:
40 | if token.type in token_type_to_proba:
41 | start_line, _ = token.map # type: ignore[misc]
42 | if (i := line_to_sentence[start_line]) != last_sentence:
43 | boundary_probas[i] = token_type_to_proba[token.type]
44 | last_sentence = i # type: ignore[assignment]
45 | # For segments of consecutive boundaries, encourage splitting on the largest boundary in the
46 | # segment by setting the other boundary probabilities in the segment to zero.
47 | mask = boundary_probas != 0.0
48 | split_indices = np.flatnonzero(mask[1:] != mask[:-1]) + 1
49 | segments = np.split(boundary_probas, split_indices)
50 | for segment in segments:
51 | max_idx, max_proba = np.argmax(segment), np.max(segment)
52 | segment[:] = 0.0
53 | segment[max_idx] = max_proba
54 | boundary_probas = np.concatenate(segments)
55 | return boundary_probas
56 |
57 |
58 | def compute_num_statements(sentences: list[str]) -> FloatVector:
59 | """Compute the approximate number of statements of each sentence in a list of sentences."""
60 | sentence_word_length = np.asarray(
61 | [len(sentence.split()) for sentence in sentences], dtype=np.float64
62 | )
63 | q25, q75 = np.quantile(sentence_word_length, [0.25, 0.75])
64 | q25 = max(q25, np.sqrt(np.finfo(np.float64).eps))
65 | q75 = max(q75, q25 + np.sqrt(np.finfo(np.float64).eps))
66 | num_statements = np.piecewise(
67 | sentence_word_length,
68 | [sentence_word_length <= q25, sentence_word_length > q25],
69 | [lambda n: 0.75 * n / q25, lambda n: 0.75 + 0.5 * (n - q25) / (q75 - q25)],
70 | )
71 | return num_statements
72 |
73 |
74 | def split_chunklets(
75 | sentences: list[str],
76 | boundary_cost: Callable[[FloatVector], float] = lambda p: (1.0 - p[0]) + np.sum(p[1:]),
77 | statement_cost: Callable[[float], float] = lambda s: ((s - 3) ** 2 / np.sqrt(max(s, 1e-6)) / 2),
78 | max_size: int = 2048,
79 | ) -> list[str]:
80 | """Split sentences into optimal chunklets.
81 |
82 | A chunklet is a concatenated contiguous list of sentences from a document. This function
83 | optimally partitions a document into chunklets using dynamic programming.
84 |
85 | A chunklet is considered optimal when it contains as close to 3 statements as possible, when the
86 | first sentence in the chunklet is a Markdown boundary such as the start of a heading or
87 | paragraph, and when the remaining sentences in the chunklet are not Markdown boundaries.
88 |
89 | Here, we define the number of statements in a sentence as a measure of the sentence's
90 | information content. A sentence is said to contain 1 statement if it contains the median number
91 | of words per sentence, across sentences in the document.
92 |
93 | The given document of sentences is optimally partitioned into chunklets by solving a dynamic
94 | programming problem that assigns a cost to each chunklet given the
95 | `boundary_cost(boundary_probas)` function and the `statement_cost(num_statements)` function. The
96 | former outputs the cost associated with the boundaries of the chunklet's sentences given the
97 | sentences' Markdown boundary probabilities, while the latter outputs the cost of the total
98 | number of statements in the chunklet.
99 |
100 | Parameters
101 | ----------
102 | sentences
103 | The input document as a list of sentences.
104 | boundary_cost
105 | A function that computes the boundary cost of a chunklet given the boundary probabilities of
106 | its sentences. The total cost of a chunklet is the sum of its boundary and statement cost.
107 | statement_cost
108 | A function that computes the statement cost of a chunklet given its number of statements.
109 | The total cost of a chunklet is the sum of its boundary and statement cost.
110 | max_size
111 | The maximum size of a chunklet in characters.
112 |
113 | Returns
114 | -------
115 | list[str]
116 | The document optimally partitioned into chunklets.
117 | """
118 | # Precompute chunklet boundary probabilities and each sentence's number of statements.
119 | boundary_probas = markdown_chunklet_boundaries(sentences)
120 | num_statements = compute_num_statements(sentences)
121 | # Initialize a dynamic programming table and backpointers. The dynamic programming table dp[i]
122 | # is defined as the minimum cost to segment the first i sentences (i.e., sentences[:i]).
123 | num_sentences = len(sentences)
124 | dp = np.full(num_sentences + 1, np.inf)
125 | dp[0] = 0.0
126 | back = -np.ones(num_sentences + 1, dtype=np.intp)
127 | # Compute the cost of partitioning sentences into chunklets.
128 | for i in range(1, num_sentences + 1):
129 | for j in range(i):
130 | # Limit the chunklets to a maximum size.
131 | if sum(len(s) for s in sentences[j:i]) > max_size:
132 | continue
133 | # Compute the cost of partitioning sentences[j:i] into a single chunklet.
134 | cost_ji = boundary_cost(boundary_probas[j:i])
135 | cost_ji += statement_cost(np.sum(num_statements[j:i]))
136 | # Compute the cost of partitioning sentences[:i] if we were to split at j.
137 | cost_0i = dp[j] + cost_ji
138 | # If the cost is less than the current minimum, update the DP table and backpointer.
139 | if cost_0i < dp[i]:
140 | dp[i] = cost_0i
141 | back[i] = j
142 | # Recover the optimal partitioning.
143 | partition_indices: list[int] = []
144 | i = back[num_sentences]
145 | while i > 0:
146 | partition_indices.append(i)
147 | i = back[i]
148 | partition_indices.reverse()
149 | # Split the sentences into optimal chunklets.
150 | chunklets = [
151 | "".join(sentences[i:j])
152 | for i, j in zip([0, *partition_indices], [*partition_indices, len(sentences)], strict=True)
153 | ]
154 | return chunklets
155 |
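A small usage sketch of the dynamic-programming split above (an editor's illustration; the document text is made up, and split_sentences, shown further down in this listing, downloads a small SaT model on first use):

from raglite._split_chunklets import split_chunklets
from raglite._split_sentences import split_sentences

doc = (
    "# Special relativity\n\n"
    "The speed of light is the same in every inertial frame. "
    "Clocks synchronized in one frame need not be synchronized in another. "
    "Moving rods contract along their direction of motion.\n"
)
sentences = split_sentences(doc)
chunklets = split_chunklets(sentences)
# The partitioning is lossless: chunklets are contiguous runs of sentences.
assert "".join(chunklets) == doc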
--------------------------------------------------------------------------------
/src/raglite/_split_chunks.py:
--------------------------------------------------------------------------------
1 | """Split a document into semantic chunks."""
2 |
3 | import re
4 |
5 | import numpy as np
6 | from scipy.optimize import linprog
7 | from scipy.sparse import coo_matrix
8 |
9 | from raglite._typing import FloatMatrix
10 |
11 |
12 | def split_chunks( # noqa: C901, PLR0915
13 | chunklets: list[str],
14 | chunklet_embeddings: FloatMatrix,
15 | max_size: int = 2048,
16 | ) -> tuple[list[str], list[FloatMatrix]]:
17 | """Split chunklets into optimal semantic chunks with corresponding chunklet embeddings.
18 |
19 | A chunk is a concatenated contiguous list of chunklets from a document. This function
20 | optimally partitions a document into chunks using binary integer programming.
21 |
22 | A partitioning of a document into chunks is considered optimal if the total cost of partitioning
23 | the document into chunks is minimized. The cost of adding a partition point is given by the
24 | cosine similarity of the chunklet embeddings before and after the partition point, corrected by
25 | the discourse vector of the chunklet embeddings across the document.
26 |
27 | Parameters
28 | ----------
29 | chunklets
30 | The input document as a list of chunklets.
31 | chunklet_embeddings
32 | A NumPy array wherein the i'th row is an embedding vector corresponding to the i'th
33 | chunklet. Embedding vectors are expected to have nonzero length.
34 | max_size
35 | The maximum size of a chunk in characters.
36 |
37 | Returns
38 | -------
39 | tuple[list[str], list[FloatMatrix]]
40 | The document and chunklet embeddings optimally partitioned into chunks.
41 | """
42 | # Validate the input.
43 | chunklet_size = np.asarray([len(chunklet) for chunklet in chunklets])
44 | if not np.all(chunklet_size <= max_size):
45 | error_message = "Chunklet larger than chunk max_size detected."
46 | raise ValueError(error_message)
47 | if not np.all(np.linalg.norm(chunklet_embeddings, axis=1) > 0.0):
48 | error_message = "Chunklet embeddings with zero norm detected."
49 | raise ValueError(error_message)
50 | # Exit early if there is only one chunk to return.
51 | if len(chunklets) <= 1 or sum(chunklet_size) <= max_size:
52 | return ["".join(chunklets)] if chunklets else chunklets, [chunklet_embeddings]
53 | # Normalise the chunklet embeddings to unit norm.
54 | X = chunklet_embeddings.astype(np.float32) # noqa: N806
55 | X = X / np.linalg.norm(X, axis=1, keepdims=True) # noqa: N806
56 | # Select nonoutlying chunklets and remove the discourse vector.
57 | q15, q85 = np.quantile(chunklet_size, [0.15, 0.85])
58 | nonoutlying_chunklets = (q15 <= chunklet_size) & (chunklet_size <= q85)
59 | if np.any(nonoutlying_chunklets):
60 | discourse = np.mean(X[nonoutlying_chunklets, :], axis=0)
61 | discourse = discourse / np.linalg.norm(discourse)
62 | X_modulo = X - np.outer(X @ discourse, discourse) # noqa: N806
63 | if not np.any(np.linalg.norm(X_modulo, axis=1) <= np.finfo(X.dtype).eps):
64 | X = X_modulo # noqa: N806
65 | X = X / np.linalg.norm(X, axis=1, keepdims=True) # noqa: N806
66 | # For each partition point in the list of chunklets, compute the similarity of chunklet before
67 | # and after the partition point.
68 | partition_similarity = np.sum(X[:-1] * X[1:], axis=1)
69 | # Make partition similarity nonnegative before modification and optimisation.
70 | partition_similarity = np.maximum(
71 | (partition_similarity + 1) / 2, np.sqrt(np.finfo(X.dtype).eps)
72 | )
73 | # Modify the partition similarity to encourage splitting on Markdown headings.
74 | prev_chunklet_is_heading = True
75 | for i, chunklet in enumerate(chunklets[:-1]):
76 | is_heading = bool(re.match(r"^#+\s", chunklet.replace("\n", "").strip()))
77 | if is_heading:
78 | # Encourage splitting before a heading.
79 | if not prev_chunklet_is_heading:
80 | partition_similarity[i - 1] = partition_similarity[i - 1] / 4
81 | # Don't split immediately after a heading.
82 | partition_similarity[i] = 1.0
83 | prev_chunklet_is_heading = is_heading
84 | # Solve an optimisation problem to find the best partition points.
85 | chunklet_size_cumsum = np.cumsum(chunklet_size)
86 | row_indices = []
87 | col_indices = []
88 | data = []
89 | for i in range(len(chunklets) - 1):
90 | r = chunklet_size_cumsum[i - 1] if i > 0 else 0
91 | idx = np.searchsorted(chunklet_size_cumsum - r, max_size, side="right")
92 | assert idx > i
93 | if idx == len(chunklet_size_cumsum):
94 | break
95 | cols = list(range(i, idx))
96 | col_indices.extend(cols)
97 | row_indices.extend([i] * len(cols))
98 | data.extend([1] * len(cols))
99 | A = coo_matrix( # noqa: N806
100 | (data, (row_indices, col_indices)),
101 | shape=(max(row_indices) + 1, len(chunklets) - 1),
102 | dtype=np.float32,
103 | )
104 | b_ub = np.ones(A.shape[0], dtype=np.float32)
105 | res = linprog(
106 | partition_similarity,
107 | A_ub=-A,
108 | b_ub=-b_ub,
109 | bounds=(0, 1),
110 | integrality=[1] * A.shape[1],
111 | )
112 | if not res.success:
113 | error_message = "Optimization of chunk partitions failed."
114 | raise ValueError(error_message)
115 | # Split the chunklets and their embeddings into optimal chunks.
116 | partition_indices = (np.where(res.x)[0] + 1).tolist()
117 | chunks = [
118 | "".join(chunklets[i:j])
119 | for i, j in zip([0, *partition_indices], [*partition_indices, len(chunklets)], strict=True)
120 | ]
121 | chunk_embeddings = np.split(chunklet_embeddings, partition_indices)
122 | return chunks, chunk_embeddings
123 |
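A self-contained sketch of the optimisation above (an editor's illustration; random unit-norm vectors stand in for real chunklet embeddings, which would normally come from the embedding model):

import numpy as np

from raglite._split_chunks import split_chunks

chunklets = [
    "# Heading\n",
    "First paragraph, sentence one. ",
    "Sentence two of the first paragraph.\n",
    "A second paragraph with different content.\n",
]
rng = np.random.default_rng(42)
chunklet_embeddings = rng.normal(size=(len(chunklets), 8)).astype(np.float32)
chunklet_embeddings /= np.linalg.norm(chunklet_embeddings, axis=1, keepdims=True)
chunks, chunk_embeddings = split_chunks(chunklets, chunklet_embeddings, max_size=64)
# Chunks concatenate back to the original document and keep their chunklet embeddings.
assert "".join(chunks) == "".join(chunklets)
assert sum(len(e) for e in chunk_embeddings) == len(chunklets)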
--------------------------------------------------------------------------------
/src/raglite/_split_sentences.py:
--------------------------------------------------------------------------------
1 | """Split a document into sentences."""
2 |
3 | import warnings
4 | from collections.abc import Callable
5 | from functools import cache
6 | from typing import Any
7 |
8 | import numpy as np
9 | from markdown_it import MarkdownIt
10 | from scipy import sparse
11 | from scipy.optimize import OptimizeWarning, linprog
12 | from wtpsplit_lite import SaT
13 |
14 | from raglite._typing import FloatVector
15 |
16 |
17 | @cache
18 | def _load_sat() -> tuple[SaT, dict[str, Any]]:
19 | """Load a Segment any Text (SaT) model."""
20 | sat = SaT("sat-3l-sm") # This model makes the best trade-off between speed and accuracy.
21 | sat_kwargs = {"stride": 128, "block_size": 256, "weighting": "hat"}
22 | return sat, sat_kwargs
23 |
24 |
25 | def markdown_sentence_boundaries(doc: str) -> FloatVector:
26 | """Determine known sentence boundaries from a Markdown document."""
27 |
28 | def get_markdown_heading_indexes(doc: str) -> list[tuple[int, int]]:
29 | """Get the indexes of the headings in a Markdown document."""
30 | md = MarkdownIt()
31 | tokens = md.parse(doc)
32 | headings = []
33 | lines = doc.splitlines(keepends=True)
34 | line_start_char = [0]
35 | for line in lines:
36 | line_start_char.append(line_start_char[-1] + len(line))
37 | for token in tokens:
38 | if token.type == "heading_open":
39 | start_line, end_line = token.map # type: ignore[misc]
40 | heading_start = line_start_char[start_line]
41 | heading_end = line_start_char[end_line]
42 | headings.append((heading_start, heading_end + 1))
43 | return headings
44 |
45 | # Get the start and end character indexes of the headings in the document.
46 | headings = get_markdown_heading_indexes(doc)
47 | # Indicate that each heading is a contiguous sentence by setting the boundary probabilities.
48 | boundary_probas = np.full(len(doc), np.nan)
49 | for heading_start, heading_end in headings:
50 | if 0 <= heading_start - 1 < len(boundary_probas):
51 | boundary_probas[heading_start - 1] = 1 # First heading character starts a sentence.
52 | boundary_probas[heading_start : heading_end - 1] = 0 # Body does not contain boundaries.
53 | if 0 <= heading_end - 1 < len(boundary_probas):
54 | boundary_probas[heading_end - 1] = 1 # Last heading character is the end of a sentence.
55 | return boundary_probas
56 |
57 |
58 | def _split_sentences(
59 | doc: str, probas: FloatVector, *, min_len: int, max_len: int | None = None
60 | ) -> list[str]:
61 | # Solve an optimisation problem to find the best sentence boundaries given the predicted
62 | # boundary probabilities. The objective is to select boundaries that maximise the sum of the
63 | # boundary probabilities above a given threshold, subject to the resulting sentences not being
64 | # shorter or longer than the given minimum or maximum length, respectively.
65 | sentence_threshold = 0.25 # Default threshold for -sm models.
66 | c = probas - sentence_threshold
67 | N = len(probas) # noqa: N806
68 | M = N - min_len + 1 # noqa: N806
69 | diagonals = [np.ones(M, dtype=np.float32) for _ in range(min_len)]
70 | offsets = list(range(min_len))
71 | A_min = sparse.diags(diagonals, offsets, shape=(M, N), format="csr") # noqa: N806
72 | b_min = np.ones(M, dtype=np.float32)
73 | bounds = [(0, 1)] * N
74 | bounds[: min_len - 1] = [(0, 0)] * (min_len - 1) # Prevent short leading sentences.
75 | bounds[-min_len:] = [(0, 0)] * min_len # Prevent short trailing sentences.
76 | if max_len is not None and (M := N - max_len + 1) > 0: # noqa: N806
77 | diagonals = [np.ones(M, dtype=np.float32) for _ in range(max_len)]
78 | offsets = list(range(max_len))
79 | A_max = sparse.diags(diagonals, offsets, shape=(M, N), format="csr") # noqa: N806
80 | b_max = np.ones(M, dtype=np.float32)
81 | A_ub = sparse.vstack([A_min, -A_max], format="csr") # noqa: N806
82 | b_ub = np.hstack([b_min, -b_max])
83 | else:
84 | A_ub = A_min # noqa: N806
85 | b_ub = b_min
86 | x0 = (probas >= sentence_threshold).astype(np.float32)
87 | with warnings.catch_warnings():
88 | warnings.filterwarnings("ignore", category=OptimizeWarning) # Ignore x0 not being used.
89 | res = linprog(-c, A_ub=A_ub, b_ub=b_ub, bounds=bounds, x0=x0, integrality=[1] * N)
90 | if not res.success:
91 | error_message = "Optimization of sentence partitions failed."
92 | raise ValueError(error_message)
93 | # Split the document into sentences where the boundary probability exceeds a threshold.
94 | partition_indices = np.where(res.x > 0.5)[0] + 1 # noqa: PLR2004
95 | sentences = [
96 | doc[i:j] for i, j in zip([0, *partition_indices], [*partition_indices, None], strict=True)
97 | ]
98 | return sentences
99 |
100 |
101 | def split_sentences(
102 | doc: str,
103 | min_len: int = 4,
104 | max_len: int | None = None,
105 | boundary_probas: FloatVector | Callable[[str], FloatVector] = markdown_sentence_boundaries,
106 | ) -> list[str]:
107 | """Split a document into sentences.
108 |
109 | Parameters
110 | ----------
111 | doc
112 | The document to split into sentences.
113 | min_len
114 | The minimum number of characters in a sentence.
115 | max_len
116 | The maximum number of characters in a sentence, with no maximum by default.
117 | boundary_probas
118 | Any known sentence boundary probabilities to override the model's predicted sentence
119 | boundary probabilities. If an element of the probability vector with index k is 1 (0), then
120 | the character at index k + 1 is (not) the start of a sentence. Elements set to NaN will not
121 | override the predicted probabilities. By default, the known sentence boundary probabilities
122 | are determined from the document's Markdown headings.
123 |
124 | Returns
125 | -------
126 | list[str]
127 | The input document partitioned into sentences. All sentences are constructed to contain at
128 | least one non-whitespace character, not have any leading whitespace (except for the first
129 | sentence if the document itself has leading whitespace), and respect the minimum and maximum
130 | sentence lengths.
131 | """
132 | # Exit early if there is only one sentence to return.
133 | if len(doc) <= min_len:
134 | return [doc]
135 | # Compute the sentence boundary probabilities with a wtpsplit Segment any Text (SaT) model.
136 | sat, sat_kwargs = _load_sat()
137 | predicted_probas = sat.predict_proba(doc, **sat_kwargs)
138 | # Override the predicted boundary probabilities with the known boundary probabilities.
139 | known_probas = boundary_probas(doc) if callable(boundary_probas) else boundary_probas
140 | probas = predicted_probas.copy()
141 | probas[np.isfinite(known_probas)] = known_probas[np.isfinite(known_probas)]
142 | # Propagate the boundary probabilities so that whitespace is always trailing and never leading.
143 | is_space = np.array([char.isspace() for char in doc], dtype=np.bool_)
144 | start = np.where(np.insert(~is_space[:-1] & is_space[1:], len(is_space) - 1, False))[0]
145 | end = np.where(np.insert(~is_space[1:] & is_space[:-1], 0, False))[0]
146 | start = start[start < np.max(end, initial=-1)]
147 | end = end[end > np.min(start, initial=len(is_space))]
148 | for i, j in zip(start, end, strict=True):
149 | min_proba, max_proba = np.min(probas[i:j]), np.max(probas[i:j])
150 | probas[i : j - 1] = min_proba # From the non-whitespace to the penultimate whitespace char.
151 | probas[j - 1] = max_proba # The last whitespace char.
152 | # Solve an optimization problem to find optimal sentences with no maximum length. We delay the
153 | # maximum length constraint to a subsequent step to avoid blowing up memory usage.
154 | sentences = _split_sentences(doc, probas, min_len=min_len, max_len=None)
155 | # For each sentence that exceeds the maximum length, solve the optimization problem again with
156 | # a maximum length constraint.
157 | if max_len is not None:
158 | sentences = [
159 | subsentence
160 | for sentence in sentences
161 | for subsentence in (
162 | [sentence]
163 | if len(sentence) <= max_len
164 | else _split_sentences(
165 | sentence,
166 | probas[doc.index(sentence) : doc.index(sentence) + len(sentence)],
167 | min_len=min_len,
168 | max_len=max_len,
169 | )
170 | )
171 | ]
172 | return sentences
173 |
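A short usage sketch (an editor's illustration; the first call downloads the small sat-3l-sm model):

from raglite._split_sentences import split_sentences

doc = (
    "# Introduction\n\n"
    "RAGLite first splits a document into sentences. "
    "Those sentences are then grouped into chunklets and chunks downstream.\n"
)
sentences = split_sentences(doc, max_len=120)
# Sentences are contiguous slices of the document, so they concatenate back losslessly.
assert "".join(sentences) == doc
assert all(len(sentence) <= 120 for sentence in sentences)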
--------------------------------------------------------------------------------
/src/raglite/_typing.py:
--------------------------------------------------------------------------------
1 | """RAGLite typing."""
2 |
3 | import io
4 | import pickle
5 | from collections.abc import Callable
6 | from typing import TYPE_CHECKING, Any, Literal, Protocol
7 |
8 | import numpy as np
9 | from sqlalchemy import literal
10 | from sqlalchemy.engine import Dialect
11 | from sqlalchemy.ext.compiler import compiles
12 | from sqlalchemy.sql.functions import FunctionElement
13 | from sqlalchemy.sql.operators import Operators
14 | from sqlalchemy.types import Float, LargeBinary, TypeDecorator, TypeEngine, UserDefinedType
15 |
16 | if TYPE_CHECKING:
17 | from raglite._config import RAGLiteConfig
18 | from raglite._database import Chunk, ChunkSpan
19 |
20 | ChunkId = str
21 | DocumentId = str
22 | EvalId = str
23 | IndexId = str
24 |
25 | DistanceMetric = Literal["cosine", "dot", "l1", "l2"]
26 |
27 | FloatMatrix = np.ndarray[tuple[int, int], np.dtype[np.floating[Any]]]
28 | FloatVector = np.ndarray[tuple[int], np.dtype[np.floating[Any]]]
29 | IntVector = np.ndarray[tuple[int], np.dtype[np.intp]]
30 |
31 |
32 | class BasicSearchMethod(Protocol):
33 | def __call__(
34 | self, query: str, *, num_results: int, config: "RAGLiteConfig | None" = None
35 | ) -> tuple[list[ChunkId], list[float]]: ...
36 |
37 |
38 | class SearchMethod(Protocol):
39 | def __call__(
40 | self, query: str, *, num_results: int, config: "RAGLiteConfig | None" = None
41 | ) -> tuple[list[ChunkId], list[float]] | list["Chunk"] | list["ChunkSpan"]: ...
42 |
43 |
44 | class NumpyArray(TypeDecorator[np.ndarray[Any, np.dtype[np.floating[Any]]]]):
45 | """A NumPy array column type for SQLAlchemy."""
46 |
47 | impl = LargeBinary
48 |
49 | def process_bind_param(
50 | self, value: np.ndarray[Any, np.dtype[np.floating[Any]]] | None, dialect: Dialect
51 | ) -> bytes | None:
52 | """Convert a NumPy array to bytes."""
53 | if value is None:
54 | return None
55 | buffer = io.BytesIO()
56 | np.save(buffer, value, allow_pickle=False, fix_imports=False)
57 | return buffer.getvalue()
58 |
59 | def process_result_value(
60 | self, value: bytes | None, dialect: Dialect
61 | ) -> np.ndarray[Any, np.dtype[np.floating[Any]]] | None:
62 | """Convert bytes to a NumPy array."""
63 | if value is None:
64 | return None
65 | return np.load(io.BytesIO(value), allow_pickle=False, fix_imports=False) # type: ignore[no-any-return]
66 |
67 |
68 | class PickledObject(TypeDecorator[object]):
69 | """A pickled object column type for SQLAlchemy."""
70 |
71 | impl = LargeBinary
72 |
73 | def process_bind_param(self, value: object | None, dialect: Dialect) -> bytes | None:
74 | """Convert a Python object to bytes."""
75 | if value is None:
76 | return None
77 | return pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL, fix_imports=False)
78 |
79 | def process_result_value(self, value: bytes | None, dialect: Dialect) -> object | None:
80 | """Convert bytes to a Python object."""
81 | if value is None:
82 | return None
83 | return pickle.loads(value, fix_imports=False) # type: ignore[no-any-return] # noqa: S301
84 |
85 |
86 | class EmbeddingDistance(FunctionElement[float]):
87 | """SQL expression that renders a distance operator per dialect."""
88 |
89 | inherit_cache = True
90 | type = Float() # The result is always a scalar float.
91 |
92 | def __init__(self, left: Any, right: Any, metric: DistanceMetric) -> None:
93 | self.metric = metric
94 | super().__init__(left, right)
95 |
96 |
97 | @compiles(EmbeddingDistance, "postgresql")
98 | def _embedding_distance_postgresql(element: EmbeddingDistance, compiler: Any, **kwargs: Any) -> str:
99 | op_map: dict[DistanceMetric, str] = {
100 | "cosine": "<=>",
101 | "dot": "<#>",
102 | "l1": "<+>",
103 | "l2": "<->",
104 | }
105 | left, right = list(element.clauses)
106 | operator = op_map[element.metric]
107 | return f"({compiler.process(left)} {operator} {compiler.process(right)})"
108 |
109 |
110 | @compiles(EmbeddingDistance, "duckdb")
111 | def _embedding_distance_duckdb(element: EmbeddingDistance, compiler: Any, **kwargs: Any) -> str:
112 | func_map: dict[DistanceMetric, str] = {
113 | "cosine": "array_cosine_distance",
114 | "dot": "array_negative_inner_product",
115 | "l2": "array_distance",
116 | }
117 | left, right = list(element.clauses)
118 | dim = left.type.dim # type: ignore[attr-defined]
119 | func_name = func_map[element.metric]
120 | right_cast = f"{compiler.process(right)}::FLOAT[{dim}]"
121 | return f"{func_name}({compiler.process(left)}, {right_cast})"
122 |
123 |
124 | class EmbeddingComparator(UserDefinedType.Comparator[FloatVector]):
125 | """An embedding distance comparator."""
126 |
127 | def distance(self, other: FloatVector, *, metric: DistanceMetric) -> Operators:
128 | rhs = literal(other, type_=self.expr.type)
129 | return EmbeddingDistance(self.expr, rhs, metric)
130 |
131 |
132 | class PostgresHalfVec(UserDefinedType[FloatVector]):
133 | """A PostgreSQL half-precision vector column type for SQLAlchemy."""
134 |
135 | cache_ok = True
136 |
137 | def __init__(self, dim: int | None = None) -> None:
138 | super().__init__()
139 | self.dim = dim
140 |
141 | def get_col_spec(self, **kwargs: Any) -> str:
142 | return f"halfvec({self.dim})"
143 |
144 | def bind_processor(self, dialect: Dialect) -> Callable[[FloatVector | None], str | None]:
145 | """Process NumPy ndarray to PostgreSQL halfvec format for bound parameters."""
146 |
147 | def process(value: FloatVector | None) -> str | None:
148 | return f"[{','.join(str(x) for x in np.ravel(value))}]" if value is not None else None
149 |
150 | return process
151 |
152 | def result_processor(
153 | self, dialect: Dialect, coltype: Any
154 | ) -> Callable[[str | None], FloatVector | None]:
155 | """Process PostgreSQL halfvec format to NumPy ndarray."""
156 |
157 | def process(value: str | None) -> FloatVector | None:
158 | if value is None:
159 | return None
160 | return np.fromstring(value.strip("[]"), sep=",", dtype=np.float16)
161 |
162 | return process
163 |
164 |
165 | class DuckDBSingleVec(UserDefinedType[FloatVector]):
166 | """A DuckDB single precision vector column type for SQLAlchemy."""
167 |
168 | cache_ok = True
169 |
170 | def __init__(self, dim: int | None = None) -> None:
171 | super().__init__()
172 | self.dim = dim
173 |
174 | def get_col_spec(self, **kwargs: Any) -> str:
175 | return f"FLOAT[{self.dim}]" if self.dim is not None else "FLOAT[]"
176 |
177 | def bind_processor(
178 | self, dialect: Dialect
179 | ) -> Callable[[FloatVector | None], list[float] | None]:
180 | """Process NumPy ndarray to DuckDB single precision vector format for bound parameters."""
181 |
182 | def process(value: FloatVector | None) -> list[float] | None:
183 | return np.ravel(value).tolist() if value is not None else None
184 |
185 | return process
186 |
187 | def result_processor(
188 | self, dialect: Dialect, coltype: Any
189 | ) -> Callable[[list[float] | None], FloatVector | None]:
190 | """Process DuckDB single precision vector format to NumPy ndarray."""
191 |
192 | def process(value: list[float] | None) -> FloatVector | None:
193 | return np.asarray(value, dtype=np.float32) if value is not None else None
194 |
195 | return process
196 |
197 |
198 | class Embedding(TypeDecorator[FloatVector]):
199 | """An embedding column type for SQLAlchemy."""
200 |
201 | cache_ok = True
202 | impl = NumpyArray
203 | comparator_factory: type[EmbeddingComparator] = EmbeddingComparator
204 |
205 | def __init__(self, dim: int = -1):
206 | super().__init__()
207 | self.dim = dim
208 |
209 | def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[FloatVector]:
210 | if dialect.name == "postgresql":
211 | return dialect.type_descriptor(PostgresHalfVec(self.dim))
212 | if dialect.name == "duckdb":
213 | return dialect.type_descriptor(DuckDBSingleVec(self.dim))
214 | return dialect.type_descriptor(NumpyArray())
215 |
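To illustrate the serialization performed by the two vector column types above, a small round-trip sketch (an editor's illustration; the processors ignore their dialect and coltype arguments, so None is passed for brevity):

import numpy as np

from raglite._typing import DuckDBSingleVec, PostgresHalfVec

vector = np.array([0.5, -0.25, 0.125], dtype=np.float32)
# PostgreSQL halfvec values travel as "[x1,x2,...]" strings and come back as float16 arrays.
pg_type = PostgresHalfVec(dim=3)
pg_encoded = pg_type.bind_processor(None)(vector)  # "[0.5,-0.25,0.125]"
pg_decoded = pg_type.result_processor(None, None)(pg_encoded)
# DuckDB FLOAT[n] values travel as plain Python lists and come back as float32 arrays.
duck_type = DuckDBSingleVec(dim=3)
duck_encoded = duck_type.bind_processor(None)(vector)  # [0.5, -0.25, 0.125]
duck_decoded = duck_type.result_processor(None, None)(duck_encoded)
assert np.allclose(pg_decoded, vector) and np.allclose(duck_decoded, vector)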
--------------------------------------------------------------------------------
/src/raglite/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/superlinear-ai/raglite/17cf038a597c23911412d6e3f34a686c0eca933a/src/raglite/py.typed
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | """RAGLite test suite."""
2 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | """Fixtures for the tests."""
2 |
3 | import os
4 | import socket
5 | import tempfile
6 | from collections.abc import Generator
7 | from pathlib import Path
8 |
9 | import pytest
10 | from sqlalchemy import create_engine, text
11 |
12 | from raglite import Document, RAGLiteConfig, insert_documents
13 |
14 | POSTGRES_URL = "postgresql+pg8000://raglite_user:raglite_password@postgres:5432/postgres"
15 |
16 |
17 | def pytest_configure(config: pytest.Config) -> None:
18 | """Configure pytest to skip slow tests in CI."""
19 | if os.environ.get("CI"):
20 | markexpr = "not slow"
21 | if config.option.markexpr:
22 | markexpr = f"({config.option.markexpr}) and ({markexpr})"
23 | config.option.markexpr = markexpr
24 |
25 |
26 | def is_postgres_running() -> bool:
27 | """Check if PostgreSQL is running."""
28 | try:
29 | with socket.create_connection(("postgres", 5432), timeout=1):
30 | return True
31 | except OSError:
32 | return False
33 |
34 |
35 | def is_openai_available() -> bool:
36 | """Check if an OpenAI API key is set."""
37 | return bool(os.environ.get("OPENAI_API_KEY"))
38 |
39 |
40 | def pytest_sessionstart(session: pytest.Session) -> None:
41 | """Reset the PostgreSQL database."""
42 | if is_postgres_running():
43 | engine = create_engine(POSTGRES_URL, isolation_level="AUTOCOMMIT")
44 | with engine.connect() as conn:
45 | for variant in ["local", "remote"]:
46 | conn.execute(text(f"DROP DATABASE IF EXISTS raglite_test_{variant}"))
47 | conn.execute(text(f"CREATE DATABASE raglite_test_{variant}"))
48 |
49 |
50 | @pytest.fixture(scope="session")
51 | def duckdb_url() -> Generator[str, None, None]:
52 | """Create a temporary DuckDB database file and return the database URL."""
53 | with tempfile.TemporaryDirectory() as temp_dir:
54 | db_file = Path(temp_dir) / "raglite_test.db"
55 | yield f"duckdb:///{db_file}"
56 |
57 |
58 | @pytest.fixture(
59 | scope="session",
60 | params=[
61 | pytest.param("duckdb", id="duckdb"),
62 | pytest.param(
63 | POSTGRES_URL,
64 | id="postgres",
65 | marks=pytest.mark.skipif(not is_postgres_running(), reason="PostgreSQL is not running"),
66 | ),
67 | ],
68 | )
69 | def database(request: pytest.FixtureRequest) -> str:
70 | """Get a database URL to test RAGLite with."""
71 | db_url: str = (
72 | request.getfixturevalue("duckdb_url") if request.param == "duckdb" else request.param
73 | )
74 | return db_url
75 |
76 |
77 | @pytest.fixture(
78 | scope="session",
79 | params=[
80 | pytest.param(
81 | (
82 | "llama-cpp-python/unsloth/Qwen3-4B-GGUF/*Q4_K_M.gguf@8192",
83 | "llama-cpp-python/lm-kit/bge-m3-gguf/*Q4_K_M.gguf@512", # More context degrades performance.
84 | ),
85 | id="qwen3_4B-bge_m3",
86 | ),
87 | pytest.param(
88 | ("gpt-4o-mini", "text-embedding-3-small"),
89 | id="gpt_4o_mini-text_embedding_3_small",
90 | marks=pytest.mark.skipif(not is_openai_available(), reason="OpenAI API key is not set"),
91 | ),
92 | ],
93 | )
94 | def llm_embedder(request: pytest.FixtureRequest) -> str:
95 | """Get an LLM and embedder pair to test RAGLite with."""
96 | llm_embedder: str = request.param
97 | return llm_embedder
98 |
99 |
100 | @pytest.fixture(scope="session")
101 | def llm(llm_embedder: tuple[str, str]) -> str:
102 | """Get an LLM to test RAGLite with."""
103 | llm, _ = llm_embedder
104 | return llm
105 |
106 |
107 | @pytest.fixture(scope="session")
108 | def embedder(llm_embedder: tuple[str, str]) -> str:
109 | """Get an embedder to test RAGLite with."""
110 | _, embedder = llm_embedder
111 | return embedder
112 |
113 |
114 | @pytest.fixture(scope="session")
115 | def raglite_test_config(database: str, llm: str, embedder: str) -> RAGLiteConfig:
116 | """Create a lightweight in-memory config for testing DuckDB and PostgreSQL."""
117 | # Select the database based on the embedder.
118 | variant = "local" if embedder.startswith("llama-cpp-python") else "remote"
119 | if "postgres" in database:
120 | database = database.replace("/postgres", f"/raglite_test_{variant}")
121 | elif "duckdb" in database:
122 | database = database.replace(".db", f"_{variant}.db")
123 | # Create a RAGLite config for the given database and embedder.
124 | db_config = RAGLiteConfig(db_url=database, llm=llm, embedder=embedder)
125 | # Insert a document and update the index.
126 | doc_path = Path(__file__).parent / "specrel.pdf" # Einstein's special relativity paper.
127 | insert_documents([Document.from_path(doc_path)], config=db_config)
128 | return db_config
129 |
--------------------------------------------------------------------------------
/tests/specrel.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/superlinear-ai/raglite/17cf038a597c23911412d6e3f34a686c0eca933a/tests/specrel.pdf
--------------------------------------------------------------------------------
/tests/test_chatml_function_calling.py:
--------------------------------------------------------------------------------
1 | """Test RAGLite's upgraded chatml-function-calling llama-cpp-python chat handler."""
2 |
3 | import os
4 | from typing import TYPE_CHECKING, cast
5 |
6 | if TYPE_CHECKING:
7 | from collections.abc import Iterator
8 |
9 | import pytest
10 | from typeguard import ForwardRefPolicy, check_type
11 |
12 | from raglite._chatml_function_calling import chatml_function_calling_with_streaming
13 | from raglite._lazy_llama import (
14 | Llama,
15 | llama_supports_gpu_offload,
16 | llama_types,
17 | )
18 |
19 |
20 | def is_accelerator_available() -> bool:
21 | """Check if an accelerator is available."""
22 | if llama_supports_gpu_offload():
23 | return True
24 | if os.environ.get("CI"):
25 | return False
26 | return (os.cpu_count() or 1) >= 8 # noqa: PLR2004
27 |
28 |
29 | @pytest.mark.parametrize(
30 | "stream",
31 | [
32 | pytest.param(True, id="stream=True"),
33 | pytest.param(False, id="stream=False"),
34 | ],
35 | )
36 | @pytest.mark.parametrize(
37 | "tool_choice",
38 | [
39 | pytest.param("none", id="tool_choice=none"),
40 | pytest.param("auto", id="tool_choice=auto"),
41 | pytest.param(
42 | {"type": "function", "function": {"name": "get_weather"}}, id="tool_choice=fixed"
43 | ),
44 | ],
45 | )
46 | @pytest.mark.parametrize(
47 | "user_prompt_expected_tool_calls",
48 | [
49 | pytest.param(
50 | ("Is 7 a prime number?", 0),
51 | id="expected_tool_calls=0",
52 | ),
53 | pytest.param(
54 | ("What's the weather like in Paris today?", 1),
55 | id="expected_tool_calls=1",
56 | ),
57 | pytest.param(
58 | ("What's the weather like in Paris today? What about New York?", 2),
59 | id="expected_tool_calls=2",
60 | ),
61 | ],
62 | )
63 | @pytest.mark.parametrize(
64 | "llm_repo_id",
65 | [
66 | pytest.param("unsloth/Qwen3-4B-GGUF", id="qwen3_4B"),
67 | pytest.param(
68 | "unsloth/Qwen3-8B-GGUF",
69 | id="qwen3_8B",
70 | marks=pytest.mark.skipif(
71 | not is_accelerator_available(), reason="Accelerator not available"
72 | ),
73 | ),
74 | ],
75 | )
76 | @pytest.mark.slow
77 | def test_llama_cpp_python_tool_use(
78 | llm_repo_id: str,
79 | user_prompt_expected_tool_calls: tuple[str, int],
80 | tool_choice: llama_types.ChatCompletionToolChoiceOption,
81 | stream: bool, # noqa: FBT001
82 | ) -> None:
83 | """Test the upgraded chatml-function-calling llama-cpp-python chat handler."""
84 | user_prompt, expected_tool_calls = user_prompt_expected_tool_calls
85 | if isinstance(tool_choice, dict) and expected_tool_calls == 0:
86 | pytest.skip("Nonsensical")
87 | llm = Llama.from_pretrained(
88 | repo_id=llm_repo_id,
89 | filename="*Q4_K_M.gguf",
90 | n_ctx=8192,
91 | n_gpu_layers=-1,
92 | verbose=False,
93 | chat_handler=chatml_function_calling_with_streaming,
94 | )
95 | messages: list[llama_types.ChatCompletionRequestMessage] = [
96 | {"role": "user", "content": user_prompt}
97 | ]
98 | tools: list[llama_types.ChatCompletionTool] = [
99 | {
100 | "type": "function",
101 | "function": {
102 | "name": "get_weather",
103 | "description": "Get the weather for a location.",
104 | "parameters": {
105 | "type": "object",
106 | "properties": {"location": {"type": "string", "description": "A city name."}},
107 | },
108 | },
109 | }
110 | ]
111 | response = llm.create_chat_completion(
112 | messages=messages, tools=tools, tool_choice=tool_choice, stream=stream
113 | )
114 | if stream:
115 | response = cast("Iterator[llama_types.CreateChatCompletionStreamResponse]", response)
116 | num_tool_calls = 0
117 | for chunk in response:
118 | check_type(chunk, llama_types.CreateChatCompletionStreamResponse)
119 | tool_calls = chunk["choices"][0]["delta"].get("tool_calls")
120 | if isinstance(tool_calls, list):
121 | num_tool_calls = max(tool_call["index"] for tool_call in tool_calls) + 1
122 | assert num_tool_calls == (expected_tool_calls if tool_choice != "none" else 0)
123 | else:
124 | response = cast("llama_types.CreateChatCompletionResponse", response)
125 | check_type(
126 | response,
127 | llama_types.CreateChatCompletionResponse,
128 | forward_ref_policy=ForwardRefPolicy.IGNORE,
129 | )
130 | if expected_tool_calls == 0 or tool_choice == "none":
131 | assert response["choices"][0]["message"].get("tool_calls") is None
132 | else:
133 | assert len(response["choices"][0]["message"]["tool_calls"]) == expected_tool_calls
134 | assert all(
135 | tool_call["function"]["name"] == tools[0]["function"]["name"]
136 | for tool_call in response["choices"][0]["message"]["tool_calls"]
137 | )
138 |
--------------------------------------------------------------------------------
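The streaming branch of the test above recovers the number of tool calls from each chunk's delta by taking the maximum tool-call index plus one. A minimal, self-contained sketch of that counting logic on hypothetical delta payloads (illustrative only, not actual llama-cpp-python output):

    # Hypothetical stream of chat-completion deltas: two tool calls arrive across several chunks.
    deltas = [
        {"tool_calls": [{"index": 0}]},
        {"tool_calls": [{"index": 0}, {"index": 1}]},
        {},  # A chunk without tool-call deltas leaves the running count untouched.
    ]
    num_tool_calls = 0
    for delta in deltas:
        tool_calls = delta.get("tool_calls")
        if isinstance(tool_calls, list):
            num_tool_calls = max(tool_call["index"] for tool_call in tool_calls) + 1
    assert num_tool_calls == 2  # The highest index seen is 1, so two distinct tool calls.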
/tests/test_database.py:
--------------------------------------------------------------------------------
1 | """Test RAGLite's database engine creation."""
2 |
3 | from pathlib import Path
4 |
5 | from raglite import RAGLiteConfig
6 | from raglite._database import create_database_engine
7 |
8 |
9 | def test_in_memory_duckdb_creation() -> None:
10 | """Test creating an in-memory DuckDB database."""
11 | config = RAGLiteConfig(db_url="duckdb:///:memory:")
12 | create_database_engine(config)
13 |
14 |
15 | def test_repeated_duckdb_creation(tmp_path: Path) -> None:
16 | """Test creating the same DuckDB database engine twice."""
17 | duckdb_filepath = tmp_path / "test.db"
18 | config = RAGLiteConfig(db_url=f"duckdb:///{duckdb_filepath.as_posix()}")
19 | create_database_engine(config)
20 | create_database_engine.cache_clear()
21 | create_database_engine(config)
22 |
--------------------------------------------------------------------------------
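The call to create_database_engine.cache_clear() in the second test above indicates that the engine factory is memoized per config. A minimal sketch of what that memoization implies, under the assumption that repeated calls with an equal RAGLiteConfig return the cached engine:

    from raglite import RAGLiteConfig
    from raglite._database import create_database_engine

    config = RAGLiteConfig(db_url="duckdb:///:memory:")
    engine_a = create_database_engine(config)
    engine_b = create_database_engine(config)
    # Assumed behaviour: the second call hits the cache, which is why test_repeated_duckdb_creation
    # must call cache_clear() before it can exercise a second, genuinely fresh engine creation.
    assert engine_a is engine_b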
/tests/test_embed.py:
--------------------------------------------------------------------------------
1 | """Test RAGLite's embedding functionality."""
2 |
3 | from pathlib import Path
4 |
5 | import numpy as np
6 |
7 | from raglite import RAGLiteConfig
8 | from raglite._embed import embed_strings
9 | from raglite._markdown import document_to_markdown
10 | from raglite._split_sentences import split_sentences
11 |
12 |
13 | def test_embed(embedder: str) -> None:
14 | """Test embedding a document."""
15 | raglite_test_config = RAGLiteConfig(embedder=embedder, embedder_normalize=True)
16 | doc_path = Path(__file__).parent / "specrel.pdf" # Einstein's special relativity paper.
17 | doc = document_to_markdown(doc_path)
18 | sentences = split_sentences(doc, max_len=raglite_test_config.chunk_max_size)
19 | sentence_embeddings = embed_strings(sentences, config=raglite_test_config)
20 | assert isinstance(sentences, list)
21 | assert isinstance(sentence_embeddings, np.ndarray)
22 | assert len(sentences) == len(sentence_embeddings)
23 | assert sentence_embeddings.shape[1] >= 128 # noqa: PLR2004
24 | assert sentence_embeddings.dtype == np.float16
25 | assert np.all(np.isfinite(sentence_embeddings))
26 | assert np.allclose(np.linalg.norm(sentence_embeddings, axis=1), 1.0, rtol=1e-3)
27 |
--------------------------------------------------------------------------------
/tests/test_extract.py:
--------------------------------------------------------------------------------
1 | """Test RAGLite's structured output extraction."""
2 |
3 | from typing import ClassVar
4 |
5 | import pytest
6 | from pydantic import BaseModel, ConfigDict, Field
7 |
8 | from raglite import RAGLiteConfig
9 | from raglite._extract import extract_with_llm
10 |
11 |
12 | @pytest.mark.parametrize(
13 | "strict", [pytest.param(False, id="strict=False"), pytest.param(True, id="strict=True")]
14 | )
15 | def test_extract(llm: str, strict: bool) -> None: # noqa: FBT001
16 | """Test extracting structured data."""
17 | # Set the LLM.
18 | config = RAGLiteConfig(llm=llm)
19 |
20 | # Define the JSON schema of the response.
21 | class LoginResponse(BaseModel):
22 | """The response to a login request."""
23 |
24 | model_config = ConfigDict(extra="forbid" if strict else "allow")
25 | username: str = Field(..., description="The username.")
26 | password: str = Field(..., description="The password.")
27 | system_prompt: ClassVar[str] = "Extract the username and password from the input."
28 |
29 | # Extract structured data.
30 | username, password = "cypher", "steak"
31 | login_response = extract_with_llm(
32 | LoginResponse, f"username: {username}\npassword: {password}", strict=strict, config=config
33 | )
34 | # Validate the response.
35 | assert isinstance(login_response, LoginResponse)
36 | assert login_response.username == username
37 | assert login_response.password == password
38 |
--------------------------------------------------------------------------------
/tests/test_import.py:
--------------------------------------------------------------------------------
1 | """Test RAGLite."""
2 |
3 | import raglite
4 |
5 |
6 | def test_import() -> None:
7 | """Test that the package can be imported."""
8 | assert isinstance(raglite.__name__, str)
9 |
--------------------------------------------------------------------------------
/tests/test_insert.py:
--------------------------------------------------------------------------------
1 | """Test RAGLite's document insertion."""
2 |
3 | from pathlib import Path
4 |
5 | from sqlmodel import Session, select
6 | from tqdm import tqdm
7 |
8 | from raglite._config import RAGLiteConfig
9 | from raglite._database import Chunk, Document, create_database_engine
10 | from raglite._markdown import document_to_markdown
11 |
12 |
13 | def test_insert(raglite_test_config: RAGLiteConfig) -> None:
14 | """Test document insertion."""
15 | # Open a session to extract the document and its chunks from the existing database.
16 | with Session(create_database_engine(raglite_test_config)) as session:
17 | # Get the first document from the database (already inserted by the fixture).
18 | document = session.exec(select(Document)).first()
19 | assert document is not None, "No document found in the database"
20 | # Get the existing chunks for this document.
21 | chunks = session.exec(
22 | select(Chunk).where(Chunk.document_id == document.id).order_by(Chunk.index) # type: ignore[arg-type]
23 | ).all()
24 | assert len(chunks) > 0, "No chunks found for the document"
25 | restored_document = ""
26 | for chunk in tqdm(chunks, desc="Processing chunks", leave=False):
27 | # The body should not contain the heading string (unless the heading is empty).
28 | if chunk.headings.strip() != "":
29 | assert chunk.headings.strip() not in chunk.body.strip(), (
30 | f"Chunk body contains heading: '{chunk.headings.strip()}'\n"
31 | f"Chunk body: '{chunk.body.strip()}'"
32 | )
33 | # A body that starts with a '# ' heading marker should have no separate heading.
34 | if chunk.body.strip().startswith("# "):
35 | assert chunk.headings.strip() == "", (
36 | f"Chunk body starts with a heading: '{chunk.body.strip()}'\n"
37 | f"Chunk headings: '{chunk.headings.strip()}'"
38 | )
39 | restored_document += chunk.body
40 | # Combining the chunks should yield the original document.
41 | restored_document = restored_document.replace("\n", "").strip()
42 | doc_path = Path(__file__).parent / "specrel.pdf" # Einstein's special relativity paper.
43 | doc = document_to_markdown(doc_path)
44 | doc = doc.replace("\n", "").strip()
45 | assert restored_document == doc, "Restored document does not match the original input."
46 |
--------------------------------------------------------------------------------
/tests/test_lazy_llama.py:
--------------------------------------------------------------------------------
1 | """Test that llama-cpp-python package is an optional dependency for RAGLite."""
2 |
3 | import builtins
4 | import sys
5 | from typing import Any
6 |
7 | import pytest
8 |
9 |
10 | def test_raglite_import_without_llama_cpp(monkeypatch: pytest.MonkeyPatch) -> None:
11 | """Test that RAGLite can be imported without llama-cpp-python being available."""
12 | # Unimport raglite and llama_cpp.
13 | module_names = list(sys.modules)
14 | for module_name in module_names:
15 | if module_name.startswith(("llama_cpp", "raglite", "sqlmodel")):
16 | monkeypatch.delitem(sys.modules, module_name, raising=False)
17 |
18 | # Save the original __import__ function.
19 | original_import = builtins.__import__
20 |
21 | # Define a fake import function that raises ModuleNotFoundError when trying to import llama_cpp.
22 | def fake_import(name: str, *args: Any) -> Any:
23 | if name.startswith("llama_cpp"):
24 | import_error = f"No module named '{name}'"
25 | raise ModuleNotFoundError(import_error)
26 | return original_import(name, *args)
27 |
28 | # Monkey patch __import__ with the fake import function.
29 | monkeypatch.setattr(builtins, "__import__", fake_import)
30 |
31 | # Verify that importing raglite does not raise an error.
32 | import raglite # noqa: F401
33 | from raglite._config import llama_supports_gpu_offload # type: ignore[attr-defined]
34 |
35 | # Verify that lazily using llama-cpp-python raises a ModuleNotFoundError.
36 | with pytest.raises(ModuleNotFoundError, match="llama.cpp models"):
37 | llama_supports_gpu_offload()
38 |
--------------------------------------------------------------------------------
/tests/test_markdown.py:
--------------------------------------------------------------------------------
1 | """Test Markdown conversion."""
2 |
3 | from pathlib import Path
4 |
5 | from raglite._markdown import document_to_markdown
6 |
7 |
8 | def test_pdf_with_missing_font_sizes() -> None:
9 | """Test conversion of a PDF with missing font sizes."""
10 | # Convert a PDF whose parsed font sizes are all equal to 1.
11 | doc_path = Path(__file__).parent / "specrel.pdf" # Einstein's special relativity paper.
12 | doc = document_to_markdown(doc_path)
13 | # Verify that we can reconstruct the font sizes and heading levels regardless of the missing
14 | # font size data.
15 | expected_heading = "# ON THE ELECTRODYNAMICS OF MOVING BODIES \n\n## By A. EINSTEIN June 30, 1905 \n\nIt is known that Maxwell"
16 | assert doc.startswith(expected_heading)
17 |
--------------------------------------------------------------------------------
/tests/test_query_adapter.py:
--------------------------------------------------------------------------------
1 | """Test RAGLite's query adapter."""
2 |
3 | from dataclasses import replace
4 |
5 | import numpy as np
6 | import pytest
7 |
8 | from raglite import RAGLiteConfig, insert_evals, update_query_adapter, vector_search
9 | from raglite._database import IndexMetadata
10 |
11 |
12 | @pytest.mark.slow
13 | def test_query_adapter(raglite_test_config: RAGLiteConfig) -> None:
14 | """Test the query adapter update functionality."""
15 | # Create a config with and without the query adapter enabled.
16 | config_with_query_adapter = replace(raglite_test_config, vector_search_query_adapter=True)
17 | config_without_query_adapter = replace(raglite_test_config, vector_search_query_adapter=False)
18 | # Verify that there is no query adapter in the database.
19 | Q = IndexMetadata.get("default", config=config_without_query_adapter).get("query_adapter") # noqa: N806
20 | assert Q is None
21 | # Insert evals.
22 | insert_evals(num_evals=2, max_chunks_per_eval=10, config=config_with_query_adapter)
23 | # Update the query adapter.
24 | A = update_query_adapter(config=config_with_query_adapter) # noqa: N806
25 | assert isinstance(A, np.ndarray)
26 | assert A.ndim == 2 # noqa: PLR2004
27 | assert A.shape[0] == A.shape[1]
28 | assert np.isfinite(A).all()
29 | # Verify that there is a query adapter in the database.
30 | Q = IndexMetadata.get("default", config=config_without_query_adapter).get("query_adapter") # noqa: N806
31 | assert isinstance(Q, np.ndarray)
32 | assert Q.ndim == 2 # noqa: PLR2004
33 | assert Q.shape[0] == Q.shape[1]
34 | assert np.isfinite(Q).all()
35 | assert np.all(A == Q)
36 | # Verify that the query adapter affects the results of vector search.
37 | query = "How does Einstein define 'simultaneous events' in his special relativity paper?"
38 | _, scores_qa = vector_search(query, config=config_with_query_adapter)
39 | _, scores_no_qa = vector_search(query, config=config_without_query_adapter)
40 | assert scores_qa != scores_no_qa
41 |
--------------------------------------------------------------------------------
/tests/test_rag.py:
--------------------------------------------------------------------------------
1 | """Test RAGLite's RAG functionality."""
2 |
3 | import json
4 |
5 | from raglite import (
6 | RAGLiteConfig,
7 | add_context,
8 | retrieve_context,
9 | )
10 | from raglite._database import ChunkSpan
11 | from raglite._rag import rag
12 |
13 |
14 | def test_rag_manual(raglite_test_config: RAGLiteConfig) -> None:
15 | """Test Retrieval-Augmented Generation with manual retrieval."""
16 | # Answer a question with manual RAG.
17 | user_prompt = "How does Einstein define 'simultaneous events' in his special relativity paper?"
18 | chunk_spans = retrieve_context(query=user_prompt, config=raglite_test_config)
19 | messages = [add_context(user_prompt, context=chunk_spans)]
20 | stream = rag(messages, config=raglite_test_config)
21 | answer = ""
22 | for update in stream:
23 | assert isinstance(update, str)
24 | answer += update
25 | assert "event" in answer.lower()
26 | # Verify that no RAG context was retrieved through tool use.
27 | assert [message["role"] for message in messages] == ["user", "assistant"]
28 |
29 |
30 | def test_rag_auto_with_retrieval(raglite_test_config: RAGLiteConfig) -> None:
31 | """Test Retrieval-Augmented Generation with automatic retrieval."""
32 | # Answer a question that requires RAG.
33 | user_prompt = "How does Einstein define 'simultaneous events' in his special relativity paper?"
34 | messages = [{"role": "user", "content": user_prompt}]
35 | chunk_spans = []
36 | stream = rag(messages, on_retrieval=lambda x: chunk_spans.extend(x), config=raglite_test_config)
37 | answer = ""
38 | for update in stream:
39 | assert isinstance(update, str)
40 | answer += update
41 | assert "event" in answer.lower()
42 | # Verify that RAG context was retrieved automatically.
43 | assert [message["role"] for message in messages] == ["user", "assistant", "tool", "assistant"]
44 | assert json.loads(messages[-2]["content"])
45 | assert chunk_spans
46 | assert all(isinstance(chunk_span, ChunkSpan) for chunk_span in chunk_spans)
47 |
48 |
49 | def test_rag_auto_without_retrieval(raglite_test_config: RAGLiteConfig) -> None:
50 | """Test Retrieval-Augmented Generation with automatic retrieval."""
51 | # Answer a question that does not require RAG.
52 | user_prompt = "Is 7 a prime number?"
53 | messages = [{"role": "user", "content": user_prompt}]
54 | chunk_spans = []
55 | stream = rag(messages, on_retrieval=lambda x: chunk_spans.extend(x), config=raglite_test_config)
56 | answer = ""
57 | for update in stream:
58 | assert isinstance(update, str)
59 | answer += update
60 | # Verify that no RAG context was retrieved.
61 | assert [message["role"] for message in messages] == ["user", "assistant"]
62 | assert not chunk_spans
63 |
--------------------------------------------------------------------------------
/tests/test_rerank.py:
--------------------------------------------------------------------------------
1 | """Test RAGLite's reranking functionality."""
2 |
3 | import random
4 | from typing import TypeVar
5 |
6 | import pytest
7 | from rerankers.models.flashrank_ranker import FlashRankRanker
8 | from rerankers.models.ranker import BaseRanker
9 | from scipy.stats import kendalltau
10 |
11 | from raglite import RAGLiteConfig, rerank_chunks, retrieve_chunks, vector_search
12 | from raglite._database import Chunk
13 |
14 | T = TypeVar("T")
15 |
16 |
17 | def kendall_tau(a: list[T], b: list[T]) -> float:
18 | """Measure the Kendall rank correlation coefficient between two lists."""
19 | τ: float = kendalltau(range(len(a)), [a.index(el) for el in b])[0] # noqa: PLC2401
20 | return τ
21 |
22 |
23 | @pytest.fixture(
24 | params=[
25 | pytest.param(None, id="no_reranker"),
26 | pytest.param(FlashRankRanker("ms-marco-MiniLM-L-12-v2", verbose=0), id="flashrank_english"),
27 | pytest.param(
28 | {
29 | "en": FlashRankRanker("ms-marco-MiniLM-L-12-v2", verbose=0),
30 | "other": FlashRankRanker("ms-marco-MultiBERT-L-12", verbose=0),
31 | },
32 | id="flashrank_multilingual",
33 | ),
34 | ],
35 | )
36 | def reranker(
37 | request: pytest.FixtureRequest,
38 | ) -> BaseRanker | dict[str, BaseRanker] | None:
39 | """Get a reranker to test RAGLite with."""
40 | reranker: BaseRanker | dict[str, BaseRanker] | None = request.param
41 | return reranker
42 |
43 |
44 | def test_reranker(
45 | raglite_test_config: RAGLiteConfig,
46 | reranker: BaseRanker | dict[str, BaseRanker] | None,
47 | ) -> None:
48 | """Test inserting a document, updating the indexes, and searching for a query."""
49 | # Update the config with the reranker.
50 | raglite_test_config = RAGLiteConfig(
51 | db_url=raglite_test_config.db_url, embedder=raglite_test_config.embedder, reranker=reranker
52 | )
53 | # Search for a query.
54 | query = "What does it mean for two events to be simultaneous?"
55 | chunk_ids, _ = vector_search(query, num_results=40, config=raglite_test_config)
56 | # Retrieve the chunks.
57 | chunks = retrieve_chunks(chunk_ids, config=raglite_test_config)
58 | assert all(isinstance(chunk, Chunk) for chunk in chunks)
59 | assert all(chunk_id == chunk.id for chunk_id, chunk in zip(chunk_ids, chunks, strict=True))
60 | # Randomly shuffle the chunks.
61 | random.seed(42)
62 | chunks_random = random.sample(chunks, len(chunks))
63 | # Rerank the chunks starting from a pathological order and verify that it improves the ranking.
64 | for arg in (chunks[::-1], chunk_ids[::-1]):
65 | reranked_chunks = rerank_chunks(query, arg, config=raglite_test_config)
66 | if reranker:
67 | τ_search = kendall_tau(chunks, reranked_chunks) # noqa: PLC2401
68 | τ_inverse = kendall_tau(chunks[::-1], reranked_chunks) # noqa: PLC2401
69 | τ_random = kendall_tau(chunks_random, reranked_chunks) # noqa: PLC2401
70 | assert τ_search >= τ_random >= τ_inverse
71 |
--------------------------------------------------------------------------------
/tests/test_search.py:
--------------------------------------------------------------------------------
1 | """Test RAGLite's search functionality."""
2 |
3 | import pytest
4 |
5 | from raglite import (
6 | RAGLiteConfig,
7 | hybrid_search,
8 | keyword_search,
9 | retrieve_chunk_spans,
10 | retrieve_chunks,
11 | vector_search,
12 | )
13 | from raglite._database import Chunk, ChunkSpan, Document
14 | from raglite._typing import BasicSearchMethod
15 |
16 |
17 | @pytest.fixture(
18 | params=[
19 | pytest.param(keyword_search, id="keyword_search"),
20 | pytest.param(vector_search, id="vector_search"),
21 | pytest.param(hybrid_search, id="hybrid_search"),
22 | ],
23 | )
24 | def search_method(
25 | request: pytest.FixtureRequest,
26 | ) -> BasicSearchMethod:
27 | """Get a search method to test RAGLite with."""
28 | search_method: BasicSearchMethod = request.param
29 | return search_method
30 |
31 |
32 | def test_search(raglite_test_config: RAGLiteConfig, search_method: BasicSearchMethod) -> None:
33 | """Test searching for a query."""
34 | # Search for a query.
35 | query = "What does it mean for two events to be simultaneous?"
36 | num_results = 5
37 | chunk_ids, scores = search_method(query, num_results=num_results, config=raglite_test_config)
38 | assert len(chunk_ids) == len(scores) == num_results
39 | assert all(isinstance(chunk_id, str) for chunk_id in chunk_ids)
40 | assert all(isinstance(score, float) for score in scores)
41 | # Retrieve the chunks.
42 | chunks = retrieve_chunks(chunk_ids, config=raglite_test_config)
43 | assert all(isinstance(chunk, Chunk) for chunk in chunks)
44 | assert all(chunk_id == chunk.id for chunk_id, chunk in zip(chunk_ids, chunks, strict=True))
45 | assert any("Definition of Simultaneity" in str(chunk) for chunk in chunks), (
46 | "Expected 'Definition of Simultaneity' in chunks but got:\n"
47 | + "\n".join(f"- Chunk {i + 1}:\n{chunk!s}\n{'-' * 80}" for i, chunk in enumerate(chunks))
48 | )
49 | assert all(isinstance(chunk.document, Document) for chunk in chunks)
50 | # Extend the chunks with their neighbours and group them into contiguous segments.
51 | chunk_spans = retrieve_chunk_spans(chunk_ids, neighbors=(-1, 1), config=raglite_test_config)
52 | assert all(isinstance(chunk_span, ChunkSpan) for chunk_span in chunk_spans)
53 | assert all(isinstance(chunk_span.document, Document) for chunk_span in chunk_spans)
54 | chunk_spans = retrieve_chunk_spans(chunks, neighbors=(-1, 1), config=raglite_test_config)
55 | assert all(isinstance(chunk_span, ChunkSpan) for chunk_span in chunk_spans)
56 | assert all(isinstance(chunk_span.document, Document) for chunk_span in chunk_spans)
57 |
58 |
59 | def test_search_no_results(
60 | raglite_test_config: RAGLiteConfig, search_method: BasicSearchMethod
61 | ) -> None:
62 | """Test searching for a query with no keyword search results."""
63 | query = "supercalifragilisticexpialidocious"
64 | num_results = 5
65 | chunk_ids, scores = search_method(query, num_results=num_results, config=raglite_test_config)
66 | num_results_expected = 0 if search_method == keyword_search else num_results
67 | assert len(chunk_ids) == len(scores) == num_results_expected
68 | assert all(isinstance(chunk_id, str) for chunk_id in chunk_ids)
69 | assert all(isinstance(score, float) for score in scores)
70 |
71 |
72 | def test_search_empty_database(llm: str, embedder: str, search_method: BasicSearchMethod) -> None:
73 | """Test searching for a query with an empty database."""
74 | raglite_test_config = RAGLiteConfig(db_url="duckdb:///:memory:", llm=llm, embedder=embedder)
75 | query = "supercalifragilisticexpialidocious"
76 | num_results = 5
77 | chunk_ids, scores = search_method(query, num_results=num_results, config=raglite_test_config)
78 | num_results_expected = 0
79 | assert len(chunk_ids) == len(scores) == num_results_expected
80 | assert all(isinstance(chunk_id, str) for chunk_id in chunk_ids)
81 | assert all(isinstance(score, float) for score in scores)
82 |
--------------------------------------------------------------------------------
/tests/test_split_chunklets.py:
--------------------------------------------------------------------------------
1 | """Test RAGLite's chunk splitting functionality."""
2 |
3 | import pytest
4 |
5 | from raglite._split_chunklets import split_chunklets
6 |
7 |
8 | @pytest.mark.parametrize(
9 | "sentences_splits",
10 | [
11 | pytest.param(
12 | (
13 | [
14 | # Sentence 1:
15 | "It is known that Maxwell’s electrodynamics—as usually understood at the\n" # noqa: RUF001
16 | "present time—when applied to moving bodies, leads to asymmetries which do\n\n"
17 | "not appear to be inherent in the phenomena. ",
18 | # Sentence 2:
19 | "Take, for example, the recipro-\ncal electrodynamic action of a magnet and a conductor. \n\n",
20 | # Sentence 3 (heading):
21 | "# ON THE ELECTRODYNAMICS OF MOVING BODIES\n\n",
22 | # Sentence 4 (heading):
23 | "## By A. EINSTEIN June 30, 1905\n\n",
24 | # Sentence 5 (paragraph boundary):
25 | "The observable phe-\n"
26 | "nomenon here depends only on the relative motion of the conductor and the\n"
27 | "magnet, whereas the customary view draws a sharp distinction between the two\n"
28 | "cases in which either the one or the other of these bodies is in motion. ",
29 | # Sentence 6:
30 | "For if the\n"
31 | "magnet is in motion and the conductor at rest, there arises in the neighbour-\n"
32 | "hood of the magnet an electric field with a certain definite energy, producing\n"
33 | "a current at the places where parts of the conductor are situated. ",
34 | ],
35 | [2],
36 | ),
37 | id="consecutive_boundaries",
38 | ),
39 | pytest.param(
40 | (
41 | [
42 | # Sentence 1:
43 | "The theory to be developed is based—like all electrodynamics—on the kine-\n"
44 | "matics of the rigid body, since the assertions of any such theory have to do\n"
45 | "with the relationships between rigid bodies (systems of co-ordinates), clocks,\n"
46 | "and electromagnetic processes. ",
47 | # Sentence 2:
48 | "Insufficient consideration of this circumstance\n"
49 | "lies at the root of the difficulties which the electrodynamics of moving bodies\n"
50 | "at present encounters.\n\n",
51 | # Sentence 3 (paragraph boundary):
52 | "The observable phe-\n"
53 | "nomenon here depends only on the relative motion of the conductor and the\n"
54 | "magnet, whereas the customary view draws a sharp distinction between the two\n"
55 | "cases in which either the one or the other of these bodies is in motion. ",
56 | # Sentence 4:
57 | "For if the\n"
58 | "magnet is in motion and the conductor at rest, there arises in the neighbour-\n"
59 | "hood of the magnet an electric field with a certain definite energy, producing\n"
60 | "a current at the places where parts of the conductor are situated. ",
61 | # Sentence 5:
62 | "But if the\n"
63 | "magnet is stationary and the conductor in motion, no electric field arises in the\n"
64 | "neighbourhood of the magnet. ",
65 | ],
66 | [2],
67 | ),
68 | id="paragraph_boundary",
69 | ),
70 | ],
71 | )
72 | def test_split_chunklets(sentences_splits: tuple[list[str], list[int]]) -> None:
73 | """Test chunklet splitting."""
74 | sentences, splits = sentences_splits
75 | chunklets = split_chunklets(sentences)
76 | expected_chunklets = [
77 | "".join(sentences[i:j])
78 | for i, j in zip([0, *splits], [*splits, len(sentences)], strict=True)
79 | ]
80 | assert isinstance(chunklets, list)
81 | assert all(isinstance(chunklet, str) for chunklet in chunklets)
82 | assert sum(len(chunklet) for chunklet in chunklets) == sum(
83 | len(sentence) for sentence in sentences
84 | )
85 | assert all(
86 | chunklet == expected_chunklet
87 | for chunklet, expected_chunklet in zip(chunklets, expected_chunklets, strict=True)
88 | )
89 |
--------------------------------------------------------------------------------
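The expected chunklets in the test above are assembled by slicing the sentence list at the split indices via a zip over [0, *splits] and [*splits, len(sentences)]. A small standalone illustration of that slicing pattern on hypothetical sentences (not taken from the test data):

    sentences = ["s0 ", "s1 ", "s2 ", "s3 ", "s4 "]
    splits = [2]  # A single split before sentence index 2.
    # zip([0, 2], [2, 5]) yields the slices 0:2 and 2:5, i.e. two chunklets.
    expected_chunklets = [
        "".join(sentences[i:j])
        for i, j in zip([0, *splits], [*splits, len(sentences)], strict=True)
    ]
    assert expected_chunklets == ["s0 s1 ", "s2 s3 s4 "]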
/tests/test_split_chunks.py:
--------------------------------------------------------------------------------
1 | """Test RAGLite's chunk splitting functionality."""
2 |
3 | import numpy as np
4 | import pytest
5 |
6 | from raglite._split_chunks import split_chunks
7 |
8 |
9 | @pytest.mark.parametrize(
10 | "chunklets",
11 | [
12 | pytest.param([], id="one_chunk:no_chunklets"),
13 | pytest.param(["Hello world"], id="one_chunk:one_chunklet"),
14 | pytest.param(["Hello world"] * 2, id="one_chunk:two_chunklets"),
15 | pytest.param(["Hello world"] * 3, id="one_chunk:three_chunklets"),
16 | pytest.param(["Hello world"] * 100, id="one_chunk:many_chunklets"),
17 | pytest.param(["Hello world", "X" * 1000], id="n_chunks:two_chunklets_a"),
18 | pytest.param(["X" * 1000, "Hello world"], id="n_chunks:two_chunklets_b"),
19 | pytest.param(["Hello world", "X" * 1000, "X" * 1000], id="n_chunks:three_chunklets_a"),
20 | pytest.param(["X" * 1000, "Hello world", "X" * 1000], id="n_chunks:three_chunklets_b"),
21 | pytest.param(["X" * 1000, "X" * 1000, "Hello world"], id="n_chunks:three_chunklets_c"),
22 | pytest.param(["X" * 1000] * 100, id="n_chunks:many_chunklets_a"),
23 | pytest.param(["X" * 100] * 1000, id="n_chunks:many_chunklets_b"),
24 | ],
25 | )
26 | def test_edge_cases(chunklets: list[str]) -> None:
27 | """Test chunk splitting edge cases."""
28 | chunklet_embeddings = np.ones((len(chunklets), 768)).astype(np.float16)
29 | chunks, chunk_embeddings = split_chunks(chunklets, chunklet_embeddings, max_size=1440)
30 | assert isinstance(chunks, list)
31 | assert isinstance(chunk_embeddings, list)
32 | assert len(chunk_embeddings) == (len(chunks) if chunklets else 1)
33 | assert all(isinstance(chunk, str) for chunk in chunks)
34 | assert all(isinstance(chunk_embedding, np.ndarray) for chunk_embedding in chunk_embeddings)
35 | assert all(ce.dtype == chunklet_embeddings.dtype for ce in chunk_embeddings)
36 | assert sum(ce.shape[0] for ce in chunk_embeddings) == chunklet_embeddings.shape[0]
37 | assert all(ce.shape[1] == chunklet_embeddings.shape[1] for ce in chunk_embeddings)
38 |
39 |
40 | @pytest.mark.parametrize(
41 | "chunklets",
42 | [
43 | pytest.param(["Hello world" * 1000] + ["X"] * 100, id="first"),
44 | pytest.param(["X"] * 50 + ["Hello world" * 1000] + ["X"] * 50, id="middle"),
45 | pytest.param(["X"] * 100 + ["Hello world" * 1000], id="last"),
46 | ],
47 | )
48 | def test_long_chunklet(chunklets: list[str]) -> None:
49 | """Test chunking on chunklets that are too long."""
50 | chunklet_embeddings = np.ones((len(chunklets), 768)).astype(np.float16)
51 | with pytest.raises(ValueError, match="Chunklet larger than chunk max_size detected."):
52 | _ = split_chunks(chunklets, chunklet_embeddings, max_size=1440)
53 |
--------------------------------------------------------------------------------
/tests/test_split_sentences.py:
--------------------------------------------------------------------------------
1 | """Test RAGLite's sentence splitting functionality."""
2 |
3 | from pathlib import Path
4 |
5 | import pytest
6 |
7 | from raglite._markdown import document_to_markdown
8 | from raglite._split_sentences import split_sentences
9 |
10 |
11 | def test_split_sentences() -> None:
12 | """Test splitting a document into sentences."""
13 | doc_path = Path(__file__).parent / "specrel.pdf" # Einstein's special relativity paper.
14 | doc = document_to_markdown(doc_path)
15 | sentences = split_sentences(doc)
16 | expected_sentences = [
17 | "# ON THE ELECTRODYNAMICS OF MOVING BODIES \n\n",
18 | "## By A. EINSTEIN June 30, 1905 \n\n",
19 | "It is known that Maxwell’s electrodynamics—as usually understood at the\npresent time—when applied to moving bodies, leads to asymmetries which do\n\nnot appear to be inherent in the phenomena. ", # noqa: RUF001
20 | "Take, for example, the recipro-\ncal electrodynamic action of a magnet and a conductor. ",
21 | "The observable phe-\nnomenon here depends only on the relative motion of the conductor and the\nmagnet, whereas the customary view draws a sharp distinction between the two\ncases in which either the one or the other of these bodies is in motion. ",
22 | "For if the\nmagnet is in motion and the conductor at rest, there arises in the neighbour-\nhood of the magnet an electric field with a certain definite energy, producing\na current at the places where parts of the conductor are situated. ",
23 | "But if the\n\nmagnet is stationary and the conductor in motion, no electric field arises in the\nneighbourhood of the magnet. ",
24 | "In the conductor, however, we find an electro-\nmotive force, to which in itself there is no corresponding energy, but which gives\nrise—assuming equality of relative motion in the two cases discussed—to elec-\n\ntric currents of the same path and intensity as those produced by the electric\nforces in the former case.\n\n",
25 | "Examples of this sort, together with the unsuccessful attempts to discover\nany motion of the earth relatively to the “light medium,” suggest that the\n\nphenomena of electrodynamics as well as of mechanics possess no properties\ncorresponding to the idea of absolute rest. ",
26 | "They suggest rather that, as has\nalready been shown to the first order of small quantities, the same laws of\nelectrodynamics and optics will be valid for all frames of reference for which the\nequations of mechanics hold good.1 ",
27 | "We will raise this conjecture (the purport\nof which will hereafter be called the “Principle of Relativity”) to the status\n\nof a postulate, and also introduce another postulate, which is only apparently\nirreconcilable with the former, namely, that light is always propagated in empty\nspace with a definite velocity c which is independent of the state of motion of the\nemitting body. ",
28 | "These two postulates suffice for the attainment of a simple and\nconsistent theory of the electrodynamics of moving bodies based on Maxwell’s\ntheory for stationary bodies. ", # noqa: RUF001
29 | "The introduction of a “luminiferous ether” will\nprove to be superfluous inasmuch as the view here to be developed will not\nrequire an “absolutely stationary space” provided with special properties, nor\n1",
30 | "The preceding memoir by Lorentz was not at this time known to the author.\n\n",
31 | "assign a velocity-vector to a point of the empty space in which electromagnetic\nprocesses take place.\n\n",
32 | "The theory to be developed is based—like all electrodynamics—on the kine-\nmatics of the rigid body, since the assertions of any such theory have to do\nwith the relationships between rigid bodies (systems of co-ordinates), clocks,\nand electromagnetic processes. ",
33 | "Insufficient consideration of this circumstance\nlies at the root of the difficulties which the electrodynamics of moving bodies\nat present encounters.\n\n",
34 | "## I. KINEMATICAL PART § **1. Definition of Simultaneity** \n\n",
35 | "Let us take a system of co-ordinates in which the equations of Newtonian\nmechanics hold good.2 ",
36 | ]
37 | assert isinstance(sentences, list)
38 | assert all(not sentence.isspace() for sentence in sentences)
39 | assert all(
40 | sentence == expected_sentence
41 | for sentence, expected_sentence in zip(
42 | sentences[: len(expected_sentences)], expected_sentences, strict=True
43 | )
44 | )
45 |
46 |
47 | @pytest.mark.parametrize(
48 | "case",
49 | [
50 | pytest.param(("", [""], (4, None)), id="tiny-0"),
51 | pytest.param(("Hi!", ["Hi!"], (4, None)), id="tiny-1a"),
52 | pytest.param(("Yes? No!", ["Yes? No!"], (4, None)), id="tiny-1b"),
53 | pytest.param(("Yes? No!", ["Yes? ", "No!"], (3, None)), id="tiny-2a"),
54 | pytest.param(("\n\nYes?\n\nNo!\n\n", ["\n\nYes?\n\n", "No!\n\n"], (3, None)), id="tiny-2b"),
55 | pytest.param(
56 | ("X" * 768 + "\n\n" + "X" * 768, ["X" * 768 + "\n\n" + "X" * 768], (4, None)),
57 | id="huge-1",
58 | ),
59 | pytest.param(
60 | ("X" * 768 + "\n\n" + "X" * 768, ["X" * 768 + "\n\n", "X" * 768], (4, 1024)),
61 | id="huge-2a",
62 | ),
63 | pytest.param(
64 | ("X" * 768 + " " + "X" * 768, ["X" * 768 + " ", "X" * 768], (4, 1024)),
65 | id="huge-2b",
66 | ),
67 | ],
68 | )
69 | def test_split_sentences_edge_cases(case: tuple[str, list[str], tuple[int, int | None]]) -> None:
70 | """Test edge cases of splitting a document into sentences."""
71 | doc, expected_sentences, (min_len, max_len) = case
72 | sentences = split_sentences(doc, min_len=min_len, max_len=max_len)
73 | assert isinstance(sentences, list)
74 | assert all(
75 | sentence == expected_sentence
76 | for sentence, expected_sentence in zip(
77 | sentences[: len(expected_sentences)], expected_sentences, strict=True
78 | )
79 | )
80 |
--------------------------------------------------------------------------------