9 |
10 | - :material-hexagon-outline: [__LSR Fine-tuning__](./lsr.md) — A tutorial on the
11 | LM-Supervised Retriever (LSR) fine-tuning method.
12 | - :material-hexagon-outline: [__RALT Fine-tuning__](./ralt.md) — A tutorial on the
13 | Retriever-Augmented LM Training (RALT) method.
14 |
15 |
16 |
--------------------------------------------------------------------------------
/src/fed_rag/inspectors/common.py:
--------------------------------------------------------------------------------
1 | """Common abstractions for inspectors"""
2 |
3 | from pydantic import BaseModel
4 |
5 |
class TrainerSignatureSpec(BaseModel):
    """Pydantic model holding the parameter names extracted from a trainer signature.

    NOTE(review): the per-field notes below are inferred from the field names
    only; confirm them against the inspector code that builds this spec.
    """

    # presumably the name of the parameter that receives the model/net
    net_parameter: str
    # presumably the name of the parameter that receives the training data
    train_data_param: str
    # presumably the name of the parameter that receives the validation data
    val_data_param: str
    # Names of any additional parameters; empty by default. Pydantic copies
    # mutable defaults per instance, so the `[]` literal is not shared.
    extra_train_kwargs: list[str] = []
    # presumably the class name of the `net_parameter` annotation — TODO confirm
    net_parameter_class_name: str
12 |
13 |
class TesterSignatureSpec(BaseModel):
    """Pydantic model holding the parameter names extracted from a tester signature.

    NOTE(review): the per-field notes below are inferred from the field names
    only; confirm them against the inspector code that builds this spec.
    """

    # The class name starts with `Test`, so pytest would try to collect it;
    # `__test__ = False` opts it out of collection.
    __test__ = (
        False  # needed for Pytest collision. Avoids PytestCollectionWarning
    )
    # presumably the name of the parameter that receives the model/net
    net_parameter: str
    # presumably the name of the parameter that receives the test data
    test_data_param: str
    # Names of any additional parameters; empty by default. Pydantic copies
    # mutable defaults per instance, so the `[]` literal is not shared.
    extra_test_kwargs: list[str] = []
    # presumably the class name of the `net_parameter` annotation — TODO confirm
    net_parameter_class_name: str
22 |
--------------------------------------------------------------------------------
/tests/utils/data/finetuning_datasets/test_pt_dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from fed_rag.utils.data.finetuning_datasets import PyTorchRAGFinetuningDataset
4 |
5 |
def test_pt_rag_ft_dataset_init(
    input_and_target_ids: tuple[torch.Tensor, torch.Tensor],
) -> None:
    """The dataset is a torch Dataset whose length and full slice mirror its inputs."""
    inputs, targets = input_and_target_ids
    dataset = PyTorchRAGFinetuningDataset(input_ids=inputs, target_ids=targets)

    assert isinstance(dataset, torch.utils.data.Dataset)
    assert len(dataset) == len(inputs)
    # Slicing the dataset must hand back the same (input_ids, target_ids) pair.
    assert dataset[:] == input_and_target_ids[:]
17 |
--------------------------------------------------------------------------------
/src/fed_rag/core/rag_system/synchronous.py:
--------------------------------------------------------------------------------
1 | """RAG System Module"""
2 |
3 | from fed_rag._bridges.langchain.bridge import LangChainBridgeMixin
4 | from fed_rag._bridges.llamaindex.bridge import LlamaIndexBridgeMixin
5 | from fed_rag.core.rag_system._synchronous import _RAGSystem
6 |
7 |
8 | # Define the public RAGSystem with all available bridges
class RAGSystem(LlamaIndexBridgeMixin, LangChainBridgeMixin, _RAGSystem):
    """Public RAG system: `_RAGSystem` combined with every available bridge mixin.

    This class is the main entry point for creating and managing
    retrieval-augmented generation systems. It defines no members of its own;
    all behaviour comes from `_RAGSystem` and the LlamaIndex / LangChain
    bridge mixins.
    """
17 |
--------------------------------------------------------------------------------
/tests/evals/metrics/test_exact_match.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from fed_rag.evals import ExactMatchEvaluationMetric
4 |
5 |
@pytest.mark.parametrize(
    ("pred", "actual", "expected"),
    [
        ("1+1=2", "1+1=2", 1.0),
        ("Yes, Correct!", "yes, correct!", 1.0),
        ("not the same", "as me", 0.0),
    ],
    ids=["match", "match case insensitive", "not match"],
)
def test_exact_match(pred: str, actual: str, expected: float) -> None:
    """Exact match scores 1.0 for equal strings (ignoring case) and 0.0 otherwise."""
    # arrange
    exact_match = ExactMatchEvaluationMetric()

    # act
    score = exact_match(prediction=pred, actual=actual)

    # assert
    assert score == expected
22 |
--------------------------------------------------------------------------------
/src/fed_rag/types/bridge.py:
--------------------------------------------------------------------------------
1 | """Bridge type definitions for fed-rag.
2 |
3 | Note: The BridgeMetadata implementation has moved to fed_rag.data_structures.bridge.
4 | This module is maintained for backward compatibility.
5 | """
6 |
7 | import warnings
8 |
9 | from ..data_structures.bridge import BridgeMetadata
10 |
# Emitted once, when this backward-compatibility module is first imported.
warnings.warn(
    "Importing BridgeMetadata from fed_rag.types.bridge is deprecated and will be "
    "removed in a future release. Use fed_rag.data_structures.bridge or "
    "fed_rag.data_structures instead.",
    DeprecationWarning,
    stacklevel=2,  # point to the user's import statement
)
18 |
19 | __all__ = ["BridgeMetadata"]
20 |
--------------------------------------------------------------------------------
/examples/knowledge_stores/ra-dit-ks/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["hatchling"]
3 | build-backend = "hatchling.build"
4 |
5 | [project]
6 | name = "ra-dit-ks"
7 | version = "0.1.0"
8 | description = "Knowledge store builds for ra-dit examples"
9 | readme = "README.md"
10 | authors = [
11 | {name = "nerdai", email = "andrei@vectorinstitute.ai"}
12 | ]
13 | requires-python = ">=3.12"
14 | dependencies = [
15 | "colorama>=0.4.6",
16 | "fed-rag[huggingface,qdrant]>=0.0.13",
17 | "fire>=0.7.0",
18 | "python-dotenv>=1.1.0"
19 | ]
20 |
21 | [project.scripts]
22 | ra-dit = "ra_dit:main"
23 |
24 | [tool.uv.sources]
25 | fed-rag = {path = "../../../", editable = true}
26 |
--------------------------------------------------------------------------------
/src/fed_rag/types/results.py:
--------------------------------------------------------------------------------
1 | """Data structures for results
2 |
3 | Note: The correct module has moved to fed_rag.data_structures.results. This module is
4 | maintained for backward compatibility.
5 | """
6 |
7 | import warnings
8 |
9 | from ..data_structures.results import TestResult, TrainResult
10 |
# Emitted once, when this backward-compatibility module is first imported.
# Adjacent string literals are concatenated verbatim, so each fragment must
# carry its own trailing space.
warnings.warn(
    "Importing TrainResult, TestResult from fed_rag.types.results "
    "is deprecated and will be removed in a future release. Use "
    "fed_rag.data_structures.results or fed_rag.data_structures instead.",
    DeprecationWarning,
    stacklevel=2,  # point to the user's import statement
)
18 |
19 | __all__ = ["TrainResult", "TestResult"]
20 |
--------------------------------------------------------------------------------
/src/fed_rag/types/rag.py:
--------------------------------------------------------------------------------
1 | """Data structures for RAG.
2 |
3 | Note: The correct module has moved to fed_rag.data_structures.rag. This module is
4 | maintained for backward compatibility.
5 | """
6 |
7 | import warnings
8 |
9 | from ..data_structures.rag import RAGConfig, RAGResponse, SourceNode
10 |
# Emitted once, when this backward-compatibility module is first imported.
# Adjacent string literals are concatenated verbatim, so each fragment must
# carry its own trailing space.
warnings.warn(
    "Importing RAGConfig, RAGResponse, SourceNode from fed_rag.types.rag "
    "is deprecated and will be removed in a future release. Use "
    "fed_rag.data_structures.rag or fed_rag.data_structures instead.",
    DeprecationWarning,
    stacklevel=2,  # point to the user's import statement
)
18 |
19 | __all__ = ["RAGConfig", "RAGResponse", "SourceNode"]
20 |
--------------------------------------------------------------------------------
/src/fed_rag/types/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | fed_rag.types
3 |
4 | Only components defined in `__all__` are considered stable and public.
5 | """
6 |
7 | from .bridge import BridgeMetadata
8 | from .knowledge_node import KnowledgeNode, NodeContent, NodeType
9 | from .rag import RAGConfig, RAGResponse, SourceNode
10 | from .results import TestResult, TrainResult
11 |
# Names re-exported by this package; each one is imported from a sibling
# module at the top of this file.
__all__ = [
    # bridge
    "BridgeMetadata",
    # results
    "TrainResult",
    "TestResult",
    # knowledge node
    "KnowledgeNode",
    "NodeType",
    "NodeContent",
    # rag
    "RAGConfig",
    "RAGResponse",
    "SourceNode",
]

# NOTE(review): presumably a marker read elsewhere to flag the whole
# `fed_rag.types` package as deprecated — confirm where it is consumed.
__deprecated__ = True
29 |
--------------------------------------------------------------------------------
/src/fed_rag/evals/benchmarks/huggingface/__init__.py:
--------------------------------------------------------------------------------
1 | from .boolq import HuggingFaceBoolQ
2 | from .hellaswag import HuggingFaceHellaSwag
3 | from .hotpotqa import HuggingFaceHotpotQA
4 | from .mixin import HuggingFaceBenchmarkMixin
5 | from .mmlu import HuggingFaceMMLU
6 | from .natural_questions import HuggingFaceNaturalQuestions
7 | from .pubmedqa import HuggingFacePubMedQA
8 | from .squad_v2 import HuggingFaceSQuADv2
9 |
# Public names of this package: the benchmark mixin plus one class per
# benchmark module imported above.
__all__ = [
    "HuggingFaceBenchmarkMixin",
    "HuggingFaceMMLU",
    "HuggingFacePubMedQA",
    "HuggingFaceHotpotQA",
    "HuggingFaceSQuADv2",
    "HuggingFaceNaturalQuestions",
    "HuggingFaceBoolQ",
    "HuggingFaceHellaSwag",
]
20 |
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | # To get started with Dependabot version updates, you'll need to specify which
2 | # package ecosystems to update and where the package manifests are located.
3 | # Please see the documentation for all configuration options:
4 | # https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file
5 |
6 | version: 2
7 | updates:
8 | - package-ecosystem: "github-actions"
9 | directory: "/"
10 | schedule:
11 | interval: "weekly"
12 | - package-ecosystem: "uv"
13 | directory: "/"
14 | schedule:
15 | interval: "weekly"
16 | groups:
17 | all-python-packages:
18 | patterns:
19 | - "**"
20 |
--------------------------------------------------------------------------------
/src/fed_rag/data_collators/huggingface/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib.util
2 |
3 | from fed_rag.exceptions.common import MissingExtraError
4 |
5 | from .lsr import DataCollatorForLSR
6 | from .ralt import DataCollatorForRALT
7 |
# check if huggingface extra is installed
# The extra counts as installed only when both `transformers` and `peft` can
# be located; `find_spec` looks the packages up without importing them.
# NOTE(review): `.lsr` and `.ralt` are imported above this check. If either
# imports `transformers`/`peft` at module level, a plain ImportError would be
# raised before this friendlier MissingExtraError is reached — confirm.
_has_huggingface = (importlib.util.find_spec("transformers") is not None) and (
    importlib.util.find_spec("peft") is not None
)
if not _has_huggingface:
    msg = (
        f"`{__name__}` requires `huggingface` extra to be installed."
        " To fix please run `pip install fed-rag[huggingface]`."
    )
    raise MissingExtraError(msg)

__all__ = ["DataCollatorForLSR", "DataCollatorForRALT"]
20 |
--------------------------------------------------------------------------------
/docs/overrides/partials/copyright.html:
--------------------------------------------------------------------------------
1 |
2 |
11 | {% if config.copyright %}
12 |
{{ config.copyright }}
13 | {% endif %} {% if not config.extra.generator == false %} Made with
14 |
19 | Material for MkDocs
20 |
21 | {% endif %}
22 |
23 |
--------------------------------------------------------------------------------
/docs/community/contributing/ask_question.md:
--------------------------------------------------------------------------------
1 | # Ask a Question
2 |
3 | We welcome questions from users and contributors at all levels of experience with
4 | FedRAG. Having questions is a natural part of engaging with a complex project,
5 | and we're here to help.
6 |
7 | ## Where to Ask Questions
8 |
9 | FedRAG offers several channels for asking questions:
10 |
11 | - **Discord Community**: Join our [Discord community](https://discord.gg/5GMpSCFbTe)
12 | for real-time discussions and quick questions.
13 |
14 | - **GitHub Discussions**: For longer, more detailed questions, use
15 | [GitHub Discussions](https://github.com/VectorInstitute/fed-rag/discussions).
16 | This is ideal for questions that might benefit the wider community.
17 |
--------------------------------------------------------------------------------
/tests/generators/test_base.py:
--------------------------------------------------------------------------------
1 | from fed_rag.base.generator import BaseGenerator
2 |
3 |
def test_generate(mock_generator: BaseGenerator) -> None:
    """`generate` returns the mock's answer for the given query and context."""
    expected = "mock output from 'hello' and 'again'."
    assert mock_generator.generate(query="hello", context="again") == expected
7 |
8 |
def test_complete(mock_generator: BaseGenerator) -> None:
    """`complete` returns the mock's completion for the given prompt."""
    expected = "mock completion output from 'hello again'."
    assert mock_generator.complete(prompt="hello again") == expected
12 |
13 |
def test_compute_target_sequence_proba(mock_generator: BaseGenerator) -> None:
    """`compute_target_sequence_proba` returns the mock's fixed probability."""
    call_kwargs = {"prompt": "mock prompt", "target": "mock target"}
    assert mock_generator.compute_target_sequence_proba(**call_kwargs) == 0.42
19 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/question.yml:
--------------------------------------------------------------------------------
1 | name: General Question
2 | description: Ask a question about using or developing with FedRAG.
3 | title: "[Question]: "
4 | labels: ["question", "triage"]
5 | body:
6 | - type: markdown
7 | attributes:
8 | value: |
9 | Have a question about FedRAG?
10 | Please ask it here — we’ll do our best to help!
11 |
12 | - type: textarea
13 | id: question
14 | attributes:
15 | label: Your Question
16 | description: Clearly state your question.
17 | validations:
18 | required: true
19 |
20 | - type: textarea
21 | id: context
22 | attributes:
23 | label: Additional Context
24 | description: If helpful, add code examples, screenshots, or links.
25 |
--------------------------------------------------------------------------------
/examples/ra-dit/README.md:
--------------------------------------------------------------------------------
1 | # RA-DIT
2 |
3 | ## Usage
4 |
5 | Run the following commands from the `examples/ra-dit` directory.
6 |
7 | ```sh
8 | # source venv
9 | source .venv/bin/activate
10 |
11 | # run federated learning
12 |
13 | ## start server (note this will load the model into cpu)
14 | uv run -m ra_dit.main --task generator --generator_id llama2_7b \
15 | --generator_variant qlora --component server
16 |
17 | ## start clients using a two-gpu setup
18 | CUDA_VISIBLE_DEVICES=0 uv run -m ra_dit.main --task generator --generator_id \
19 | llama2_7b --generator_variant qlora --component client_1
20 |
21 | CUDA_VISIBLE_DEVICES=1 uv run -m ra_dit.main --task generator --generator_id \
22 | llama2_7b --generator_variant qlora --component client_2
23 | ```
24 |
--------------------------------------------------------------------------------
/src/fed_rag/utils/data/finetuning_datasets/pytorch.py:
--------------------------------------------------------------------------------
1 | """PyTorch RAG Finetuning Dataset"""
2 |
3 | from typing import Any
4 |
5 | import torch
6 | from torch.utils.data import Dataset
7 |
8 |
9 | class PyTorchRAGFinetuningDataset(Dataset):
10 | """PyTorch RAG Fine-Tuning Dataset Class.
11 |
12 | Args:
13 | Dataset (_type_): _description_
14 | """
15 |
16 | def __init__(
17 | self, input_ids: list[torch.Tensor], target_ids: list[torch.Tensor]
18 | ):
19 | self.input_ids = input_ids
20 | self.target_ids = target_ids
21 |
22 | def __len__(self) -> int:
23 | return len(self.input_ids)
24 |
25 | def __getitem__(self, idx: int) -> Any:
26 | return self.input_ids[idx], self.target_ids[idx]
27 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | help: ## Show all Makefile targets.
2 | @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}'
3 |
4 | format: ## Run code autoformatters (black).
5 | pre-commit install
6 | git ls-files | xargs pre-commit run black --files
7 |
8 | lint: ## Run linters: pre-commit (black, ruff, codespell) and mypy
9 | pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files
10 |
test: ## Run the test suite with pytest (verbose, output not captured).
	pytest tests -v --capture=no
13 |
14 | coverage: # for ci purposes
15 | pytest --cov fed_rag --cov-report=xml tests
16 |
17 | coverage-report: ## Show coverage summary in terminal
18 | coverage report -m
19 |
20 | coverage-html: ## Generate HTML coverage report
21 | coverage html
22 |
--------------------------------------------------------------------------------
/src/fed_rag/evals/utils.py:
--------------------------------------------------------------------------------
1 | """Utils module for evals"""
2 |
3 | import json
4 | from pathlib import Path
5 |
6 | from fed_rag.data_structures.evals import BenchmarkEvaluatedExample
7 | from fed_rag.exceptions import EvaluationsFileNotFoundError
8 |
9 |
def load_evaluations(filename: str | Path) -> list[BenchmarkEvaluatedExample]:
    """Load serialized BenchmarkEvaluatedExamples from a JSONL file.

    Args:
        filename: Path to the JSONL file, with one JSON object per line.

    Returns:
        One ``BenchmarkEvaluatedExample`` per non-blank line, in file order.

    Raises:
        EvaluationsFileNotFoundError: If ``filename`` does not exist.
    """
    path = Path(filename)

    if not path.exists():
        raise EvaluationsFileNotFoundError(str(path))

    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with path.open("r", encoding="utf-8") as f:
        # Blank lines (e.g. a trailing newline at end of file) carry no record
        # and would otherwise make json.loads raise.
        records = [json.loads(line) for line in f if line.strip()]

    return [BenchmarkEvaluatedExample(**record) for record in records]
23 |
--------------------------------------------------------------------------------
/src/fed_rag/types/rag_system.py:
--------------------------------------------------------------------------------
1 | """
2 | RAG System type definitions and implementation.
3 |
4 | Note: The RAGSystem implementation has moved to fed_rag.core.rag_system.
5 | This module is maintained for backward compatibility.
6 | """
7 |
8 | import warnings
9 |
10 | from ..core.rag_system import RAGSystem
11 | from .rag import RAGConfig, RAGResponse, SourceNode
12 |
# Emitted once, when this backward-compatibility module is first imported.
# Adjacent string literals are concatenated verbatim, so each fragment must
# carry its own trailing space.
warnings.warn(
    "Importing RAGSystem from fed_rag.types.rag_system is deprecated and will be "
    "removed in a future release. Use fed_rag.core.rag_system or fed_rag instead.",
    DeprecationWarning,
    stacklevel=2,  # point to the user's import statement
)
19 |
20 |
21 | # Export all symbols for backward compatibility
22 | __all__ = ["RAGSystem", "RAGConfig", "RAGResponse", "SourceNode"]
23 |
--------------------------------------------------------------------------------
/src/fed_rag/types/knowledge_node.py:
--------------------------------------------------------------------------------
1 | """Knowledge Node
2 |
3 | Note: The KnowledgeNode implementation has moved to fed_rag.data_structures.knowledge_node.
4 | This module is maintained for backward compatibility.
5 | """
6 |
7 | import warnings
8 |
9 | from ..data_structures.knowledge_node import (
10 | KnowledgeNode,
11 | NodeContent,
12 | NodeType,
13 | )
14 |
# Emitted once, when this backward-compatibility module is first imported.
# Adjacent string literals are concatenated verbatim, so each fragment must
# carry its own trailing space.
warnings.warn(
    "Importing KnowledgeNode, NodeContent, and NodeType from fed_rag.types.knowledge_node "
    "is deprecated and will be removed in a future release. Use "
    "fed_rag.data_structures.knowledge_node or fed_rag.data_structures instead.",
    DeprecationWarning,
    stacklevel=2,  # point to the user's import statement
)
22 |
23 | __all__ = ["KnowledgeNode", "NodeContent", "NodeType"]
24 |
--------------------------------------------------------------------------------
/tests/rag_system/test_source_node.py:
--------------------------------------------------------------------------------
1 | from fed_rag.data_structures import KnowledgeNode, SourceNode
2 |
3 |
def test_getattr_sourcenode_wraps_knowledge_node() -> None:
    """A SourceNode exposes its own score plus the wrapped node's attributes."""
    # arrange
    node_kwargs = {
        "embedding": [0.1, 0.2],
        "node_type": "text",
        "text_content": "fake text context",
        "metadata": {"some_field": 12},
    }
    wrapped = KnowledgeNode(**node_kwargs)

    # act
    source_node = SourceNode(score=0.42, node=wrapped)

    # assert
    assert source_node.score == 0.42
    assert source_node.text_content == wrapped.text_content
    assert source_node.node_type == wrapped.node_type
    assert source_node.node_id == wrapped.node_id
    assert source_node.metadata == wrapped.metadata
22 |
--------------------------------------------------------------------------------
/src/fed_rag/exceptions/trainer.py:
--------------------------------------------------------------------------------
1 | from .core import FedRAGError
2 |
3 |
class TrainerError(FedRAGError):
    """Root of the exception hierarchy for RAG trainers."""
8 |
9 |
class InconsistentDatasetError(TrainerError):
    """Signals that the datasets underlying the dataloaders are inconsistent."""
14 |
15 |
class InvalidLossError(TrainerError):
    """Signals that a trainer object has an unexpected loss attached."""
20 |
21 |
class InvalidDataCollatorError(TrainerError):
    """Signals that a trainer object has an invalid data collator attached."""
26 |
27 |
class MissingInputTensor(TrainerError):
    """Signals that the inputs lack a tensor that is required."""
32 |
--------------------------------------------------------------------------------
/tests/generators/test_hf_utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | from unittest.mock import patch
3 |
4 | import pytest
5 |
6 | from fed_rag.exceptions import MissingExtraError
7 | from fed_rag.generators.huggingface.utils import check_huggingface_installed
8 |
9 |
def test_check_raises_error() -> None:
    """The check raises MissingExtraError when `transformers` cannot be imported."""
    expected_msg = (
        "Missing installation of the huggingface extra, yet is required "
        "by an imported class. To fix please run `pip install fed-rag[huggingface]`."
    )

    # A `None` entry in sys.modules makes the named module unimportable.
    with patch.dict("sys.modules", {"transformers": None}):
        with pytest.raises(MissingExtraError, match=re.escape(expected_msg)):
            check_huggingface_installed()
25 |
--------------------------------------------------------------------------------
/src/fed_rag/exceptions/evals.py:
--------------------------------------------------------------------------------
1 | """Exceptions for Evals."""
2 |
3 | from .core import FedRAGError, FedRAGWarning
4 |
5 |
class EvalsError(FedRAGError):
    """Root of the exception hierarchy for everything evals-related."""
10 |
11 |
class EvalsWarning(FedRAGWarning):
    """Root of the warning hierarchy for everything evals-related."""
16 |
17 |
class BenchmarkGetExamplesError(EvalsError):
    """Signals a failure while getting the examples for a benchmark."""
22 |
23 |
class BenchmarkParseError(EvalsError):
    """Signals a failure while parsing examples."""
28 |
29 |
class EvaluationsFileNotFoundError(EvalsError, FileNotFoundError):
    """Signals that a benchmark evaluations file could not be found.

    Inherits from both `EvalsError` and the builtin `FileNotFoundError`, so it
    can be caught as either.
    """
34 |
--------------------------------------------------------------------------------
/.github/workflows/lint.yml:
--------------------------------------------------------------------------------
1 | name: Linting
2 | on:
3 | push:
4 | branches:
5 | - main
6 | pull_request:
7 | types:
8 | - opened
9 | - synchronize
10 | jobs:
11 | lint:
12 | runs-on: ubuntu-latest
13 | steps:
14 | - name: get code
15 | uses: actions/checkout@v6
16 |
17 | - name: Install uv
18 | uses: astral-sh/setup-uv@v7
19 | with:
20 | # Install a specific version of uv.
21 | version: "0.5.21"
22 | enable-cache: true
23 |
24 | - name: "Set up Python"
25 | uses: actions/setup-python@v6
26 | with:
27 | python-version: "3.12"
28 |
29 | - name: Install the project
30 | run: uv sync --all-extras --dev
31 |
32 | - name: Run linter and formatter
33 | run: |
34 | uv run make lint
35 |
--------------------------------------------------------------------------------
/src/fed_rag/generators/unsloth/utils.py:
--------------------------------------------------------------------------------
1 | from importlib.util import find_spec
2 |
3 | from fed_rag.exceptions import MissingExtraError
4 |
5 |
def check_unsloth_installed(cls_name: str | None = None) -> None:
    """Raise ``MissingExtraError`` when the ``unsloth`` package cannot be located.

    Args:
        cls_name: Optional name of the class that needs ``unsloth``. When
            given (and non-empty) it is named in the error message; otherwise
            a generic message is used.

    Raises:
        MissingExtraError: If ``find_spec("unsloth")`` finds nothing.
    """
    # Nothing to report when the package is discoverable.
    if find_spec("unsloth") is not None:
        return

    if cls_name:
        error_msg = (
            f"`{cls_name}` requires the `unsloth` extra to be installed. "
            "To fix please run `pip install fed-rag[unsloth]`."
        )
    else:
        error_msg = (
            "Missing installation of the `unsloth` extra, yet is required "
            "by an imported class. To fix please run `pip install fed-rag[unsloth]`."
        )

    raise MissingExtraError(error_msg)
23 |
--------------------------------------------------------------------------------
/src/fed_rag/exceptions/fl_tasks.py:
--------------------------------------------------------------------------------
1 | """Exceptions for FL Tasks."""
2 |
3 | from .core import FedRAGError
4 |
5 |
class FLTaskError(FedRAGError):
    """Root of the exception hierarchy for everything FL-task-related."""
10 |
11 |
class MissingFLTaskConfig(FLTaskError):
    """Signals that an FL task's `trainer` and `tester` lack the `__fl_task_tester_config` attribute."""
16 |
17 |
class MissingRequiredNetParam(FLTaskError):
    """Signals that `fl_task.server` was invoked without the specified model/net param."""
22 |
23 |
class NetTypeMismatch(FLTaskError):
    """Signals that a `trainer` spec and a `tester` spec disagree on `net_parameter_class_name`.

    In other words, the two methods declare different types for their
    `net_parameter`.
    """
31 |
--------------------------------------------------------------------------------
/docs/community/contributing/submit_issue.md:
--------------------------------------------------------------------------------
1 | # Submitting an Issue
2 |
3 | Issues are an important way to track bugs, feature requests, and improvements to
4 | FedRAG.
5 |
6 | ## Before Creating an Issue
7 |
8 | Before submitting a new issue:
9 |
10 | 1. **Search existing issues**: Check [GitHub Issues](https://github.com/VectorInstitute/fed-rag/issues)
11 | to see if your problem has already been reported or if a related feature request exists.
12 |
13 | 2. **Check the documentation**: Verify that your question isn't already addressed
14 | in our documentation.
15 |
16 | 3. **Confirm it's an issue**: For general questions, please use [GitHub Discussions](https://github.com/VectorInstitute/fed-rag/discussions)
17 | or our [Discord community](https://discord.gg/5GMpSCFbTe) instead.
18 |
19 | We appreciate your contributions to making FedRAG better through thoughtful issue submissions!
20 |
--------------------------------------------------------------------------------
/src/fed_rag/core/no_encode_rag_system/synchronous.py:
--------------------------------------------------------------------------------
1 | """No Encode RAG System Module"""
2 |
3 | from fed_rag.core.no_encode_rag_system._synchronous import _NoEncodeRAGSystem
4 |
5 |
6 | # Define the public NoEncodeRAGSystem (no bridge mixins are applied to this class)
class NoEncodeRAGSystem(_NoEncodeRAGSystem):
    """Public NoEncode RAG system.

    This is the main entry point for creating and managing
    retrieval-augmented generation systems that skip encoding altogether,
    so that natural language queries go directly to knowledge sources such
    as MCP servers, APIs, and databases.

    Traditional RAG systems need separate retriever components and
    pre-computed embeddings; NoEncode RAG systems instead query NoEncode
    knowledge sources directly.

    The class defines no members of its own; everything is inherited from
    `_NoEncodeRAGSystem`.
    """
21 |
--------------------------------------------------------------------------------
/tests/trainers/test_base.py:
--------------------------------------------------------------------------------
1 | from fed_rag import RAGSystem
2 |
3 | from .conftest import MockRetrieverTrainer, MockTrainer
4 |
5 |
def test_init(mock_rag_system: RAGSystem) -> None:
    """A trainer keeps the RAG system it was constructed with."""
    examples = [{"query": "mock example", "response": "mock response"}]

    trainer = MockTrainer(rag_system=mock_rag_system, train_dataset=examples)

    assert trainer.rag_system == mock_rag_system
13 |
14 |
def test_retriever_trainer_with_dual_encoder_retriever(
    mock_rag_system_dual_encoder: RAGSystem,
) -> None:
    """With a dual-encoder retriever, the trainer's model is the query encoder."""
    examples = [{"query": "mock example", "response": "mock response"}]

    trainer = MockRetrieverTrainer(
        rag_system=mock_rag_system_dual_encoder, train_dataset=examples
    )

    assert trainer.rag_system == mock_rag_system_dual_encoder
    query_encoder = mock_rag_system_dual_encoder.retriever.query_encoder
    assert trainer.model == query_encoder
27 |
--------------------------------------------------------------------------------
/src/fed_rag/exceptions/trainer_manager.py:
--------------------------------------------------------------------------------
1 | from .core import FedRAGError
2 |
3 |
class RAGTrainerManagerError(FedRAGError):
    """Root of the exception hierarchy for the RAG trainer manager."""
8 |
9 |
class UnspecifiedRetrieverTrainer(RAGTrainerManagerError):
    """Signals that a retriever trainer was expected but none was specified."""
14 |
15 |
class UnspecifiedGeneratorTrainer(RAGTrainerManagerError):
    """Signals that a generator trainer was expected but none was specified."""
20 |
21 |
class UnsupportedTrainerMode(RAGTrainerManagerError):
    """Signals that the supplied trainer mode is not supported."""
26 |
27 |
class InconsistentRAGSystems(RAGTrainerManagerError):
    """Signals that the trainers' underlying RAG systems are inconsistent."""
32 |
--------------------------------------------------------------------------------
/tests/evals/benchmarks/huggingface/test_utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | from unittest.mock import patch
3 |
4 | import pytest
5 |
6 | from fed_rag.evals.benchmarks.huggingface.utils import (
7 | check_huggingface_evals_installed,
8 | )
9 | from fed_rag.exceptions import MissingExtraError
10 |
11 |
def test_check_raises_error() -> None:
    """The check raises MissingExtraError when `datasets` cannot be imported."""
    expected_msg = (
        "Missing installation of the huggingface-evals extra, yet is required "
        "by an import `HuggingFaceBenchmark` class. To fix please run "
        "`pip install fed-rag[huggingface-evals]`."
    )

    # A `None` entry in sys.modules makes the named module unimportable.
    with patch.dict("sys.modules", {"datasets": None}):
        with pytest.raises(MissingExtraError, match=re.escape(expected_msg)):
            check_huggingface_evals_installed()
29 |
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | message: "If you use this software, please cite it as below."
3 | authors:
4 | - family-names: "Fajardo"
5 | given-names: "Andrei"
6 | email: "andrei.fajardo@vectorinstitute.ai"
7 | - family-names: "Emerson"
8 | given-names: "David"
9 | email: "david.emerson@vectorinstitute.ai"
10 | title: "fed-rag"
11 | version: "0.0.27"
12 | abstract: "Simplified fine-tuning of retrieval-augmented generation (RAG) systems."
13 | keywords:
14 | - machine learning
15 | - federated learning
16 | - deep learning
17 | - llms
18 | - rag
19 | - retrieval
20 | - semantic search
21 | license: Apache-2.0
22 | doi: 10.5281/zenodo.15092361
23 | repository-code: "https://github.com/VectorInstitute/fed-rag"
24 | type: software
25 | date-released: "2025-03-26"
26 | contact:
27 | - family-names: "Fajardo"
28 | given-names: "Andrei"
29 | email: "andrei.fajardo@vectorinstitute.ai"
30 |
--------------------------------------------------------------------------------
/tests/api/test_evals_imports.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | import pytest
4 |
5 | from fed_rag.evals import __all__ as _evals_all
6 | from fed_rag.evals.benchmarks import __all__ as _benchmarks_all
7 |
8 |
@pytest.mark.parametrize("name", _evals_all)
def test_evals_all_importable(name: str) -> None:
    """Tests that all names listed in evals __all__ are importable."""
    mod = importlib.import_module("fed_rag.evals")

    # Check existence before reading the attribute: a getattr placed ahead of
    # this assertion would raise AttributeError first, so the hasattr
    # assertion could never be the one that fails.
    assert hasattr(mod, name)
    assert getattr(mod, name) is not None
17 |
18 |
@pytest.mark.parametrize("name", _benchmarks_all)
def test_evals_benchmarks_all_importable(name: str) -> None:
    """Check that every name in `fed_rag.evals.benchmarks.__all__` is importable.

    Args:
        name: A public name exported through
            `fed_rag.evals.benchmarks.__all__`.
    """
    mod = importlib.import_module("fed_rag.evals.benchmarks")

    # Check presence first: calling `getattr` on a missing name raises
    # `AttributeError`, which would make a later `hasattr` assertion
    # unreachable and turn a test failure into a test error.
    assert hasattr(mod, name), f"`{name}` is in __all__ but not on the module"
    assert getattr(mod, name) is not None
27 |
--------------------------------------------------------------------------------
/tests/evals/benchmarks/test_base.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | import pytest
4 |
5 | from fed_rag.exceptions import BenchmarkGetExamplesError
6 |
7 | from . import _benchmarks as benchmarks
8 |
9 |
def test_sequence_interface() -> None:
    """Exercise `len`, indexing and iteration on the test benchmark."""
    bench = benchmarks.TestBenchmark()
    expected_size = 3

    assert len(bench) == expected_size
    assert bench.num_examples == expected_size

    # every valid index must return the matching entry of the backing list
    assert all(
        bench[position] == bench._examples[position]
        for position in range(len(bench))
    )

    first_example = next(iter(bench.as_iterator()))
    assert first_example == bench[0]
20 |
21 |
def test_get_example_raises_exception() -> None:
    """Constructing the bad benchmark raises `BenchmarkGetExamplesError`."""
    expected_msg = re.escape("Failed to get examples: Too bad, so sad.")

    with pytest.raises(BenchmarkGetExamplesError, match=expected_msg):
        benchmarks.TestBenchmarkBadExamples()
30 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Local development
2 | venv/
3 | .venv/
4 | .ipynb_checkpoints
5 | .__pycache__
6 | __pycache__
7 | dev_notebooks/
8 | .vscode
9 | .mypy_cache
10 | .pytest_cache
11 | .ruff_cache
12 | .env
13 | fed-rag.code-workspace
14 | .fed_rag/
15 |
16 | # docs
17 | site/
18 | docs/stylesheets/extra.css.map
19 |
20 | # datasets for running examples
21 | data/
22 | !src/fed_rag/utils/data
23 | !tests/utils/data
24 |
25 | # HF training artifacts
26 | tmp_trainer/
27 |
28 | # example benchmark results
29 | .benchmark_results
30 |
31 | # example checkpoints
32 | .checkpoints
33 |
34 | # qdrant
35 | qdrant_storage
36 |
37 | # notebooks
38 | dev_notebooks
39 | **/rag_federated_learning.py
40 |
41 | # trainer outputs
42 | trainer_output
43 | unsloth_compiled_cache
44 |
45 | # import profile
46 | # python -X importtime -c "import fed_rag" 2> import_profile.txt
47 | import_profile.txt
48 |
49 | # coverage
50 | .coverage
51 | coverage.xml
52 | htmlcov/
53 |
--------------------------------------------------------------------------------
/src/fed_rag/data_structures/bridge.py:
--------------------------------------------------------------------------------
1 | from typing import TypedDict
2 |
3 |
class CompatibleVersions(TypedDict, total=False):
    """Type definition for compatible versions.

    Defines optional, inclusive version bounds for compatibility checks.
    Because the class is declared with `total=False`, either key may be
    omitted.

    Attributes:
        min: Minimum compatible version (inclusive).
        max: Maximum compatible version (inclusive).
    """

    min: str
    max: str
16 |
17 |
class BridgeMetadata(TypedDict):
    """Type definition for bridge metadata.

    All four keys are required (the class uses the default `total=True`).

    Attributes:
        bridge_version: The version of the bridge.
        framework: The framework name.
        compatible_versions: Version bounds for compatibility, given as a
            `CompatibleVersions` mapping.
        method_name: The method name associated with the bridge.
    """

    bridge_version: str
    framework: str
    compatible_versions: CompatibleVersions
    method_name: str
32 |
--------------------------------------------------------------------------------
/examples/ra-dit/ra_dit/retrievers/dragon.py:
--------------------------------------------------------------------------------
1 | """Dragon Retriever."""
2 |
3 | from fed_rag.retrievers.huggingface.hf_sentence_transformer import (
4 | HFSentenceTransformerRetriever,
5 | )
6 |
# Module-level retriever configured with separate query and context encoders.
retriever = HFSentenceTransformerRetriever(
    query_model_name="nthakur/dragon-plus-query-encoder",
    context_model_name="nthakur/dragon-plus-context-encoder",
    # NOTE(review): presumably defers loading model weights until first use
    # -- confirm against `HFSentenceTransformerRetriever`.
    load_model_at_init=False,
)

if __name__ == "__main__":
    # Small demo: score two candidate contexts against a single query.
    query = "Where was Marie Curie born?"
    contexts = [
        "Maria Sklodowska, later known as Marie Curie, was born on November 7, 1867.",
        "Born in Paris on 15 May 1859, Pierre Curie was the son of Eugène Curie, a doctor of French Catholic origin from Alsace.",
    ]

    query_embeddings = retriever.encode_query(query)
    context_embeddings = retriever.encode_context(contexts)

    # matrix product of the query embeddings with the transposed context
    # embeddings; assumes both are 2-D arrays/tensors -- TODO confirm
    scores = query_embeddings @ context_embeddings.T
    print(scores)
25 |
--------------------------------------------------------------------------------
/src/fed_rag/evals/benchmarks/huggingface/utils.py:
--------------------------------------------------------------------------------
1 | from importlib.util import find_spec
2 |
3 | from fed_rag.exceptions import MissingExtraError
4 |
5 |
def check_huggingface_evals_installed(cls_name: str | None = None) -> None:
    """Ensure the `huggingface-evals` extra is available.

    Availability is decided by whether the `datasets` package can be located.

    Args:
        cls_name: Optional name of the class that needs the extra. When
            given, it is used in the error message.

    Raises:
        MissingExtraError: If the `datasets` package cannot be found.
    """
    if find_spec("datasets") is not None:
        return

    if cls_name:
        error_msg = (
            f"`{cls_name}` requires the `huggingface-evals` extra to be installed. "
            "To fix please run `pip install fed-rag[huggingface-evals]`."
        )
    else:
        error_msg = (
            "Missing installation of the huggingface-evals extra, yet is required "
            "by an import `HuggingFaceBenchmark` class. To fix please run "
            "`pip install fed-rag[huggingface-evals]`."
        )

    raise MissingExtraError(error_msg)
25 |
--------------------------------------------------------------------------------
/src/fed_rag/base/generator_mixins/audio.py:
--------------------------------------------------------------------------------
1 | """Generator Mixins."""
2 |
3 | from typing import Protocol, runtime_checkable
4 |
5 | from fed_rag.exceptions.generator import GeneratorError
6 |
7 |
@runtime_checkable
class GeneratorHasAudioModality(Protocol):
    """Associated protocol for `AudioModalityMixin`.

    Declares the `__supports_audio__` marker attribute. The protocol is
    `runtime_checkable`, so it can be used in `isinstance` checks.
    """

    # marker attribute, also set by `AudioModalityMixin`
    __supports_audio__: bool = True
13 |
14 |
class AudioModalityMixin:
    """Mixin that flags a generator as accepting audio inputs.

    Intended to be combined with a `BaseGenerator`; a subclass without a
    class named `BaseGenerator` in its MRO is rejected when it is created.
    """

    __supports_audio__ = True

    def __init_subclass__(cls) -> None:
        """Reject subclasses that are not also a `BaseGenerator`."""
        super().__init_subclass__()

        # the check is by class name, over the full method resolution order
        has_base_generator = any(
            klass.__name__ == "BaseGenerator" for klass in cls.__mro__
        )
        if not has_base_generator:
            raise GeneratorError(
                "`AudioModalityMixin` must be mixed with `BaseGenerator`."
            )
32 |
--------------------------------------------------------------------------------
/src/fed_rag/base/generator_mixins/video.py:
--------------------------------------------------------------------------------
1 | """Generator Mixins."""
2 |
3 | from typing import Protocol, runtime_checkable
4 |
5 | from fed_rag.exceptions.generator import GeneratorError
6 |
7 |
@runtime_checkable
class GeneratorHasVideoModality(Protocol):
    """Associated protocol for `VideoModalityMixin`.

    Declares the `__supports_video__` marker attribute. The protocol is
    `runtime_checkable`, so it can be used in `isinstance` checks.
    """

    # marker attribute, also set by `VideoModalityMixin`
    __supports_video__: bool = True
13 |
14 |
class VideoModalityMixin:
    """Mixin that flags a generator as accepting video inputs.

    Intended to be combined with a `BaseGenerator`; a subclass without a
    class named `BaseGenerator` in its MRO is rejected when it is created.
    """

    __supports_video__ = True

    def __init_subclass__(cls) -> None:
        """Reject subclasses that are not also a `BaseGenerator`."""
        super().__init_subclass__()

        # the check is by class name, over the full method resolution order
        has_base_generator = any(
            klass.__name__ == "BaseGenerator" for klass in cls.__mro__
        )
        if not has_base_generator:
            raise GeneratorError(
                "`VideoModalityMixin` must be mixed with `BaseGenerator`."
            )
32 |
--------------------------------------------------------------------------------
/src/fed_rag/base/generator_mixins/image.py:
--------------------------------------------------------------------------------
1 | """Generator Mixins."""
2 |
3 | from typing import Protocol, runtime_checkable
4 |
5 | from fed_rag.exceptions.generator import GeneratorError
6 |
7 |
@runtime_checkable
class GeneratorHasImageModality(Protocol):
    """Associated protocol for `ImageModalityMixin`.

    Declares the `__supports_images__` marker attribute. The protocol is
    `runtime_checkable`, so it can be used in `isinstance` checks.
    """

    # marker attribute, also set by `ImageModalityMixin`
    __supports_images__: bool = True
13 |
14 |
class ImageModalityMixin:
    """Mixin that flags a generator as accepting image inputs.

    Intended to be combined with a `BaseGenerator`; a subclass without a
    class named `BaseGenerator` in its MRO is rejected when it is created.
    """

    __supports_images__ = True

    def __init_subclass__(cls) -> None:
        """Reject subclasses that are not also a `BaseGenerator`."""
        super().__init_subclass__()

        # the check is by class name, over the full method resolution order
        has_base_generator = any(
            klass.__name__ == "BaseGenerator" for klass in cls.__mro__
        )
        if not has_base_generator:
            raise GeneratorError(
                "`ImageModalityMixin` must be mixed with `BaseGenerator`."
            )
32 |
--------------------------------------------------------------------------------
/tests/generators/mixins/test_image_mixin.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pydantic import BaseModel
3 |
4 | from fed_rag.base.generator import BaseGenerator
5 | from fed_rag.base.generator_mixins import (
6 | GeneratorHasImageModality,
7 | ImageModalityMixin,
8 | )
9 | from fed_rag.exceptions.generator import GeneratorError
10 |
11 | from ..conftest import MockGenerator
12 |
13 |
class MockMMGenerator(ImageModalityMixin, MockGenerator):
    """`MockGenerator` combined with `ImageModalityMixin` for the tests below."""

    pass
16 |
17 |
def test_mixin() -> None:
    """A mixed generator satisfies both the protocol and `BaseGenerator`."""
    generator = MockMMGenerator()

    for expected_type in (GeneratorHasImageModality, BaseGenerator):
        assert isinstance(generator, expected_type)
23 |
24 |
def test_mixin_fails_validation() -> None:
    """Mixing with a class that is not a `BaseGenerator` raises at definition."""
    expected_msg = "`ImageModalityMixin` must be mixed with `BaseGenerator`."

    with pytest.raises(GeneratorError, match=expected_msg):

        class InvalidMockMMGenerator(ImageModalityMixin, BaseModel):
            pass
33 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/documentation_improvement.yml:
--------------------------------------------------------------------------------
1 | name: Documentation Improvement
2 | description: Suggest a fix or improvement to the FedRAG documentation.
3 | title: "[Docs]: "
4 | labels: ["documentation", "triage"]
5 | body:
6 | - type: markdown
7 | attributes:
8 | value: |
9 | Help us improve our documentation!
10 | Please provide as much detail as you can below.
11 |
12 | - type: textarea
13 | id: location
14 | attributes:
15 | label: Location of Issue
16 | description: What page, section, or example needs updating?
17 |
18 | - type: textarea
19 | id: problem
20 | attributes:
21 | label: Problem or Gap
22 | description: Describe the issue — missing info, outdated content, unclear instructions, etc.
23 | validations:
24 | required: true
25 |
26 | - type: textarea
27 | id: suggestion
28 | attributes:
29 | label: Suggested Change
30 | description: What would you like to see instead?
31 |
--------------------------------------------------------------------------------
/src/fed_rag/base/retriever_mixins/audio.py:
--------------------------------------------------------------------------------
1 | """Retriever Mixins."""
2 |
3 | from abc import ABC
4 | from typing import Protocol, runtime_checkable
5 |
6 | from fed_rag.exceptions.retriever import RetrieverError
7 |
8 |
@runtime_checkable
class RetrieverHasAudioModality(Protocol):
    """Associated protocol for `AudioRetrieverMixin`.

    Declares the `__supports_audio__` marker attribute. The protocol is
    `runtime_checkable`, so it can be used in `isinstance` checks.
    """

    # marker attribute, also set by `AudioRetrieverMixin`
    __supports_audio__: bool = True
14 |
15 |
class AudioRetrieverMixin(ABC):
    """Mixin that adds the audio modality marker to a retriever.

    Intended to be combined with a `BaseRetriever`; a subclass without a
    class named `BaseRetriever` in its MRO is rejected when it is created.
    """

    __supports_audio__ = True

    def __init_subclass__(cls) -> None:
        """Reject subclasses that are not also a `BaseRetriever`."""
        super().__init_subclass__()

        # the check is by class name, over the full method resolution order
        has_base_retriever = any(
            klass.__name__ == "BaseRetriever" for klass in cls.__mro__
        )
        if not has_base_retriever:
            raise RetrieverError(
                "`AudioRetrieverMixin` must be mixed with `BaseRetriever`."
            )
33 |
--------------------------------------------------------------------------------
/src/fed_rag/base/retriever_mixins/image.py:
--------------------------------------------------------------------------------
1 | """Retriever Mixins."""
2 |
3 | from abc import ABC
4 | from typing import Protocol, runtime_checkable
5 |
6 | from fed_rag.exceptions.retriever import RetrieverError
7 |
8 |
@runtime_checkable
class RetrieverHasImageModality(Protocol):
    """Associated protocol for `ImageRetrieverMixin`.

    Declares the `__supports_images__` marker attribute. The protocol is
    `runtime_checkable`, so it can be used in `isinstance` checks.
    """

    # marker attribute, also set by `ImageRetrieverMixin`
    __supports_images__: bool = True
14 |
15 |
class ImageRetrieverMixin(ABC):
    """Mixin that adds the image modality marker to a retriever.

    Intended to be combined with a `BaseRetriever`; a subclass without a
    class named `BaseRetriever` in its MRO is rejected when it is created.
    """

    __supports_images__ = True

    def __init_subclass__(cls) -> None:
        """Reject subclasses that are not also a `BaseRetriever`."""
        super().__init_subclass__()

        # the check is by class name, over the full method resolution order
        has_base_retriever = any(
            klass.__name__ == "BaseRetriever" for klass in cls.__mro__
        )
        if not has_base_retriever:
            raise RetrieverError(
                "`ImageRetrieverMixin` must be mixed with `BaseRetriever`."
            )
33 |
--------------------------------------------------------------------------------
/src/fed_rag/base/retriever_mixins/video.py:
--------------------------------------------------------------------------------
1 | """Retriever Mixins."""
2 |
3 | from abc import ABC
4 | from typing import Protocol, runtime_checkable
5 |
6 | from fed_rag.exceptions.retriever import RetrieverError
7 |
8 |
@runtime_checkable
class RetrieverHasVideoModality(Protocol):
    """Associated protocol for `VideoRetrieverMixin`.

    Declares the `__supports_video__` marker attribute. The protocol is
    `runtime_checkable`, so it can be used in `isinstance` checks.
    """

    # marker attribute, also set by `VideoRetrieverMixin`
    __supports_video__: bool = True
14 |
15 |
class VideoRetrieverMixin(ABC):
    """Mixin that adds the video modality marker to a retriever.

    Intended to be combined with a `BaseRetriever`; a subclass without a
    class named `BaseRetriever` in its MRO is rejected when it is created.
    """

    __supports_video__ = True

    def __init_subclass__(cls) -> None:
        """Reject subclasses that are not also a `BaseRetriever`."""
        super().__init_subclass__()

        # the check is by class name, over the full method resolution order
        has_base_retriever = any(
            klass.__name__ == "BaseRetriever" for klass in cls.__mro__
        )
        if not has_base_retriever:
            raise RetrieverError(
                "`VideoRetrieverMixin` must be mixed with `BaseRetriever`."
            )
33 |
--------------------------------------------------------------------------------
/docs/getting_started/quick_starts/index.md:
--------------------------------------------------------------------------------
1 | # Quick Starts
2 |
3 |
4 |
5 | In this next part of getting to know FedRAG, we provide a mini-series of
6 | quick start examples to give you a better feel for the library.
7 |
8 |
9 |
10 | - :material-hexagon-outline: [__Centralized to Federated__](./federated.md) — Transform
11 | a centralized training task into a federated learning task.
12 | - :material-hexagon-outline: [__Build a RAG System__](./rag_inference.md) — Assemble
13 | a RAG system using FedRAG's lightweight abstractions.
14 | - :material-hexagon-outline: [__Fine-tune a RAG System__](./rag_finetuning.md) — Fine-tune
15 | a RAG system on custom QA data, demonstrating both centralized training and
16 | optional federation capabilities.
17 | - :material-hexagon-outline: [__Benchmark a RAG System__](./benchmark_mmlu.md) —
18 | Evaluate a RAG system on popular benchmarks like MMLU.
19 |
20 |
21 |
--------------------------------------------------------------------------------
/tests/generators/mixins/test_audio_mixin.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pydantic import BaseModel
3 |
4 | from fed_rag.base.generator import BaseGenerator
5 | from fed_rag.base.generator_mixins.audio import (
6 | AudioModalityMixin,
7 | GeneratorHasAudioModality,
8 | )
9 | from fed_rag.exceptions.generator import GeneratorError
10 |
11 | from ..conftest import MockGenerator
12 |
13 |
class MockAudioGenerator(AudioModalityMixin, MockGenerator):
    """`MockGenerator` combined with `AudioModalityMixin` for the tests below."""

    pass
16 |
17 |
def test_audio_mixin() -> None:
    """A mixed generator satisfies both the protocol and `BaseGenerator`."""
    generator = MockAudioGenerator()

    for expected_type in (GeneratorHasAudioModality, BaseGenerator):
        assert isinstance(generator, expected_type)
22 |
23 |
def test_audio_mixin_fails_validation() -> None:
    """Mixing with a class that is not a `BaseGenerator` raises at definition."""
    expected_msg = "`AudioModalityMixin` must be mixed with `BaseGenerator`."

    with pytest.raises(GeneratorError, match=expected_msg):

        class InvalidMockAudioGenerator(AudioModalityMixin, BaseModel):
            pass
32 |
--------------------------------------------------------------------------------
/tests/generators/mixins/test_video_mixin.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pydantic import BaseModel
3 |
4 | from fed_rag.base.generator import BaseGenerator
5 | from fed_rag.base.generator_mixins.video import (
6 | GeneratorHasVideoModality,
7 | VideoModalityMixin,
8 | )
9 | from fed_rag.exceptions.generator import GeneratorError
10 |
11 | from ..conftest import MockGenerator
12 |
13 |
class MockVideoGenerator(VideoModalityMixin, MockGenerator):
    """`MockGenerator` combined with `VideoModalityMixin` for the tests below."""

    pass
16 |
17 |
def test_video_mixin() -> None:
    """A mixed generator satisfies both the protocol and `BaseGenerator`."""
    generator = MockVideoGenerator()

    for expected_type in (GeneratorHasVideoModality, BaseGenerator):
        assert isinstance(generator, expected_type)
22 |
23 |
def test_video_mixin_fails_validation() -> None:
    """Mixing with a class that is not a `BaseGenerator` raises at definition."""
    expected_msg = "`VideoModalityMixin` must be mixed with `BaseGenerator`."

    with pytest.raises(GeneratorError, match=expected_msg):

        class InvalidMockVideoGenerator(VideoModalityMixin, BaseModel):
            pass
32 |
--------------------------------------------------------------------------------
/src/fed_rag/knowledge_stores/no_encode/mcp/sources/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Protocol
2 |
3 | from mcp.types import CallToolResult
4 |
5 | from fed_rag.data_structures import KnowledgeNode
6 | from fed_rag.exceptions import CallToolResultConversionError
7 |
8 |
class CallToolResultConverter(Protocol):
    """Callable protocol for converting a `CallToolResult` to knowledge nodes.

    Implementations accept the tool result plus optional metadata and return
    a list of `KnowledgeNode` objects. `default_converter` in this module has
    this signature.
    """

    def __call__(
        self, result: CallToolResult, metadata: dict[str, Any] | None = None
    ) -> list[KnowledgeNode]:
        pass  # pragma: no cover
14 |
15 |
def default_converter(
    result: CallToolResult, metadata: dict[str, Any] | None = None
) -> list[KnowledgeNode]:
    """Convert the text items of a `CallToolResult` into text knowledge nodes.

    Content items whose `type` is not `"text"` are skipped.

    Args:
        result: The tool call result to convert.
        metadata: Optional metadata attached to every created node.

    Returns:
        One text `KnowledgeNode` per text content item, in content order.

    Raises:
        CallToolResultConversionError: If `result.isError` is truthy.
    """
    if result.isError:
        raise CallToolResultConversionError(
            "Cannot convert a `CallToolResult` with `isError` set to `True`."
        )

    nodes: list[KnowledgeNode] = []
    for item in result.content:
        if item.type != "text":
            continue
        nodes.append(
            KnowledgeNode(
                node_type="text",
                text_content=item.text,
                metadata=metadata,
            )
        )
    return nodes
33 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.yml:
--------------------------------------------------------------------------------
1 | name: Feature Request
2 | description: Suggest a new feature or improvement for FedRAG.
3 | title: "[Feature]: "
4 | labels: ["enhancement", "triage"]
5 | body:
6 | - type: markdown
7 | attributes:
8 | value: |
9 | Thank you for suggesting a feature to improve FedRAG!
10 | Please describe your idea in detail below.
11 |
12 | - type: textarea
13 | id: problem
14 | attributes:
15 | label: Problem Statement
16 | description: What problem or need would this feature solve?
17 | validations:
18 | required: true
19 |
20 | - type: textarea
21 | id: proposal
22 | attributes:
23 | label: Proposed Solution
24 | description: How would you like to see this implemented? Feel free to share ideas or API sketches.
25 | validations:
26 | required: true
27 |
28 | - type: textarea
29 | id: alternatives
30 | attributes:
31 | label: Alternatives Considered
32 | description: Have you considered other solutions or workarounds?
33 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/integration_request.yml:
--------------------------------------------------------------------------------
1 | name: Integration Request
2 | description: Request support for a new framework or tool with FedRAG.
3 | title: "[Integration]: "
4 | labels: ["integration", "triage"]
5 | body:
6 | - type: markdown
7 | attributes:
8 | value: |
9 | Suggest a new integration for FedRAG!
10 | Tell us about the tool or framework and why it would be useful.
11 |
12 | - type: input
13 | id: framework
14 | attributes:
15 | label: Target Framework/Tool
16 | description: Name and (optionally) link to the tool you want to integrate.
17 | validations:
18 | required: true
19 |
20 | - type: textarea
21 | id: motivation
22 | attributes:
23 | label: Motivation
24 | description: Why would this integration be valuable for FedRAG users?
25 | validations:
26 | required: true
27 |
28 | - type: textarea
29 | id: ideas
30 | attributes:
31 | label: Proposed Approach
32 | description: If you have ideas about how integration might work, share them here.
33 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.yml:
--------------------------------------------------------------------------------
1 | name: Bug Report
2 | description: Report an unexpected error or broken behavior in FedRAG.
3 | title: "[Bug]: "
4 | labels: ["bug", "triage"]
5 | body:
6 | - type: markdown
7 | attributes:
8 | value: |
9 | Thanks for taking the time to file a bug report!
10 | Please complete the following sections to help us reproduce and fix the issue.
11 |
12 | - type: textarea
13 | id: what-happened
14 | attributes:
15 | label: Bug Description
16 | description: What happened? What behavior did you expect?
17 | validations:
18 | required: true
19 |
20 | - type: textarea
21 | id: steps-to-reproduce
22 | attributes:
23 | label: Steps to Reproduce
24 | description: Provide a clear minimal example to reproduce the bug.
25 | validations:
26 | required: true
27 |
28 | - type: textarea
29 | id: logs
30 | attributes:
31 | label: Relevant Logs/Tracebacks
32 | description: Please copy and paste any error messages or logs.
33 | render: shell
34 |
--------------------------------------------------------------------------------
/src/fed_rag/core/rag_system/asynchronous.py:
--------------------------------------------------------------------------------
1 | """Async RAG System Module"""
2 |
3 | from fed_rag._bridges.langchain.bridge import LangChainBridgeMixin
4 | from fed_rag._bridges.llamaindex.bridge import LlamaIndexBridgeMixin
5 | from fed_rag.core.rag_system._asynchronous import _AsyncRAGSystem
6 |
7 | from .synchronous import RAGSystem
8 |
9 |
# Define the public AsyncRAGSystem with all available bridges
class AsyncRAGSystem(
    LlamaIndexBridgeMixin, LangChainBridgeMixin, _AsyncRAGSystem
):
    """Async RAG System with all available bridge functionality.

    The AsyncRAGSystem is the asynchronous entry point for creating and
    managing retrieval-augmented generation systems. It combines
    `_AsyncRAGSystem` with the LlamaIndex and LangChain bridge mixins.
    """

    def to_sync(
        self,
    ) -> RAGSystem:
        """Build a synchronous `RAGSystem` from this async system.

        The knowledge store is converted through its own `to_sync()`; the
        generator, retriever and RAG config are passed through unchanged.

        Returns:
            RAGSystem: The synchronous counterpart of this system.
        """
        return RAGSystem(
            knowledge_store=self.knowledge_store.to_sync(),
            generator=self.generator,  # NOTE: this should actually be sync!
            retriever=self.retriever,  # NOTE: this should actually be sync!
            rag_config=self.rag_config,
        )
29 |
--------------------------------------------------------------------------------
/src/fed_rag/data_structures/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | fed_rag.data_structures
3 |
4 | Only components defined in `__all__` are considered stable and public.
5 | """
6 |
7 | from .bridge import BridgeMetadata, CompatibleVersions
8 | from .evals import (
9 | AggregationMode,
10 | BenchmarkEvaluatedExample,
11 | BenchmarkExample,
12 | BenchmarkResult,
13 | )
14 | from .knowledge_node import KnowledgeNode, NodeContent, NodeType
15 | from .rag import Context, Prompt, Query, RAGConfig, RAGResponse, SourceNode
16 | from .results import TestResult, TrainResult
17 |
18 | __all__ = [
19 | # bridge
20 | "BridgeMetadata",
21 | "CompatibleVersions",
22 | # evals
23 | "AggregationMode",
24 | "BenchmarkExample",
25 | "BenchmarkResult",
26 | "BenchmarkEvaluatedExample",
27 | # results
28 | "TrainResult",
29 | "TestResult",
30 | # knowledge node
31 | "KnowledgeNode",
32 | "NodeType",
33 | "NodeContent",
34 | # rag
35 | "RAGConfig",
36 | "RAGResponse",
37 | "SourceNode",
38 | "Query",
39 | "Context",
40 | "Prompt",
41 | ]
42 |
--------------------------------------------------------------------------------
/src/fed_rag/generators/huggingface/utils.py:
--------------------------------------------------------------------------------
1 | from importlib.util import find_spec
2 |
3 | from fed_rag.exceptions import MissingExtraError
4 |
5 |
def check_huggingface_installed(cls_name: str | None = None) -> None:
    """Ensure the `huggingface` extra is available.

    The extra counts as installed only when `transformers`, `peft` and
    `sentence_transformers` can all be located.

    Args:
        cls_name: Optional name of the class that needs the extra. When
            given, it is used in the error message.

    Raises:
        MissingExtraError: If any of the three packages cannot be found.
    """
    required_packages = ("transformers", "peft", "sentence_transformers")
    # look up every package (no short-circuit), in the listed order
    specs = [find_spec(package) for package in required_packages]

    if all(spec is not None for spec in specs):
        return

    if cls_name:
        error_msg = (
            f"`{cls_name}` requires the `huggingface` extra to be installed. "
            "To fix please run `pip install fed-rag[huggingface]`."
        )
    else:
        error_msg = (
            "Missing installation of the huggingface extra, yet is required "
            "by an imported class. To fix please run `pip install fed-rag[huggingface]`."
        )

    raise MissingExtraError(error_msg)
29 |
--------------------------------------------------------------------------------
/src/fed_rag/utils/data/finetuning_datasets/huggingface.py:
--------------------------------------------------------------------------------
1 | """HuggingFace RAG Finetuning Dataset"""
2 |
3 | from typing_extensions import Self
4 |
5 | from fed_rag.exceptions.common import MissingExtraError
6 |
# check if huggingface extra was installed
try:
    from datasets import Dataset
except ModuleNotFoundError as exc:
    msg = (
        "`HuggingFaceRAGFinetuningDataset` requires the `huggingface` extra to be installed. "
        "To fix please run `pip install fed-rag[huggingface]`."
    )
    # Chain explicitly so the traceback reports the failed import as the
    # direct cause rather than as an error raised while handling another.
    raise MissingExtraError(msg) from exc
16 |
17 |
class HuggingFaceRAGFinetuningDataset(Dataset):
    """Thin wrapper over ~datasets.Dataset."""

    @classmethod
    def from_inputs(
        cls,
        input_ids: list[list[int]],
        target_ids: list[list[int]],
        attention_mask: list[list[int]],
    ) -> Self:
        """Create a dataset from parallel token-id and mask lists.

        Args:
            input_ids: Stored under the `"input_ids"` column.
            target_ids: Stored under the `"target_ids"` column.
            attention_mask: Stored under the `"attention_mask"` column.

        Returns:
            The dataset produced by `from_dict` for these three columns.
        """
        columns = {
            "input_ids": input_ids,
            "target_ids": target_ids,
            "attention_mask": attention_mask,
        }
        return cls.from_dict(columns)  # type: ignore[no-any-return]
35 |
--------------------------------------------------------------------------------
/src/fed_rag/base/data_collator.py:
--------------------------------------------------------------------------------
1 | """Base Data Collator"""
2 |
3 | from abc import ABC, abstractmethod
4 | from typing import Any
5 |
6 | from pydantic import BaseModel, ConfigDict
7 |
8 | from fed_rag import RAGSystem
9 |
10 |
class BaseDataCollator(BaseModel, ABC):
    """
    Base Data Collator.

    Abstract base class for collating input examples into batches that can
    be used by a retrieval-augmented generation (RAG) system.

    Subclasses must implement `__call__`.
    """

    # allow field types that pydantic has no built-in validation for
    model_config = ConfigDict(arbitrary_types_allowed=True)
    # the RAG system this collator is associated with
    rag_system: RAGSystem

    @abstractmethod
    def __call__(self, features: list[dict[str, Any]], **kwargs: Any) -> Any:
        """Collate examples into a batch.

        Args:
            features (list[dict[str, Any]]): A list of feature dictionaries,
                where each dictionary represents one example.
            **kwargs (Any): Additional keyword arguments that may be used
                by specific implementations.

        Returns:
            Any: A collated batch, with format depending on the implementation.
        """
35 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: true
2 | contact_links:
3 | - name: General Questions or Support
4 | url: https://github.com/VectorInstitute/fed-rag/discussions
5 | about: Please ask questions, request help, or start a conversation here.
6 |
7 | issue_templates:
8 | - name: Bug Report
9 | description: Report an unexpected error or broken behavior in FedRAG.
10 | labels: ["bug"]
11 | file: bug_report.md
12 |
13 | - name: Feature Request
14 | description: Suggest a new feature or improvement for FedRAG.
15 | labels: ["enhancement"]
16 | file: feature_request.md
17 |
18 | - name: Documentation Improvement
19 | description: Propose improvements or updates to the FedRAG documentation.
20 | labels: ["documentation"]
21 | file: documentation_improvement.md
22 |
23 | - name: Integration Request
24 | description: Request support for new frameworks or tools with FedRAG.
25 | labels: ["integration"]
26 | file: integration_request.md
27 |
28 | - name: General Question
29 | description: Ask a question about using or developing with FedRAG.
30 | labels: ["question"]
31 | file: question.md
32 |
--------------------------------------------------------------------------------
/src/fed_rag/__init__.pyi:
--------------------------------------------------------------------------------
1 | """Type stubs for fed_rag module"""
2 |
3 | # Lazy-loaded classes (type-only declarations)
4 | # Lazy-loaded modules
5 | from fed_rag import generators as generators
6 | from fed_rag import retrievers as retrievers
7 | from fed_rag import trainer_managers as trainer_managers
8 | from fed_rag import trainers as trainers
9 | from fed_rag.generators import HFPeftModelGenerator as HFPeftModelGenerator
10 | from fed_rag.generators import (
11 | HFPretrainedModelGenerator as HFPretrainedModelGenerator,
12 | )
13 | from fed_rag.generators import (
14 | UnslothFastModelGenerator as UnslothFastModelGenerator,
15 | )
16 | from fed_rag.retrievers import (
17 | HFSentenceTransformerRetriever as HFSentenceTransformerRetriever,
18 | )
19 | from fed_rag.trainer_managers import (
20 | HuggingFaceRAGTrainerManager as HuggingFaceRAGTrainerManager,
21 | )
22 | from fed_rag.trainer_managers import (
23 | PyTorchRAGTrainerManager as PyTorchRAGTrainerManager,
24 | )
25 | from fed_rag.trainers import (
26 | HuggingFaceTrainerForLSR as HuggingFaceTrainerForLSR,
27 | )
28 | from fed_rag.trainers import (
29 | HuggingFaceTrainerForRALT as HuggingFaceTrainerForRALT,
30 | )
31 |
--------------------------------------------------------------------------------
/tests/data_structures/test_evals.py:
--------------------------------------------------------------------------------
1 | from fed_rag.data_structures import (
2 | BenchmarkEvaluatedExample,
3 | BenchmarkExample,
4 | KnowledgeNode,
5 | RAGResponse,
6 | SourceNode,
7 | )
8 |
9 |
def test_model_dump_without_embs() -> None:
    """Embeddings are absent after a JSON round trip without embeddings."""
    # arrange
    node = KnowledgeNode(
        embedding=[1, 2, 3],  # embeddings not persisted
        node_type="text",
        text_content="fake content",
    )
    rag_response = RAGResponse(
        response="mock rag reponse",
        source_nodes=[SourceNode(score=0.1, node=node)],
    )
    example = BenchmarkExample(query="mock query", response="mock response")
    evaluated = BenchmarkEvaluatedExample(
        score=0.42,
        example=example,
        rag_response=rag_response,
    )

    # act
    serialized = evaluated.model_dump_json_without_embeddings()

    # assert
    restored = BenchmarkEvaluatedExample.model_validate_json(serialized)
    restored_node = restored.rag_response.source_nodes[0].node
    assert restored_node.embedding is None
35 |
--------------------------------------------------------------------------------
/tests/api/test_deprecated_types_imports.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | import pytest
4 |
5 | DEPRECATED_IMPORTS = [
6 | ("fed_rag.types.bridge", "BridgeMetadata"),
7 | ("fed_rag.types.results", "TrainResult"),
8 | ("fed_rag.types.results", "TestResult"),
9 | ("fed_rag.types.knowledge_node", "KnowledgeNode"),
10 | ("fed_rag.types.knowledge_node", "NodeType"),
11 | ("fed_rag.types.knowledge_node", "NodeContent"),
12 | ("fed_rag.types.rag", "RAGConfig"),
13 | ("fed_rag.types.rag", "RAGResponse"),
14 | ("fed_rag.types.rag", "SourceNode"),
15 | ]
16 |
17 |
@pytest.mark.parametrize("module_path,class_name", DEPRECATED_IMPORTS)
def test_import_from_types_raises_deprecation_warning(
    module_path: str, class_name: str
) -> None:
    """Importing from a deprecated ``types`` module must emit a DeprecationWarning."""
    import importlib

    # Forget any cached copy of the module so the import below is a fresh one.
    sys.modules.pop(module_path, None)

    with pytest.warns(DeprecationWarning):
        imported = importlib.import_module(module_path)
        getattr(imported, class_name)  # make sure the attribute is resolved
33 |
--------------------------------------------------------------------------------
/tests/loss/pytorch/conftest.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import pytest
4 | import torch
5 |
6 | EMB_DIM = 10
7 | BATCH_SIZE = 2
8 | NUM_CHUNKS = 3
9 |
10 |
@pytest.fixture()
def retrieved_chunks() -> torch.Tensor:
    """Embeddings of 'retrieved' chunks.

    For 1-based batch index ``bx`` and chunk index ``ix`` every component of
    the chunk embedding equals ``bx / ix``. The resulting float32 tensor has
    shape ``(BATCH_SIZE, NUM_CHUNKS, EMB_DIM)``.
    """
    nested = [
        [[bx / ix] * EMB_DIM for ix in range(1, NUM_CHUNKS + 1)]
        for bx in range(1, BATCH_SIZE + 1)
    ]
    return torch.tensor(nested, dtype=torch.float32)
22 |
23 |
@pytest.fixture()
def contexts() -> torch.Tensor:
    """Mock context embeddings, one row per batch element.

    Row ``b`` (0-based) is a constant vector filled with ``b + 1``, giving a
    tensor of shape ``(BATCH_SIZE, EMB_DIM)``.
    """
    # The range must end at BATCH_SIZE + 1 (exclusive) so that this fixture
    # yields BATCH_SIZE rows, matching the batch dimension produced by the
    # ``retrieved_chunks`` and ``lm_scores`` fixtures in this module.
    rows = [torch.ones(EMB_DIM) * ix for ix in range(1, BATCH_SIZE + 1)]
    return torch.stack(rows, dim=0)
30 |
31 |
@pytest.fixture()
def lm_scores() -> torch.Tensor:
    """Mock probas of generated outputs 'given' context and chunk.

    Each row holds ``exp(ix)`` for ``ix`` in ``range(NUM_CHUNKS)``, normalised
    to sum to one; the same row is repeated ``BATCH_SIZE`` times.
    """
    weights = [math.exp(ix) for ix in range(NUM_CHUNKS)]
    row = [w / sum(weights) for w in weights]
    stacked = [list(row) for _ in range(BATCH_SIZE)]
    return torch.tensor(stacked, dtype=torch.float32)
42 |
--------------------------------------------------------------------------------
/examples/ra-dit/ra_dit/_dataset_prep/qa/pubmed.py:
--------------------------------------------------------------------------------
1 | """PubmedQA
2 |
3 | Example
4 | ===
5 | {
6 | "question": ...,
7 | "context": {
8 | "contexts": [],
9 | ...
10 | },
11 | "long_answer": ...,
12 | "final_decision": ...
13 | }
14 | """
15 |
16 | import pandas as pd
17 |
18 | from ..base_data_prepper import DEFAULT_SAVE_DIR, BaseDataPrepper
19 | from .mixin import QAMixin
20 |
21 | QA_SAVE_DIR = DEFAULT_SAVE_DIR / "qa"
22 |
23 |
class PubmedQADataPrepper(QAMixin, BaseDataPrepper):
    """Data prepper that maps PubmedQA rows to question / evidence / answer."""

    @property
    def dataset_name(self) -> str:
        """Name identifying this dataset."""
        return "pubmed_qa"

    def _get_answer(self, row: pd.Series) -> str:
        """Join the long answer and the final decision with a blank line."""
        long_answer = row["long_answer"]
        decision = row["final_decision"]
        return str(long_answer + "\n\n" + decision)

    def _get_evidence(self, row: pd.Series) -> str:
        """Concatenate every context passage, separated by blank lines."""
        passages = row["context"]["contexts"]
        return "\n\n".join(passages)

    def _get_question(self, row: pd.Series) -> str:
        """Return the question field as a string."""
        question = row["question"]
        return str(question)
37 |
38 |
# Module-level script: load the `pqa_artificial` train parquet of PubMedQA
# from the Hugging Face hub path below and run the prepper on it.
# NOTE(review): these statements execute at import time, so importing this
# module triggers the download and the save — confirm that this is intended.
df = pd.read_parquet(
    "hf://datasets/qiaojin/PubMedQA/pqa_artificial/train-00000-of-00001.parquet"
)
data_prepper = PubmedQADataPrepper(df=df, save_dir=QA_SAVE_DIR)
# `execute_and_save` is defined outside this file; presumably it writes its
# output under QA_SAVE_DIR — verify against BaseDataPrepper.
data_prepper.execute_and_save()
44 |
--------------------------------------------------------------------------------
/tests/generators/test_unsloth_utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | from unittest.mock import patch
3 |
4 | import pytest
5 |
6 | from fed_rag.exceptions import MissingExtraError
7 | from fed_rag.generators.unsloth.utils import check_unsloth_installed
8 |
9 |
def test_check_raises_error() -> None:
    """Check raises error from utils, both without and with a class name."""
    generic_msg = (
        "Missing installation of the `unsloth` extra, yet is required "
        "by an imported class. To fix please run `pip install fed-rag[unsloth]`."
    )
    class_msg = (
        "`FakeClass` requires the `unsloth` extra to be installed. "
        "To fix please run `pip install fed-rag[unsloth]`."
    )
    # (positional args for the check, expected error message)
    cases = [
        ((), generic_msg),  # without class name
        (("FakeClass",), class_msg),  # with class name
    ]

    # Mapping the module name to None makes `import unsloth` fail.
    with patch.dict("sys.modules", {"unsloth": None}):
        for call_args, expected_msg in cases:
            with pytest.raises(
                MissingExtraError,
                match=re.escape(expected_msg),
            ):
                check_unsloth_installed(*call_args)
37 |
--------------------------------------------------------------------------------
/tests/tokenizers/conftest.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | import pytest
4 | import tokenizers
5 | from tokenizers import Tokenizer, models
6 | from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
7 |
8 | from fed_rag.base.tokenizer import BaseTokenizer
9 |
10 |
class MockTokenizer(BaseTokenizer):
    """Stub ``BaseTokenizer`` that returns canned values for tests."""

    def encode(self, input: str, **kwargs: Any) -> list[int]:
        """Return the same three token ids regardless of ``input``."""
        return list(range(3))

    def decode(self, input_ids: list[int], **kwargs: Any) -> str:
        """Return a fixed sentence regardless of ``input_ids``."""
        decoded = "mock decoded sentence"
        return decoded

    @property
    def unwrapped(self) -> None:
        """This mock wraps no underlying tokenizer object."""
        return None
21 |
22 |
@pytest.fixture()
def mock_tokenizer() -> BaseTokenizer:
    """Provide a fresh ``MockTokenizer`` instance."""
    tokenizer = MockTokenizer()
    return tokenizer
26 |
27 |
@pytest.fixture
def hf_tokenizer() -> PreTrainedTokenizer:
    """Build a tiny WordPiece-backed ``PreTrainedTokenizerFast`` for tests."""
    vocab = {"hello": 0, "[UNK]": 1}
    backend = Tokenizer(models.WordPiece(vocab, unk_token="[UNK]"))
    backend.pre_tokenizer = tokenizers.pre_tokenizers.WhitespaceSplit()

    special_tokens = {
        "pad_token": "[PAD]",
        "cls_token": "[CLS]",
        "sep_token": "[SEP]",
        "mask_token": "[MASK]",
    }
    return PreTrainedTokenizerFast(tokenizer_object=backend, **special_tokens)
41 |
--------------------------------------------------------------------------------
/tests/retrievers/conftest.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | import pytest
4 | import torch
5 | from pydantic import PrivateAttr
6 | from sentence_transformers import SentenceTransformer
7 |
8 | from fed_rag.base.retriever import BaseRetriever
9 |
10 |
class MockRetriever(BaseRetriever):
    """Stub retriever backed by a single ``Linear(2, 1)`` encoder."""

    _encoder: torch.nn.Module = PrivateAttr(default=torch.nn.Linear(2, 1))

    def encode_context(self, context: str, **kwargs: Any) -> torch.Tensor:
        """Encode a fixed all-ones input; ``context`` itself is ignored."""
        ones_input = torch.ones(2)
        return self._encoder.forward(ones_input)

    def encode_query(self, query: str, **kwargs: Any) -> torch.Tensor:
        """Encode a fixed all-zeros input; ``query`` itself is ignored."""
        zeros_input = torch.zeros(2)
        return self._encoder.forward(zeros_input)

    @property
    def encoder(self) -> torch.nn.Module:
        """The single encoder used for both queries and contexts."""
        return self._encoder

    @property
    def query_encoder(self) -> torch.nn.Module | None:
        """No dedicated query encoder in this mock."""
        return None

    @property
    def context_encoder(self) -> torch.nn.Module | None:
        """No dedicated context encoder in this mock."""
        return None
31 |
32 |
@pytest.fixture
def mock_retriever() -> MockRetriever:
    """Provide a fresh ``MockRetriever`` instance."""
    retriever = MockRetriever()
    return retriever
36 |
37 |
@pytest.fixture
def dummy_sentence_transformer() -> SentenceTransformer:
    """Provide a ``SentenceTransformer`` whose only module is ``Linear(5, 5)``."""
    linear_module = torch.nn.Linear(5, 5)
    return SentenceTransformer(modules=[linear_module])
41 |
--------------------------------------------------------------------------------
/examples/ra-dit/ra_dit/_dataset_prep/qa/web_questions.py:
--------------------------------------------------------------------------------
1 | """WebQA
2 |
3 | Example
4 | ===
5 | {
6 | 'url': 'http://www.freebase.com/view/en/justin_bieber',
7 | 'question': ...,
8 | 'answers': [...]
9 | },
10 | """
11 |
12 | import pandas as pd
13 |
14 | from ..base_data_prepper import DEFAULT_SAVE_DIR, BaseDataPrepper
15 | from .mixin import QAMixin
16 |
17 | QA_SAVE_DIR = DEFAULT_SAVE_DIR / "qa"
18 |
19 |
class WebQuestionsDataPrepper(QAMixin, BaseDataPrepper):
    """Data prepper that maps web_questions rows to question / answer pairs."""

    @property
    def dataset_name(self) -> str:
        """Name identifying this dataset."""
        return "web_questions_qa"

    def _get_answer(self, row: pd.Series) -> str:
        """Join all answers of the row into one comma-separated string."""
        joined = ", ".join(row["answers"])
        return str(joined)

    def _get_evidence(self, row: pd.Series) -> str | None:
        """No evidence is returned for this dataset."""
        return None

    def _get_question(self, row: pd.Series) -> str:
        """Return the question field as a string."""
        question = row["question"]
        return str(question)
33 |
34 |
# Parquet file of each split, relative to the root of the
# `Stanford/web_questions` dataset on the Hugging Face hub.
splits = {
    "train": "data/train-00000-of-00001.parquet",
    "test": "data/test-00000-of-00001.parquet",
}

# Module-level script: only the "train" split is loaded and prepared here.
# NOTE(review): these statements execute at import time, so importing this
# module triggers the download and the save — confirm that this is intended.
df = pd.read_parquet("hf://datasets/Stanford/web_questions/" + splits["train"])
data_prepper = WebQuestionsDataPrepper(df=df, save_dir=QA_SAVE_DIR)
# `execute_and_save` is defined outside this file; presumably it writes its
# output under QA_SAVE_DIR — verify against BaseDataPrepper.
data_prepper.execute_and_save()
43 |
--------------------------------------------------------------------------------
/src/fed_rag/exceptions/knowledge_stores.py:
--------------------------------------------------------------------------------
1 | """Exceptions for Knowledge Stores."""
2 |
3 | from .core import FedRAGError, FedRAGWarning
4 |
5 |
class KnowledgeStoreError(FedRAGError):
    """Common base class of every knowledge-store-related exception."""
10 |
11 |
class KnowledgeStoreWarning(FedRAGWarning):
    """Base knowledge store warning for all knowledge-store-related warnings."""

    pass
16 |
17 |
class KnowledgeStoreNotFoundError(KnowledgeStoreError, FileNotFoundError):
    """Raised when a knowledge store cannot be found or loaded from file.

    Also derives from ``FileNotFoundError``, so handlers for that built-in
    exception catch it as well.
    """
22 |
23 |
class InvalidDistanceError(KnowledgeStoreError):
    """Raised when the provided similarity distance is not valid."""
28 |
29 |
class LoadNodeError(KnowledgeStoreError):
    """Raised when loading a node fails."""
34 |
35 |
class MCPKnowledgeStoreError(KnowledgeStoreError):
    """Base error for all MCP-knowledge-store-related exceptions."""

    pass
40 |
41 |
class CallToolResultConversionError(MCPKnowledgeStoreError):
    """Raised on an attempt to convert a ~mcp.CallToolResult with error status."""
46 |
--------------------------------------------------------------------------------
/tests/retrievers/test_base.py:
--------------------------------------------------------------------------------
1 | from contextlib import nullcontext as does_not_raise
2 |
3 | import torch
4 |
5 | from fed_rag.base.retriever import BaseRetriever
6 |
7 |
def test_base_abstract_attr() -> None:
    """Every member of the retriever interface must be declared abstract."""
    required = (
        "encode_context",
        "encode_query",
        "encoder",
        "query_encoder",
        "context_encoder",
    )
    declared = BaseRetriever.__abstractmethods__

    for member in required:
        assert member in declared
16 |
17 |
def test_base_encode(mock_retriever: BaseRetriever) -> None:
    """Encoding through the mock retriever yields encoder-sized outputs."""
    ctx_emb = mock_retriever.encode_context("mock context")
    query_emb = mock_retriever.encode_query("mock query")
    similarity = ctx_emb @ query_emb.T
    # the last parameter yielded by the encoder
    final_layer = list(mock_retriever.encoder.parameters())[-1]

    with does_not_raise():
        # cosine sim should be a Tensor with a single item
        similarity.item()

    expected_numel = final_layer.size()[-1]
    assert ctx_emb.numel() == expected_numel
    assert query_emb.numel() == expected_numel
    assert isinstance(mock_retriever.encoder, torch.nn.Module)
    assert mock_retriever.query_encoder is None
    assert mock_retriever.context_encoder is None
33 |
--------------------------------------------------------------------------------
/.github/workflows/unit_test.yml:
--------------------------------------------------------------------------------
1 | name: Unit Testing and Upload Coverage
2 | on:
3 | push:
4 | branches:
5 | - main
6 | pull_request:
7 | types:
8 | - opened
9 | - synchronize
10 | jobs:
11 | test:
12 | runs-on: ubuntu-latest
13 | permissions:
14 | id-token: write
15 | contents: read
16 | strategy:
17 | fail-fast: false
18 | matrix:
19 | python-version: ["3.10", "3.11", "3.12"]
20 | steps:
21 | - name: get code
22 | uses: actions/checkout@v6
23 |
24 | - name: Install uv
25 | uses: astral-sh/setup-uv@v7
26 | with:
27 | # Install a specific version of uv.
28 | version: "0.5.21"
29 | enable-cache: true
30 | python-version: ${{ matrix.python-version }}
31 |
32 | - name: Install the project
33 | run: uv sync --all-extras --dev
34 |
35 | - name: Run tests
36 | run: |
37 | uv run make coverage
38 |
39 | - if: matrix.python-version == '3.12'
40 | name: Upload results to Codecov
41 | uses: codecov/codecov-action@v5
42 | with:
43 | token: ${{ secrets.CODECOV_TOKEN }}
44 | slug: VectorInstitute/fed-rag
45 | fail_ci_if_error: true
46 | verbose: true
47 |
--------------------------------------------------------------------------------
/src/fed_rag/core/no_encode_rag_system/asynchronous.py:
--------------------------------------------------------------------------------
1 | """Async No Encode RAG System Module"""
2 |
3 | from fed_rag.core.no_encode_rag_system._asynchronous import (
4 | _AsyncNoEncodeRAGSystem,
5 | )
6 |
7 | from .synchronous import NoEncodeRAGSystem
8 |
9 |
# Public AsyncNoEncodeRAGSystem, defined with all available bridges
class AsyncNoEncodeRAGSystem(_AsyncNoEncodeRAGSystem):
    """Async NoEncode RAG System with all available bridge functionality.

    This class is the main entry point for creating and managing
    retrieval-augmented generation systems that skip encoding altogether,
    which enables direct natural language queries to knowledge sources such
    as MCP servers, APIs, and databases.

    Traditional RAG systems require separate retriever components and
    pre-computed embeddings; NoEncode RAG systems instead perform direct
    queries against NoEncode knowledge sources.
    """

    def to_sync(
        self,
    ) -> NoEncodeRAGSystem:
        """Build a ``NoEncodeRAGSystem`` from this async system.

        The knowledge store is converted with its own ``to_sync()``; the
        generator and the RAG config are passed through unchanged.
        """
        sync_knowledge_store = self.knowledge_store.to_sync()
        return NoEncodeRAGSystem(
            knowledge_store=sync_knowledge_store,
            generator=self.generator,  # NOTE: this should actually be sync!
            rag_config=self.rag_config,
        )
32 |
--------------------------------------------------------------------------------
/tests/api/test_namespaced_imports.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | import pytest
4 |
5 | from fed_rag import generators, knowledge_stores, retrievers
6 |
7 |
@pytest.mark.parametrize("name", generators.__all__)
def test_public_generators_all_importable(name: str) -> None:
    """Every name in ``generators.__all__`` resolves to a non-None attribute."""
    package = importlib.import_module("fed_rag.generators")
    exported = getattr(package, name)

    assert hasattr(package, name)
    assert exported is not None
16 |
17 |
@pytest.mark.parametrize("name", retrievers.__all__)
def test_public_retrievers_all_importable(name: str) -> None:
    """Every name in ``retrievers.__all__`` resolves to a non-None attribute."""
    package = importlib.import_module("fed_rag.retrievers")
    exported = getattr(package, name)

    assert hasattr(package, name)
    assert exported is not None
26 |
27 |
@pytest.mark.parametrize("name", knowledge_stores.__all__)
def test_public_knowledge_stores_all_importable(name: str) -> None:
    """Every name in ``knowledge_stores.__all__`` resolves to a non-None attribute."""
    package = importlib.import_module("fed_rag.knowledge_stores")
    exported = getattr(package, name)

    assert hasattr(package, name)
    assert exported is not None
36 |
--------------------------------------------------------------------------------
/src/fed_rag/utils/huggingface.py:
--------------------------------------------------------------------------------
1 | from fed_rag import NoEncodeRAGSystem, RAGSystem
2 | from fed_rag.exceptions import FedRAGError
3 |
4 |
def _validate_rag_system(rag_system: RAGSystem | NoEncodeRAGSystem) -> None:
    """Check that ``rag_system`` uses the supported generator/retriever types.

    The check is skipped entirely when the environment variable
    ``FEDRAG_SKIP_VALIDATION`` is set to ``"1"``.

    Args:
        rag_system: The system whose generator (and, for a ``RAGSystem``,
            retriever) is checked.

    Raises:
        FedRAGError: If the generator is not one of the accepted generator
            classes, or if a ``RAGSystem`` has a retriever that is not a
            ``HFSentenceTransformerRetriever``.
    """
    import os

    # Skip validation if environment variable is set
    skip_flag = os.environ.get("FEDRAG_SKIP_VALIDATION")
    if skip_flag == "1":
        return

    # These imports happen only after the skip check above.
    from fed_rag.generators.huggingface import (
        HFPeftModelGenerator,
        HFPretrainedModelGenerator,
    )
    from fed_rag.generators.unsloth import UnslothFastModelGenerator
    from fed_rag.retrievers.huggingface.hf_sentence_transformer import (
        HFSentenceTransformerRetriever,
    )

    accepted_generators = (
        HFPretrainedModelGenerator,
        HFPeftModelGenerator,
        UnslothFastModelGenerator,
    )
    if not isinstance(rag_system.generator, accepted_generators):
        # NOTE(review): the message does not mention UnslothFastModelGenerator
        # although it is accepted above.
        raise FedRAGError(
            "Generator must be HFPretrainedModelGenerator or HFPeftModelGenerator"
        )

    # Only a RAGSystem has its retriever checked.
    if not isinstance(rag_system, RAGSystem):
        return
    if not isinstance(rag_system.retriever, HFSentenceTransformerRetriever):
        raise FedRAGError("Retriever must be a HFSentenceTransformerRetriever")
37 |
--------------------------------------------------------------------------------
/src/fed_rag/decorators/tester.py:
--------------------------------------------------------------------------------
1 | """Tester Decorators"""
2 |
3 | from typing import Callable
4 |
5 |
class TesterDecorators:
    """Decorators that attach an inspected tester signature spec to a callable."""

    def pytorch(self, func: Callable) -> Callable:
        """Inspect ``func`` with the PyTorch inspector and tag it with the spec.

        Returns the same ``func`` object, with the spec stored under the
        attribute ``__fl_task_tester_config``.
        """
        from fed_rag.inspectors.pytorch import inspect_tester_signature

        # inspect func sig; may need to create a cfg for this if the
        # decorator accepts params
        tester_spec = inspect_tester_signature(func)

        # store fl_task config
        func.__setattr__("__fl_task_tester_config", tester_spec)  # type: ignore[attr-defined]
        return func

    def huggingface(self, func: Callable) -> Callable:
        """Inspect ``func`` with the HuggingFace inspector and tag it with the spec.

        Returns the same ``func`` object, with the spec stored under the
        attribute ``__fl_task_tester_config``.
        """
        from fed_rag.inspectors.huggingface import inspect_tester_signature

        # inspect func sig; may need to create a cfg for this if the
        # decorator accepts params
        tester_spec = inspect_tester_signature(func)

        # store fl_task config
        func.__setattr__("__fl_task_tester_config", tester_spec)  # type: ignore[attr-defined]
        return func
38 |
--------------------------------------------------------------------------------
/src/fed_rag/decorators/trainer.py:
--------------------------------------------------------------------------------
1 | """Trainer Decorators"""
2 |
3 | from typing import Callable
4 |
5 |
class TrainerDecorators:
    """Decorators that attach an inspected trainer signature spec to a callable."""

    def pytorch(self, func: Callable) -> Callable:
        """Inspect ``func`` with the PyTorch inspector and tag it with the spec.

        Returns the same ``func`` object, with the spec stored under the
        attribute ``__fl_task_trainer_config``.
        """
        from fed_rag.inspectors.pytorch import inspect_trainer_signature

        # inspect func sig; may need to create a cfg for this if the
        # decorator accepts params
        trainer_spec = inspect_trainer_signature(func)

        # store fl_task config
        func.__setattr__("__fl_task_trainer_config", trainer_spec)  # type: ignore[attr-defined]
        return func

    def huggingface(self, func: Callable) -> Callable:
        """Inspect ``func`` with the HuggingFace inspector and tag it with the spec.

        Returns the same ``func`` object, with the spec stored under the
        attribute ``__fl_task_trainer_config``.
        """
        from fed_rag.inspectors.huggingface import inspect_trainer_signature

        # inspect func sig; may need to create a cfg for this if the
        # decorator accepts params
        trainer_spec = inspect_trainer_signature(func)

        # store fl_task config
        func.__setattr__("__fl_task_trainer_config", trainer_spec)  # type: ignore[attr-defined]
        return func
38 |
--------------------------------------------------------------------------------
/src/fed_rag/data_structures/retriever.py:
--------------------------------------------------------------------------------
1 | """Data structures for retrievers."""
2 |
3 | from typing import TypedDict
4 |
5 | import torch
6 |
7 |
class EncodeResult(TypedDict):
    """
    Represents the result of encoding multiple types of data.

    This TypedDict is used as a structured output for encoding operations
    involving various data modalities such as text, image, audio, or video.
    Each key corresponds to a specific modality and holds either a tensor
    result or None if that modality is not used or applicable.

    Attributes:
        text: torch.Tensor | None
            The tensor representation of encoded text data, or None if text
            is not processed.
        image: torch.Tensor | None
            The tensor representation of encoded image data, or None if image
            processing is not performed.
        audio: torch.Tensor | None
            The tensor representation of encoded audio data, or None if audio
            is not processed.
        video: torch.Tensor | None
            The tensor representation of encoded video data, or None if video
            processing is not performed.
    """

    text: torch.Tensor | None
    image: torch.Tensor | None
    audio: torch.Tensor | None
    video: torch.Tensor | None
36 |
--------------------------------------------------------------------------------
/.github/workflows/docs.yml:
--------------------------------------------------------------------------------
1 | name: Docs Publish
2 | on:
3 | push:
4 | branches:
5 | - main
6 | workflow_dispatch:
7 |
8 | jobs:
9 | deploy:
10 | runs-on: ubuntu-latest
11 | permissions:
12 | contents: write # To push a branch
13 | pull-requests: write # To create a PR from that branch
14 | steps:
15 | - name: get code
16 | uses: actions/checkout@v6
17 |
18 | - name: Configure Git Credentials
19 | run: |
20 | git config user.name github-actions[bot]
21 | git config user.email 41898282+github-actions[bot]@users.noreply.github.com
22 |
23 | - uses: actions/setup-python@v6
24 | with:
25 | python-version: "3.12"
26 |
27 | - name: Install uv
28 | uses: astral-sh/setup-uv@v7
29 | with:
30 | # Install a specific version of uv.
31 | version: "0.5.21"
32 | enable-cache: true
33 |
34 | - name: Install the project
35 | run: uv sync --all-extras --group dev --group docs
36 |
37 | - name: Build docs
38 | run: |
39 | uv run mkdocs build
40 |
41 | - name: Deploy to github pages
42 | uses: JamesIves/github-pages-deploy-action@v4.7.6
43 | with:
44 | branch: gh-pages # The branch the action should deploy to.
45 | folder: site # The folder the action should deploy.
46 |
--------------------------------------------------------------------------------
/docs/community/resources/pocket_references.md:
--------------------------------------------------------------------------------
1 | # AI Pocket References
2 |
3 |
4 |
5 |
6 | 
7 |
8 |
9 |
10 | 
11 |
12 |
13 | The [AI Pocket Reference](https://github.com/VectorInstitute/ai-pocket-reference)
14 | project is maintained by Vector AI Engineering as an accessible resource for the
15 | AI community. It provides a collection of _pocket references_ offering concise
16 | information on a wide range of AI topics, including Natural Language Processing
17 | (NLP) and Federated Learning (FL).
18 |
19 | ## Recommended Collections
20 |
21 | - [NLP Collection](https://vectorinstitute.github.io/ai-pocket-reference/nlp/) —
22 | Covers various topics within NLP, including RAG, LoRA, Quantization, Chain of Thought,
23 | Agents, and more.
24 |
25 | - [FL Collection](https://vectorinstitute.github.io/ai-pocket-reference/fl/) —
26 | Encompasses the fundamentals of federated learning along with advanced topics such
27 | as personalized federated learning and vertical federated learning.
28 |
--------------------------------------------------------------------------------
/src/fed_rag/_bridges/llamaindex/bridge.py:
--------------------------------------------------------------------------------
1 | """LlamaIndex Bridge"""
2 |
3 | from typing import TYPE_CHECKING
4 |
5 | from fed_rag._bridges.llamaindex._version import __version__
6 | from fed_rag.base.bridge import BaseBridgeMixin
7 |
8 | if TYPE_CHECKING: # pragma: no cover
9 | from llama_index.core.indices.managed.base import BaseManagedIndex
10 |
11 | from fed_rag.core.rag_system._synchronous import ( # avoids circular import
12 | _RAGSystem,
13 | )
14 |
15 |
class LlamaIndexBridgeMixin(BaseBridgeMixin):
    """LlamaIndex Bridge.

    Mixing this class into an unbridged _RAGSystem gives it a
    ``to_llamaindex()`` method that converts the system directly into a
    LlamaIndex BaseManagedIndex.
    """

    _bridge_version = __version__
    _bridge_extra = "llama-index"
    _framework = "llama-index-core"
    _compatible_versions = {"min": "0.12.35"}
    _method_name = "to_llamaindex"

    def to_llamaindex(self: "_RAGSystem") -> "BaseManagedIndex":
        """Converts the _RAGSystem to a ~llamaindex.core.BaseManagedIndex."""
        self._validate_framework_installed()

        # Imported only after the framework check above has been run.
        from fed_rag._bridges.llamaindex._managed_index import (
            FedRAGManagedIndex,
        )

        managed_index = FedRAGManagedIndex(rag_system=self)
        return managed_index
39 |
--------------------------------------------------------------------------------
/.github/pull_request_template.md:
--------------------------------------------------------------------------------
1 | # FedRAG Pull Request Template
2 |
3 | Thanks for contributing to FedRAG!
4 | Please fill out the sections below to help us review your PR efficiently.
5 |
6 | ## Summary
7 |
8 | What does this PR do? Please provide a brief summary of the changes introduced.
9 |
10 | - [ ] Bug fix
11 | - [ ] New feature
12 | - [ ] Documentation update
13 | - [ ] Code quality / linting
14 | - [ ] Other (please describe):
15 |
16 | ## Description
17 | Any information reviewers should be aware of:
18 |
19 | ## Testing
20 |
21 | Describe how you tested your changes. Include the steps to reproduce, commands run, and any relevant outputs.
22 |
23 | - [ ] Unit tests added or updated
24 | - [ ] All tests pass locally (`make test`)
25 | - [ ] Code coverage maintained or improved
26 |
27 | ## Checklist
28 |
29 | Before submitting your PR, please check off the following:
30 |
31 | - [ ] My code follows the existing style and conventions
32 | - [ ] I’ve run linting (`make lint`)
33 | - [ ] I’ve added/updated relevant documentation
34 | - [ ] I’ve added/updated tests as needed
35 | - [ ] I’ve verified integration with existing tools (HuggingFace, LlamaIndex, LangChain, etc. if applicable)
36 | - [ ] I’ve added an entry to the CHANGELOG.md (if applicable)
37 |
38 | ## Related Issues or PRs
39 |
40 | If this PR addresses or relates to existing issues or pull requests, link them here:
41 |
42 | - Closes #
43 | - Related to #
44 |
--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
1 | name: Release
2 |
3 | on:
4 | push:
5 | tags:
6 | - 'v*' # Push events to matching v*, e.g. v1.0, v20.15.10
7 |
8 | jobs:
9 | deploy:
10 | runs-on: ubuntu-latest
11 | steps:
12 | - name: Install apt dependencies
13 | run: |
14 | sudo apt-get update
15 | sudo apt-get install libcurl4-openssl-dev libssl-dev
16 | - uses: actions/checkout@v6
17 |
18 | - name: Install uv
19 | uses: astral-sh/setup-uv@v7
20 | with:
21 | # Install a specific version of uv.
22 | version: "0.5.21"
23 | enable-cache: true
24 |
25 | - name: "Set up Python"
26 | uses: actions/setup-python@v6
27 | with:
28 | python-version: "3.10"
29 |
30 | - name: Install the project
31 | run: uv sync --all-extras --dev
32 |
33 | - name: Build package
34 | run: uv build
35 |
36 | - name: Publish package
37 | uses: pypa/gh-action-pypi-publish@v1.13.0
38 | with:
39 | user: __token__
40 | password: ${{ secrets.PYPI_API_TOKEN }}
41 |
42 | release_github:
43 | needs: deploy
44 | runs-on: ubuntu-latest
45 | steps:
46 | - name: Create GitHub Release
47 | id: create_release
48 | uses: ncipollo/release-action@v1.20.0
49 | with:
50 | artifacts: "dist/*"
51 | generateReleaseNotes: true
52 |
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | ---
2 | hide:
3 | - navigation
4 | - toc
5 | ---
6 |
7 |
8 |
9 |
21 |
22 | Comprehensive support for state-of-the-art RAG fine-tuning methods that can
23 | be federated with ease.
24 |
25 | [:octicons-arrow-right-24: Getting started](getting_started/essentials.md)
26 |
27 | -
:fontawesome-solid-cubes-stacked:{ .lg .middle } Work with your tools
28 |
29 | Seamlessly integrates with popular frameworks including HuggingFace,
30 | and LlamaIndex — use the tools you already know.
31 |
32 | [:octicons-arrow-right-24: In-Depth Examples](examples/index.md)
33 |
34 | -