├── tests ├── __init__.py ├── api │ ├── __init__.py │ ├── test_data_structures_imports.py │ ├── test_evals_imports.py │ ├── test_deprecated_types_imports.py │ └── test_namespaced_imports.py ├── evals │ ├── __init__.py │ ├── metrics │ │ ├── __init__.py │ │ ├── test_base.py │ │ └── test_exact_match.py │ └── benchmarks │ │ ├── __init__.py │ │ ├── huggingface │ │ ├── __init__.py │ │ ├── test_utils.py │ │ └── test_boolq.py │ │ ├── test_base.py │ │ └── _benchmarks.py ├── loss │ ├── __init__.py │ └── pytorch │ │ ├── __init__.py │ │ └── conftest.py ├── utils │ ├── __init__.py │ ├── data │ │ └── finetuning_datasets │ │ │ ├── __init__.py │ │ │ ├── conftest.py │ │ │ └── test_pt_dataset.py │ └── test_asyncio.py ├── bridges │ └── __init__.py ├── decorators │ └── __init__.py ├── generators │ ├── __init__.py │ ├── mixins │ │ ├── __init__.py │ │ ├── test_image_mixin.py │ │ ├── test_audio_mixin.py │ │ └── test_video_mixin.py │ ├── test_base.py │ ├── test_hf_utils.py │ └── test_unsloth_utils.py ├── rag_system │ ├── __init__.py │ └── test_source_node.py ├── retrievers │ ├── __init__.py │ ├── mixins │ │ ├── __init__.py │ │ ├── test_audio_retriever_mixin.py │ │ ├── test_video_retriever_mixin.py │ │ └── test_image_retriever_mixin.py │ ├── conftest.py │ └── test_base.py ├── tokenizers │ ├── __init__.py │ ├── test_base.py │ ├── conftest.py │ └── test_unsloth_pretrained.py ├── trainers │ ├── __init__.py │ ├── huggingface │ │ ├── __init__.py │ │ └── conftest.py │ ├── pytorch │ │ ├── __init__.py │ │ └── conftest.py │ └── test_base.py ├── data_collators │ ├── __init__.py │ ├── huggingface │ │ └── __init__.py │ └── test_base.py ├── trainer_configs │ ├── __init__.py │ └── test_pytorch_trainer_config.py ├── fl_tasks │ ├── huggingface │ │ └── __init__.py │ ├── pytorch │ │ ├── __init__.py │ │ └── conftest.py │ └── __init__.py ├── knowledge_stores │ ├── __init__.py │ └── no_encode │ │ ├── __init__.py │ │ └── mcp │ │ └── __init__.py ├── trainer_managers │ ├── __init__.py │ └── huggingface │ │ └── __init__.py └── data_structures │ └── test_evals.py ├── docs ├── community │ ├── index.md │ ├── resources │ │ ├── index.md │ │ └── pocket_references.md │ ├── changelog.md │ └── contributing │ │ ├── index.md │ │ ├── ask_question.md │ │ └── submit_issue.md ├── api_reference │ ├── index.md │ ├── retrievers │ │ ├── index.md │ │ └── huggingface.md │ ├── tokenizers │ │ ├── index.md │ │ └── huggingface.md │ ├── exceptions │ │ ├── loss.md │ │ ├── bridge.md │ │ ├── evals.md │ │ ├── fl_tasks.md │ │ ├── generator.md │ │ ├── retriever.md │ │ ├── tokenizer.md │ │ ├── trainer.md │ │ ├── inspectors.md │ │ ├── data_collator.md │ │ ├── rag_trainer.md │ │ ├── knowledge_stores.md │ │ └── index.md │ ├── loss │ │ └── pytorch.md │ ├── data_structures │ │ ├── evals.md │ │ ├── knowledge_node.md │ │ ├── results.md │ │ ├── bridge.md │ │ └── rag.md │ ├── knowledge_stores │ │ ├── mixins.md │ │ ├── index.md │ │ ├── qdrant.md │ │ └── in_memory.md │ ├── knowledge_nodes │ │ └── index.md │ ├── generators │ │ ├── index.md │ │ ├── unsloth.md │ │ └── huggingface.md │ ├── data_collators │ │ ├── index.md │ │ └── huggingface.md │ ├── evals │ │ ├── benchmarker.md │ │ ├── benchmarks │ │ │ └── huggingface │ │ │ │ ├── mmlu.md │ │ │ │ ├── boolq.md │ │ │ │ ├── hotpotqa.md │ │ │ │ ├── pubmedqa.md │ │ │ │ ├── squad_v2.md │ │ │ │ ├── hellaswag.md │ │ │ │ └── natural_questions.md │ │ ├── metrics │ │ │ └── exact_match.md │ │ └── index.md │ ├── fl_tasks │ │ ├── index.md │ │ ├── pytorch.md │ │ └── huggingface.md │ ├── bridges │ │ ├── langchain.md │ │ ├── llamaindex.md │ │ └── 
index.md │ ├── finetuning_datasets │ │ ├── index.md │ │ ├── pytorch.md │ │ └── huggingface.md │ ├── trainer_managers │ │ ├── pytorch.md │ │ ├── index.md │ │ └── huggingface.md │ ├── trainers │ │ ├── index.md │ │ ├── pytorch.md │ │ └── huggingface.md │ ├── inspectors │ │ ├── pytorch.md │ │ ├── huggingface.md │ │ └── index.md │ ├── decorators │ │ └── index.md │ └── rag_system │ │ └── index.md ├── assets │ └── favicon.ico ├── examples │ ├── ra_dit │ │ ├── benchmarking.md │ │ ├── finetune.md │ │ ├── federated_finetune.md │ │ └── index.md │ └── index.md ├── overrides │ └── partials │ │ ├── logo.html │ │ └── copyright.html ├── javascripts │ └── mathjax.js ├── getting_started │ ├── tutorials │ │ └── index.md │ ├── quick_starts │ │ └── index.md │ ├── import_patterns.md │ ├── integrations.md │ └── installation.md └── index.md ├── src └── fed_rag │ ├── py.typed │ ├── base │ ├── __init__.py │ ├── evals │ │ ├── __init__.py │ │ ├── metric.py │ │ └── benchmark.py │ ├── generator_mixins │ │ ├── __init__.py │ │ ├── audio.py │ │ ├── video.py │ │ └── image.py │ ├── retriever_mixins │ │ ├── __init__.py │ │ ├── audio.py │ │ ├── image.py │ │ └── video.py │ ├── data_collator.py │ ├── tokenizer.py │ └── generator.py │ ├── loss │ ├── __init__.py │ └── pytorch │ │ └── __init__.py │ ├── _bridges │ ├── __init__.py │ ├── langchain │ │ ├── _version.py │ │ ├── __init__.py │ │ └── bridge.py │ └── llamaindex │ │ ├── _version.py │ │ ├── __init__.py │ │ └── bridge.py │ ├── fl_tasks │ └── __init__.py │ ├── utils │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ └── finetuning_datasets │ │ │ ├── __init__.py │ │ │ ├── pytorch.py │ │ │ └── huggingface.py │ └── huggingface.py │ ├── data_collators │ ├── __init__.py │ └── huggingface │ │ └── __init__.py │ ├── inspectors │ ├── __init__.py │ ├── pytorch │ │ ├── __init__.py │ │ └── tester.py │ ├── huggingface │ │ ├── __init__.py │ │ └── utils.py │ └── common.py │ ├── tokenizers │ ├── __init__.py │ └── unsloth_pretrained_tokenizer.py │ ├── _version.py │ ├── trainer_configs │ ├── __init__.py │ └── pytorch.py │ ├── retrievers │ ├── huggingface │ │ └── __init__.py │ └── __init__.py │ ├── core │ ├── rag_system │ │ ├── __init__.py │ │ ├── synchronous.py │ │ └── asynchronous.py │ ├── no_encode_rag_system │ │ ├── __init__.py │ │ ├── synchronous.py │ │ └── asynchronous.py │ └── __init__.py │ ├── evals │ ├── metrics │ │ ├── __init__.py │ │ └── exact_match.py │ ├── benchmarks │ │ ├── __init__.py │ │ └── huggingface │ │ │ ├── __init__.py │ │ │ ├── utils.py │ │ │ └── boolq.py │ ├── __init__.py │ └── utils.py │ ├── exceptions │ ├── common.py │ ├── data_collator.py │ ├── core.py │ ├── loss.py │ ├── generator.py │ ├── rag_system.py │ ├── retriever.py │ ├── tokenizer.py │ ├── bridge.py │ ├── trainer.py │ ├── evals.py │ ├── fl_tasks.py │ ├── trainer_manager.py │ ├── knowledge_stores.py │ └── inspectors.py │ ├── knowledge_stores │ ├── qdrant │ │ ├── __init__.py │ │ └── utils.py │ ├── no_encode │ │ ├── mcp │ │ │ ├── sources │ │ │ │ ├── __init__.py │ │ │ │ └── utils.py │ │ │ └── __init__.py │ │ └── __init__.py │ ├── __init__.py │ └── mixins.py │ ├── generators │ ├── unsloth │ │ ├── mixin.py │ │ ├── __init__.py │ │ └── utils.py │ ├── huggingface │ │ ├── __init__.py │ │ └── utils.py │ └── __init__.py │ ├── trainers │ ├── pytorch │ │ ├── __init__.py │ │ ├── training_args.py │ │ └── mixin.py │ ├── huggingface │ │ └── __init__.py │ └── __init__.py │ ├── trainer_managers │ └── __init__.py │ ├── decorators │ ├── __init__.py │ ├── tester.py │ └── trainer.py │ ├── types │ ├── bridge.py │ ├── results.py │ ├── 
rag.py │ ├── __init__.py │ ├── rag_system.py │ └── knowledge_node.py │ ├── data_structures │ ├── bridge.py │ ├── __init__.py │ ├── retriever.py │ └── results.py │ └── __init__.pyi ├── .python-versions ├── examples ├── quick-start │ ├── quick_start │ │ ├── __init__.py │ │ └── _cifar_dataloaders.py │ ├── pyproject.toml │ └── README.md ├── ra-dit │ ├── ra_dit │ │ ├── _dataset_prep │ │ │ ├── __init__.py │ │ │ └── qa │ │ │ │ ├── __init__.py │ │ │ │ ├── pubmed.py │ │ │ │ ├── web_questions.py │ │ │ │ ├── commonsense.py │ │ │ │ ├── wiki.py │ │ │ │ ├── math.py │ │ │ │ └── mixin.py │ │ ├── knowledge_stores │ │ │ ├── __init__.py │ │ │ └── from_dragon.py │ │ ├── trainers_and_testers │ │ │ └── __init__.py │ │ ├── retrievers │ │ │ ├── __init__.py │ │ │ └── dragon.py │ │ ├── utils.py │ │ ├── generators │ │ │ ├── __init__.py │ │ │ └── utils.py │ │ ├── evaluation_benchmarks │ │ │ └── __init__.py │ │ └── __init__.py │ ├── pyproject.toml │ └── README.md └── knowledge_stores │ └── ra-dit-ks │ ├── src │ └── ra_dit_ks │ │ └── __init__.py │ ├── docker │ └── healthcheck.sh │ ├── pyproject.toml │ └── README.md ├── pytest.ini ├── .gitmodules ├── codecov.yml ├── mypy.ini ├── example_scripts └── README.md ├── .github ├── dependabot.yml ├── ISSUE_TEMPLATE │ ├── question.yml │ ├── documentation_improvement.yml │ ├── feature_request.yml │ ├── integration_request.yml │ ├── bug_report.yml │ └── config.yml ├── workflows │ ├── lint.yml │ ├── unit_test.yml │ ├── docs.yml │ └── release.yml └── pull_request_template.md ├── Makefile ├── CITATION.cff └── .gitignore /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/community/index.md: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/fed_rag/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/api/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/evals/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/loss/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/fed_rag/base/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/fed_rag/loss/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/bridges/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/decorators/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/generators/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/rag_system/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/retrievers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/community/resources/index.md: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/fed_rag/_bridges/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/fed_rag/fl_tasks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/fed_rag/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/data_collators/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/evals/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/loss/pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/trainer_configs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/fed_rag/base/evals/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/fed_rag/data_collators/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/fed_rag/inspectors/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/fed_rag/loss/pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /src/fed_rag/tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/evals/benchmarks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/fl_tasks/huggingface/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/fl_tasks/pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/generators/mixins/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/knowledge_stores/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/retrievers/mixins/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/trainer_managers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/trainers/huggingface/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/trainers/pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.python-versions: -------------------------------------------------------------------------------- 1 | 3.12 2 | 3.11 3 | 3.10 4 | -------------------------------------------------------------------------------- /docs/api_reference/index.md: -------------------------------------------------------------------------------- 1 | # API Reference 2 | -------------------------------------------------------------------------------- /examples/quick-start/quick_start/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/ra-dit/ra_dit/_dataset_prep/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | asyncio_mode=auto 3 | -------------------------------------------------------------------------------- /src/fed_rag/_version.py: -------------------------------------------------------------------------------- 1 | VERSION = "0.0.27" 2 | -------------------------------------------------------------------------------- /tests/data_collators/huggingface/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /tests/evals/benchmarks/huggingface/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/knowledge_stores/no_encode/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/knowledge_stores/no_encode/mcp/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/trainer_managers/huggingface/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/utils/data/finetuning_datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/community/changelog.md: -------------------------------------------------------------------------------- 1 | --8<-- "CHANGELOG.md" 2 | -------------------------------------------------------------------------------- /examples/ra-dit/ra_dit/_dataset_prep/qa/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/ra-dit/ra_dit/knowledge_stores/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/ra-dit/ra_dit/trainers_and_testers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/fed_rag/_bridges/langchain/_version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.0" 2 | -------------------------------------------------------------------------------- /tests/fl_tasks/__init__.py: -------------------------------------------------------------------------------- 1 | """PyTorchTrainerConfig Unit Tests""" 2 | -------------------------------------------------------------------------------- /src/fed_rag/_bridges/llamaindex/_version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.0" 2 | -------------------------------------------------------------------------------- /docs/api_reference/retrievers/index.md: -------------------------------------------------------------------------------- 1 | # Retrievers 2 | 3 | ::: src.fed_rag.base.retriever 4 | -------------------------------------------------------------------------------- /docs/api_reference/tokenizers/index.md: -------------------------------------------------------------------------------- 1 | # Tokenizers 2 | 3 | ::: src.fed_rag.base.tokenizer 4 | -------------------------------------------------------------------------------- /docs/assets/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VectorInstitute/fed-rag/HEAD/docs/assets/favicon.ico 
-------------------------------------------------------------------------------- /docs/api_reference/exceptions/loss.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: src.fed_rag.exceptions.loss 4 | -------------------------------------------------------------------------------- /docs/api_reference/loss/pytorch.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: src.fed_rag.loss.pytorch.lsr 4 | -------------------------------------------------------------------------------- /examples/knowledge_stores/ra-dit-ks/src/ra_dit_ks/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import logger 2 | 3 | __all__ = ["logger"] 4 | -------------------------------------------------------------------------------- /docs/api_reference/exceptions/bridge.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: src.fed_rag.exceptions.bridge 4 | -------------------------------------------------------------------------------- /docs/api_reference/exceptions/evals.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: src.fed_rag.exceptions.evals 4 | -------------------------------------------------------------------------------- /docs/api_reference/exceptions/fl_tasks.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: src.fed_rag.exceptions.fl_tasks 4 | -------------------------------------------------------------------------------- /docs/api_reference/exceptions/generator.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: src.fed_rag.exceptions.generator 4 | -------------------------------------------------------------------------------- /docs/api_reference/exceptions/retriever.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: src.fed_rag.exceptions.retriever 4 | -------------------------------------------------------------------------------- /docs/api_reference/exceptions/tokenizer.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: src.fed_rag.exceptions.tokenizer 4 | -------------------------------------------------------------------------------- /docs/api_reference/exceptions/trainer.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: src.fed_rag.exceptions.trainer 4 | -------------------------------------------------------------------------------- /src/fed_rag/_bridges/langchain/__init__.py: -------------------------------------------------------------------------------- 1 | from .bridge import LangChainBridgeMixin 2 | 3 | __all__ = ["LangChainBridgeMixin"] 4 | -------------------------------------------------------------------------------- /src/fed_rag/trainer_configs/__init__.py: -------------------------------------------------------------------------------- 1 | from .pytorch import PyTorchTrainerConfig 2 | 3 | __all__ = ["PyTorchTrainerConfig"] 4 | -------------------------------------------------------------------------------- /src/fed_rag/utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | from ._functions import build_finetune_dataset 2 | 3 | __all__ = ["build_finetune_dataset"] 4 | 
-------------------------------------------------------------------------------- /docs/api_reference/data_structures/evals.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: src.fed_rag.data_structures.evals 4 | -------------------------------------------------------------------------------- /docs/api_reference/exceptions/inspectors.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: src.fed_rag.exceptions.inspectors 4 | -------------------------------------------------------------------------------- /src/fed_rag/_bridges/llamaindex/__init__.py: -------------------------------------------------------------------------------- 1 | from .bridge import LlamaIndexBridgeMixin 2 | 3 | __all__ = ["LlamaIndexBridgeMixin"] 4 | -------------------------------------------------------------------------------- /docs/api_reference/exceptions/data_collator.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: src.fed_rag.exceptions.data_collator 4 | -------------------------------------------------------------------------------- /docs/api_reference/exceptions/rag_trainer.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: src.fed_rag.exceptions.trainer_manager 4 | -------------------------------------------------------------------------------- /docs/api_reference/knowledge_stores/mixins.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: src.fed_rag.knowledge_stores.mixins 4 | -------------------------------------------------------------------------------- /docs/api_reference/exceptions/knowledge_stores.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: src.fed_rag.exceptions.knowledge_stores 4 | -------------------------------------------------------------------------------- /docs/api_reference/knowledge_nodes/index.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: src.fed_rag.data_structures.knowledge_node 4 | -------------------------------------------------------------------------------- /docs/api_reference/data_structures/knowledge_node.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: src.fed_rag.data_structures.knowledge_node 4 | -------------------------------------------------------------------------------- /examples/ra-dit/ra_dit/retrievers/__init__.py: -------------------------------------------------------------------------------- 1 | from .dragon import retriever as dragon_retriever 2 | 3 | RETRIEVERS = {"dragon": dragon_retriever} 4 | -------------------------------------------------------------------------------- /docs/examples/ra_dit/benchmarking.md: -------------------------------------------------------------------------------- 1 | # Evaluate with Benchmarks 2 | 3 | __Coming Soon!__ 4 | 5 | This documentation page is currently under development. 6 | -------------------------------------------------------------------------------- /docs/examples/ra_dit/finetune.md: -------------------------------------------------------------------------------- 1 | # Fine-tuning with QA Datasets 2 | 3 | __Coming Soon!__ 4 | 5 | This documentation page is currently under development. 
6 | -------------------------------------------------------------------------------- /docs/api_reference/generators/index.md: -------------------------------------------------------------------------------- 1 | # Generators 2 | 3 | ::: src.fed_rag.base.generator 4 | options: 5 | members: 6 | - BaseGenerator 7 | -------------------------------------------------------------------------------- /docs/examples/ra_dit/federated_finetune.md: -------------------------------------------------------------------------------- 1 | # Federated Fine-tuning 2 | 3 | __Coming Soon!__ 4 | 5 | This documentation page is currently under development. 6 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "examples-vector-compute"] 2 | path = examples-vector-compute 3 | url = git@github.com:VectorInstitute/fed-rag-examples-vector-compute.git 4 | -------------------------------------------------------------------------------- /examples/ra-dit/ra_dit/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | 4 | def generate_timestamp() -> str: 5 | return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f") 6 | -------------------------------------------------------------------------------- /src/fed_rag/retrievers/huggingface/__init__.py: -------------------------------------------------------------------------------- 1 | from .hf_sentence_transformer import HFSentenceTransformerRetriever 2 | 3 | __all__ = ["HFSentenceTransformerRetriever"] 4 | -------------------------------------------------------------------------------- /src/fed_rag/utils/data/finetuning_datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .pytorch import PyTorchRAGFinetuningDataset 2 | 3 | __all__ = [ 4 | "PyTorchRAGFinetuningDataset", 5 | ] 6 | -------------------------------------------------------------------------------- /docs/api_reference/data_collators/index.md: -------------------------------------------------------------------------------- 1 | # Base Data Collator 2 | 3 | ::: src.fed_rag.base.data_collator 4 | options: 5 | members: 6 | - BaseDataCollator 7 | -------------------------------------------------------------------------------- /src/fed_rag/core/rag_system/__init__.py: -------------------------------------------------------------------------------- 1 | from .asynchronous import AsyncRAGSystem 2 | from .synchronous import RAGSystem 3 | 4 | __all__ = ["RAGSystem", "AsyncRAGSystem"] 5 | -------------------------------------------------------------------------------- /src/fed_rag/evals/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | """Metrics public API""" 2 | 3 | from .exact_match import ExactMatchEvaluationMetric 4 | 5 | __all__ = ["ExactMatchEvaluationMetric"] 6 | -------------------------------------------------------------------------------- /docs/api_reference/knowledge_stores/index.md: -------------------------------------------------------------------------------- 1 | # Base KnowledgeStore 2 | 3 | ::: src.fed_rag.base.knowledge_store 4 | options: 5 | members: 6 | - BaseKnowledgeStore 7 | -------------------------------------------------------------------------------- /examples/ra-dit/ra_dit/generators/__init__.py: -------------------------------------------------------------------------------- 1 | from 
.llama2_7b import generator_registry as llama2_7b_generators 2 | 3 | GENERATORS = { 4 | "llama2_7b": llama2_7b_generators, 5 | } 6 | -------------------------------------------------------------------------------- /docs/api_reference/evals/benchmarker.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: src.fed_rag.evals.benchmarker 4 | options: 5 | members: 6 | - Benchmarker 7 | -------------------------------------------------------------------------------- /docs/api_reference/evals/benchmarks/huggingface/mmlu.md: -------------------------------------------------------------------------------- 1 | # MMLU 2 | 3 | ::: src.fed_rag.evals.benchmarks.huggingface.mmlu 4 | options: 5 | members: 6 | - HuggingFaceMMLU 7 | -------------------------------------------------------------------------------- /docs/api_reference/fl_tasks/index.md: -------------------------------------------------------------------------------- 1 | # Base FL Task Classes 2 | 3 | ::: src.fed_rag.base.fl_task 4 | options: 5 | members: 6 | - BaseFLTask 7 | - BaseFLTaskConfig 8 | -------------------------------------------------------------------------------- /docs/api_reference/fl_tasks/pytorch.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: src.fed_rag.fl_tasks.pytorch 4 | options: 5 | members: 6 | - PyTorchFLTask 7 | -------------------------------------------------------------------------------- /docs/api_reference/evals/benchmarks/huggingface/boolq.md: -------------------------------------------------------------------------------- 1 | # BoolQ 2 | 3 | ::: src.fed_rag.evals.benchmarks.huggingface.boolq 4 | options: 5 | members: 6 | - HuggingFaceBoolQ 7 | -------------------------------------------------------------------------------- /docs/api_reference/bridges/langchain.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: src.fed_rag._bridges.langchain 4 | options: 5 | members: 6 | - LangChainBridgeMixin 7 | -------------------------------------------------------------------------------- /docs/api_reference/bridges/llamaindex.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: src.fed_rag._bridges.llamaindex 4 | options: 5 | members: 6 | - LlamaIndexBridgeMixin 7 | -------------------------------------------------------------------------------- /docs/api_reference/evals/benchmarks/huggingface/hotpotqa.md: -------------------------------------------------------------------------------- 1 | # HotPotQA 2 | 3 | ::: src.fed_rag.evals.benchmarks.huggingface.hotpotqa 4 | options: 5 | members: 6 | - HuggingFaceHotPotQA 7 | -------------------------------------------------------------------------------- /docs/api_reference/evals/benchmarks/huggingface/pubmedqa.md: -------------------------------------------------------------------------------- 1 | # PubMedQA 2 | 3 | ::: src.fed_rag.evals.benchmarks.huggingface.pubmedqa 4 | options: 5 | members: 6 | - HuggingFacePubMedQA 7 | -------------------------------------------------------------------------------- /docs/api_reference/evals/benchmarks/huggingface/squad_v2.md: -------------------------------------------------------------------------------- 1 | # SQuAD v2 2 | 3 | ::: src.fed_rag.evals.benchmarks.huggingface.squad_v2 4 | options: 5 | members: 6 | - HuggingFaceSQuADv2 7 | -------------------------------------------------------------------------------- 
/docs/api_reference/evals/benchmarks/huggingface/hellaswag.md: -------------------------------------------------------------------------------- 1 | # HellaSwag 2 | 3 | ::: src.fed_rag.evals.benchmarks.huggingface.hellaswag 4 | options: 5 | members: 6 | - HuggingFaceHellaSwag 7 | -------------------------------------------------------------------------------- /src/fed_rag/exceptions/common.py: -------------------------------------------------------------------------------- 1 | """Common exceptions.""" 2 | 3 | from .core import FedRAGError 4 | 5 | 6 | class MissingExtraError(FedRAGError): 7 | """Raised when a fed-rag extra is not installed.""" 8 | -------------------------------------------------------------------------------- /src/fed_rag/exceptions/data_collator.py: -------------------------------------------------------------------------------- 1 | from .core import FedRAGError 2 | 3 | 4 | class DataCollatorError(FedRAGError): 5 | """Base errors for all data collator relevant exceptions.""" 6 | 7 | pass 8 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | project: # https://docs.codecov.com/docs/commit-status 4 | default: 5 | target: 80% 6 | threshold: 5% 7 | github_checks: 8 | annotations: false 9 | -------------------------------------------------------------------------------- /docs/api_reference/finetuning_datasets/index.md: -------------------------------------------------------------------------------- 1 | # Finetuning Datasets 2 | 3 | ::: src.fed_rag.utils.data._functions 4 | options: 5 | members: 6 | - build_finetune_dataset 7 | - ReturnType 8 | -------------------------------------------------------------------------------- /docs/api_reference/trainer_managers/pytorch.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: src.fed_rag.trainer_managers.pytorch 4 | options: 5 | members: 6 | - PyTorchRAGTrainerManager 7 | -------------------------------------------------------------------------------- /examples/ra-dit/ra_dit/evaluation_benchmarks/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BenchmarkResult 2 | from .mmlu import mmlu_benchmark 3 | 4 | benchmarks = {"mmlu": mmlu_benchmark} 5 | 6 | __all__ = ["BenchmarkResult"] 7 | -------------------------------------------------------------------------------- /src/fed_rag/core/no_encode_rag_system/__init__.py: -------------------------------------------------------------------------------- 1 | from .asynchronous import AsyncNoEncodeRAGSystem 2 | from .synchronous import NoEncodeRAGSystem 3 | 4 | __all__ = ["NoEncodeRAGSystem", "AsyncNoEncodeRAGSystem"] 5 | -------------------------------------------------------------------------------- /docs/api_reference/evals/metrics/exact_match.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: src.fed_rag.evals.metrics.exact_match 4 | options: 5 | members: 6 | - ExactMatchEvaluationMetric 7 | -------------------------------------------------------------------------------- /docs/api_reference/tokenizers/huggingface.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: src.fed_rag.tokenizers.hf_pretrained_tokenizer 4 | options: 5 | members: 6 | - HFPretrainedTokenizer 7 | 
-------------------------------------------------------------------------------- /docs/api_reference/trainer_managers/index.md: -------------------------------------------------------------------------------- 1 | # Base RAG Trainer Manager 2 | 3 | ::: src.fed_rag.base.trainer_manager 4 | options: 5 | members: 6 | - BaseRAGTrainerManager 7 | - RAGTrainMode 8 | -------------------------------------------------------------------------------- /src/fed_rag/knowledge_stores/qdrant/__init__.py: -------------------------------------------------------------------------------- 1 | from .asynchronous import AsyncQdrantKnowledgeStore 2 | from .sync import QdrantKnowledgeStore 3 | 4 | __all__ = ["QdrantKnowledgeStore", "AsyncQdrantKnowledgeStore"] 5 | -------------------------------------------------------------------------------- /docs/api_reference/bridges/index.md: -------------------------------------------------------------------------------- 1 | # Base Bridges Module 2 | 3 | ::: src.fed_rag.base.bridge 4 | options: 5 | members: 6 | - BaseBridgeMixin 7 | - BridgeRegistryMixin 8 | - BridgeMetadata 9 | -------------------------------------------------------------------------------- /docs/api_reference/data_structures/results.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: src.fed_rag.data_structures.results 4 | options: 5 | members: 6 | - TrainResult 7 | - TestResult 8 | -------------------------------------------------------------------------------- /docs/api_reference/trainer_managers/huggingface.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: src.fed_rag.trainer_managers.huggingface 4 | options: 5 | members: 6 | - HuggingFaceRAGTrainerManager 7 | -------------------------------------------------------------------------------- /docs/api_reference/trainers/index.md: -------------------------------------------------------------------------------- 1 | # Base Trainer 2 | 3 | ::: src.fed_rag.base.trainer 4 | options: 5 | members: 6 | - BaseTrainer 7 | - BaseRetrieverTrainer 8 | - BaseGeneratorTrainer 9 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | warn_return_any = True 3 | warn_unused_configs = True 4 | disallow_untyped_defs = True 5 | ignore_missing_imports = True 6 | explicit_package_bases = True 7 | mypy_path = "src" 8 | plugins = pydantic.mypy 9 | -------------------------------------------------------------------------------- /src/fed_rag/generators/unsloth/mixin.py: -------------------------------------------------------------------------------- 1 | from ..huggingface.mixin import HFGeneratorProtocol, HuggingFaceGeneratorMixin 2 | 3 | UnslothGeneratorMixin = HuggingFaceGeneratorMixin 4 | UnslothGeneratorProtocol = HFGeneratorProtocol 5 | -------------------------------------------------------------------------------- /docs/api_reference/data_structures/bridge.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: src.fed_rag.data_structures.bridge 4 | options: 5 | members: 6 | - BridgeMetadata 7 | - CompatibleVersions 8 | -------------------------------------------------------------------------------- /docs/api_reference/finetuning_datasets/pytorch.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: src.fed_rag.utils.data.finetuning_datasets.pytorch 4 | 
options: 5 | members: 6 | - PyTorchRAGFinetuningDataset 7 | -------------------------------------------------------------------------------- /docs/api_reference/trainers/pytorch.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: src.fed_rag.trainers.pytorch.mixin 4 | options: 5 | members: 6 | - PyTorchTrainerProtocol 7 | - PyTorchTrainerMixin 8 | -------------------------------------------------------------------------------- /src/fed_rag/trainers/pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from .mixin import PyTorchTrainerMixin, PyTorchTrainerProtocol 2 | from .training_args import TrainingArgs 3 | 4 | __all__ = ["TrainingArgs", "PyTorchTrainerMixin", "PyTorchTrainerProtocol"] 5 | -------------------------------------------------------------------------------- /docs/api_reference/evals/benchmarks/huggingface/natural_questions.md: -------------------------------------------------------------------------------- 1 | # Natural Questions 2 | 3 | ::: src.fed_rag.evals.benchmarks.huggingface.natural_questions 4 | options: 5 | members: 6 | - HuggingFaceNaturalQuestions 7 | -------------------------------------------------------------------------------- /docs/api_reference/fl_tasks/huggingface.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: src.fed_rag.fl_tasks.huggingface 4 | options: 5 | members: 6 | - HuggingFaceFlowerClient 7 | - HuggingFaceFLTask 8 | -------------------------------------------------------------------------------- /docs/api_reference/retrievers/huggingface.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: src.fed_rag.retrievers.huggingface.hf_sentence_transformer 4 | options: 5 | members: 6 | - HFSentenceTransformerRetriever 7 | -------------------------------------------------------------------------------- /docs/api_reference/data_collators/huggingface.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: src.fed_rag.data_collators.huggingface 4 | options: 5 | members: 6 | - DataCollatorForLSR 7 | - DataCollatorForRALT 8 | -------------------------------------------------------------------------------- /docs/api_reference/exceptions/index.md: -------------------------------------------------------------------------------- 1 | # FedRAG Exceptions 2 | 3 | ::: src.fed_rag.exceptions.core 4 | options: 5 | members: 6 | - FedRAGError 7 | - FedRAGWarning 8 | 9 | ::: src.fed_rag.exceptions.common 10 | -------------------------------------------------------------------------------- /docs/api_reference/finetuning_datasets/huggingface.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: src.fed_rag.utils.data.finetuning_datasets.huggingface 4 | options: 5 | members: 6 | - HuggingFaceRAGFinetuningDataset 7 | -------------------------------------------------------------------------------- /docs/api_reference/inspectors/pytorch.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: src.fed_rag.inspectors.pytorch 4 | options: 5 | members: 6 | - inspect_trainer_signature 7 | - inspect_tester_signature 8 | -------------------------------------------------------------------------------- /docs/api_reference/inspectors/huggingface.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: 
src.fed_rag.inspectors.huggingface 4 | options: 5 | members: 6 | - inspect_trainer_signature 7 | - inspect_tester_signature 8 | -------------------------------------------------------------------------------- /docs/api_reference/knowledge_stores/qdrant.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: src.fed_rag.knowledge_stores.qdrant 4 | options: 5 | members: 6 | - QdrantKnowledgeStore 7 | - AsyncQdrantKnowledgeStore 8 | -------------------------------------------------------------------------------- /src/fed_rag/knowledge_stores/no_encode/mcp/sources/__init__.py: -------------------------------------------------------------------------------- 1 | from .stdio import MCPStdioKnowledgeSource 2 | from .streamable_http import MCPStreamableHttpKnowledgeSource 3 | 4 | __all__ = ["MCPStreamableHttpKnowledgeSource", "MCPStdioKnowledgeSource"] 5 | -------------------------------------------------------------------------------- /docs/api_reference/knowledge_stores/in_memory.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: src.fed_rag.knowledge_stores.in_memory 4 | options: 5 | members: 6 | - InMemoryKnowledgeStore 7 | - ManagedInMemoryKnowledgeStore 8 | -------------------------------------------------------------------------------- /docs/overrides/partials/logo.html: -------------------------------------------------------------------------------- 1 | logo 6 | logo 11 | -------------------------------------------------------------------------------- /src/fed_rag/generators/unsloth/__init__.py: -------------------------------------------------------------------------------- 1 | from .unsloth_fast_model import UnslothFastModelGenerator 2 | from .unsloth_fast_multimodal_model import UnslothFastMultimodalModelGenerator 3 | 4 | __all__ = ["UnslothFastModelGenerator", "UnslothFastMultimodalModelGenerator"] 5 | -------------------------------------------------------------------------------- /src/fed_rag/retrievers/__init__.py: -------------------------------------------------------------------------------- 1 | """Public Retrievers API""" 2 | 3 | # Disable the F403 warning for wildcard imports 4 | # ruff: noqa: F403, F401 5 | from .huggingface import * 6 | from .huggingface import __all__ as _huggingface_all 7 | 8 | __all__ = _huggingface_all 9 | -------------------------------------------------------------------------------- /src/fed_rag/trainer_managers/__init__.py: -------------------------------------------------------------------------------- 1 | """Public RAG Trainer Managers API""" 2 | 3 | from .huggingface import HuggingFaceRAGTrainerManager 4 | from .pytorch import PyTorchRAGTrainerManager 5 | 6 | __all__ = ["HuggingFaceRAGTrainerManager", "PyTorchRAGTrainerManager"] 7 | -------------------------------------------------------------------------------- /docs/api_reference/evals/index.md: -------------------------------------------------------------------------------- 1 | # Evals 2 | 3 | ::: src.fed_rag.base.evals.benchmark 4 | options: 5 | members: 6 | - BaseBenchmark 7 | 8 | ::: src.fed_rag.base.evals.metric 9 | options: 10 | members: 11 | - BaseEvaluationMetric 12 | -------------------------------------------------------------------------------- /src/fed_rag/evals/benchmarks/__init__.py: -------------------------------------------------------------------------------- 1 | """Evals Benchmarks Public API""" 2 | 3 | # Disable the F403 warning for wildcard imports 4 | # ruff: noqa: F403, F401 5 | 
from .huggingface import * 6 | from .huggingface import __all__ as _huggingface_all 7 | 8 | __all__ = _huggingface_all 9 | -------------------------------------------------------------------------------- /docs/api_reference/data_structures/rag.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: src.fed_rag.data_structures.rag 4 | options: 5 | members: 6 | - Query 7 | - Context 8 | - Prompt 9 | - SourceNode 10 | - RAGResponse 11 | -------------------------------------------------------------------------------- /docs/api_reference/decorators/index.md: -------------------------------------------------------------------------------- 1 | # Trainer and Tester Decorators 2 | 3 | ::: src.fed_rag.decorators.trainer 4 | options: 5 | members: 6 | - TrainerDecorators 7 | 8 | ::: src.fed_rag.decorators.tester 9 | options: 10 | members: 11 | - TesterDecorators 12 | -------------------------------------------------------------------------------- /src/fed_rag/exceptions/core.py: -------------------------------------------------------------------------------- 1 | """Base Error Class for FedRAG.""" 2 | 3 | 4 | class FedRAGError(Exception): 5 | """Base error for all fed-rag exceptions.""" 6 | 7 | pass 8 | 9 | 10 | class FedRAGWarning(Warning): 11 | """Base warning for all fed-rag warnings.""" 12 | 13 | pass 14 | -------------------------------------------------------------------------------- /src/fed_rag/knowledge_stores/no_encode/mcp/__init__.py: -------------------------------------------------------------------------------- 1 | from .sources import MCPStdioKnowledgeSource, MCPStreamableHttpKnowledgeSource 2 | from .store import MCPKnowledgeStore 3 | 4 | __all__ = [ 5 | "MCPKnowledgeStore", 6 | "MCPStdioKnowledgeSource", 7 | "MCPStreamableHttpKnowledgeSource", 8 | ] 9 | -------------------------------------------------------------------------------- /src/fed_rag/evals/__init__.py: -------------------------------------------------------------------------------- 1 | """Evals public API""" 2 | 3 | # Disable the F403 warning for wildcard imports 4 | # ruff: noqa: F403, F401 5 | from .benchmarker import Benchmarker 6 | from .metrics import * 7 | from .metrics import __all__ as _metrics_all 8 | 9 | __all__ = ["Benchmarker"] + _metrics_all 10 | -------------------------------------------------------------------------------- /src/fed_rag/knowledge_stores/no_encode/__init__.py: -------------------------------------------------------------------------------- 1 | from .mcp import ( 2 | MCPKnowledgeStore, 3 | MCPStdioKnowledgeSource, 4 | MCPStreamableHttpKnowledgeSource, 5 | ) 6 | 7 | __all__ = [ 8 | "MCPKnowledgeStore", 9 | "MCPStdioKnowledgeSource", 10 | "MCPStreamableHttpKnowledgeSource", 11 | ] 12 | -------------------------------------------------------------------------------- /tests/utils/data/finetuning_datasets/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | 5 | @pytest.fixture() 6 | def input_and_target_ids() -> tuple[list[torch.Tensor], list[torch.Tensor]]: 7 | input_ids = [torch.zeros(3)] * 3 8 | target_ids = [torch.ones(3)] * 3 9 | return input_ids, target_ids 10 | -------------------------------------------------------------------------------- /src/fed_rag/trainers/huggingface/__init__.py: -------------------------------------------------------------------------------- 1 | from .lsr import HuggingFaceTrainerForLSR 2 | from .mixin import HuggingFaceTrainerMixin 3 | from 
.ralt import HuggingFaceTrainerForRALT 4 | 5 | __all__ = [ 6 | "HuggingFaceTrainerForLSR", 7 | "HuggingFaceTrainerForRALT", 8 | "HuggingFaceTrainerMixin", 9 | ] 10 | -------------------------------------------------------------------------------- /src/fed_rag/core/__init__.py: -------------------------------------------------------------------------------- 1 | """Public Core API""" 2 | 3 | from .no_encode_rag_system import AsyncNoEncodeRAGSystem, NoEncodeRAGSystem 4 | from .rag_system import AsyncRAGSystem, RAGSystem 5 | 6 | __all__ = [ 7 | "AsyncNoEncodeRAGSystem", 8 | "AsyncRAGSystem", 9 | "NoEncodeRAGSystem", 10 | "RAGSystem", 11 | ] 12 | -------------------------------------------------------------------------------- /docs/api_reference/inspectors/index.md: -------------------------------------------------------------------------------- 1 | # Inspectors 2 | 3 | ::: src.fed_rag.inspectors.common 4 | options: 5 | members: 6 | - TesterSignatureSpec 7 | - TrainerSignatureSpec 8 | 9 | ::: src.fed_rag.data_structures.results 10 | options: 11 | members: 12 | - TrainResult 13 | - TestResult 14 | -------------------------------------------------------------------------------- /examples/ra-dit/ra_dit/knowledge_stores/from_dragon.py: -------------------------------------------------------------------------------- 1 | """Knowledge Store.""" 2 | 3 | # ra_dit 4 | from ra_dit.retrievers.dragon import retriever 5 | 6 | from .utils import knowledge_store_from_retriever 7 | 8 | knowledge_store = knowledge_store_from_retriever( 9 | retriever=retriever, persist=True, name="dragon", overwrite=True 10 | ) 11 | -------------------------------------------------------------------------------- /src/fed_rag/inspectors/pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from .tester import TesterSignatureSpec, inspect_tester_signature 2 | from .trainer import TrainerSignatureSpec, inspect_trainer_signature 3 | 4 | __all__ = [ 5 | "TesterSignatureSpec", 6 | "TrainerSignatureSpec", 7 | "inspect_tester_signature", 8 | "inspect_trainer_signature", 9 | ] 10 | -------------------------------------------------------------------------------- /src/fed_rag/inspectors/huggingface/__init__.py: -------------------------------------------------------------------------------- 1 | from .tester import TesterSignatureSpec, inspect_tester_signature 2 | from .trainer import TrainerSignatureSpec, inspect_trainer_signature 3 | 4 | __all__ = [ 5 | "TesterSignatureSpec", 6 | "TrainerSignatureSpec", 7 | "inspect_tester_signature", 8 | "inspect_trainer_signature", 9 | ] 10 | -------------------------------------------------------------------------------- /src/fed_rag/trainer_configs/pytorch.py: -------------------------------------------------------------------------------- 1 | """PyTorch Trainer Config""" 2 | 3 | import torch 4 | from torch.utils.data import DataLoader 5 | 6 | from fed_rag.base.trainer_config import BaseTrainerConfig 7 | 8 | 9 | class PyTorchTrainerConfig(BaseTrainerConfig): 10 | net: torch.nn.Module 11 | train_data: DataLoader 12 | val_data: DataLoader 13 | -------------------------------------------------------------------------------- /src/fed_rag/generators/huggingface/__init__.py: -------------------------------------------------------------------------------- 1 | from .hf_multimodal_model import HFMultimodalModelGenerator 2 | from .hf_peft_model import HFPeftModelGenerator 3 | from .hf_pretrained_model import HFPretrainedModelGenerator 4 | 5 | __all__ = [ 6 |
"HFPeftModelGenerator", 7 | "HFPretrainedModelGenerator", 8 | "HFMultimodalModelGenerator", 9 | ] 10 | -------------------------------------------------------------------------------- /src/fed_rag/knowledge_stores/__init__.py: -------------------------------------------------------------------------------- 1 | """Public KnowledgeStores API""" 2 | 3 | from .in_memory import InMemoryKnowledgeStore 4 | 5 | # Disable the F403 warning for wildcard imports 6 | # ruff: noqa: F403, F401 7 | from .qdrant import * 8 | from .qdrant import __all__ as _qdrant_all 9 | 10 | __all__ = sorted(["InMemoryKnowledgeStore"] + _qdrant_all) 11 | -------------------------------------------------------------------------------- /examples/knowledge_stores/ra-dit-ks/docker/healthcheck.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Check if Qdrant API is responding 4 | if curl -s -f -X GET "http://localhost:6333/collections" -H "Content-Type: application/json" > /dev/null; then 5 | # Qdrant API is responding successfully 6 | exit 0 7 | else 8 | # Qdrant API is not responding 9 | exit 1 10 | fi 11 | -------------------------------------------------------------------------------- /src/fed_rag/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | """Public Trainers API""" 2 | 3 | # Disable the F403 warning for wildcard imports 4 | # ruff: noqa: F403, F401 5 | from .huggingface import * 6 | from .huggingface import __all__ as _huggingface_all 7 | from .pytorch import * 8 | from .pytorch import __all__ as _pytorch_all 9 | 10 | __all__ = _huggingface_all + _pytorch_all 11 | -------------------------------------------------------------------------------- /src/fed_rag/exceptions/loss.py: -------------------------------------------------------------------------------- 1 | """Exceptions for loss.""" 2 | 3 | from .core import FedRAGError 4 | 5 | 6 | class LossError(FedRAGError): 7 | """Base loss errors for all loss-related exceptions.""" 8 | 9 | pass 10 | 11 | 12 | class InvalidReductionParam(LossError): 13 | """Raised if an invalid aggregation mode is provided.""" 14 | 15 | pass 16 | -------------------------------------------------------------------------------- /src/fed_rag/generators/__init__.py: -------------------------------------------------------------------------------- 1 | """Public Generators API""" 2 | 3 | # Disable the F403 warning for wildcard imports 4 | # ruff: noqa: F403, F401 5 | from .huggingface import * 6 | from .huggingface import __all__ as _huggingface_all 7 | from .unsloth import * 8 | from .unsloth import __all__ as _unsloth_all 9 | 10 | __all__ = sorted(_huggingface_all + _unsloth_all) 11 | -------------------------------------------------------------------------------- /docs/community/contributing/index.md: -------------------------------------------------------------------------------- 1 | # Contributing to FedRAG 2 | 3 | Thank you for your interest in contributing to FedRAG! This document provides 4 | guidelines and instructions for contributing. 5 | 6 | We welcome contributions from developers of all skill levels. Whether you're 7 | fixing a typo or implementing a complex feature, your help is valuable to the 8 | FedRAG project. 
9 | -------------------------------------------------------------------------------- /examples/quick-start/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "quick-start" 3 | version = "0.1.0" 4 | description = "A quick start example using fed-rag." 5 | readme = "README.md" 6 | requires-python = ">=3.10,<4.0" 7 | dependencies = [ 8 | "fed-rag", 9 | "torch", 10 | "fire", 11 | "flwr-datasets>=0.5.0" 12 | ] 13 | 14 | [tool.uv.sources] 15 | fed-rag = {workspace = true} 16 | -------------------------------------------------------------------------------- /docs/api_reference/generators/unsloth.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: src.fed_rag.generators.unsloth.unsloth_fast_model 4 | options: 5 | members: 6 | - UnslothFastModelGenerator 7 | 8 | ::: src.fed_rag.generators.unsloth.unsloth_fast_multimodal_model 9 | options: 10 | members: 11 | - UnslothFastMultimodalModelGenerator 12 | -------------------------------------------------------------------------------- /examples/ra-dit/ra_dit/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import logger 2 | 3 | CHECKPOINT_DIR_TEMPLATES = { 4 | "generator": ( 5 | ".checkpoints/{retriever_id}-{generator_id}-{generator_variant}/generator" 6 | ), 7 | "retriever": ( 8 | ".checkpoints/{retriever_id}-{generator_id}-{generator_variant}/retriever" 9 | ), 10 | } 11 | 12 | __all__ = ["logger"] 13 | -------------------------------------------------------------------------------- /src/fed_rag/exceptions/generator.py: -------------------------------------------------------------------------------- 1 | """Exceptions for Generators.""" 2 | 3 | from .core import FedRAGError, FedRAGWarning 4 | 5 | 6 | class GeneratorError(FedRAGError): 7 | """Base generator error for all generator-related exceptions.""" 8 | 9 | pass 10 | 11 | 12 | class GeneratorWarning(FedRAGWarning): 13 | """Base generator warning for all generator-related warnings.""" 14 | 15 | pass 16 | -------------------------------------------------------------------------------- /src/fed_rag/exceptions/rag_system.py: -------------------------------------------------------------------------------- 1 | """Exceptions for RAG System.""" 2 | 3 | from .core import FedRAGError, FedRAGWarning 4 | 5 | 6 | class RAGSystemError(FedRAGError): 7 | """Base RAG system error for all RAG-system-related exceptions.""" 8 | 9 | pass 10 | 11 | 12 | class RAGSystemWarning(FedRAGWarning): 13 | """Base RAG system warning for all RAG-system-related warnings.""" 14 | 15 | pass 16 | -------------------------------------------------------------------------------- /src/fed_rag/exceptions/retriever.py: -------------------------------------------------------------------------------- 1 | """Exceptions for Retrievers.""" 2 | 3 | from .core import FedRAGError, FedRAGWarning 4 | 5 | 6 | class RetrieverError(FedRAGError): 7 | """Base retriever error for all retriever-related exceptions.""" 8 | 9 | pass 10 | 11 | 12 | class RetrieverWarning(FedRAGWarning): 13 | """Base retriever warning for all retriever-related warnings.""" 14 | 15 | pass 16 | -------------------------------------------------------------------------------- /src/fed_rag/exceptions/tokenizer.py: -------------------------------------------------------------------------------- 1 | """Exceptions for Tokenizer.""" 2 | 3 | from .core import FedRAGError, FedRAGWarning 4 | 5 | 6 | class
TokenizerError(FedRAGError): 7 | """Base tokenizer error for all tokenizer-related exceptions.""" 8 | 9 | pass 10 | 11 | 12 | class TokenizerWarning(FedRAGWarning): 13 | """Base tokenizer warning for all tokenizer-related warnings.""" 14 | 15 | pass 16 | -------------------------------------------------------------------------------- /tests/tokenizers/test_base.py: -------------------------------------------------------------------------------- 1 | from fed_rag.base.tokenizer import BaseTokenizer 2 | 3 | 4 | def test_base(mock_tokenizer: BaseTokenizer) -> None: 5 | input_ids = mock_tokenizer.encode("hello world!") 6 | decoded_str = mock_tokenizer.decode([1, 2, 3]) 7 | 8 | assert input_ids == [0, 1, 2] 9 | assert decoded_str == "mock decoded sentence" 10 | assert mock_tokenizer.unwrapped is None 11 | -------------------------------------------------------------------------------- /docs/examples/index.md: -------------------------------------------------------------------------------- 1 | # Case Studies 2 | 3 | 4 | 5 | Here are some in-depth case studies to further demonstrate FedRAG's usage and its 6 | overall utility. 7 | 8 |
9 | 10 | - :material-hexagon-outline: [__RA-DIT__](./ra_dit/index.md) — A comprehensive 11 | reproduction of the RA-DIT method, adapted for practical demonstration. 12 | 13 |
14 | -------------------------------------------------------------------------------- /src/fed_rag/decorators/__init__.py: -------------------------------------------------------------------------------- 1 | """Decorators""" 2 | 3 | from .tester import TesterDecorators 4 | from .trainer import TrainerDecorators 5 | 6 | 7 | class Federate: 8 | def __init__(self, trainer: TrainerDecorators, tester: TesterDecorators): 9 | self.trainer = trainer 10 | self.tester = tester 11 | 12 | 13 | federate = Federate(trainer=TrainerDecorators(), tester=TesterDecorators()) 14 | 15 | 16 | __all__ = ["federate"] 17 | -------------------------------------------------------------------------------- /docs/api_reference/rag_system/index.md: -------------------------------------------------------------------------------- 1 | # RAG Systems and its Variations 2 | 3 | ::: src.fed_rag.core.rag_system 4 | options: 5 | members: 6 | - RAGSystem 7 | - AsyncRAGSystem 8 | 9 | ::: src.fed_rag.core.no_encode_rag_system 10 | options: 11 | members: 12 | - NoEncodeRAGSystem 13 | - AsyncNoEncodeRAGSystem 14 | 15 | ::: src.fed_rag.data_structures.rag 16 | options: 17 | members: 18 | - RAGConfig 19 | -------------------------------------------------------------------------------- /src/fed_rag/base/generator_mixins/__init__.py: -------------------------------------------------------------------------------- 1 | from .audio import AudioModalityMixin, GeneratorHasAudioModality 2 | from .image import GeneratorHasImageModality, ImageModalityMixin 3 | from .video import GeneratorHasVideoModality, VideoModalityMixin 4 | 5 | __all__ = [ 6 | "ImageModalityMixin", 7 | "GeneratorHasImageModality", 8 | "AudioModalityMixin", 9 | "GeneratorHasAudioModality", 10 | "VideoModalityMixin", 11 | "GeneratorHasVideoModality", 12 | ] 13 | -------------------------------------------------------------------------------- /src/fed_rag/evals/metrics/exact_match.py: -------------------------------------------------------------------------------- 1 | """Exact Match Metric""" 2 | 3 | from typing import Any 4 | 5 | from fed_rag.base.evals.metric import BaseEvaluationMetric 6 | 7 | 8 | class ExactMatchEvaluationMetric(BaseEvaluationMetric): 9 | """Exact match evaluation metric class.""" 10 | 11 | def __call__( 12 | self, prediction: str, actual: str, *args: Any, **kwargs: Any 13 | ) -> float: 14 | return float(prediction.lower() == actual.lower()) 15 | -------------------------------------------------------------------------------- /src/fed_rag/inspectors/huggingface/utils.py: -------------------------------------------------------------------------------- 1 | from inspect import Parameter 2 | from typing import cast 3 | 4 | 5 | def get_type_name(t: Parameter) -> str | None: 6 | if isinstance(t.annotation, str): 7 | type_name = t.annotation 8 | else: 9 | type_name = getattr( 10 | t.annotation, "__name__", None 11 | ) # type:ignore [assignment] 12 | type_name = cast(str | None, type_name) # type:ignore [assignment] 13 | return type_name 14 | -------------------------------------------------------------------------------- /src/fed_rag/base/retriever_mixins/__init__.py: -------------------------------------------------------------------------------- 1 | from .audio import AudioRetrieverMixin, RetrieverHasAudioModality 2 | from .image import ImageRetrieverMixin, RetrieverHasImageModality 3 | from .video import RetrieverHasVideoModality, VideoRetrieverMixin 4 | 5 | __all__ = [ 6 | "ImageRetrieverMixin", 7 | "RetrieverHasImageModality", 8 | "AudioRetrieverMixin", 9 | 
"RetrieverHasAudioModality", 10 | "VideoRetrieverMixin", 11 | "RetrieverHasVideoModality", 12 | ] 13 | -------------------------------------------------------------------------------- /tests/evals/metrics/test_base.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from fed_rag.base.evals.metric import BaseEvaluationMetric 4 | 5 | 6 | class MyMetric(BaseEvaluationMetric): 7 | def __call__( 8 | self, prediction: str, actual: str, *args: Any, **kwargs: Any 9 | ) -> float: 10 | return 0.42 11 | 12 | 13 | def test_metric_call() -> None: 14 | metric = MyMetric() 15 | 16 | score = metric("fake pred", "fake actual") 17 | 18 | assert score == 0.42 19 | -------------------------------------------------------------------------------- /src/fed_rag/base/evals/metric.py: -------------------------------------------------------------------------------- 1 | """Base EvaluationMetric""" 2 | 3 | from abc import ABC, abstractmethod 4 | from typing import Any 5 | 6 | from pydantic import BaseModel 7 | 8 | 9 | class BaseEvaluationMetric(BaseModel, ABC): 10 | """Base Data Collator.""" 11 | 12 | @abstractmethod 13 | def __call__( 14 | self, prediction: str, actual: str, *args: Any, **kwargs: Any 15 | ) -> float: 16 | """Evaluate an example prediction against the actual response.""" 17 | -------------------------------------------------------------------------------- /docs/javascripts/mathjax.js: -------------------------------------------------------------------------------- 1 | window.MathJax = { 2 | tex: { 3 | inlineMath: [["\\(", "\\)"]], 4 | displayMath: [["\\[", "\\]"]], 5 | processEscapes: true, 6 | processEnvironments: true, 7 | }, 8 | options: { 9 | ignoreHtmlClass: ".*|", 10 | processHtmlClass: "arithmatex", 11 | }, 12 | }; 13 | 14 | document$.subscribe(() => { 15 | MathJax.startup.output.clearCache(); 16 | MathJax.typesetClear(); 17 | MathJax.texReset(); 18 | MathJax.typesetPromise(); 19 | }); 20 | -------------------------------------------------------------------------------- /src/fed_rag/trainers/pytorch/training_args.py: -------------------------------------------------------------------------------- 1 | """Training Args""" 2 | 3 | from typing import Any 4 | 5 | from pydantic import BaseModel, Field 6 | 7 | 8 | class TrainingArgs(BaseModel): 9 | """Arguments for training.""" 10 | 11 | learning_rate: float | None = None 12 | batch_size: int | None = None 13 | num_epochs: int | None = None 14 | warmup_steps: int | None = None 15 | weight_decay: float | None = None 16 | custom_kwargs: dict[str, Any] = Field(default_factory=dict) 17 | -------------------------------------------------------------------------------- /examples/quick-start/README.md: -------------------------------------------------------------------------------- 1 | # Example: Quick Start 2 | 3 | ## Usage 4 | 5 | Install the dependencies with `uv`: 6 | ```sh 7 | uv sync --all-extras --dev 8 | ``` 9 | 10 | To run, execute the following commands while in `fed-rag/examples/quickstart`: 11 | 12 | ```sh 13 | # start server 14 | uv run -m quick_start.main --component server 15 | 16 | # start client 1 17 | uv run -m quick_start.main --component client_1 18 | 19 | # start client 2 20 | uv run -m quick_start.main --component client_2 21 | ``` 22 | -------------------------------------------------------------------------------- /docs/api_reference/generators/huggingface.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: 
src.fed_rag.generators.huggingface.hf_peft_model 4 | options: 5 | members: 6 | - HFPeftModelGenerator 7 | 8 | ::: src.fed_rag.generators.huggingface.hf_pretrained_model 9 | options: 10 | members: 11 | - HFPretrainedModelGenerator 12 | 13 | ::: src.fed_rag.generators.huggingface.hf_multimodal_model 14 | options: 15 | members: 16 | - HFMultimodalModelGenerator 17 | -------------------------------------------------------------------------------- /src/fed_rag/knowledge_stores/mixins.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from abc import ABC, abstractmethod 3 | 4 | from pydantic import BaseModel, Field 5 | from typing_extensions import Self 6 | 7 | 8 | def generate_ks_id() -> str: 9 | return str(uuid.uuid4()) 10 | 11 | 12 | class ManagedMixin(BaseModel, ABC): 13 | ks_id: str = Field(default_factory=generate_ks_id) 14 | 15 | @classmethod 16 | @abstractmethod 17 | def from_name_and_id(cls, ks_id: str) -> Self: 18 | """Load a managed Knowledge Store by id.""" 19 | -------------------------------------------------------------------------------- /src/fed_rag/exceptions/bridge.py: -------------------------------------------------------------------------------- 1 | """Exceptions for Bridges.""" 2 | 3 | from .core import FedRAGError 4 | 5 | 6 | class BridgeError(FedRAGError): 7 | """Base bridge error for all bridge-related exceptions.""" 8 | 9 | pass 10 | 11 | 12 | class MissingSpecifiedConversionMethod(BridgeError): 13 | """Raised when bridge is missing its specified method.""" 14 | 15 | pass 16 | 17 | 18 | class IncompatibleVersionError(FedRAGError): 19 | """Raised when a fed-rag component is not compatible with the current version.""" 20 | -------------------------------------------------------------------------------- /tests/api/test_data_structures_imports.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import pytest 4 | 5 | from fed_rag.data_structures import __all__ as _data_structures_all 6 | 7 | 8 | @pytest.mark.parametrize("name", _data_structures_all) 9 | def test_data_structures_all_importable(name: str) -> None: 10 | """Tests that all names listed in generators __all__ are importable.""" 11 | mod = importlib.import_module("fed_rag.data_structures") 12 | attr = getattr(mod, name) 13 | 14 | assert hasattr(mod, name) 15 | assert attr is not None 16 | -------------------------------------------------------------------------------- /docs/api_reference/trainers/huggingface.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ::: src.fed_rag.trainers.huggingface.mixin 4 | options: 5 | members: 6 | - HuggingFaceTrainerProtocol 7 | - HuggingFaceTrainerMixin 8 | 9 | ::: src.fed_rag.trainers.huggingface.lsr 10 | options: 11 | members: 12 | - HuggingFaceTrainerForLSR 13 | - LSRSentenceTransformerTrainer 14 | 15 | ::: src.fed_rag.trainers.huggingface.ralt 16 | options: 17 | members: 18 | - HuggingFaceTrainerForRALT 19 | -------------------------------------------------------------------------------- /example_scripts/README.md: -------------------------------------------------------------------------------- 1 | # Example Scripts 2 | 3 | This sub-directory contains example scripts that are run either using the command 4 | line or within a Jupyter notebook (i.e., one of our cookbooks). 5 | 6 | To generally be able to run any of these scripts you should install from source. 
7 | 8 | ```sh 9 | # clone repo 10 | git clone git@github.com:VectorInstitute/fed-rag.git 11 | 12 | # install from source with uv 13 | cd fed-rag 14 | uv sync --all-extras --dev --group docs 15 | 16 | # run script 17 | uv run example_scripts/.py 18 | ``` 19 | -------------------------------------------------------------------------------- /examples/ra-dit/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "ra-dit" 3 | version = "0.1.0" 4 | description = "Example of FedRAG following dual finetuning framework of RA-DIT." 5 | readme = "README.md" 6 | requires-python = ">=3.12" 7 | dependencies = [ 8 | "accelerate>=1.4.0", 9 | "bitsandbytes>=0.45.3", 10 | "colorama>=0.4.6", 11 | "fed-rag[huggingface,qdrant]", 12 | "fire>=0.7.0", 13 | "pandas>=2.2.3", 14 | "sentence-transformers>=3.4.1", 15 | "tqdm>=4.67.1", 16 | "transformers>=4.49.0", 17 | "trl>=0.15.2" 18 | ] 19 | 20 | [tool.uv.sources] 21 | fed-rag = {path = "../../", editable = true} 22 | -------------------------------------------------------------------------------- /tests/data_collators/test_base.py: -------------------------------------------------------------------------------- 1 | from fed_rag import RAGSystem 2 | 3 | from .conftest import MockDataCollator 4 | 5 | 6 | def test_init(mock_rag_system: RAGSystem) -> None: 7 | collator = MockDataCollator(rag_system=mock_rag_system) 8 | 9 | assert collator.rag_system == mock_rag_system 10 | 11 | 12 | def test_collate(mock_rag_system: RAGSystem) -> None: 13 | collator = MockDataCollator(rag_system=mock_rag_system) 14 | 15 | # act 16 | res = collator(features={"feat": ["mock_input"]}) 17 | 18 | assert collator.rag_system == mock_rag_system 19 | assert res == "collated!" 20 | -------------------------------------------------------------------------------- /docs/getting_started/tutorials/index.md: -------------------------------------------------------------------------------- 1 | # Tutorials 2 | 3 | 4 | 5 | We've prepared the following tutorials on some of the core concepts and methods 6 | that underpin the important features and capabilities of FedRAG. 7 | 8 |
9 | 10 | - :material-hexagon-outline: [__LSR Fine-tuning__](./lsr.md) — A tutorial on the 11 | LM-Supervised Retriever (LSR) fine-tuning method. 12 | - :material-hexagon-outline: [__RALT Fine-tuning__](./ralt.md) — A tutorial on the 13 | Retriever-Augmented LM Training (RALT) method. 14 | 15 |
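Both tutorials center on the corresponding FedRAG trainer classes (for example, `HuggingFaceTrainerForLSR` and `HuggingFaceTrainerForRALT` in the HuggingFace integration). The snippet below is only a rough, illustrative sketch of the shared pattern: the `rag_system` and `train_dataset` arguments mirror the base trainer used in this repository's test suite, while the `train()` call is an assumed entry point that should be checked against the trainers API reference.

```python
from fed_rag.trainers.huggingface import HuggingFaceTrainerForRALT

# `rag_system` is a previously assembled fed-rag RAG system, and
# `train_dataset` holds query/response style training examples.
trainer = HuggingFaceTrainerForRALT(
    rag_system=rag_system,
    train_dataset=train_dataset,
)
result = trainer.train()  # assumed entry point; see the trainers API reference
```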
16 | -------------------------------------------------------------------------------- /src/fed_rag/inspectors/common.py: -------------------------------------------------------------------------------- 1 | """Common abstractions for inspectors""" 2 | 3 | from pydantic import BaseModel 4 | 5 | 6 | class TrainerSignatureSpec(BaseModel): 7 | net_parameter: str 8 | train_data_param: str 9 | val_data_param: str 10 | extra_train_kwargs: list[str] = [] 11 | net_parameter_class_name: str 12 | 13 | 14 | class TesterSignatureSpec(BaseModel): 15 | __test__ = ( 16 | False # needed for Pytest collision. Avoids PytestCollectionWarning 17 | ) 18 | net_parameter: str 19 | test_data_param: str 20 | extra_test_kwargs: list[str] = [] 21 | net_parameter_class_name: str 22 | -------------------------------------------------------------------------------- /tests/utils/data/finetuning_datasets/test_pt_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from fed_rag.utils.data.finetuning_datasets import PyTorchRAGFinetuningDataset 4 | 5 | 6 | def test_pt_rag_ft_dataset_init( 7 | input_and_target_ids: tuple[torch.Tensor, torch.Tensor], 8 | ) -> None: 9 | input_ids, target_ids = input_and_target_ids 10 | rag_ft_dataset = PyTorchRAGFinetuningDataset( 11 | input_ids=input_ids, target_ids=target_ids 12 | ) 13 | 14 | assert len(rag_ft_dataset) == len(input_ids) 15 | assert isinstance(rag_ft_dataset, torch.utils.data.Dataset) 16 | assert rag_ft_dataset[:] == input_and_target_ids[:] 17 | -------------------------------------------------------------------------------- /src/fed_rag/core/rag_system/synchronous.py: -------------------------------------------------------------------------------- 1 | """RAG System Module""" 2 | 3 | from fed_rag._bridges.langchain.bridge import LangChainBridgeMixin 4 | from fed_rag._bridges.llamaindex.bridge import LlamaIndexBridgeMixin 5 | from fed_rag.core.rag_system._synchronous import _RAGSystem 6 | 7 | 8 | # Define the public RAGSystem with all available bridges 9 | class RAGSystem(LlamaIndexBridgeMixin, LangChainBridgeMixin, _RAGSystem): 10 | """RAG System with all available bridge functionality. 11 | 12 | The RAGSystem is the main entry point for creating and managing 13 | retrieval-augmented generation systems. 14 | """ 15 | 16 | pass 17 | -------------------------------------------------------------------------------- /tests/evals/metrics/test_exact_match.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from fed_rag.evals import ExactMatchEvaluationMetric 4 | 5 | 6 | @pytest.mark.parametrize( 7 | ("pred", "actual", "expected"), 8 | [ 9 | ("1+1=2", "1+1=2", 1.0), 10 | ("Yes, Correct!", "yes, correct!", 1.0), 11 | ("not the same", "as me", 0.0), 12 | ], 13 | ids=["match", "match case insensitive", "not match"], 14 | ) 15 | def test_exact_match(pred: str, actual: str, expected: float) -> None: 16 | metric = ExactMatchEvaluationMetric() 17 | 18 | # act 19 | res = metric(prediction=pred, actual=actual) 20 | 21 | assert res == expected 22 | -------------------------------------------------------------------------------- /src/fed_rag/types/bridge.py: -------------------------------------------------------------------------------- 1 | """Bridge type definitions for fed-rag. 2 | 3 | Note: The BridgeMetadata implementation has moved to fed_rag.data_structures.bridge. 4 | This module is maintained for backward compatibility. 
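For reference, a minimal example of the preferred import going forward (the replacement module is the one named in the note above): ``from fed_rag.data_structures import BridgeMetadata``.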
5 | """ 6 | 7 | import warnings 8 | 9 | from ..data_structures.bridge import BridgeMetadata 10 | 11 | warnings.warn( 12 | "Importing BridgeMetadata from fed_rag.types.bridge is deprecated and will be " 13 | "removed in a future release. Use fed_rag.data_structures.bridge or " 14 | "fed_rag.data_structures instead.", 15 | DeprecationWarning, 16 | stacklevel=2, # point to users import statement 17 | ) 18 | 19 | __all__ = ["BridgeMetadata"] 20 | -------------------------------------------------------------------------------- /examples/knowledge_stores/ra-dit-ks/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "ra-dit-ks" 7 | version = "0.1.0" 8 | description = "Knowledge store builds for ra-dit examples" 9 | readme = "README.md" 10 | authors = [ 11 | {name = "nerdai", email = "andrei@vectorinstitute.ai"} 12 | ] 13 | requires-python = ">=3.12" 14 | dependencies = [ 15 | "colorama>=0.4.6", 16 | "fed-rag[huggingface,qdrant]>=0.0.13", 17 | "fire>=0.7.0", 18 | "python-dotenv>=1.1.0" 19 | ] 20 | 21 | [project.scripts] 22 | ra-dit = "ra_dit:main" 23 | 24 | [tool.uv.sources] 25 | fed-rag = {path = "../../../", editable = true} 26 | -------------------------------------------------------------------------------- /src/fed_rag/types/results.py: -------------------------------------------------------------------------------- 1 | """Data structures for results 2 | 3 | Note: The correct module has moved to fed_rag.data_structures.results. This module is 4 | maintained for backward compatibility. 5 | """ 6 | 7 | import warnings 8 | 9 | from ..data_structures.results import TestResult, TrainResult 10 | 11 | warnings.warn( 12 | "Importing TrainResult, TestResult from fed_rag.types.results" 13 | "is deprecated and will be removed in a future release. Use " 14 | "fed_rag.data_structures.results or fed_rag.data_structures instead.", 15 | DeprecationWarning, 16 | stacklevel=2, # point to users import statement 17 | ) 18 | 19 | __all__ = ["TrainResult", "TestResult"] 20 | -------------------------------------------------------------------------------- /src/fed_rag/types/rag.py: -------------------------------------------------------------------------------- 1 | """Data structures for RAG. 2 | 3 | Note: The correct module has moved to fed_rag.data_structures.rag. This module is 4 | maintained for backward compatibility. 5 | """ 6 | 7 | import warnings 8 | 9 | from ..data_structures.rag import RAGConfig, RAGResponse, SourceNode 10 | 11 | warnings.warn( 12 | "Importing RAGConfig, RAGResponse, SourceNode from fed_rag.types.rag" 13 | "is deprecated and will be removed in a future release. Use " 14 | "fed_rag.data_structures.rag or fed_rag.data_structures instead.", 15 | DeprecationWarning, 16 | stacklevel=2, # point to users import statement 17 | ) 18 | 19 | __all__ = ["RAGConfig", "RAGResponse", "SourceNode"] 20 | -------------------------------------------------------------------------------- /src/fed_rag/types/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | fed_rag.types 3 | 4 | Only components defined in `__all__` are considered stable and public. 
5 | """ 6 | 7 | from .bridge import BridgeMetadata 8 | from .knowledge_node import KnowledgeNode, NodeContent, NodeType 9 | from .rag import RAGConfig, RAGResponse, SourceNode 10 | from .results import TestResult, TrainResult 11 | 12 | __all__ = [ 13 | # bridge 14 | "BridgeMetadata", 15 | # results 16 | "TrainResult", 17 | "TestResult", 18 | # knowledge node 19 | "KnowledgeNode", 20 | "NodeType", 21 | "NodeContent", 22 | # rag 23 | "RAGConfig", 24 | "RAGResponse", 25 | "SourceNode", 26 | ] 27 | 28 | __deprecated__ = True 29 | -------------------------------------------------------------------------------- /src/fed_rag/evals/benchmarks/huggingface/__init__.py: -------------------------------------------------------------------------------- 1 | from .boolq import HuggingFaceBoolQ 2 | from .hellaswag import HuggingFaceHellaSwag 3 | from .hotpotqa import HuggingFaceHotpotQA 4 | from .mixin import HuggingFaceBenchmarkMixin 5 | from .mmlu import HuggingFaceMMLU 6 | from .natural_questions import HuggingFaceNaturalQuestions 7 | from .pubmedqa import HuggingFacePubMedQA 8 | from .squad_v2 import HuggingFaceSQuADv2 9 | 10 | __all__ = [ 11 | "HuggingFaceBenchmarkMixin", 12 | "HuggingFaceMMLU", 13 | "HuggingFacePubMedQA", 14 | "HuggingFaceHotpotQA", 15 | "HuggingFaceSQuADv2", 16 | "HuggingFaceNaturalQuestions", 17 | "HuggingFaceBoolQ", 18 | "HuggingFaceHellaSwag", 19 | ] 20 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 3 | # Please see the documentation for all configuration options: 4 | # https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "github-actions" 9 | directory: "/" 10 | schedule: 11 | interval: "weekly" 12 | - package-ecosystem: "uv" 13 | directory: "/" 14 | schedule: 15 | interval: "weekly" 16 | groups: 17 | all-python-packages: 18 | patterns: 19 | - "**" 20 | -------------------------------------------------------------------------------- /src/fed_rag/data_collators/huggingface/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib.util 2 | 3 | from fed_rag.exceptions.common import MissingExtraError 4 | 5 | from .lsr import DataCollatorForLSR 6 | from .ralt import DataCollatorForRALT 7 | 8 | # check if huggingface extra is installed 9 | _has_huggingface = (importlib.util.find_spec("transformers") is not None) and ( 10 | importlib.util.find_spec("peft") is not None 11 | ) 12 | if not _has_huggingface: 13 | msg = ( 14 | f"`{__name__}` requires `huggingface` extra to be installed." 15 | " To fix please run `pip install fed-rag[huggingface]`." 
16 | ) 17 | raise MissingExtraError(msg) 18 | 19 | __all__ = ["DataCollatorForLSR", "DataCollatorForRALT"] 20 | -------------------------------------------------------------------------------- /docs/overrides/partials/copyright.html: -------------------------------------------------------------------------------- 1 | 23 | -------------------------------------------------------------------------------- /docs/community/contributing/ask_question.md: -------------------------------------------------------------------------------- 1 | # Ask a Question 2 | 3 | We welcome questions from users and contributors at all levels of experience with 4 | FedRAG. Having questions is a natural part of engaging with a complex project, 5 | and we're here to help. 6 | 7 | ## Where to Ask Questions 8 | 9 | FedRAG offers several channels for asking questions: 10 | 11 | - **Discord Community**: Join our [Discord community](https://discord.gg/5GMpSCFbTe) 12 | for real-time discussions and quick questions. 13 | 14 | - **GitHub Discussions**: For longer, more detailed questions, use 15 | [GitHub Discussions](https://github.com/VectorInstitute/fed-rag/discussions). 16 | This is ideal for questions that might benefit the wider community. 17 | -------------------------------------------------------------------------------- /tests/generators/test_base.py: -------------------------------------------------------------------------------- 1 | from fed_rag.base.generator import BaseGenerator 2 | 3 | 4 | def test_generate(mock_generator: BaseGenerator) -> None: 5 | output = mock_generator.generate(query="hello", context="again") 6 | assert output == "mock output from 'hello' and 'again'." 7 | 8 | 9 | def test_complete(mock_generator: BaseGenerator) -> None: 10 | output = mock_generator.complete(prompt="hello again") 11 | assert output == "mock completion output from 'hello again'." 12 | 13 | 14 | def test_compute_target_sequence_proba(mock_generator: BaseGenerator) -> None: 15 | proba = mock_generator.compute_target_sequence_proba( 16 | prompt="mock prompt", target="mock target" 17 | ) 18 | assert proba == 0.42 19 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/question.yml: -------------------------------------------------------------------------------- 1 | name: General Question 2 | description: Ask a question about using or developing with FedRAG. 3 | title: "[Question]: " 4 | labels: ["question", "triage"] 5 | body: 6 | - type: markdown 7 | attributes: 8 | value: | 9 | Have a question about FedRAG? 10 | Please ask it here — we’ll do our best to help! 11 | 12 | - type: textarea 13 | id: question 14 | attributes: 15 | label: Your Question 16 | description: Clearly state your question. 17 | validations: 18 | required: true 19 | 20 | - type: textarea 21 | id: context 22 | attributes: 23 | label: Additional Context 24 | description: If helpful, add code examples, screenshots, or links. 25 | -------------------------------------------------------------------------------- /examples/ra-dit/README.md: -------------------------------------------------------------------------------- 1 | # RA-DIT 2 | 3 | ## Usage 4 | 5 | Run the following commands from the `examples/ra-dit` directory. 
6 | 7 | ```sh 8 | # source venv 9 | source .venv/bin/activate 10 | 11 | # run federated learning 12 | 13 | ## start server (note this will load the model into cpu) 14 | uv run -m ra_dit.main --task generator --generator_id llama2_7b \ 15 | --generator_variant qlora --component server 16 | 17 | ## start clients using a two-gpu setup 18 | CUDA_VISIBLE_DEVICES=0 uv run -m ra_dit.main --task generator --generator_id \ 19 | llama2_7b --generator_variant qlora --component client_1 20 | 21 | CUDA_VISIBLE_DEVICES=1 uv run -m ra_dit.main --task generator --generator_id \ 22 | llama2_7b --generator_variant qlora --component client_2 23 | ``` 24 | -------------------------------------------------------------------------------- /src/fed_rag/utils/data/finetuning_datasets/pytorch.py: -------------------------------------------------------------------------------- 1 | """PyTorch RAG Finetuning Dataset""" 2 | 3 | from typing import Any 4 | 5 | import torch 6 | from torch.utils.data import Dataset 7 | 8 | 9 | class PyTorchRAGFinetuningDataset(Dataset): 10 | """PyTorch RAG Fine-Tuning Dataset Class. 11 | 12 | Args: 13 | Dataset (_type_): _description_ 14 | """ 15 | 16 | def __init__( 17 | self, input_ids: list[torch.Tensor], target_ids: list[torch.Tensor] 18 | ): 19 | self.input_ids = input_ids 20 | self.target_ids = target_ids 21 | 22 | def __len__(self) -> int: 23 | return len(self.input_ids) 24 | 25 | def __getitem__(self, idx: int) -> Any: 26 | return self.input_ids[idx], self.target_ids[idx] 27 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | help: ## Show all Makefile targets. 2 | @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}' 3 | 4 | format: ## Run code autoformatters (black). 
5 | pre-commit install 6 | git ls-files | xargs pre-commit run black --files 7 | 8 | lint: ## Run linters: pre-commit (black, ruff, codespell) and mypy 9 | pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files 10 | 11 | test: 12 | pytest tests -v --capture=no 13 | 14 | coverage: # for ci purposes 15 | pytest --cov fed_rag --cov-report=xml tests 16 | 17 | coverage-report: ## Show coverage summary in terminal 18 | coverage report -m 19 | 20 | coverage-html: ## Generate HTML coverage report 21 | coverage html 22 | -------------------------------------------------------------------------------- /src/fed_rag/evals/utils.py: -------------------------------------------------------------------------------- 1 | """Utils module for evals""" 2 | 3 | import json 4 | from pathlib import Path 5 | 6 | from fed_rag.data_structures.evals import BenchmarkEvaluatedExample 7 | from fed_rag.exceptions import EvaluationsFileNotFoundError 8 | 9 | 10 | def load_evaluations(filename: str | Path) -> list[BenchmarkEvaluatedExample]: 11 | """Utility for loading serialized BenchmarkEvaluatedExamples in a JSONL file.""" 12 | 13 | if isinstance(filename, str): 14 | filename = Path(filename) 15 | 16 | if not filename.exists(): 17 | raise EvaluationsFileNotFoundError(str(filename)) 18 | 19 | with open(filename, "r") as f: 20 | data = [json.loads(line) for line in f] 21 | 22 | return [BenchmarkEvaluatedExample(**item) for item in data] 23 | -------------------------------------------------------------------------------- /src/fed_rag/types/rag_system.py: -------------------------------------------------------------------------------- 1 | """ 2 | RAG System type definitions and implementation. 3 | 4 | Note: The RAGSystem implementation has moved to fed_rag.core.rag_system. 5 | This module is maintained for backward compatibility. 6 | """ 7 | 8 | import warnings 9 | 10 | from ..core.rag_system import RAGSystem 11 | from .rag import RAGConfig, RAGResponse, SourceNode 12 | 13 | warnings.warn( 14 | "Importing RAGSystem from fed_rag.types.rag_system is deprecated and will be " 15 | "removed in a future release. Use fed_rag.core.rag_system or fed_rag instead.", 16 | DeprecationWarning, 17 | stacklevel=2, # point to users import statement 18 | ) 19 | 20 | 21 | # Export all symbols for backward compatibility 22 | __all__ = ["RAGSystem", "RAGConfig", "RAGResponse", "SourceNode"] 23 | -------------------------------------------------------------------------------- /src/fed_rag/types/knowledge_node.py: -------------------------------------------------------------------------------- 1 | """Knowledge Node 2 | 3 | Note: The KnowledgeNode implementation has moved to fed_rag.data_structures.knowledge_node. 4 | This module is maintained for backward compatibility. 5 | """ 6 | 7 | import warnings 8 | 9 | from ..data_structures.knowledge_node import ( 10 | KnowledgeNode, 11 | NodeContent, 12 | NodeType, 13 | ) 14 | 15 | warnings.warn( 16 | "Importing KnowledgeNode, NodeContent, and NodeType from fed_rag.types.knowledge_node " 17 | "is deprecated and will be removed in a future release.
Use " 18 | "fed_rag.data_structures.knowledge_node or fed_rag.data_structures instead.", 19 | DeprecationWarning, 20 | stacklevel=2, # point to users import statement 21 | ) 22 | 23 | __all__ = ["KnowledgeNode", "NodeContent", "NodeType"] 24 | -------------------------------------------------------------------------------- /tests/rag_system/test_source_node.py: -------------------------------------------------------------------------------- 1 | from fed_rag.data_structures import KnowledgeNode, SourceNode 2 | 3 | 4 | def test_getattr_sourcenode_wraps_knowledge_node() -> None: 5 | # arrange 6 | knowledge_node = KnowledgeNode( 7 | embedding=[0.1, 0.2], 8 | node_type="text", 9 | text_content="fake text context", 10 | metadata={"some_field": 12}, 11 | ) 12 | 13 | # act 14 | source_node = SourceNode(score=0.42, node=knowledge_node) 15 | 16 | # assert 17 | assert source_node.score == 0.42 18 | assert source_node.text_content == knowledge_node.text_content 19 | assert source_node.node_type == knowledge_node.node_type 20 | assert source_node.node_id == knowledge_node.node_id 21 | assert source_node.metadata == knowledge_node.metadata 22 | -------------------------------------------------------------------------------- /src/fed_rag/exceptions/trainer.py: -------------------------------------------------------------------------------- 1 | from .core import FedRAGError 2 | 3 | 4 | class TrainerError(FedRAGError): 5 | """Base errors for all rag trainer relevant exceptions.""" 6 | 7 | pass 8 | 9 | 10 | class InconsistentDatasetError(TrainerError): 11 | """Raised if underlying datasets between dataloaders are inconsistent.""" 12 | 13 | pass 14 | 15 | 16 | class InvalidLossError(TrainerError): 17 | """Raised if an unexpected loss is attached to a trainer object.""" 18 | 19 | pass 20 | 21 | 22 | class InvalidDataCollatorError(TrainerError): 23 | """Raised if an invalid data collator is attached to a trainer object.""" 24 | 25 | pass 26 | 27 | 28 | class MissingInputTensor(TrainerError): 29 | """Raised if a required tensor has not been supplied in the inputs.""" 30 | 31 | pass 32 | -------------------------------------------------------------------------------- /tests/generators/test_hf_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | from unittest.mock import patch 3 | 4 | import pytest 5 | 6 | from fed_rag.exceptions import MissingExtraError 7 | from fed_rag.generators.huggingface.utils import check_huggingface_installed 8 | 9 | 10 | def test_check_raises_error() -> None: 11 | """Check raises error from utils.""" 12 | 13 | modules = {"transformers": None} 14 | 15 | with patch.dict("sys.modules", modules): 16 | msg = ( 17 | "Missing installation of the huggingface extra, yet is required " 18 | "by an imported class. To fix please run `pip install fed-rag[huggingface]`." 
19 | ) 20 | with pytest.raises( 21 | MissingExtraError, 22 | match=re.escape(msg), 23 | ): 24 | check_huggingface_installed() 25 | -------------------------------------------------------------------------------- /src/fed_rag/exceptions/evals.py: -------------------------------------------------------------------------------- 1 | """Exceptions for Evals.""" 2 | 3 | from .core import FedRAGError, FedRAGWarning 4 | 5 | 6 | class EvalsError(FedRAGError): 7 | """Base evals error for all evals-related exceptions.""" 8 | 9 | pass 10 | 11 | 12 | class EvalsWarning(FedRAGWarning): 13 | """Base inspector warning for all evals-related warnings.""" 14 | 15 | pass 16 | 17 | 18 | class BenchmarkGetExamplesError(EvalsError): 19 | """Raised if an error occurs when getting examples for a benchmark.""" 20 | 21 | pass 22 | 23 | 24 | class BenchmarkParseError(EvalsError): 25 | """Raised when errors occur during parsing examples.""" 26 | 27 | pass 28 | 29 | 30 | class EvaluationsFileNotFoundError(EvalsError, FileNotFoundError): 31 | """Benchmark evaluations file not found error.""" 32 | 33 | pass 34 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Linting 2 | on: 3 | push: 4 | branches: 5 | - main 6 | pull_request: 7 | types: 8 | - opened 9 | - synchronize 10 | jobs: 11 | lint: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - name: get code 15 | uses: actions/checkout@v6 16 | 17 | - name: Install uv 18 | uses: astral-sh/setup-uv@v7 19 | with: 20 | # Install a specific version of uv. 21 | version: "0.5.21" 22 | enable-cache: true 23 | 24 | - name: "Set up Python" 25 | uses: actions/setup-python@v6 26 | with: 27 | python-version: "3.12" 28 | 29 | - name: Install the project 30 | run: uv sync --all-extras --dev 31 | 32 | - name: Run linter and formatter 33 | run: | 34 | uv run make lint 35 | -------------------------------------------------------------------------------- /src/fed_rag/generators/unsloth/utils.py: -------------------------------------------------------------------------------- 1 | from importlib.util import find_spec 2 | 3 | from fed_rag.exceptions import MissingExtraError 4 | 5 | 6 | def check_unsloth_installed(cls_name: str | None = None) -> None: 7 | unsloth_spec = find_spec("unsloth") 8 | 9 | has_unsloth = unsloth_spec is not None 10 | if not has_unsloth: 11 | if cls_name: 12 | msg = ( 13 | f"`{cls_name}` requires the `unsloth` extra to be installed. " 14 | "To fix please run `pip install fed-rag[unsloth]`." 15 | ) 16 | else: 17 | msg = ( 18 | "Missing installation of the `unsloth` extra, yet is required " 19 | "by an imported class. To fix please run `pip install fed-rag[unsloth]`." 
20 | ) 21 | 22 | raise MissingExtraError(msg) 23 | -------------------------------------------------------------------------------- /src/fed_rag/exceptions/fl_tasks.py: -------------------------------------------------------------------------------- 1 | """Exceptions for FL Tasks.""" 2 | 3 | from .core import FedRAGError 4 | 5 | 6 | class FLTaskError(FedRAGError): 7 | """Base fl task error for all fl-task-related exceptions.""" 8 | 9 | pass 10 | 11 | 12 | class MissingFLTaskConfig(FLTaskError): 13 | """Raised if fl task `trainer` and `tester` do not have `__fl_task_tester_config` attr set.""" 14 | 15 | pass 16 | 17 | 18 | class MissingRequiredNetParam(FLTaskError): 19 | """Raised when invoking fl_task.server without passing the specified model/net param.""" 20 | 21 | pass 22 | 23 | 24 | class NetTypeMismatch(FLTaskError): 25 | """Raised when a `trainer` and `tester` spec have differing `net_parameter_class_name`. 26 | 27 | This indicates that the these methods have different types for the `net_parameter`. 28 | """ 29 | 30 | pass 31 | -------------------------------------------------------------------------------- /docs/community/contributing/submit_issue.md: -------------------------------------------------------------------------------- 1 | # Submitting an Issue 2 | 3 | Issues are an important way to track bugs, feature requests, and improvements to 4 | FedRAG. 5 | 6 | ## Before Creating an Issue 7 | 8 | Before submitting a new issue: 9 | 10 | 1. **Search existing issues**: Check [GitHub Issues](https://github.com/VectorInstitute/fed-rag/issues) 11 | to see if your problem has already been reported or if a related feature request exists. 12 | 13 | 2. **Check the documentation**: Verify that your question isn't already addressed 14 | in our documentation. 15 | 16 | 3. **Confirm it's an issue**: For general questions, please use [GitHub Discussions](https://github.com/VectorInstitute/fed-rag/discussions) 17 | or our [Discord community](https://discord.gg/5GMpSCFbTe) instead. 18 | 19 | We appreciate your contributions to making FedRAG better through thoughtful issue submissions! 20 | -------------------------------------------------------------------------------- /src/fed_rag/core/no_encode_rag_system/synchronous.py: -------------------------------------------------------------------------------- 1 | """No Encode RAG System Module""" 2 | 3 | from fed_rag.core.no_encode_rag_system._synchronous import _NoEncodeRAGSystem 4 | 5 | 6 | # Define the public NoEncodeRAGSystem with all available bridges 7 | class NoEncodeRAGSystem(_NoEncodeRAGSystem): 8 | """NoEncode RAG System with all available bridge functionality. 9 | 10 | The NoEncodeRAGSystem is the main entry point for creating and managing 11 | retrieval-augmented generation systems that skip encoding altogether, 12 | enabling direct natural language queries to knowledge sources like MCP 13 | servers, APIs, and databases. 14 | 15 | Unlike traditional RAG systems that require separate retriever components 16 | and pre-computed embeddings, NoEncode RAG systems perform direct queries 17 | against NoEncode knowledge sources. 
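    Example (an illustrative sketch only; the constructor arguments are assumed
    to mirror the encode-based ``RAGSystem`` minus the retriever, and the
    ``query`` entry point is likewise an assumption to verify against the
    actual class signature):

        rag_system = NoEncodeRAGSystem(
            generator=generator,
            knowledge_store=mcp_knowledge_store,
            rag_config=rag_config,
        )
        response = rag_system.query("Which sources mention federated fine-tuning?")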
18 | """ 19 | 20 | pass 21 | -------------------------------------------------------------------------------- /tests/trainers/test_base.py: -------------------------------------------------------------------------------- 1 | from fed_rag import RAGSystem 2 | 3 | from .conftest import MockRetrieverTrainer, MockTrainer 4 | 5 | 6 | def test_init(mock_rag_system: RAGSystem) -> None: 7 | trainer = MockTrainer( 8 | rag_system=mock_rag_system, 9 | train_dataset=[{"query": "mock example", "response": "mock response"}], 10 | ) 11 | 12 | assert trainer.rag_system == mock_rag_system 13 | 14 | 15 | def test_retriever_trainer_with_dual_encoder_retriever( 16 | mock_rag_system_dual_encoder: RAGSystem, 17 | ) -> None: 18 | trainer = MockRetrieverTrainer( 19 | rag_system=mock_rag_system_dual_encoder, 20 | train_dataset=[{"query": "mock example", "response": "mock response"}], 21 | ) 22 | 23 | assert trainer.rag_system == mock_rag_system_dual_encoder 24 | assert ( 25 | trainer.model == mock_rag_system_dual_encoder.retriever.query_encoder 26 | ) 27 | -------------------------------------------------------------------------------- /src/fed_rag/exceptions/trainer_manager.py: -------------------------------------------------------------------------------- 1 | from .core import FedRAGError 2 | 3 | 4 | class RAGTrainerManagerError(FedRAGError): 5 | """Base errors for all rag trainer manager relevant exceptions.""" 6 | 7 | pass 8 | 9 | 10 | class UnspecifiedRetrieverTrainer(RAGTrainerManagerError): 11 | """Raised if a retriever trainer has not been specified when one was expected to be.""" 12 | 13 | pass 14 | 15 | 16 | class UnspecifiedGeneratorTrainer(RAGTrainerManagerError): 17 | """Raised if a generator trainer has not been specified when one was expected to be.""" 18 | 19 | pass 20 | 21 | 22 | class UnsupportedTrainerMode(RAGTrainerManagerError): 23 | """Raised if an unsupported trainer mode has been supplied.""" 24 | 25 | pass 26 | 27 | 28 | class InconsistentRAGSystems(RAGTrainerManagerError): 29 | """Raised if trainers have inconsistent underlying RAG systems.""" 30 | 31 | pass 32 | -------------------------------------------------------------------------------- /tests/evals/benchmarks/huggingface/test_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | from unittest.mock import patch 3 | 4 | import pytest 5 | 6 | from fed_rag.evals.benchmarks.huggingface.utils import ( 7 | check_huggingface_evals_installed, 8 | ) 9 | from fed_rag.exceptions import MissingExtraError 10 | 11 | 12 | def test_check_raises_error() -> None: 13 | """Check raises error from utils.""" 14 | 15 | modules = {"datasets": None} 16 | 17 | with patch.dict("sys.modules", modules): 18 | msg = ( 19 | "Missing installation of the huggingface-evals extra, yet is required " 20 | "by an import `HuggingFaceBenchmark` class. To fix please run " 21 | "`pip install fed-rag[huggingface-evals]`." 22 | ) 23 | 24 | with pytest.raises( 25 | MissingExtraError, 26 | match=re.escape(msg), 27 | ): 28 | check_huggingface_evals_installed() 29 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 
3 | authors: 4 | - family-names: "Fajardo" 5 | given-names: "Andrei" 6 | email: "andrei.fajardo@vectorinstitute.ai" 7 | - family-names: "Emerson" 8 | given-names: "David" 9 | email: "david.emerson@vectorinstitute.ai" 10 | title: "fed-rag" 11 | version: "0.0.27" 12 | abstract: "Simplified fine-tuning of retrieval-augmented generation (RAG) systems." 13 | keywords: 14 | - machine learning 15 | - federated learning 16 | - deep learning 17 | - llms 18 | - rag 19 | - retrieval 20 | - semantic search 21 | license: Apache-2.0 22 | doi: 10.5281/zenodo.15092361 23 | repository-code: "https://github.com/VectorInstitute/fed-rag" 24 | type: software 25 | date-released: "2025-03-26" 26 | contact: 27 | - family-names: "Fajardo" 28 | given-names: "Andrei" 29 | email: "andrei.fajardo@vectorinstitute.ai" 30 | -------------------------------------------------------------------------------- /tests/api/test_evals_imports.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import pytest 4 | 5 | from fed_rag.evals import __all__ as _evals_all 6 | from fed_rag.evals.benchmarks import __all__ as _benchmarks_all 7 | 8 | 9 | @pytest.mark.parametrize("name", _evals_all) 10 | def test_evals_all_importable(name: str) -> None: 11 | """Tests that all names listed in evals __all__ are importable.""" 12 | mod = importlib.import_module("fed_rag.evals") 13 | attr = getattr(mod, name) 14 | 15 | assert hasattr(mod, name) 16 | assert attr is not None 17 | 18 | 19 | @pytest.mark.parametrize("name", _benchmarks_all) 20 | def test_evals_benchmarks_all_importable(name: str) -> None: 21 | """Tests that all names listed in evals.benchmarks __all__ are importable.""" 22 | mod = importlib.import_module("fed_rag.evals.benchmarks") 23 | attr = getattr(mod, name) 24 | 25 | assert hasattr(mod, name) 26 | assert attr is not None 27 | -------------------------------------------------------------------------------- /tests/evals/benchmarks/test_base.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import pytest 4 | 5 | from fed_rag.exceptions import BenchmarkGetExamplesError 6 | 7 | from . 
import _benchmarks as benchmarks 8 | 9 | 10 | def test_sequence_interface() -> None: 11 | # typical pattern 12 | test_benchmark = benchmarks.TestBenchmark() 13 | 14 | assert len(test_benchmark) == 3 15 | assert test_benchmark.num_examples == 3 16 | for ix in range(len(test_benchmark)): 17 | assert test_benchmark[ix] == test_benchmark._examples[ix] 18 | example_iter = iter(test_benchmark.as_iterator()) 19 | assert next(example_iter) == test_benchmark[0] 20 | 21 | 22 | def test_get_example_raises_exception() -> None: 23 | # typical pattern 24 | 25 | with pytest.raises( 26 | BenchmarkGetExamplesError, 27 | match=re.escape("Failed to get examples: Too bad, so sad."), 28 | ): 29 | _ = benchmarks.TestBenchmarkBadExamples() 30 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Local development 2 | venv/ 3 | .venv/ 4 | .ipynb_checkpoints 5 | .__pycache__ 6 | __pycache__ 7 | dev_notebooks/ 8 | .vscode 9 | .mypy_cache 10 | .pytest_cache 11 | .ruff_cache 12 | .env 13 | fed-rag.code-workspace 14 | .fed_rag/ 15 | 16 | # docs 17 | site/ 18 | docs/stylesheets/extra.css.map 19 | 20 | # datasets for running examples 21 | data/ 22 | !src/fed_rag/utils/data 23 | !tests/utils/data 24 | 25 | # HF training artifacts 26 | tmp_trainer/ 27 | 28 | # example benchmark results 29 | .benchmark_results 30 | 31 | # example checkpoints 32 | .checkpoints 33 | 34 | # qdrant 35 | qdrant_storage 36 | 37 | # notebooks 38 | dev_notebooks 39 | **/rag_federated_learning.py 40 | 41 | # trainer outputs 42 | trainer_output 43 | unsloth_compiled_cache 44 | 45 | # import profile 46 | # python -X importtime -c "import fed_rag" 2> import_profile.txt 47 | import_profile.txt 48 | 49 | # coverage 50 | .coverage 51 | coverage.xml 52 | htmlcov/ 53 | -------------------------------------------------------------------------------- /src/fed_rag/data_structures/bridge.py: -------------------------------------------------------------------------------- 1 | from typing import TypedDict 2 | 3 | 4 | class CompatibleVersions(TypedDict, total=False): 5 | """Type definition for compatible versions. 6 | 7 | Defines optional, inclusive version bounds for compatibility checks. 8 | 9 | Attributes: 10 | min: Minimum compatible version (inclusive). 11 | max: Maximum compatible version (inclusive). 12 | """ 13 | 14 | min: str 15 | max: str 16 | 17 | 18 | class BridgeMetadata(TypedDict): 19 | """Type definition for bridge metadata. 20 | 21 | Attributes: 22 | bridge_version: The version of the bridge. 23 | framework: The framework name. 24 | compatible_versions: Version bounds for compatibility. 25 | method_name: The method name associated with the bridge. 26 | """ 27 | 28 | bridge_version: str 29 | framework: str 30 | compatible_versions: CompatibleVersions 31 | method_name: str 32 | -------------------------------------------------------------------------------- /examples/ra-dit/ra_dit/retrievers/dragon.py: -------------------------------------------------------------------------------- 1 | """Dragon Retriever.""" 2 | 3 | from fed_rag.retrievers.huggingface.hf_sentence_transformer import ( 4 | HFSentenceTransformerRetriever, 5 | ) 6 | 7 | retriever = HFSentenceTransformerRetriever( 8 | query_model_name="nthakur/dragon-plus-query-encoder", 9 | context_model_name="nthakur/dragon-plus-context-encoder", 10 | load_model_at_init=False, 11 | ) 12 | 13 | if __name__ == "__main__": 14 | query = "Where was Marie Curie born?" 
15 | contexts = [ 16 | "Maria Sklodowska, later known as Marie Curie, was born on November 7, 1867.", 17 | "Born in Paris on 15 May 1859, Pierre Curie was the son of Eugène Curie, a doctor of French Catholic origin from Alsace.", 18 | ] 19 | 20 | query_embeddings = retriever.encode_query(query) 21 | context_embeddings = retriever.encode_context(contexts) 22 | 23 | scores = query_embeddings @ context_embeddings.T 24 | print(scores) 25 | -------------------------------------------------------------------------------- /src/fed_rag/evals/benchmarks/huggingface/utils.py: -------------------------------------------------------------------------------- 1 | from importlib.util import find_spec 2 | 3 | from fed_rag.exceptions import MissingExtraError 4 | 5 | 6 | def check_huggingface_evals_installed(cls_name: str | None = None) -> None: 7 | datasets_spec = find_spec("datasets") 8 | 9 | has_huggingface = datasets_spec is not None 10 | 11 | if not has_huggingface: 12 | if cls_name: 13 | msg = ( 14 | f"`{cls_name}` requires the `huggingface-evals` extra to be installed. " 15 | "To fix please run `pip install fed-rag[huggingface-evals]`." 16 | ) 17 | else: 18 | msg = ( 19 | "Missing installation of the huggingface-evals extra, yet is required " 20 | "by an import `HuggingFaceBenchmark` class. To fix please run " 21 | "`pip install fed-rag[huggingface-evals]`." 22 | ) 23 | 24 | raise MissingExtraError(msg) 25 | -------------------------------------------------------------------------------- /src/fed_rag/base/generator_mixins/audio.py: -------------------------------------------------------------------------------- 1 | """Generator Mixins.""" 2 | 3 | from typing import Protocol, runtime_checkable 4 | 5 | from fed_rag.exceptions.generator import GeneratorError 6 | 7 | 8 | @runtime_checkable 9 | class GeneratorHasAudioModality(Protocol): 10 | """Associated protocol for `AudioModalityMixin`.""" 11 | 12 | __supports_audio__: bool = True 13 | 14 | 15 | class AudioModalityMixin: 16 | """Audio Modality Mixin. 17 | 18 | Meant to be mixed with a `BaseGenerator` to indicate the ability to accept 19 | audio inputs. 20 | """ 21 | 22 | __supports_audio__ = True 23 | 24 | def __init_subclass__(cls) -> None: 25 | """Validate this is mixed with `BaseGenerator`.""" 26 | super().__init_subclass__() 27 | 28 | if "BaseGenerator" not in [t.__name__ for t in cls.__mro__]: 29 | raise GeneratorError( 30 | "`AudioModalityMixin` must be mixed with `BaseGenerator`." 31 | ) 32 | -------------------------------------------------------------------------------- /src/fed_rag/base/generator_mixins/video.py: -------------------------------------------------------------------------------- 1 | """Generator Mixins.""" 2 | 3 | from typing import Protocol, runtime_checkable 4 | 5 | from fed_rag.exceptions.generator import GeneratorError 6 | 7 | 8 | @runtime_checkable 9 | class GeneratorHasVideoModality(Protocol): 10 | """Associated protocol for `VideoModalityMixin`.""" 11 | 12 | __supports_video__: bool = True 13 | 14 | 15 | class VideoModalityMixin: 16 | """Video Modality Mixin. 17 | 18 | Meant to be mixed with a `BaseGenerator` to indicate the ability to accept 19 | video inputs. 20 | """ 21 | 22 | __supports_video__ = True 23 | 24 | def __init_subclass__(cls) -> None: 25 | """Validate this is mixed with `BaseGenerator`.""" 26 | super().__init_subclass__() 27 | 28 | if "BaseGenerator" not in [t.__name__ for t in cls.__mro__]: 29 | raise GeneratorError( 30 | "`VideoModalityMixin` must be mixed with `BaseGenerator`." 
31 | ) 32 | -------------------------------------------------------------------------------- /src/fed_rag/base/generator_mixins/image.py: -------------------------------------------------------------------------------- 1 | """Generator Mixins.""" 2 | 3 | from typing import Protocol, runtime_checkable 4 | 5 | from fed_rag.exceptions.generator import GeneratorError 6 | 7 | 8 | @runtime_checkable 9 | class GeneratorHasImageModality(Protocol): 10 | """Associated protocol for `ImageModalityMixin`.""" 11 | 12 | __supports_images__: bool = True 13 | 14 | 15 | class ImageModalityMixin: 16 | """Image Modality Mixin. 17 | 18 | Meant to be mixed with a `BaseGenerator` to indicate the ability to accept 19 | image inputs. 20 | """ 21 | 22 | __supports_images__ = True 23 | 24 | def __init_subclass__(cls) -> None: 25 | """Validate this is mixed with `BaseGenerator`.""" 26 | super().__init_subclass__() 27 | 28 | if "BaseGenerator" not in [t.__name__ for t in cls.__mro__]: 29 | raise GeneratorError( 30 | "`ImageModalityMixin` must be mixed with `BaseGenerator`." 31 | ) 32 | -------------------------------------------------------------------------------- /tests/generators/mixins/test_image_mixin.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pydantic import BaseModel 3 | 4 | from fed_rag.base.generator import BaseGenerator 5 | from fed_rag.base.generator_mixins import ( 6 | GeneratorHasImageModality, 7 | ImageModalityMixin, 8 | ) 9 | from fed_rag.exceptions.generator import GeneratorError 10 | 11 | from ..conftest import MockGenerator 12 | 13 | 14 | class MockMMGenerator(ImageModalityMixin, MockGenerator): 15 | pass 16 | 17 | 18 | def test_mixin() -> None: 19 | mixed_generator = MockMMGenerator() 20 | 21 | assert isinstance(mixed_generator, GeneratorHasImageModality) 22 | assert isinstance(mixed_generator, BaseGenerator) 23 | 24 | 25 | def test_mixin_fails_validation() -> None: 26 | with pytest.raises( 27 | GeneratorError, 28 | match="`ImageModalityMixin` must be mixed with `BaseGenerator`.", 29 | ): 30 | 31 | class InvalidMockMMGenerator(ImageModalityMixin, BaseModel): 32 | pass 33 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation_improvement.yml: -------------------------------------------------------------------------------- 1 | name: Documentation Improvement 2 | description: Suggest a fix or improvement to the FedRAG documentation. 3 | title: "[Docs]: " 4 | labels: ["documentation", "triage"] 5 | body: 6 | - type: markdown 7 | attributes: 8 | value: | 9 | Help us improve our documentation! 10 | Please provide as much detail as you can below. 11 | 12 | - type: textarea 13 | id: location 14 | attributes: 15 | label: Location of Issue 16 | description: What page, section, or example needs updating? 17 | 18 | - type: textarea 19 | id: problem 20 | attributes: 21 | label: Problem or Gap 22 | description: Describe the issue — missing info, outdated content, unclear instructions, etc. 23 | validations: 24 | required: true 25 | 26 | - type: textarea 27 | id: suggestion 28 | attributes: 29 | label: Suggested Change 30 | description: What would you like to see instead? 
31 | -------------------------------------------------------------------------------- /src/fed_rag/base/retriever_mixins/audio.py: -------------------------------------------------------------------------------- 1 | """Retriever Mixins.""" 2 | 3 | from abc import ABC 4 | from typing import Protocol, runtime_checkable 5 | 6 | from fed_rag.exceptions.retriever import RetrieverError 7 | 8 | 9 | @runtime_checkable 10 | class RetrieverHasAudioModality(Protocol): 11 | """Associated protocol for `AudioRetrieverMixin`.""" 12 | 13 | __supports_audio__: bool = True 14 | 15 | 16 | class AudioRetrieverMixin(ABC): 17 | """Audio Retriever Mixin. 18 | 19 | Meant to be mixed with a `BaseRetriever` to add audio modality for 20 | retrieval. 21 | """ 22 | 23 | __supports_audio__ = True 24 | 25 | def __init_subclass__(cls) -> None: 26 | """Validate this is mixed with `BaseRetriever`.""" 27 | super().__init_subclass__() 28 | 29 | if "BaseRetriever" not in [t.__name__ for t in cls.__mro__]: 30 | raise RetrieverError( 31 | "`AudioRetrieverMixin` must be mixed with `BaseRetriever`." 32 | ) 33 | -------------------------------------------------------------------------------- /src/fed_rag/base/retriever_mixins/image.py: -------------------------------------------------------------------------------- 1 | """Retriever Mixins.""" 2 | 3 | from abc import ABC 4 | from typing import Protocol, runtime_checkable 5 | 6 | from fed_rag.exceptions.retriever import RetrieverError 7 | 8 | 9 | @runtime_checkable 10 | class RetrieverHasImageModality(Protocol): 11 | """Associated protocol for `ImageRetrieverMixin`.""" 12 | 13 | __supports_images__: bool = True 14 | 15 | 16 | class ImageRetrieverMixin(ABC): 17 | """Image Retriever Mixin. 18 | 19 | Meant to be mixed with a `BaseRetriever` to add image modality for 20 | retrieval. 21 | """ 22 | 23 | __supports_images__ = True 24 | 25 | def __init_subclass__(cls) -> None: 26 | """Validate this is mixed with `BaseRetriever`.""" 27 | super().__init_subclass__() 28 | 29 | if "BaseRetriever" not in [t.__name__ for t in cls.__mro__]: 30 | raise RetrieverError( 31 | "`ImageRetrieverMixin` must be mixed with `BaseRetriever`." 32 | ) 33 | -------------------------------------------------------------------------------- /src/fed_rag/base/retriever_mixins/video.py: -------------------------------------------------------------------------------- 1 | """Retriever Mixins.""" 2 | 3 | from abc import ABC 4 | from typing import Protocol, runtime_checkable 5 | 6 | from fed_rag.exceptions.retriever import RetrieverError 7 | 8 | 9 | @runtime_checkable 10 | class RetrieverHasVideoModality(Protocol): 11 | """Associated protocol for `VideoRetrieverMixin`.""" 12 | 13 | __supports_video__: bool = True 14 | 15 | 16 | class VideoRetrieverMixin(ABC): 17 | """Video Retriever Mixin. 18 | 19 | Meant to be mixed with a `BaseRetriever` to add video modality for 20 | retrieval. 21 | """ 22 | 23 | __supports_video__ = True 24 | 25 | def __init_subclass__(cls) -> None: 26 | """Validate this is mixed with `BaseRetriever`.""" 27 | super().__init_subclass__() 28 | 29 | if "BaseRetriever" not in [t.__name__ for t in cls.__mro__]: 30 | raise RetrieverError( 31 | "`VideoRetrieverMixin` must be mixed with `BaseRetriever`." 
32 | ) 33 | -------------------------------------------------------------------------------- /docs/getting_started/quick_starts/index.md: -------------------------------------------------------------------------------- 1 | # Quick Starts 2 | 3 | 4 | 5 | In this next part in getting to know FedRAG, we provide a mini series of 6 | quick start examples in order to get a better feeling of the library. 7 | 8 |
9 | 10 | - :material-hexagon-outline: [__Centralized to Federated__](./federated.md) — Transform 11 | a centralized training task into a federated learning task. 12 | - :material-hexagon-outline: [__Build a RAG System__](./rag_inference.md) — Assemble 13 | a RAG system using FedRAG's lightweight abstractions. 14 | - :material-hexagon-outline: [__Fine-tune a RAG System__](./rag_finetuning.md) — Fine-tune 15 | a RAG system on custom QA data, demonstrating both centralized training and 16 | optional federation capabilities. 17 | - :material-hexagon-outline: [__Benchmark a RAG System__](./benchmark_mmlu.md) — 18 | Evaluate a RAG system on popular benchmarks like MMLU. 19 | 20 |
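If it helps to see the destination first, here is a rough sketch of the kind of system these quick starts assemble. It is not lifted from the guides: the knowledge store class, the generator arguments, and the `RAGConfig` field shown below are illustrative assumptions, and each quick start above contains the exact, tested code for its scenario.

```python
# Illustrative sketch only; names flagged as assumed are not canonical quick start code.
from fed_rag import RAGSystem
from fed_rag.data_structures import RAGConfig
from fed_rag.generators import HFPretrainedModelGenerator
from fed_rag.knowledge_stores import InMemoryKnowledgeStore  # assumed in-memory store class
from fed_rag.retrievers import HFSentenceTransformerRetriever

# Dual-encoder retriever (same encoder pair used in the RA-DIT example).
retriever = HFSentenceTransformerRetriever(
    query_model_name="nthakur/dragon-plus-query-encoder",
    context_model_name="nthakur/dragon-plus-context-encoder",
    load_model_at_init=False,
)

# Any HuggingFace causal LM; the model name here is a placeholder.
generator = HFPretrainedModelGenerator(model_name="Qwen/Qwen2.5-0.5B-Instruct")

# Store to retrieve from; populate it with KnowledgeNode entries before querying.
knowledge_store = InMemoryKnowledgeStore()

rag_system = RAGSystem(
    retriever=retriever,
    generator=generator,
    knowledge_store=knowledge_store,
    rag_config=RAGConfig(top_k=2),  # assumed field name for retrieval depth
)

# Query the assembled system; the quick starts cover ingestion and fine-tuning in detail.
print(rag_system.query("Where was Marie Curie born?"))
```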
21 | -------------------------------------------------------------------------------- /tests/generators/mixins/test_audio_mixin.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pydantic import BaseModel 3 | 4 | from fed_rag.base.generator import BaseGenerator 5 | from fed_rag.base.generator_mixins.audio import ( 6 | AudioModalityMixin, 7 | GeneratorHasAudioModality, 8 | ) 9 | from fed_rag.exceptions.generator import GeneratorError 10 | 11 | from ..conftest import MockGenerator 12 | 13 | 14 | class MockAudioGenerator(AudioModalityMixin, MockGenerator): 15 | pass 16 | 17 | 18 | def test_audio_mixin() -> None: 19 | mixed_generator = MockAudioGenerator() 20 | assert isinstance(mixed_generator, GeneratorHasAudioModality) 21 | assert isinstance(mixed_generator, BaseGenerator) 22 | 23 | 24 | def test_audio_mixin_fails_validation() -> None: 25 | with pytest.raises( 26 | GeneratorError, 27 | match="`AudioModalityMixin` must be mixed with `BaseGenerator`.", 28 | ): 29 | 30 | class InvalidMockAudioGenerator(AudioModalityMixin, BaseModel): 31 | pass 32 | -------------------------------------------------------------------------------- /tests/generators/mixins/test_video_mixin.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pydantic import BaseModel 3 | 4 | from fed_rag.base.generator import BaseGenerator 5 | from fed_rag.base.generator_mixins.video import ( 6 | GeneratorHasVideoModality, 7 | VideoModalityMixin, 8 | ) 9 | from fed_rag.exceptions.generator import GeneratorError 10 | 11 | from ..conftest import MockGenerator 12 | 13 | 14 | class MockVideoGenerator(VideoModalityMixin, MockGenerator): 15 | pass 16 | 17 | 18 | def test_video_mixin() -> None: 19 | mixed_generator = MockVideoGenerator() 20 | assert isinstance(mixed_generator, GeneratorHasVideoModality) 21 | assert isinstance(mixed_generator, BaseGenerator) 22 | 23 | 24 | def test_video_mixin_fails_validation() -> None: 25 | with pytest.raises( 26 | GeneratorError, 27 | match="`VideoModalityMixin` must be mixed with `BaseGenerator`.", 28 | ): 29 | 30 | class InvalidMockVideoGenerator(VideoModalityMixin, BaseModel): 31 | pass 32 | -------------------------------------------------------------------------------- /src/fed_rag/knowledge_stores/no_encode/mcp/sources/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Protocol 2 | 3 | from mcp.types import CallToolResult 4 | 5 | from fed_rag.data_structures import KnowledgeNode 6 | from fed_rag.exceptions import CallToolResultConversionError 7 | 8 | 9 | class CallToolResultConverter(Protocol): 10 | def __call__( 11 | self, result: CallToolResult, metadata: dict[str, Any] | None = None 12 | ) -> list[KnowledgeNode]: 13 | pass # pragma: no cover 14 | 15 | 16 | def default_converter( 17 | result: CallToolResult, metadata: dict[str, Any] | None = None 18 | ) -> list[KnowledgeNode]: 19 | if result.isError: 20 | raise CallToolResultConversionError( 21 | "Cannot convert a `CallToolResult` with `isError` set to `True`." 
22 | ) 23 | 24 | return [ 25 | KnowledgeNode( 26 | node_type="text", 27 | text_content=c.text, 28 | metadata=metadata, 29 | ) 30 | for c in result.content 31 | if c.type == "text" 32 | ] 33 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yml: -------------------------------------------------------------------------------- 1 | name: Feature Request 2 | description: Suggest a new feature or improvement for FedRAG. 3 | title: "[Feature]: " 4 | labels: ["enhancement", "triage"] 5 | body: 6 | - type: markdown 7 | attributes: 8 | value: | 9 | Thank you for suggesting a feature to improve FedRAG! 10 | Please describe your idea in detail below. 11 | 12 | - type: textarea 13 | id: problem 14 | attributes: 15 | label: Problem Statement 16 | description: What problem or need would this feature solve? 17 | validations: 18 | required: true 19 | 20 | - type: textarea 21 | id: proposal 22 | attributes: 23 | label: Proposed Solution 24 | description: How would you like to see this implemented? Feel free to share ideas or API sketches. 25 | validations: 26 | required: true 27 | 28 | - type: textarea 29 | id: alternatives 30 | attributes: 31 | label: Alternatives Considered 32 | description: Have you considered other solutions or workarounds? 33 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/integration_request.yml: -------------------------------------------------------------------------------- 1 | name: Integration Request 2 | description: Request support for a new framework or tool with FedRAG. 3 | title: "[Integration]: " 4 | labels: ["integration", "triage"] 5 | body: 6 | - type: markdown 7 | attributes: 8 | value: | 9 | Suggest a new integration for FedRAG! 10 | Tell us about the tool or framework and why it would be useful. 11 | 12 | - type: input 13 | id: framework 14 | attributes: 15 | label: Target Framework/Tool 16 | description: Name and (optionally) link to the tool you want to integrate. 17 | validations: 18 | required: true 19 | 20 | - type: textarea 21 | id: motivation 22 | attributes: 23 | label: Motivation 24 | description: Why would this integration be valuable for FedRAG users? 25 | validations: 26 | required: true 27 | 28 | - type: textarea 29 | id: ideas 30 | attributes: 31 | label: Proposed Approach 32 | description: If you have ideas about how integration might work, share them here. 33 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.yml: -------------------------------------------------------------------------------- 1 | name: Bug Report 2 | description: Report an unexpected error or broken behavior in FedRAG. 3 | title: "[Bug]: " 4 | labels: ["bug", "triage"] 5 | body: 6 | - type: markdown 7 | attributes: 8 | value: | 9 | Thanks for taking the time to file a bug report! 10 | Please complete the following sections to help us reproduce and fix the issue. 11 | 12 | - type: textarea 13 | id: what-happened 14 | attributes: 15 | label: Bug Description 16 | description: What happened? What behavior did you expect? 17 | validations: 18 | required: true 19 | 20 | - type: textarea 21 | id: steps-to-reproduce 22 | attributes: 23 | label: Steps to Reproduce 24 | description: Provide a clear minimal example to reproduce the bug. 
25 | validations: 26 | required: true 27 | 28 | - type: textarea 29 | id: logs 30 | attributes: 31 | label: Relevant Logs/Tracebacks 32 | description: Please copy and paste any error messages or logs. 33 | render: shell 34 | -------------------------------------------------------------------------------- /src/fed_rag/core/rag_system/asynchronous.py: -------------------------------------------------------------------------------- 1 | """Async RAG System Module""" 2 | 3 | from fed_rag._bridges.langchain.bridge import LangChainBridgeMixin 4 | from fed_rag._bridges.llamaindex.bridge import LlamaIndexBridgeMixin 5 | from fed_rag.core.rag_system._asynchronous import _AsyncRAGSystem 6 | 7 | from .synchronous import RAGSystem 8 | 9 | 10 | # Define the public RAGSystem with all available bridges 11 | class AsyncRAGSystem( 12 | LlamaIndexBridgeMixin, LangChainBridgeMixin, _AsyncRAGSystem 13 | ): 14 | """Async RAG System with all available bridge functionality. 15 | 16 | The RAGSystem is the main entry point for creating and managing 17 | retrieval-augmented generation systems. 18 | """ 19 | 20 | def to_sync( 21 | self, 22 | ) -> RAGSystem: 23 | return RAGSystem( 24 | knowledge_store=self.knowledge_store.to_sync(), 25 | generator=self.generator, # NOTE: this should actually be sync! 26 | retriever=self.retriever, # NOTE: this should actually be sync! 27 | rag_config=self.rag_config, 28 | ) 29 | -------------------------------------------------------------------------------- /src/fed_rag/data_structures/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | fed_rag.data_structures 3 | 4 | Only components defined in `__all__` are considered stable and public. 5 | """ 6 | 7 | from .bridge import BridgeMetadata, CompatibleVersions 8 | from .evals import ( 9 | AggregationMode, 10 | BenchmarkEvaluatedExample, 11 | BenchmarkExample, 12 | BenchmarkResult, 13 | ) 14 | from .knowledge_node import KnowledgeNode, NodeContent, NodeType 15 | from .rag import Context, Prompt, Query, RAGConfig, RAGResponse, SourceNode 16 | from .results import TestResult, TrainResult 17 | 18 | __all__ = [ 19 | # bridge 20 | "BridgeMetadata", 21 | "CompatibleVersions", 22 | # evals 23 | "AggregationMode", 24 | "BenchmarkExample", 25 | "BenchmarkResult", 26 | "BenchmarkEvaluatedExample", 27 | # results 28 | "TrainResult", 29 | "TestResult", 30 | # knowledge node 31 | "KnowledgeNode", 32 | "NodeType", 33 | "NodeContent", 34 | # rag 35 | "RAGConfig", 36 | "RAGResponse", 37 | "SourceNode", 38 | "Query", 39 | "Context", 40 | "Prompt", 41 | ] 42 | -------------------------------------------------------------------------------- /src/fed_rag/generators/huggingface/utils.py: -------------------------------------------------------------------------------- 1 | from importlib.util import find_spec 2 | 3 | from fed_rag.exceptions import MissingExtraError 4 | 5 | 6 | def check_huggingface_installed(cls_name: str | None = None) -> None: 7 | transformers_spec = find_spec("transformers") 8 | peft_spec = find_spec("peft") 9 | sentence_transformers_spec = find_spec("sentence_transformers") 10 | 11 | has_huggingface = ( 12 | (transformers_spec is not None) 13 | and (peft_spec is not None) 14 | and (sentence_transformers_spec is not None) 15 | ) 16 | if not has_huggingface: 17 | if cls_name: 18 | msg = ( 19 | f"`{cls_name}` requires the `huggingface` extra to be installed. " 20 | "To fix please run `pip install fed-rag[huggingface]`." 
21 | ) 22 | else: 23 | msg = ( 24 | "Missing installation of the huggingface extra, yet is required " 25 | "by an imported class. To fix please run `pip install fed-rag[huggingface]`." 26 | ) 27 | 28 | raise MissingExtraError(msg) 29 | -------------------------------------------------------------------------------- /src/fed_rag/utils/data/finetuning_datasets/huggingface.py: -------------------------------------------------------------------------------- 1 | """HuggingFace RAG Finetuning Dataset""" 2 | 3 | from typing_extensions import Self 4 | 5 | from fed_rag.exceptions.common import MissingExtraError 6 | 7 | # check if huggingface extra was installed 8 | try: 9 | from datasets import Dataset 10 | except ModuleNotFoundError: 11 | msg = ( 12 | "`HuggingFaceRAGFinetuningDataset` requires the `huggingface` extra to be installed. " 13 | "To fix please run `pip install fed-rag[huggingface]`." 14 | ) 15 | raise MissingExtraError(msg) 16 | 17 | 18 | class HuggingFaceRAGFinetuningDataset(Dataset): 19 | """Thin wrapper over ~datasets.Dataset.""" 20 | 21 | @classmethod 22 | def from_inputs( 23 | cls, 24 | input_ids: list[list[int]], 25 | target_ids: list[list[int]], 26 | attention_mask: list[list[int]], 27 | ) -> Self: 28 | return cls.from_dict( # type: ignore[no-any-return] 29 | { 30 | "input_ids": input_ids, 31 | "target_ids": target_ids, 32 | "attention_mask": attention_mask, 33 | } 34 | ) 35 | -------------------------------------------------------------------------------- /src/fed_rag/base/data_collator.py: -------------------------------------------------------------------------------- 1 | """Base Data Collator""" 2 | 3 | from abc import ABC, abstractmethod 4 | from typing import Any 5 | 6 | from pydantic import BaseModel, ConfigDict 7 | 8 | from fed_rag import RAGSystem 9 | 10 | 11 | class BaseDataCollator(BaseModel, ABC): 12 | """ 13 | Base Data Collator. 14 | 15 | Abstract base class for collating input examples into batches that can 16 | be used by a retrieval-augmented generation (RAG) system. 17 | """ 18 | 19 | model_config = ConfigDict(arbitrary_types_allowed=True) 20 | rag_system: RAGSystem 21 | 22 | @abstractmethod 23 | def __call__(self, features: list[dict[str, Any]], **kwargs: Any) -> Any: 24 | """Collate examples into a batch. 25 | 26 | Args: 27 | features (list[dict[str, Any]]): A list of feature dictionaries, 28 | where each dictionary represents one example. 29 | **kwargs (Any): Additional keyword arguments that may be used 30 | by specific implementations. 31 | 32 | Returns: 33 | Any: A collated batch, with format depending on the implementation. 34 | """ 35 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: true 2 | contact_links: 3 | - name: General Questions or Support 4 | url: https://github.com/VectorInstitute/fed-rag/discussions 5 | about: Please ask questions, request help, or start a conversation here. 6 | 7 | issue_templates: 8 | - name: Bug Report 9 | description: Report an unexpected error or broken behavior in FedRAG. 10 | labels: ["bug"] 11 | file: bug_report.md 12 | 13 | - name: Feature Request 14 | description: Suggest a new feature or improvement for FedRAG. 15 | labels: ["enhancement"] 16 | file: feature_request.md 17 | 18 | - name: Documentation Improvement 19 | description: Propose improvements or updates to the FedRAG documentation. 
20 | labels: ["documentation"] 21 | file: documentation_improvement.md 22 | 23 | - name: Integration Request 24 | description: Request support for new frameworks or tools with FedRAG. 25 | labels: ["integration"] 26 | file: integration_request.md 27 | 28 | - name: General Question 29 | description: Ask a question about using or developing with FedRAG. 30 | labels: ["question"] 31 | file: question.md 32 | -------------------------------------------------------------------------------- /src/fed_rag/__init__.pyi: -------------------------------------------------------------------------------- 1 | """Type stubs for fed_rag module""" 2 | 3 | # Lazy-loaded classes (type-only declarations) 4 | # Lazy-loaded modules 5 | from fed_rag import generators as generators 6 | from fed_rag import retrievers as retrievers 7 | from fed_rag import trainer_managers as trainer_managers 8 | from fed_rag import trainers as trainers 9 | from fed_rag.generators import HFPeftModelGenerator as HFPeftModelGenerator 10 | from fed_rag.generators import ( 11 | HFPretrainedModelGenerator as HFPretrainedModelGenerator, 12 | ) 13 | from fed_rag.generators import ( 14 | UnslothFastModelGenerator as UnslothFastModelGenerator, 15 | ) 16 | from fed_rag.retrievers import ( 17 | HFSentenceTransformerRetriever as HFSentenceTransformerRetriever, 18 | ) 19 | from fed_rag.trainer_managers import ( 20 | HuggingFaceRAGTrainerManager as HuggingFaceRAGTrainerManager, 21 | ) 22 | from fed_rag.trainer_managers import ( 23 | PyTorchRAGTrainerManager as PyTorchRAGTrainerManager, 24 | ) 25 | from fed_rag.trainers import ( 26 | HuggingFaceTrainerForLSR as HuggingFaceTrainerForLSR, 27 | ) 28 | from fed_rag.trainers import ( 29 | HuggingFaceTrainerForRALT as HuggingFaceTrainerForRALT, 30 | ) 31 | -------------------------------------------------------------------------------- /tests/data_structures/test_evals.py: -------------------------------------------------------------------------------- 1 | from fed_rag.data_structures import ( 2 | BenchmarkEvaluatedExample, 3 | BenchmarkExample, 4 | KnowledgeNode, 5 | RAGResponse, 6 | SourceNode, 7 | ) 8 | 9 | 10 | def test_model_dump_without_embs() -> None: 11 | evaluated = BenchmarkEvaluatedExample( 12 | score=0.42, 13 | example=BenchmarkExample(query="mock query", response="mock response"), 14 | rag_response=RAGResponse( 15 | response="mock rag reponse", 16 | source_nodes=[ 17 | SourceNode( 18 | score=0.1, 19 | node=KnowledgeNode( 20 | embedding=[1, 2, 3], # embeddings not persisted 21 | node_type="text", 22 | text_content="fake content", 23 | ), 24 | ), 25 | ], 26 | ), 27 | ) 28 | 29 | # act 30 | json_str = evaluated.model_dump_json_without_embeddings() 31 | 32 | # assert 33 | loaded_evaluated = BenchmarkEvaluatedExample.model_validate_json(json_str) 34 | assert loaded_evaluated.rag_response.source_nodes[0].node.embedding is None 35 | -------------------------------------------------------------------------------- /tests/api/test_deprecated_types_imports.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import pytest 4 | 5 | DEPRECATED_IMPORTS = [ 6 | ("fed_rag.types.bridge", "BridgeMetadata"), 7 | ("fed_rag.types.results", "TrainResult"), 8 | ("fed_rag.types.results", "TestResult"), 9 | ("fed_rag.types.knowledge_node", "KnowledgeNode"), 10 | ("fed_rag.types.knowledge_node", "NodeType"), 11 | ("fed_rag.types.knowledge_node", "NodeContent"), 12 | ("fed_rag.types.rag", "RAGConfig"), 13 | ("fed_rag.types.rag", "RAGResponse"), 14 | 
("fed_rag.types.rag", "SourceNode"), 15 | ] 16 | 17 | 18 | @pytest.mark.parametrize("module_path,class_name", DEPRECATED_IMPORTS) 19 | def test_import_from_types_raises_deprecation_warning( 20 | module_path: str, class_name: str 21 | ) -> None: 22 | """Test that importing from deprecated types modules raises warnings.""" 23 | 24 | # clear the module from sys.modules if it exists 25 | if module_path in sys.modules: 26 | del sys.modules[module_path] 27 | 28 | with pytest.warns(DeprecationWarning): 29 | import importlib 30 | 31 | module = importlib.import_module(module_path) 32 | getattr(module, class_name) # ensure its loaded 33 | -------------------------------------------------------------------------------- /tests/loss/pytorch/conftest.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import pytest 4 | import torch 5 | 6 | EMB_DIM = 10 7 | BATCH_SIZE = 2 8 | NUM_CHUNKS = 3 9 | 10 | 11 | @pytest.fixture() 12 | def retrieved_chunks() -> torch.Tensor: 13 | """Embeddings of 'retrieved' chunks.""" 14 | batch = [] 15 | for bx in range(1, BATCH_SIZE + 1): 16 | embs = [] 17 | for ix in range(1, NUM_CHUNKS + 1): 18 | embs.append([bx / ix for _ in range(EMB_DIM)]) 19 | batch.append(embs) 20 | 21 | return torch.tensor(batch, dtype=torch.float32) 22 | 23 | 24 | @pytest.fixture() 25 | def contexts() -> torch.Tensor: 26 | batch = [] 27 | for ix in range(1, BATCH_SIZE): 28 | batch.append(torch.ones(EMB_DIM) * ix) 29 | return torch.stack(batch, dim=0) 30 | 31 | 32 | @pytest.fixture() 33 | def lm_scores() -> torch.Tensor: 34 | """Mock probas of generated outputs 'given' context and chunk.""" 35 | batch = [] 36 | for bx in range(1, BATCH_SIZE + 1): 37 | scores = [math.exp(ix) for ix in range(NUM_CHUNKS)] 38 | scores = [el / sum(scores) for el in scores] 39 | batch.append(scores) 40 | 41 | return torch.tensor(batch, dtype=torch.float32) 42 | -------------------------------------------------------------------------------- /examples/ra-dit/ra_dit/_dataset_prep/qa/pubmed.py: -------------------------------------------------------------------------------- 1 | """PubmedQA 2 | 3 | Example 4 | === 5 | { 6 | "question": ..., 7 | "context": { 8 | "contexts": [], 9 | ... 10 | }, 11 | "long_answer": ..., 12 | "final_decision": ... 
13 | } 14 | """ 15 | 16 | import pandas as pd 17 | 18 | from ..base_data_prepper import DEFAULT_SAVE_DIR, BaseDataPrepper 19 | from .mixin import QAMixin 20 | 21 | QA_SAVE_DIR = DEFAULT_SAVE_DIR / "qa" 22 | 23 | 24 | class PubmedQADataPrepper(QAMixin, BaseDataPrepper): 25 | @property 26 | def dataset_name(self) -> str: 27 | return "pubmed_qa" 28 | 29 | def _get_answer(self, row: pd.Series) -> str: 30 | return str(row["long_answer"] + "\n\n" + row["final_decision"]) 31 | 32 | def _get_evidence(self, row: pd.Series) -> str: 33 | return "\n\n".join(row["context"]["contexts"]) 34 | 35 | def _get_question(self, row: pd.Series) -> str: 36 | return str(row["question"]) 37 | 38 | 39 | df = pd.read_parquet( 40 | "hf://datasets/qiaojin/PubMedQA/pqa_artificial/train-00000-of-00001.parquet" 41 | ) 42 | data_prepper = PubmedQADataPrepper(df=df, save_dir=QA_SAVE_DIR) 43 | data_prepper.execute_and_save() 44 | -------------------------------------------------------------------------------- /tests/generators/test_unsloth_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | from unittest.mock import patch 3 | 4 | import pytest 5 | 6 | from fed_rag.exceptions import MissingExtraError 7 | from fed_rag.generators.unsloth.utils import check_unsloth_installed 8 | 9 | 10 | def test_check_raises_error() -> None: 11 | """Check raises error from utils.""" 12 | 13 | modules = {"unsloth": None} 14 | 15 | with patch.dict("sys.modules", modules): 16 | # without class name 17 | msg = ( 18 | "Missing installation of the `unsloth` extra, yet is required " 19 | "by an imported class. To fix please run `pip install fed-rag[unsloth]`." 20 | ) 21 | with pytest.raises( 22 | MissingExtraError, 23 | match=re.escape(msg), 24 | ): 25 | check_unsloth_installed() 26 | 27 | # with class name 28 | msg = ( 29 | "`FakeClass` requires the `unsloth` extra to be installed. " 30 | "To fix please run `pip install fed-rag[unsloth]`." 
31 | ) 32 | with pytest.raises( 33 | MissingExtraError, 34 | match=re.escape(msg), 35 | ): 36 | check_unsloth_installed("FakeClass") 37 | -------------------------------------------------------------------------------- /tests/tokenizers/conftest.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import pytest 4 | import tokenizers 5 | from tokenizers import Tokenizer, models 6 | from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast 7 | 8 | from fed_rag.base.tokenizer import BaseTokenizer 9 | 10 | 11 | class MockTokenizer(BaseTokenizer): 12 | def encode(self, input: str, **kwargs: Any) -> list[int]: 13 | return [0, 1, 2] 14 | 15 | def decode(self, input_ids: list[int], **kwargs: Any) -> str: 16 | return "mock decoded sentence" 17 | 18 | @property 19 | def unwrapped(self) -> None: 20 | return None 21 | 22 | 23 | @pytest.fixture() 24 | def mock_tokenizer() -> BaseTokenizer: 25 | return MockTokenizer() 26 | 27 | 28 | @pytest.fixture 29 | def hf_tokenizer() -> PreTrainedTokenizer: 30 | tokenizer = Tokenizer( 31 | models.WordPiece({"hello": 0, "[UNK]": 1}, unk_token="[UNK]") 32 | ) 33 | tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.WhitespaceSplit() 34 | return PreTrainedTokenizerFast( 35 | tokenizer_object=tokenizer, 36 | pad_token="[PAD]", 37 | cls_token="[CLS]", 38 | sep_token="[SEP]", 39 | mask_token="[MASK]", 40 | ) 41 | -------------------------------------------------------------------------------- /tests/retrievers/conftest.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import pytest 4 | import torch 5 | from pydantic import PrivateAttr 6 | from sentence_transformers import SentenceTransformer 7 | 8 | from fed_rag.base.retriever import BaseRetriever 9 | 10 | 11 | class MockRetriever(BaseRetriever): 12 | _encoder: torch.nn.Module = PrivateAttr(default=torch.nn.Linear(2, 1)) 13 | 14 | def encode_context(self, context: str, **kwargs: Any) -> torch.Tensor: 15 | return self._encoder.forward(torch.ones(2)) 16 | 17 | def encode_query(self, query: str, **kwargs: Any) -> torch.Tensor: 18 | return self._encoder.forward(torch.zeros(2)) 19 | 20 | @property 21 | def encoder(self) -> torch.nn.Module: 22 | return self._encoder 23 | 24 | @property 25 | def query_encoder(self) -> torch.nn.Module | None: 26 | return None 27 | 28 | @property 29 | def context_encoder(self) -> torch.nn.Module | None: 30 | return None 31 | 32 | 33 | @pytest.fixture 34 | def mock_retriever() -> MockRetriever: 35 | return MockRetriever() 36 | 37 | 38 | @pytest.fixture 39 | def dummy_sentence_transformer() -> SentenceTransformer: 40 | return SentenceTransformer(modules=[torch.nn.Linear(5, 5)]) 41 | -------------------------------------------------------------------------------- /examples/ra-dit/ra_dit/_dataset_prep/qa/web_questions.py: -------------------------------------------------------------------------------- 1 | """WebQA 2 | 3 | Example 4 | === 5 | { 6 | 'url': 'http://www.freebase.com/view/en/justin_bieber', 7 | 'question': 'http://www.freebase.com/view/en/justin_bieber', 8 | 'answer': 'answers' 9 | }, 10 | """ 11 | 12 | import pandas as pd 13 | 14 | from ..base_data_prepper import DEFAULT_SAVE_DIR, BaseDataPrepper 15 | from .mixin import QAMixin 16 | 17 | QA_SAVE_DIR = DEFAULT_SAVE_DIR / "qa" 18 | 19 | 20 | class WebQuestionsDataPrepper(QAMixin, BaseDataPrepper): 21 | @property 22 | def dataset_name(self) -> str: 23 | return "web_questions_qa" 24 | 25 | def 
_get_answer(self, row: pd.Series) -> str: 26 | return str(", ".join(row["answers"])) 27 | 28 | def _get_evidence(self, row: pd.Series) -> str | None: 29 | return None 30 | 31 | def _get_question(self, row: pd.Series) -> str: 32 | return str(row["question"]) 33 | 34 | 35 | splits = { 36 | "train": "data/train-00000-of-00001.parquet", 37 | "test": "data/test-00000-of-00001.parquet", 38 | } 39 | 40 | df = pd.read_parquet("hf://datasets/Stanford/web_questions/" + splits["train"]) 41 | data_prepper = WebQuestionsDataPrepper(df=df, save_dir=QA_SAVE_DIR) 42 | data_prepper.execute_and_save() 43 | -------------------------------------------------------------------------------- /src/fed_rag/exceptions/knowledge_stores.py: -------------------------------------------------------------------------------- 1 | """Exceptions for Knowledge Stores.""" 2 | 3 | from .core import FedRAGError, FedRAGWarning 4 | 5 | 6 | class KnowledgeStoreError(FedRAGError): 7 | """Base knowledge store error for all knowledge-store-related exceptions.""" 8 | 9 | pass 10 | 11 | 12 | class KnowledgeStoreWarning(FedRAGWarning): 13 | """Base knowledge store error for all knowledge-store-related warnings.""" 14 | 15 | pass 16 | 17 | 18 | class KnowledgeStoreNotFoundError(KnowledgeStoreError, FileNotFoundError): 19 | """Raised if the knowledge store can not be found or loaded from file.""" 20 | 21 | pass 22 | 23 | 24 | class InvalidDistanceError(KnowledgeStoreError): 25 | """Raised if provided an invalid similarity distance.""" 26 | 27 | pass 28 | 29 | 30 | class LoadNodeError(KnowledgeStoreError): 31 | """Raised if an error occurs when loading a node.""" 32 | 33 | pass 34 | 35 | 36 | class MCPKnowledgeStoreError(KnowledgeStoreError): 37 | """Base knowledge store error for all knowledge-store-related exceptions.""" 38 | 39 | pass 40 | 41 | 42 | class CallToolResultConversionError(MCPKnowledgeStoreError): 43 | """Raised when trying to convert a ~mcp.CallToolResult that has error status.""" 44 | 45 | pass 46 | -------------------------------------------------------------------------------- /tests/retrievers/test_base.py: -------------------------------------------------------------------------------- 1 | from contextlib import nullcontext as does_not_raise 2 | 3 | import torch 4 | 5 | from fed_rag.base.retriever import BaseRetriever 6 | 7 | 8 | def test_base_abstract_attr() -> None: 9 | abstract_methods = BaseRetriever.__abstractmethods__ 10 | 11 | assert "encode_context" in abstract_methods 12 | assert "encode_query" in abstract_methods 13 | assert "encoder" in abstract_methods 14 | assert "query_encoder" in abstract_methods 15 | assert "context_encoder" in abstract_methods 16 | 17 | 18 | def test_base_encode(mock_retriever: BaseRetriever) -> None: 19 | encoded_ctx = mock_retriever.encode_context("mock context") 20 | encoded_query = mock_retriever.encode_query("mock query") 21 | cosine_sim = encoded_ctx @ encoded_query.T 22 | *_, final_layer = mock_retriever.encoder.parameters() 23 | 24 | with does_not_raise(): 25 | # cosine sim should be a Tensor with a single item 26 | cosine_sim.item() 27 | 28 | assert encoded_ctx.numel() == final_layer.size()[-1] 29 | assert encoded_query.numel() == final_layer.size()[-1] 30 | assert isinstance(mock_retriever.encoder, torch.nn.Module) 31 | assert mock_retriever.query_encoder is None 32 | assert mock_retriever.context_encoder is None 33 | -------------------------------------------------------------------------------- /.github/workflows/unit_test.yml: 
-------------------------------------------------------------------------------- 1 | name: Unit Testing and Upload Coverage 2 | on: 3 | push: 4 | branches: 5 | - main 6 | pull_request: 7 | types: 8 | - opened 9 | - synchronize 10 | jobs: 11 | test: 12 | runs-on: ubuntu-latest 13 | permissions: 14 | id-token: write 15 | contents: read 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | python-version: ["3.10", "3.11", "3.12"] 20 | steps: 21 | - name: get code 22 | uses: actions/checkout@v6 23 | 24 | - name: Install uv 25 | uses: astral-sh/setup-uv@v7 26 | with: 27 | # Install a specific version of uv. 28 | version: "0.5.21" 29 | enable-cache: true 30 | python-version: ${{ matrix.python-version }} 31 | 32 | - name: Install the project 33 | run: uv sync --all-extras --dev 34 | 35 | - name: Run tests 36 | run: | 37 | uv run make coverage 38 | 39 | - if: matrix.python-version == '3.12' 40 | name: Upload results to Codecov 41 | uses: codecov/codecov-action@v5 42 | with: 43 | token: ${{ secrets.CODECOV_TOKEN }} 44 | slug: VectorInstitute/fed-rag 45 | fail_ci_if_error: true 46 | verbose: true 47 | -------------------------------------------------------------------------------- /src/fed_rag/core/no_encode_rag_system/asynchronous.py: -------------------------------------------------------------------------------- 1 | """Async No Encode RAG System Module""" 2 | 3 | from fed_rag.core.no_encode_rag_system._asynchronous import ( 4 | _AsyncNoEncodeRAGSystem, 5 | ) 6 | 7 | from .synchronous import NoEncodeRAGSystem 8 | 9 | 10 | # Define the public NoEncodeRAGSystem with all available bridges 11 | class AsyncNoEncodeRAGSystem(_AsyncNoEncodeRAGSystem): 12 | """Async NoEncode RAG System with all available bridge functionality. 13 | 14 | The AsyncNoEncodeRAGSystem is the main entry point for creating and managing 15 | retrieval-augmented generation systems that skip encoding altogether, 16 | enabling direct natural language queries to knowledge sources like MCP 17 | servers, APIs, and databases. 18 | 19 | Unlike traditional RAG systems that require separate retriever components 20 | and pre-computed embeddings, NoEncode RAG systems perform direct queries 21 | against NoEncode knowledge sources. 22 | """ 23 | 24 | def to_sync( 25 | self, 26 | ) -> NoEncodeRAGSystem: 27 | return NoEncodeRAGSystem( 28 | knowledge_store=self.knowledge_store.to_sync(), 29 | generator=self.generator, # NOTE: this should actually be sync! 
30 | rag_config=self.rag_config, 31 | ) 32 | -------------------------------------------------------------------------------- /tests/api/test_namespaced_imports.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import pytest 4 | 5 | from fed_rag import generators, knowledge_stores, retrievers 6 | 7 | 8 | @pytest.mark.parametrize("name", generators.__all__) 9 | def test_public_generators_all_importable(name: str) -> None: 10 | """Tests that all names listed in generators __all__ are importable.""" 11 | mod = importlib.import_module("fed_rag.generators") 12 | attr = getattr(mod, name) 13 | 14 | assert hasattr(mod, name) 15 | assert attr is not None 16 | 17 | 18 | @pytest.mark.parametrize("name", retrievers.__all__) 19 | def test_public_retrievers_all_importable(name: str) -> None: 20 | """Tests that all names listed in retrievers __all__ are importable.""" 21 | mod = importlib.import_module("fed_rag.retrievers") 22 | attr = getattr(mod, name) 23 | 24 | assert hasattr(mod, name) 25 | assert attr is not None 26 | 27 | 28 | @pytest.mark.parametrize("name", knowledge_stores.__all__) 29 | def test_public_knowledge_stores_all_importable(name: str) -> None: 30 | """Tests that all names listed in knowledge_stores __all__ are importable.""" 31 | mod = importlib.import_module("fed_rag.knowledge_stores") 32 | attr = getattr(mod, name) 33 | 34 | assert hasattr(mod, name) 35 | assert attr is not None 36 | -------------------------------------------------------------------------------- /src/fed_rag/utils/huggingface.py: -------------------------------------------------------------------------------- 1 | from fed_rag import NoEncodeRAGSystem, RAGSystem 2 | from fed_rag.exceptions import FedRAGError 3 | 4 | 5 | def _validate_rag_system(rag_system: RAGSystem | NoEncodeRAGSystem) -> None: 6 | # Skip validation if environment variable is set 7 | import os 8 | 9 | if os.environ.get("FEDRAG_SKIP_VALIDATION") == "1": 10 | return 11 | 12 | from fed_rag.generators.huggingface import ( 13 | HFPeftModelGenerator, 14 | HFPretrainedModelGenerator, 15 | ) 16 | from fed_rag.generators.unsloth import UnslothFastModelGenerator 17 | from fed_rag.retrievers.huggingface.hf_sentence_transformer import ( 18 | HFSentenceTransformerRetriever, 19 | ) 20 | 21 | if not isinstance( 22 | rag_system.generator, 23 | ( 24 | HFPretrainedModelGenerator, 25 | HFPeftModelGenerator, 26 | UnslothFastModelGenerator, 27 | ), 28 | ): 29 | raise FedRAGError( 30 | "Generator must be HFPretrainedModelGenerator or HFPeftModelGenerator" 31 | ) 32 | 33 | if isinstance(rag_system, RAGSystem) and not isinstance( 34 | rag_system.retriever, HFSentenceTransformerRetriever 35 | ): 36 | raise FedRAGError("Retriever must be a HFSentenceTransformerRetriever") 37 | -------------------------------------------------------------------------------- /src/fed_rag/decorators/tester.py: -------------------------------------------------------------------------------- 1 | """Tester Decorators""" 2 | 3 | from typing import Callable 4 | 5 | 6 | class TesterDecorators: 7 | def pytorch(self, func: Callable) -> Callable: 8 | from fed_rag.inspectors.pytorch import inspect_tester_signature 9 | 10 | def decorator(func: Callable) -> Callable: 11 | # inspect func sig 12 | spec = inspect_tester_signature( 13 | func 14 | ) # may need to create a cfg for this if decorater accepts params 15 | 16 | # store fl_task config 17 | func.__setattr__("__fl_task_tester_config", spec) # type: ignore[attr-defined] 18 | 19 | return 
func 20 | 21 | return decorator(func) 22 | 23 | def huggingface(self, func: Callable) -> Callable: 24 | from fed_rag.inspectors.huggingface import inspect_tester_signature 25 | 26 | def decorator(func: Callable) -> Callable: 27 | # inspect func sig 28 | spec = inspect_tester_signature( 29 | func 30 | ) # may need to create a cfg for this if decorater accepts params 31 | 32 | # store fl_task config 33 | func.__setattr__("__fl_task_tester_config", spec) # type: ignore[attr-defined] 34 | 35 | return func 36 | 37 | return decorator(func) 38 | -------------------------------------------------------------------------------- /src/fed_rag/decorators/trainer.py: -------------------------------------------------------------------------------- 1 | """Trainer Decorators""" 2 | 3 | from typing import Callable 4 | 5 | 6 | class TrainerDecorators: 7 | def pytorch(self, func: Callable) -> Callable: 8 | from fed_rag.inspectors.pytorch import inspect_trainer_signature 9 | 10 | def decorator(func: Callable) -> Callable: 11 | # inspect func sig 12 | spec = inspect_trainer_signature( 13 | func 14 | ) # may need to create a cfg for this if decorater accepts params 15 | 16 | # store fl_task config 17 | func.__setattr__("__fl_task_trainer_config", spec) # type: ignore[attr-defined] 18 | 19 | return func 20 | 21 | return decorator(func) 22 | 23 | def huggingface(self, func: Callable) -> Callable: 24 | from fed_rag.inspectors.huggingface import inspect_trainer_signature 25 | 26 | def decorator(func: Callable) -> Callable: 27 | # inspect func sig 28 | spec = inspect_trainer_signature( 29 | func 30 | ) # may need to create a cfg for this if decorater accepts params 31 | 32 | # store fl_task config 33 | func.__setattr__("__fl_task_trainer_config", spec) # type: ignore[attr-defined] 34 | 35 | return func 36 | 37 | return decorator(func) 38 | -------------------------------------------------------------------------------- /src/fed_rag/data_structures/retriever.py: -------------------------------------------------------------------------------- 1 | """Data structures for retrievers.""" 2 | 3 | from typing import TypedDict 4 | 5 | import torch 6 | 7 | 8 | class EncodeResult(TypedDict): 9 | """ 10 | Represents the result of encoding multiple types of data. 11 | 12 | This TypedDict is used as a structured output for encoding operations 13 | involving various data modalities such as text, image, audio, or video. 14 | Each key corresponds to a specific modality and may contain a tensor 15 | result or None if that modality is not used or applicable. 16 | 17 | Attributes: 18 | text: Union[torch.Tensor, None] 19 | The tensor representation of encoded text data, or None if text 20 | is not processed. 21 | image: Union[torch.Tensor, None] 22 | The tensor representation of encoded image data, or None if image 23 | processing is not performed. 24 | audio: Union[torch.Tensor, None] 25 | The tensor representation of encoded audio data, or None if audio 26 | is not processed. 27 | video: Union[torch.Tensor, None] 28 | The tensor representation of encoded video data, or None if video 29 | processing is not performed. 
30 | """ 31 | 32 | text: torch.Tensor | None 33 | image: torch.Tensor | None 34 | audio: torch.Tensor | None 35 | video: torch.Tensor | None 36 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: Docs Publish 2 | on: 3 | push: 4 | branches: 5 | - main 6 | workflow_dispatch: 7 | 8 | jobs: 9 | deploy: 10 | runs-on: ubuntu-latest 11 | permissions: 12 | contents: write # To push a branch 13 | pull-requests: write # To create a PR from that branch 14 | steps: 15 | - name: get code 16 | uses: actions/checkout@v6 17 | 18 | - name: Configure Git Credentials 19 | run: | 20 | git config user.name github-actions[bot] 21 | git config user.email 41898282+github-actions[bot]@users.noreply.github.com 22 | 23 | - uses: actions/setup-python@v6 24 | with: 25 | python-version: "3.12" 26 | 27 | - name: Install uv 28 | uses: astral-sh/setup-uv@v7 29 | with: 30 | # Install a specific version of uv. 31 | version: "0.5.21" 32 | enable-cache: true 33 | 34 | - name: Install the project 35 | run: uv sync --all-extras --group dev --group docs 36 | 37 | - name: Build docs 38 | run: | 39 | uv run mkdocs build 40 | 41 | - name: Deploy to github pages 42 | uses: JamesIves/github-pages-deploy-action@v4.7.6 43 | with: 44 | branch: gh-pages # The branch the action should deploy to. 45 | folder: site # The folder the action should deploy. 46 | -------------------------------------------------------------------------------- /docs/community/resources/pocket_references.md: -------------------------------------------------------------------------------- 1 | # AI Pocket References 2 | 3 | 4 | 5 | 8 | 9 | 12 | 13 | The [AI Pocket Reference](https://github.com/VectorInstitute/ai-pocket-reference) 14 | project is maintained by Vector AI Engineering as an accessible resource for the 15 | AI community. It provides a collection of _pocket references_ offering concise 16 | information on a wide range of AI topics, including Natural Language Processing 17 | (NLP) and Federated Learning (FL). 18 | 19 | ## Recommended Collections 20 | 21 | - [NLP Collection](https://vectorinstitute.github.io/ai-pocket-reference/nlp/) — 22 | Covers various topics within NLP, including RAG, LoRA, Quantization, Chain of Thought, 23 | Agents, and more. 24 | 25 | - [FL Collection](https://vectorinstitute.github.io/ai-pocket-reference/fl/) — 26 | Encompasses the fundamentals of federated learning along with advanced topics such 27 | as personalized federated learning and vertical federated learning. 28 | -------------------------------------------------------------------------------- /src/fed_rag/_bridges/llamaindex/bridge.py: -------------------------------------------------------------------------------- 1 | """LlamaIndex Bridge""" 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | from fed_rag._bridges.llamaindex._version import __version__ 6 | from fed_rag.base.bridge import BaseBridgeMixin 7 | 8 | if TYPE_CHECKING: # pragma: no cover 9 | from llama_index.core.indices.managed.base import BaseManagedIndex 10 | 11 | from fed_rag.core.rag_system._synchronous import ( # avoids circular import 12 | _RAGSystem, 13 | ) 14 | 15 | 16 | class LlamaIndexBridgeMixin(BaseBridgeMixin): 17 | """LlamaIndex Bridge. 18 | 19 | This mixin adds LlamaIndex conversion capabilities to _RAGSystem. 
20 | When mixed with an unbridged _RAGSystem, it allows direct conversion to 21 | LlamaIndex's BaseManagedIndex through the to_llamaindex() method. 22 | """ 23 | 24 | _bridge_version = __version__ 25 | _bridge_extra = "llama-index" 26 | _framework = "llama-index-core" 27 | _compatible_versions = {"min": "0.12.35"} 28 | _method_name = "to_llamaindex" 29 | 30 | def to_llamaindex(self: "_RAGSystem") -> "BaseManagedIndex": 31 | """Converts the _RAGSystem to a ~llamaindex.core.BaseManagedIndex.""" 32 | self._validate_framework_installed() 33 | 34 | from fed_rag._bridges.llamaindex._managed_index import ( 35 | FedRAGManagedIndex, 36 | ) 37 | 38 | return FedRAGManagedIndex(rag_system=self) 39 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | # FedRAG Pull Request Template 2 | 3 | Thanks for contributing to FedRAG! 4 | Please fill out the sections below to help us review your PR efficiently. 5 | 6 | ## Summary 7 | 8 | What does this PR do? Please provide a brief summary of the changes introduced. 9 | 10 | - [ ] Bug fix 11 | - [ ] New feature 12 | - [ ] Documentation update 13 | - [ ] Code quality / linting 14 | - [ ] Other (please describe): 15 | 16 | ## Description 17 | Any information reviewers should be aware of: 18 | 19 | ## Testing 20 | 21 | Describe how you tested your changes. Include the steps to reproduce, commands run, and any relevant outputs. 22 | 23 | - [ ] Unit tests added or updated 24 | - [ ] All tests pass locally (`make test`) 25 | - [ ] Code coverage maintained or improved 26 | 27 | ## Checklist 28 | 29 | Before submitting your PR, please check off the following: 30 | 31 | - [ ] My code follows the existing style and conventions 32 | - [ ] I’ve run linting (`make lint`) 33 | - [ ] I’ve added/updated relevant documentation 34 | - [ ] I’ve added/updated tests as needed 35 | - [ ] I’ve verified integration with existing tools (HuggingFace, LlamaIndex, LangChain, etc. if applicable) 36 | - [ ] I’ve added an entry to the CHANGELOG.md (if applicable) 37 | 38 | ## Related Issues or PRs 39 | 40 | If this PR addresses or relates to existing issues or pull requests, link them here: 41 | 42 | - Closes # 43 | - Related to # 44 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v*' # Push events to matching v*, i.e. v1.0, v20.15.10 7 | 8 | jobs: 9 | deploy: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - name: Install apt dependencies 13 | run: | 14 | sudo apt-get update 15 | sudo apt-get install libcurl4-openssl-dev libssl-dev 16 | - uses: actions/checkout@v6 17 | 18 | - name: Install uv 19 | uses: astral-sh/setup-uv@v7 20 | with: 21 | # Install a specific version of uv. 
22 | version: "0.5.21" 23 | enable-cache: true 24 | 25 | - name: "Set up Python" 26 | uses: actions/setup-python@v6 27 | with: 28 | python-version: "3.10" 29 | 30 | - name: Install the project 31 | run: uv sync --all-extras --dev 32 | 33 | - name: Build package 34 | run: uv build 35 | 36 | - name: Publish package 37 | uses: pypa/gh-action-pypi-publish@v1.13.0 38 | with: 39 | user: __token__ 40 | password: ${{ secrets.PYPI_API_TOKEN }} 41 | 42 | release_github: 43 | needs: deploy 44 | runs-on: ubuntu-latest 45 | steps: 46 | - name: Create GitHub Release 47 | id: create_release 48 | uses: ncipollo/release-action@v1.20.0 49 | with: 50 | artifacts: "dist/*" 51 | generateReleaseNotes: true 52 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | --- 2 | hide: 3 | - navigation 4 | - toc 5 | --- 6 | 7 | 8 | 9 |

10 | :material-hexagon-multiple:{ .vector-icon } FedRAG 11 |

12 | 13 | ## Simplified RAG fine-tuning across centralized or federated architectures 14 | 15 | [Read the paper](https://d3ddy8balm3goa.cloudfront.net/papers/fedrag-codeml-icml-2025-camera-ready.pdf) 16 | (_Accepted in CODEML Workshop at ICML 2025, Vancouver_) 17 | 18 |
19 | 20 | -

:fontawesome-solid-wand-magic-sparkles:{ .lg .middle } Advanced RAG fine-tuning

21 | 22 | Comprehensive support for state-of-the-art RAG fine-tuning methods that can 23 | be federated with ease. 24 | 25 | [:octicons-arrow-right-24: Getting started](getting_started/essentials.md) 26 | 27 | -

:fontawesome-solid-cubes-stacked:{ .lg .middle } Work with your tools

28 | 29 | Seamlessly integrates with popular frameworks including HuggingFace, 30 | and LlamaIndex — use the tools you already know. 31 | 32 | [:octicons-arrow-right-24: In-Depth Examples](examples/index.md) 33 | 34 | -

:fontawesome-solid-feather:{ .lg .middle } Lightweight abstractions

35 | 36 | Clean, intuitive abstractions that simplify RAG fine-tuning while 37 | maintaining full flexibility and control. 38 | 39 | [:octicons-arrow-right-24: API Reference](api_reference/index.md) 40 | 41 |
42 | -------------------------------------------------------------------------------- /src/fed_rag/tokenizers/unsloth_pretrained_tokenizer.py: -------------------------------------------------------------------------------- 1 | """Unsloth Pretrained Tokenizer""" 2 | 3 | import importlib.util 4 | from typing import TYPE_CHECKING 5 | 6 | from fed_rag.exceptions import MissingExtraError 7 | 8 | from .hf_pretrained_tokenizer import HFPretrainedTokenizer 9 | 10 | if importlib.util.find_spec("unsloth") is None: 11 | _has_unsloth = False 12 | else: 13 | _has_unsloth = True 14 | 15 | if TYPE_CHECKING: # pragma: no cover 16 | from transformers import PreTrainedTokenizer 17 | 18 | 19 | class UnslothPretrainedTokenizer(HFPretrainedTokenizer): 20 | """Unsloth Pretrained Tokenizer. 21 | 22 | NOTE: Unsloth adds a patch on HF tokenizers, so this is a light wrapper. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | tokenizer: "PreTrainedTokenizer", 28 | model_name: str, 29 | ): 30 | if not _has_unsloth: 31 | msg = ( 32 | f"`{self.__class__.__name__}` requires the `unsloth` extra to be installed. " 33 | "To fix please run `pip install fed-rag[unsloth]`." 34 | ) 35 | raise MissingExtraError(msg) 36 | super().__init__( 37 | model_name=model_name, 38 | load_model_at_init=False, 39 | ) 40 | # set the tokenizer manually as with Unsloth we get patched tokenizer along 41 | # with the patched model i.e., model, tokenizer = FastModel.from_pretrained(...) 42 | self._tokenizer = tokenizer 43 | -------------------------------------------------------------------------------- /src/fed_rag/base/tokenizer.py: -------------------------------------------------------------------------------- 1 | """Base Tokenizer""" 2 | 3 | from abc import ABC, abstractmethod 4 | from typing import Any, TypedDict 5 | 6 | from pydantic import BaseModel, ConfigDict 7 | 8 | 9 | class EncodeResult(TypedDict): 10 | """Data container for tokenizer encoding results.""" 11 | 12 | input_ids: list[int] 13 | attention_mask: list[int] | None 14 | 15 | 16 | class BaseTokenizer(BaseModel, ABC): 17 | """Base Tokenizer Class. 18 | 19 | This abstract class provides the interface for creating Tokenizer objects that 20 | converts strings into tokens. 21 | """ 22 | 23 | model_config = ConfigDict(arbitrary_types_allowed=True) 24 | 25 | @abstractmethod 26 | def encode(self, input: str, **kwargs: Any) -> EncodeResult: 27 | """Encode the input string into list of integers. 28 | 29 | Args: 30 | input (str): The input string to be encoded. 31 | 32 | Returns: 33 | EncodeResult: The result of encoding. 34 | """ 35 | 36 | @abstractmethod 37 | def decode(self, input_ids: list[int], **kwargs: Any) -> str: 38 | """Decode the input token ids into a string. 39 | 40 | Args: 41 | input_ids (list[int]): The token ids to be decoded back to text. 42 | 43 | Returns: 44 | str: The decoded text. 
45 | """ 46 | 47 | @property 48 | @abstractmethod 49 | def unwrapped(self) -> Any: 50 | """Return the underlying tokenizer if there is one.""" 51 | -------------------------------------------------------------------------------- /src/fed_rag/evals/benchmarks/huggingface/boolq.py: -------------------------------------------------------------------------------- 1 | """BoolQ benchmark""" 2 | 3 | from typing import Any 4 | 5 | from pydantic import model_validator 6 | 7 | from fed_rag.base.evals.benchmark import BaseBenchmark 8 | 9 | from .mixin import HuggingFaceBenchmarkMixin 10 | from .utils import check_huggingface_evals_installed 11 | 12 | 13 | class HuggingFaceBoolQ(HuggingFaceBenchmarkMixin, BaseBenchmark): 14 | """HuggingFace BoolQ Benchmark. 15 | 16 | BoolQ is a question answering dataset for yes/no questions about a short passage. 17 | 18 | Example schema: 19 | { 20 | "question": "does ethanol take more energy make that produces", 21 | "answer": false, 22 | "passage": "\"All biomass goes through at least some of these steps: ...", 23 | } 24 | """ 25 | 26 | dataset_name = "google/boolq" 27 | 28 | def _get_query_from_example(self, example: dict[str, Any]) -> str: 29 | return str(example["question"]) 30 | 31 | def _get_response_from_example(self, example: dict[str, Any]) -> str: 32 | # Return as string "true"/"false" for consistency 33 | return "true" if example["answer"] else "false" 34 | 35 | def _get_context_from_example(self, example: dict[str, Any]) -> str: 36 | return str(example["passage"]) 37 | 38 | @model_validator(mode="before") 39 | @classmethod 40 | def _validate_extra_installed(cls, data: Any) -> Any: 41 | check_huggingface_evals_installed(cls.__name__) 42 | return data 43 | -------------------------------------------------------------------------------- /src/fed_rag/exceptions/inspectors.py: -------------------------------------------------------------------------------- 1 | """Exceptions for inspectors.""" 2 | 3 | from .core import FedRAGError, FedRAGWarning 4 | 5 | 6 | class InspectorError(FedRAGError): 7 | """Base inspector error for all inspector-related exceptions.""" 8 | 9 | pass 10 | 11 | 12 | class InspectorWarning(FedRAGWarning): 13 | """Base inspector warning for all inspector-related warnings.""" 14 | 15 | pass 16 | 17 | 18 | class MissingNetParam(InspectorError): 19 | """Raised if function is missing nn.Module param.""" 20 | 21 | pass 22 | 23 | 24 | class MissingMultipleDataParams(InspectorError): 25 | """Raised if multiple data params for training, testing and validation are missing.""" 26 | 27 | pass 28 | 29 | 30 | class MissingDataParam(InspectorError): 31 | """Raised if a single data param is missing.""" 32 | 33 | pass 34 | 35 | 36 | class MissingTrainerSpec(InspectorError): 37 | """Raised during inspection if trainer is missing `__fl_task_trainer_config` attr.""" 38 | 39 | pass 40 | 41 | 42 | class MissingTesterSpec(InspectorError): 43 | """Raised during inspection if tester is missing `__fl_task_trainer_config` attr.""" 44 | 45 | pass 46 | 47 | 48 | class UnequalNetParamWarning(InspectorWarning): 49 | """Thrown if trainer and testers have different parameter names for their nn.Module param.""" 50 | 51 | pass 52 | 53 | 54 | class InvalidReturnType(InspectorError): 55 | """Raised if the return type of a function is not the expected one.""" 56 | 57 | pass 58 | -------------------------------------------------------------------------------- /examples/quick-start/quick_start/_cifar_dataloaders.py: 
-------------------------------------------------------------------------------- 1 | # torch and flwr 2 | import torch 3 | import torchvision.transforms as transforms 4 | from flwr_datasets import FederatedDataset 5 | from flwr_datasets.partitioner import IidPartitioner 6 | from torch.utils.data import DataLoader 7 | 8 | # partition cifar dataset 9 | partitioner = IidPartitioner(num_partitions=2) 10 | fds = FederatedDataset( 11 | dataset="uoft-cs/cifar10", 12 | partitioners={"train": partitioner}, 13 | ) 14 | 15 | 16 | def get_loaders(partition_id: int) -> tuple[DataLoader, DataLoader]: 17 | partition = fds.load_partition(partition_id) 18 | # Divide data on each node: 80% train, 20% test 19 | partition_train_test = partition.train_test_split(test_size=0.2, seed=42) 20 | pytorch_transforms = transforms.Compose( 21 | [ 22 | transforms.ToTensor(), 23 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 24 | ] 25 | ) 26 | 27 | def apply_transforms(batch: torch.Tensor) -> torch.Tensor: 28 | """Apply transforms to the partition from FederatedDataset.""" 29 | batch["img"] = [pytorch_transforms(img) for img in batch["img"]] 30 | return batch 31 | 32 | partition_train_test = partition_train_test.with_transform( 33 | apply_transforms 34 | ) 35 | trainloader = DataLoader( 36 | partition_train_test["train"], batch_size=32, shuffle=True 37 | ) 38 | testloader = DataLoader(partition_train_test["test"], batch_size=32) 39 | return trainloader, testloader 40 | -------------------------------------------------------------------------------- /src/fed_rag/_bridges/langchain/bridge.py: -------------------------------------------------------------------------------- 1 | """LangChain Bridge""" 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | from fed_rag._bridges.langchain._version import __version__ 6 | from fed_rag.base.bridge import BaseBridgeMixin 7 | 8 | if TYPE_CHECKING: # pragma: no cover 9 | from langchain_core.language_models import BaseLLM 10 | from langchain_core.vectorstores import VectorStore 11 | 12 | from fed_rag.core.rag_system._synchronous import ( 13 | _RAGSystem, # avoids circular import 14 | ) 15 | 16 | 17 | class LangChainBridgeMixin(BaseBridgeMixin): 18 | """LangChain Bridge. 19 | 20 | This mixin adds LangChain conversion capabilities to _RAGSystem. 21 | When mixed with an unbridged _RAGSystem, it allows direct conversion to 22 | LangChain's VectorStore and BaseLLM through the to_langchain() method. 23 | """ 24 | 25 | _bridge_version = __version__ 26 | _bridge_extra = "langchain" 27 | _framework = "langchain-core" 28 | _compatible_versions = {"min": "0.3.62"} 29 | _method_name = "to_langchain" 30 | 31 | def to_langchain(self: "_RAGSystem") -> tuple["VectorStore", "BaseLLM"]: 32 | """Converts the _RAGSystem to a tuple of ~langchain_core.vectorstores.VectorStore and ~langchain_core.language_models.BaseLLM.""" 33 | self._validate_framework_installed() 34 | 35 | from fed_rag._bridges.langchain._bridge_classes import ( 36 | FedRAGLLM, 37 | FedRAGVectorStore, 38 | ) 39 | 40 | return FedRAGVectorStore(self), FedRAGLLM(self) 41 | -------------------------------------------------------------------------------- /src/fed_rag/data_structures/results.py: -------------------------------------------------------------------------------- 1 | """Data structures for results""" 2 | 3 | from typing import Any 4 | 5 | from pydantic import BaseModel, Field 6 | 7 | 8 | class TrainResult(BaseModel): 9 | """ 10 | Represents the result of a training process. 
11 | 12 | This class encapsulates the outcome of a model's training process, 13 | specifically storing the loss value calculated during training. 14 | 15 | Attributes: 16 | loss (float): The training loss value. 17 | """ 18 | 19 | loss: float 20 | 21 | 22 | class TestResult(BaseModel): 23 | """ 24 | Represents the results of a test process, including loss and additional metrics. 25 | 26 | This class is used to encapsulate the results of testing, such as the calculated 27 | loss value and optional additional metrics. It includes fields for storing the 28 | primary loss value and a dictionary of computed metrics for more detailed analysis 29 | or performance evaluation. This ensures a structured representation of test outcomes. 30 | 31 | Attributes: 32 | loss (float): The primary loss value resulting from the test process. 33 | metrics (dict[str, Any]): Additional metrics computed on the test set. These can 34 | include various performance indicators or statistics relevant to the test. 35 | """ 36 | 37 | __test__ = ( 38 | False # needed for Pytest collision. Avoids PytestCollectionWarning 39 | ) 40 | loss: float 41 | metrics: dict[str, Any] = Field( 42 | description="Additional metrics computed on test set.", 43 | default_factory=dict, 44 | ) 45 | -------------------------------------------------------------------------------- /examples/ra-dit/ra_dit/_dataset_prep/qa/commonsense.py: -------------------------------------------------------------------------------- 1 | """CommonsenseQA 2 | 3 | Example 4 | === 5 | {'id': '075e483d21c29a511267ef62bedc0461', 6 | 'question': 'The sanctions against the school were a punishing blow, and they seemed to what the efforts the school had made to change?', 7 | 'question_concept': 'punishing', 8 | 'choices': {'label': ['A', 'B', 'C', 'D', 'E'], 9 | 'text': ['ignore', 'enforce', 'authoritarian', 'yell at', 'avoid']}, 10 | 'answerKey': 'A'} 11 | """ 12 | 13 | import numpy as np 14 | import pandas as pd 15 | 16 | from ..base_data_prepper import DEFAULT_SAVE_DIR, BaseDataPrepper 17 | from .mixin import QAMixin 18 | 19 | QA_SAVE_DIR = DEFAULT_SAVE_DIR / "qa" 20 | 21 | 22 | class CommonsenseQADataPrepper(QAMixin, BaseDataPrepper): 23 | @property 24 | def dataset_name(self) -> str: 25 | return "commonsense_qa" 26 | 27 | def _get_answer(self, row: pd.Series) -> str: 28 | answer_ix = np.where(row["choices"]["label"] == row["answerKey"]) 29 | return str(row["choices"]["text"][answer_ix][0]) 30 | 31 | def _get_question(self, row: pd.Series) -> str: 32 | return str(row["question"]) 33 | 34 | def _get_evidence(self, row: pd.Series) -> str | None: 35 | return None 36 | 37 | 38 | splits = { 39 | "train": "data/train-00000-of-00001.parquet", 40 | "validation": "data/validation-00000-of-00001.parquet", 41 | "test": "data/test-00000-of-00001.parquet", 42 | } 43 | df = pd.read_parquet("hf://datasets/tau/commonsense_qa/" + splits["train"]) 44 | data_prepper = CommonsenseQADataPrepper(df=df, save_dir=QA_SAVE_DIR) 45 | data_prepper.execute_and_save() 46 | -------------------------------------------------------------------------------- /examples/ra-dit/ra_dit/_dataset_prep/qa/wiki.py: -------------------------------------------------------------------------------- 1 | """WikiQA 2 | 3 | Example 4 | === 5 | { 6 | 'question_id': 'Q1', 7 | 'question': 'how are glacier caves formed? ', 8 | 'document_title': 'Glacier cave', 9 | 'answer': 'A partly submerged glacier cave on Perito Moreno Glacier .' 
10 | 'label': 0 11 | }, 12 | { 13 | 'question_id': 'Q1', 14 | 'question': 'how are glacier caves formed? ', 15 | 'document_title': 'Glacier cave', 16 | 'answer': 'A glacier cave is a cave formed within the ice of a glacier .' 17 | 'label': 1 18 | } 19 | """ 20 | 21 | import pandas as pd 22 | 23 | from ..base_data_prepper import DEFAULT_SAVE_DIR, BaseDataPrepper 24 | from .mixin import QAMixin 25 | 26 | QA_SAVE_DIR = DEFAULT_SAVE_DIR / "qa" 27 | 28 | 29 | class WikiQADataPrepper(QAMixin, BaseDataPrepper): 30 | @property 31 | def dataset_name(self) -> str: 32 | return "wiki_qa" 33 | 34 | def _get_answer(self, row: pd.Series) -> str: 35 | return str(row["answer"]) 36 | 37 | def _get_question(self, row: pd.Series) -> str: 38 | return str(row["question"]) 39 | 40 | def _get_evidence(self, row: pd.Series) -> str | None: 41 | return None 42 | 43 | 44 | splits = { 45 | "test": "data/test-00000-of-00001.parquet", 46 | "validation": "data/validation-00000-of-00001.parquet", 47 | "train": "data/train-00000-of-00001.parquet", 48 | } 49 | 50 | df = pd.read_parquet("hf://datasets/microsoft/wiki_qa/" + splits["test"]) 51 | # Keeping only the entries with the correct answer (i.e., label=1) because that's all we need. 52 | df = df[df["label"] == 1] 53 | data_prepper = WikiQADataPrepper(df=df, save_dir=QA_SAVE_DIR) 54 | data_prepper.execute_and_save() 55 | -------------------------------------------------------------------------------- /tests/tokenizers/test_unsloth_pretrained.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sys 3 | from unittest.mock import MagicMock, patch 4 | 5 | import pytest 6 | 7 | from fed_rag.base.tokenizer import BaseTokenizer 8 | from fed_rag.exceptions import MissingExtraError 9 | from fed_rag.tokenizers.unsloth_pretrained_tokenizer import ( 10 | UnslothPretrainedTokenizer, 11 | ) 12 | 13 | 14 | def test_hf_pretrained_generator_class() -> None: 15 | names_of_base_classes = [ 16 | b.__name__ for b in UnslothPretrainedTokenizer.__mro__ 17 | ] 18 | assert BaseTokenizer.__name__ in names_of_base_classes 19 | 20 | 21 | def test_unsloth_extra_missing() -> None: 22 | """Test extra is not installed.""" 23 | 24 | modules = {"unsloth": None} 25 | module_to_import = "fed_rag.tokenizers.unsloth_pretrained_tokenizer" 26 | 27 | if module_to_import in sys.modules: 28 | original_module = sys.modules.pop(module_to_import) 29 | 30 | with patch.dict("sys.modules", modules): 31 | msg = ( 32 | "`UnslothPretrainedTokenizer` requires the `unsloth` extra to be installed. " 33 | "To fix please run `pip install fed-rag[unsloth]`." 
34 | ) 35 | with pytest.raises( 36 | MissingExtraError, 37 | match=re.escape(msg), 38 | ): 39 | from fed_rag.tokenizers.unsloth_pretrained_tokenizer import ( 40 | UnslothPretrainedTokenizer, 41 | ) 42 | 43 | mock_tokenizer = MagicMock() 44 | 45 | UnslothPretrainedTokenizer(mock_tokenizer, "fake_name") 46 | 47 | # restore module so to not affect other tests 48 | sys.modules[module_to_import] = original_module 49 | -------------------------------------------------------------------------------- /src/fed_rag/knowledge_stores/qdrant/utils.py: -------------------------------------------------------------------------------- 1 | """Qdrant utils module.""" 2 | 3 | from importlib.util import find_spec 4 | from typing import TYPE_CHECKING 5 | 6 | from fed_rag.data_structures.knowledge_node import KnowledgeNode 7 | from fed_rag.exceptions import KnowledgeStoreError, MissingExtraError 8 | 9 | if TYPE_CHECKING: # pragma: no cover 10 | from qdrant_client.http.models import ScoredPoint 11 | from qdrant_client.models import PointStruct 12 | 13 | 14 | def check_qdrant_installed() -> None: 15 | if find_spec("qdrant_client") is None: 16 | raise MissingExtraError( 17 | "Qdrant knowledge stores require the qdrant-client to be installed. " 18 | "To fix please run `pip install fed-rag[qdrant]`." 19 | ) 20 | 21 | 22 | def convert_knowledge_node_to_qdrant_point( 23 | node: KnowledgeNode, 24 | ) -> "PointStruct": 25 | from qdrant_client.models import PointStruct 26 | 27 | if node.embedding is None: 28 | raise KnowledgeStoreError( 29 | "Cannot load a node with embedding set to None." 30 | ) 31 | 32 | return PointStruct( 33 | id=node.node_id, 34 | vector=node.embedding, 35 | payload=node.model_dump_without_embeddings(), 36 | ) 37 | 38 | 39 | def convert_scored_point_to_knowledge_node_and_score_tuple( 40 | scored_point: "ScoredPoint", 41 | ) -> tuple[float, KnowledgeNode]: 42 | knowledge_data = scored_point.payload 43 | knowledge_data.update( 44 | embedding=scored_point.vector 45 | ) # attach vector to embedding if it is even returned 46 | return ( 47 | scored_point.score, 48 | KnowledgeNode.model_validate(knowledge_data), 49 | ) 50 | -------------------------------------------------------------------------------- /examples/ra-dit/ra_dit/_dataset_prep/qa/math.py: -------------------------------------------------------------------------------- 1 | """MathQA 2 | 3 | Example 4 | === 5 | { 6 | 'Problem': 'the banker ' s gain of a certain sum due 3 yea...', 7 | 'Rationale': 'explanation : t = 3 years r = 10 % td = ( bg ...', 8 | 'options': 'a ) rs . 400 , b ) rs . 300 , c ) rs . 500 , d...', 9 | 'correct': 'a', 10 | 'annotated_formula': 'divide(multiply(const_100, divide(multiply(36,...', 11 | 'linear_formula': 'multiply(n2,const_100)|multiply(n0,n1)|divide(... 
', 12 | 'category' : 'gain' 13 | } 14 | """ 15 | 16 | import re 17 | 18 | import pandas as pd 19 | from datasets import load_dataset 20 | 21 | from ..base_data_prepper import DEFAULT_SAVE_DIR, BaseDataPrepper 22 | from .mixin import QAMixin 23 | 24 | QA_SAVE_DIR = DEFAULT_SAVE_DIR / "qa" 25 | 26 | 27 | class MathQADataPrepper(QAMixin, BaseDataPrepper): 28 | @property 29 | def dataset_name(self) -> str: 30 | return "math_qa" 31 | 32 | def _get_answer(self, row: pd.Series) -> str: 33 | options = re.findall(r"([a-z])\s*\)\s*([^,]+)", row["options"]) 34 | for label, text in options: 35 | if label.strip() == row["correct"].strip(): 36 | answer = row["Rationale"] + "\n\n" + text.strip() 37 | return str(answer) 38 | 39 | def _get_question(self, row: pd.Series) -> str: 40 | return str(row["Problem"]) 41 | 42 | def _get_evidence(self, row: pd.Series) -> str | None: 43 | return None 44 | 45 | 46 | splits = {"test": "test", "validation": "valid", "train": "train"} 47 | 48 | dataset = load_dataset("allenai/math_qa", trust_remote_code=True) 49 | df = pd.DataFrame(dataset["train"]) 50 | data_prepper = MathQADataPrepper(df=df, save_dir=QA_SAVE_DIR) 51 | data_prepper.execute_and_save() 52 | -------------------------------------------------------------------------------- /tests/trainer_configs/test_pytorch_trainer_config.py: -------------------------------------------------------------------------------- 1 | """PyTorchTrainerConfig Unit Tests""" 2 | 3 | from typing import Any 4 | 5 | import numpy as np 6 | import pytest 7 | import torch 8 | from torch.utils.data import DataLoader, Dataset 9 | 10 | from fed_rag.trainer_configs import PyTorchTrainerConfig 11 | 12 | 13 | class _TestDataset(Dataset): 14 | def __init__(self, size: int) -> None: 15 | self.features = np.random.rand(size, 2) 16 | self.labels = np.random.choice(2, size=size) 17 | 18 | def __len__(self) -> int: 19 | return len(self.labels) 20 | 21 | def __getitem__(self, index: int) -> tuple[np.ndarray, Any]: 22 | return self.features[index], self.labels[index] 23 | 24 | 25 | @pytest.fixture() 26 | def train_dataloader() -> DataLoader: 27 | dataset = _TestDataset(size=10) 28 | return DataLoader(dataset, batch_size=2, shuffle=True) 29 | 30 | 31 | @pytest.fixture() 32 | def val_dataloader() -> DataLoader: 33 | dataset = _TestDataset(size=4) 34 | return DataLoader(dataset, batch_size=2, shuffle=True) 35 | 36 | 37 | def test_init( 38 | train_dataloader: DataLoader, val_dataloader: DataLoader 39 | ) -> None: 40 | mdl = torch.nn.Linear(2, 1) 41 | cfg = PyTorchTrainerConfig( 42 | net=mdl, 43 | train_data=train_dataloader, 44 | val_data=val_dataloader, 45 | a=1, 46 | b=2, 47 | c="3", 48 | ) 49 | 50 | # get item 51 | assert cfg["a"] == 1 52 | assert cfg["b"] == 2 53 | assert cfg["c"] == "3" 54 | # get attr 55 | assert cfg.a == 1 56 | assert cfg.b == 2 57 | assert cfg.c == "3" 58 | assert cfg.net == mdl 59 | assert cfg.train_data == train_dataloader 60 | assert cfg.val_data == val_dataloader 61 | -------------------------------------------------------------------------------- /docs/getting_started/import_patterns.md: -------------------------------------------------------------------------------- 1 | # Import Patterns 2 | 3 | FedRAG provides a carefully designed public API for working with RAG and both centralized 4 | and federated fine-tuning components. All components exported at the root level 5 | and from public subpackages are considered stable and follow semantic versioning guidelines. 
6 | 7 | ## Root Imports 8 | 9 | Import core components directly from the root: 10 | 11 | ```py 12 | from fed_rag import ( 13 | RAGSystem, 14 | RAGConfig, 15 | HFPretrainedModelGenerator, 16 | HFSentenceTransformerRetriever, 17 | InMemoryKnowledgeStore, 18 | ) 19 | 20 | # Now use the components directly 21 | system = RAGSystem( 22 | retriever=HFSentenceTransformerRetriever(...), 23 | generator=HFPretrainedModelGenerator(...), 24 | knowledge_store=InMemoryKnowledgeStore(), 25 | rag_config=RAGConfig(...), 26 | ) 27 | ``` 28 | 29 | ## Namespaced Imports 30 | 31 | For better organization and increased clarity, you can import from specific 32 | component categories: 33 | 34 | ```py 35 | from fed_rag.core import RAGSystem 36 | from fed_rag.data_structures.rag import RAGConfig 37 | from fed_rag.generators import HFPretrainedModelGenerator 38 | from fed_rag.retrievers import HFSentenceTransformerRetriever 39 | from fed_rag.knowledge_stores import InMemoryKnowledgeStore 40 | 41 | # Create system with components from different namespaces 42 | system = RAGSystem( 43 | retriever=HFSentenceTransformerRetriever(...), 44 | generator=HFPretrainedModelGenerator(...), 45 | knowledge_store=InMemoryKnowledgeStore(), 46 | rag_config=RAGConfig(...), 47 | ) 48 | ``` 49 | 50 | !!! note 51 | Modules and functions prefixed with an underscore (e.g., `_internal`) are considered 52 | implementation details and may change between versions. 53 | -------------------------------------------------------------------------------- /examples/ra-dit/ra_dit/_dataset_prep/qa/mixin.py: -------------------------------------------------------------------------------- 1 | """QA Data Prepper""" 2 | 3 | from abc import ABC, abstractmethod 4 | from typing import TypedDict 5 | 6 | import pandas as pd 7 | 8 | 9 | class QAMixin(ABC): 10 | @property 11 | def required_cols(self) -> list[str]: 12 | return ["answer", "question"] 13 | 14 | @abstractmethod 15 | def _get_answer(self, row: pd.Series) -> str: 16 | """Get answer from an example row.""" 17 | 18 | @abstractmethod 19 | def _get_question(self, row: pd.Series) -> str: 20 | """Get question from an example row.""" 21 | 22 | @abstractmethod 23 | def _get_evidence(self, row: pd.Series) -> str | None: 24 | """Get evidence from an example row.""" 25 | 26 | def _prep_df(self) -> None: 27 | if not hasattr(self, "df"): 28 | raise ValueError("Missing 'df' property.") 29 | if not isinstance(self.df, pd.DataFrame): 30 | raise ValueError( 31 | "Invalid type for 'df' property. Should be a ~pd.DataFrame." 
32 | ) 33 | self.df["answer"] = self.df.apply( 34 | lambda row: self._get_answer(row), axis=1 35 | ) 36 | self.df["question"] = self.df.apply( 37 | lambda row: self._get_question(row), axis=1 38 | ) 39 | self.df["evidence"] = self.df.apply( 40 | lambda row: self._get_evidence(row), axis=1 41 | ) 42 | 43 | class InstructionExample(TypedDict): 44 | answer: str 45 | question: str 46 | evidence: str | None 47 | 48 | def example_to_json(self, row: pd.Series) -> dict[str, str]: 49 | instruction_example: QAMixin.InstructionExample = { 50 | "answer": row["answer"], 51 | "question": row["question"], 52 | "evidence": None, 53 | } 54 | return instruction_example # type:ignore [return-value] 55 | -------------------------------------------------------------------------------- /examples/ra-dit/ra_dit/generators/utils.py: -------------------------------------------------------------------------------- 1 | """Utils module.""" 2 | 3 | from enum import Enum 4 | 5 | from pydantic import BaseModel, PrivateAttr 6 | 7 | from fed_rag.generators.huggingface import ( 8 | HFPeftModelGenerator, 9 | HFPretrainedModelGenerator, 10 | ) 11 | 12 | 13 | class ModelVariants(str, Enum): 14 | PLAIN = "plain" 15 | Q4BIT = "q4bit" 16 | LORA = "lora" 17 | QLORA = "qlora" 18 | 19 | 20 | class ModelRegistry(BaseModel): 21 | _plain: HFPretrainedModelGenerator | None = PrivateAttr(default=None) 22 | _q4bit: HFPretrainedModelGenerator | None = PrivateAttr(default=None) 23 | _lora: HFPeftModelGenerator | None = PrivateAttr(default=None) 24 | _qlora: HFPeftModelGenerator | None = PrivateAttr(default=None) 25 | 26 | def __init__( 27 | self, 28 | plain: HFPretrainedModelGenerator | None = None, 29 | q4bit: HFPretrainedModelGenerator | None = None, 30 | lora: HFPeftModelGenerator | None = None, 31 | qlora: HFPeftModelGenerator | None = None, 32 | ): 33 | super().__init__() 34 | self._plain = plain 35 | self._q4bit = q4bit 36 | self._lora = lora 37 | self._qlora = qlora 38 | 39 | def __getitem__( 40 | self, key: str 41 | ) -> HFPeftModelGenerator | HFPretrainedModelGenerator: 42 | match key: 43 | case ModelVariants.PLAIN: 44 | retval = self._plain 45 | case ModelVariants.Q4BIT: 46 | retval = self._q4bit 47 | case ModelVariants.LORA: 48 | retval = self._lora 49 | case ModelVariants.QLORA: 50 | retval = self._qlora 51 | case _: 52 | raise ValueError(f"Invalid variant {key}.") 53 | if retval is None: 54 | raise ValueError(f"Variant {key} has not been specified") 55 | return retval 56 | -------------------------------------------------------------------------------- /tests/retrievers/mixins/test_audio_retriever_mixin.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import pytest 4 | import torch 5 | from pydantic import BaseModel, PrivateAttr 6 | 7 | from fed_rag.base.retriever import BaseRetriever 8 | from fed_rag.base.retriever_mixins import ( 9 | AudioRetrieverMixin, 10 | RetrieverHasAudioModality, 11 | ) 12 | from fed_rag.exceptions.retriever import RetrieverError 13 | 14 | from ..conftest import MockRetriever 15 | 16 | 17 | class MockMMRetriever(AudioRetrieverMixin, MockRetriever): 18 | _audio_encoder: torch.nn.Module = PrivateAttr( 19 | default=torch.nn.Linear(2, 1) 20 | ) 21 | 22 | @property 23 | def audio_encoder(self) -> torch.nn.Module | None: 24 | return self._audio_encoder 25 | 26 | def encode_audio( 27 | self, audio: Any | list[Any], **kwargs: Any 28 | ) -> torch.Tensor: 29 | return self._audio_encoder.forward(torch.ones(2)) 30 | 31 | 32 | def 
test_audio_retriever_mixin() -> None: 33 | mixed_retriever = MockMMRetriever() 34 | 35 | assert isinstance(mixed_retriever, RetrieverHasAudioModality) 36 | assert isinstance(mixed_retriever, BaseRetriever) 37 | 38 | 39 | def test_audio_retriever_mixin_fails_validation() -> None: 40 | with pytest.raises( 41 | RetrieverError, 42 | match="`AudioRetrieverMixin` must be mixed with `BaseRetriever`.", 43 | ): 44 | 45 | class InvalidMockMMRetriever(AudioRetrieverMixin, BaseModel): 46 | _audio_encoder: torch.nn.Module = PrivateAttr( 47 | default=torch.nn.Linear(2, 1) 48 | ) 49 | 50 | @property 51 | def audio_encoder(self) -> torch.nn.Module | None: 52 | return self._audio_encoder 53 | 54 | def encode_audio( 55 | self, audio: Any | list[Any], **kwargs: Any 56 | ) -> torch.Tensor: 57 | return torch.ones(2) 58 | -------------------------------------------------------------------------------- /tests/retrievers/mixins/test_video_retriever_mixin.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import pytest 4 | import torch 5 | from pydantic import BaseModel, PrivateAttr 6 | 7 | from fed_rag.base.retriever import BaseRetriever 8 | from fed_rag.base.retriever_mixins import ( 9 | RetrieverHasVideoModality, 10 | VideoRetrieverMixin, 11 | ) 12 | from fed_rag.exceptions.retriever import RetrieverError 13 | 14 | from ..conftest import MockRetriever 15 | 16 | 17 | class MockMMRetriever(VideoRetrieverMixin, MockRetriever): 18 | _video_encoder: torch.nn.Module = PrivateAttr( 19 | default=torch.nn.Linear(2, 1) 20 | ) 21 | 22 | @property 23 | def video_encoder(self) -> torch.nn.Module | None: 24 | return self._video_encoder 25 | 26 | def encode_video( 27 | self, video: Any | list[Any], **kwargs: Any 28 | ) -> torch.Tensor: 29 | return self._video_encoder.forward(torch.ones(2)) 30 | 31 | 32 | def test_video_retriever_mixin() -> None: 33 | mixed_retriever = MockMMRetriever() 34 | 35 | assert isinstance(mixed_retriever, RetrieverHasVideoModality) 36 | assert isinstance(mixed_retriever, BaseRetriever) 37 | 38 | 39 | def test_video_retriever_mixin_fails_validation() -> None: 40 | with pytest.raises( 41 | RetrieverError, 42 | match="`VideoRetrieverMixin` must be mixed with `BaseRetriever`.", 43 | ): 44 | 45 | class InvalidMockMMRetriever(VideoRetrieverMixin, BaseModel): 46 | _video_encoder: torch.nn.Module = PrivateAttr( 47 | default=torch.nn.Linear(2, 1) 48 | ) 49 | 50 | @property 51 | def video_encoder(self) -> torch.nn.Module | None: 52 | return self._video_encoder 53 | 54 | def encode_video( 55 | self, video: Any | list[Any], **kwargs: Any 56 | ) -> torch.Tensor: 57 | return torch.ones(2) 58 | -------------------------------------------------------------------------------- /tests/trainers/pytorch/conftest.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import numpy as np 4 | import pytest 5 | from torch.utils.data import DataLoader, Dataset 6 | 7 | from fed_rag.base.trainer import BaseGeneratorTrainer, BaseRetrieverTrainer 8 | from fed_rag.data_structures.results import TestResult, TrainResult 9 | from fed_rag.trainers.pytorch.mixin import PyTorchTrainerMixin 10 | 11 | 12 | class TestRetrieverTrainer(PyTorchTrainerMixin, BaseRetrieverTrainer): 13 | __test__ = ( 14 | False # needed for Pytest collision. 
Avoids PytestCollectionWarning 15 | ) 16 | 17 | def train(self) -> TrainResult: 18 | return TrainResult(loss=0.42) 19 | 20 | def evaluate(self) -> TestResult: 21 | return TestResult(loss=0.42) 22 | 23 | 24 | class TestGeneratorTrainer(PyTorchTrainerMixin, BaseGeneratorTrainer): 25 | __test__ = ( 26 | False # needed for Pytest collision. Avoids PytestCollectionWarning 27 | ) 28 | 29 | def train(self) -> TrainResult: 30 | return TrainResult(loss=0.42) 31 | 32 | def evaluate(self) -> TestResult: 33 | return TestResult(loss=0.42) 34 | 35 | 36 | class _TestDataset(Dataset): 37 | def __init__(self, size: int) -> None: 38 | self.features = np.random.rand(size, 2) 39 | self.labels = np.random.choice(2, size=size) 40 | 41 | def __len__(self) -> int: 42 | return len(self.labels) 43 | 44 | def __getitem__(self, index: int) -> tuple[np.ndarray, Any]: 45 | return self.features[index], self.labels[index] 46 | 47 | 48 | @pytest.fixture() 49 | def train_dataset() -> Dataset: 50 | return _TestDataset(size=10) 51 | 52 | 53 | @pytest.fixture() 54 | def another_train_dataset() -> Dataset: 55 | return _TestDataset(size=10) 56 | 57 | 58 | @pytest.fixture() 59 | def train_dataloader(train_dataset: Dataset) -> DataLoader: 60 | return DataLoader(train_dataset, batch_size=2, shuffle=True) 61 | -------------------------------------------------------------------------------- /tests/retrievers/mixins/test_image_retriever_mixin.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import pytest 4 | import torch 5 | from PIL import Image 6 | from pydantic import BaseModel, PrivateAttr 7 | 8 | from fed_rag.base.retriever import BaseRetriever 9 | from fed_rag.base.retriever_mixins import ( 10 | ImageRetrieverMixin, 11 | RetrieverHasImageModality, 12 | ) 13 | from fed_rag.exceptions.retriever import RetrieverError 14 | 15 | from ..conftest import MockRetriever 16 | 17 | 18 | class MockMMRetriever(ImageRetrieverMixin, MockRetriever): 19 | _image_encoder: torch.nn.Module = PrivateAttr( 20 | default=torch.nn.Linear(2, 1) 21 | ) 22 | 23 | @property 24 | def image_encoder(self) -> torch.nn.Module | None: 25 | return self._image_encoder 26 | 27 | def encode_image( 28 | self, image: Image.Image | list[Image.Image], **kwargs: Any 29 | ) -> torch.Tensor: 30 | return self._image_encoder.forward(torch.ones(2)) 31 | 32 | 33 | def test_image_retriever_mixin() -> None: 34 | mixed_retriever = MockMMRetriever() 35 | 36 | assert isinstance(mixed_retriever, RetrieverHasImageModality) 37 | assert isinstance(mixed_retriever, BaseRetriever) 38 | 39 | 40 | def test_image_retriever_mixin_fails_validation() -> None: 41 | with pytest.raises( 42 | RetrieverError, 43 | match="`ImageRetrieverMixin` must be mixed with `BaseRetriever`.", 44 | ): 45 | 46 | class InvalidMockMMRetriever(ImageRetrieverMixin, BaseModel): 47 | _image_encoder: torch.nn.Module = PrivateAttr( 48 | default=torch.nn.Linear(2, 1) 49 | ) 50 | 51 | @property 52 | def image_encoder(self) -> torch.nn.Module | None: 53 | return self._image_encoder 54 | 55 | def encode_image( 56 | self, image: Image.Image | list[Image.Image], **kwargs: Any 57 | ) -> torch.Tensor: 58 | return torch.ones(2) 59 | -------------------------------------------------------------------------------- /tests/evals/benchmarks/huggingface/test_boolq.py: -------------------------------------------------------------------------------- 1 | """Tests for BoolQ benchmark""" 2 | 3 | from unittest.mock import MagicMock, patch 4 | 5 | import pytest 6 | from 
datasets import Dataset 7 | 8 | import fed_rag.evals.benchmarks as benchmarks 9 | from fed_rag.data_structures.evals import BenchmarkExample 10 | 11 | 12 | @pytest.fixture 13 | def dummy_boolq() -> Dataset: 14 | """Create a dummy BoolQ dataset for testing.""" 15 | return Dataset.from_dict( 16 | { 17 | "question": ["is confectionary sugar the same as powdered sugar"], 18 | "answer": [True], 19 | "passage": [ 20 | "Powdered sugar, also called confectioners' sugar, is a finely ground sugar ..." 21 | ], 22 | } 23 | ) 24 | 25 | 26 | @patch("datasets.load_dataset") 27 | def test_boolq_query_response_context_extractors( 28 | mock_load_dataset: MagicMock, dummy_boolq: Dataset 29 | ) -> None: 30 | mock_load_dataset.return_value = dummy_boolq 31 | boolq = benchmarks.HuggingFaceBoolQ() 32 | 33 | assert isinstance(boolq[0], BenchmarkExample) 34 | assert ( 35 | boolq[0].query == "is confectionary sugar the same as powdered sugar" 36 | ) 37 | assert boolq[0].response == "true" 38 | assert ( 39 | boolq[0].context 40 | == "Powdered sugar, also called confectioners' sugar, is a finely ground sugar ..." 41 | ) 42 | 43 | 44 | @patch("datasets.load_dataset") 45 | def test_boolq_false_response(mock_load_dataset: MagicMock) -> None: 46 | dataset = Dataset.from_dict( 47 | { 48 | "question": ["is elder scrolls online the same as skyrim"], 49 | "answer": [False], 50 | "passage": [ 51 | "As with other games in The Elder Scrolls series, the game is set on the continent of Tamriel." 52 | ], 53 | } 54 | ) 55 | mock_load_dataset.return_value = dataset 56 | boolq = benchmarks.HuggingFaceBoolQ() 57 | 58 | assert boolq[0].response == "false" 59 | -------------------------------------------------------------------------------- /src/fed_rag/trainers/pytorch/mixin.py: -------------------------------------------------------------------------------- 1 | """PyTorch Trainer Mixin""" 2 | 3 | from abc import ABC 4 | from typing import Any, Protocol, runtime_checkable 5 | 6 | import torch.nn as nn 7 | from pydantic import BaseModel, ConfigDict 8 | from torch.utils.data import DataLoader, Dataset 9 | 10 | from fed_rag.exceptions import InconsistentDatasetError 11 | 12 | from .training_args import TrainingArgs 13 | 14 | 15 | # Define the protocol for runtime checking 16 | @runtime_checkable 17 | class PyTorchTrainerProtocol(Protocol): 18 | train_dataset: Dataset 19 | training_arguments: TrainingArgs | None 20 | train_dataloader: DataLoader 21 | 22 | def model(self) -> nn.Module: 23 | pass # pragma: no cover 24 | 25 | 26 | class PyTorchTrainerMixin(BaseModel, ABC): 27 | """PyTorch Trainer Mixin.""" 28 | 29 | model_config = ConfigDict( 30 | arbitrary_types_allowed=True, 31 | ) 32 | train_dataset: Dataset 33 | train_dataloader: DataLoader 34 | training_arguments: TrainingArgs | None = None 35 | 36 | def __init__( 37 | self, 38 | train_dataloader: DataLoader, 39 | train_dataset: Dataset | None = None, 40 | training_arguments: TrainingArgs | None = None, 41 | **kwargs: Any, 42 | ): 43 | if train_dataset is None: 44 | train_dataset = train_dataloader.dataset 45 | else: 46 | # ensure consistency between loader.dataset and the supplied one 47 | if id(train_dataset) != id(train_dataloader.dataset): 48 | raise InconsistentDatasetError( 49 | "Inconsistent datasets detected between supplied `train_dataset` and that " 50 | "associated with the `train_dataloader`. These two datasets must be the same." 
51 | ) 52 | 53 | super().__init__( 54 | train_dataset=train_dataset, 55 | train_dataloader=train_dataloader, 56 | training_arguments=training_arguments, 57 | **kwargs, 58 | ) 59 | -------------------------------------------------------------------------------- /tests/fl_tasks/pytorch/conftest.py: -------------------------------------------------------------------------------- 1 | """PyTorchFLTask Unit Tests""" 2 | 3 | from typing import Any, Callable 4 | 5 | import numpy as np 6 | import pytest 7 | import torch.nn as nn 8 | from torch.utils.data import DataLoader, Dataset 9 | 10 | from fed_rag.data_structures import TestResult, TrainResult 11 | from fed_rag.decorators import federate 12 | 13 | 14 | class _TestDataset(Dataset): 15 | def __init__(self, size: int) -> None: 16 | self.features = np.random.rand(size, 2) 17 | self.labels = np.random.choice(2, size=size) 18 | 19 | def __len__(self) -> int: 20 | return len(self.labels) 21 | 22 | def __getitem__(self, index: int) -> tuple[np.ndarray, Any]: 23 | return self.features[index], self.labels[index] 24 | 25 | 26 | @pytest.fixture() 27 | def train_dataloader() -> DataLoader: 28 | dataset = _TestDataset(size=10) 29 | return DataLoader(dataset, batch_size=2, shuffle=True) 30 | 31 | 32 | @pytest.fixture() 33 | def val_dataloader() -> DataLoader: 34 | dataset = _TestDataset(size=4) 35 | return DataLoader(dataset, batch_size=2, shuffle=True) 36 | 37 | 38 | @pytest.fixture() 39 | def trainer() -> Callable: 40 | @federate.trainer.pytorch 41 | def fn( 42 | net: nn.Module, 43 | train_loader: DataLoader, 44 | val_loader: DataLoader, 45 | ) -> TrainResult: 46 | return TrainResult(loss=0.0) 47 | 48 | return fn # type: ignore 49 | 50 | 51 | @pytest.fixture() 52 | def tester() -> Callable: 53 | @federate.tester.pytorch 54 | def fn( 55 | net: nn.Module, 56 | test_loader: DataLoader, 57 | ) -> TestResult: 58 | return TestResult(loss=0.0, metrics={}) 59 | 60 | return fn # type: ignore 61 | 62 | 63 | @pytest.fixture() 64 | def mismatch_tester() -> Callable: 65 | @federate.tester.pytorch 66 | def fn( 67 | mdl: nn.Module, # mismatch in name here 68 | test_loader: DataLoader, 69 | ) -> TestResult: 70 | return TestResult(loss=0.0, metrics={}) 71 | 72 | return fn # type: ignore 73 | -------------------------------------------------------------------------------- /docs/getting_started/integrations.md: -------------------------------------------------------------------------------- 1 | # Integrations 2 | 3 | FedRAG offers integrations with popular frameworks and tools across the RAG ecosystem. 4 | This page documents currently supported integrations and our roadmap for future compatibility. 5 | 6 | !!! 
info "Status Legend" 7 | :material-check-bold: — Currently supported; 8 | :material-clock: — Planned (linked to GitHub issue); 9 | Empty — Not currently planned 10 | 11 | ## Deep learning libraries 12 | 13 | | Framework | Status | 14 | |------------| --------------------- | 15 | | PyTorch | :material-check-bold: | 16 | | Keras | | 17 | | TensorFlow | | 18 | | Jax | | 19 | 20 | ## Fine-tuning frameworks 21 | 22 | | Framework | Status | 23 | | ----------- | --------------------- | 24 | | HuggingFace | :material-check-bold: | 25 | | Unsloth | :material-check-bold: | 26 | 27 | ## RAG inference frameworks 28 | 29 | | Framework | Status | 30 | | ---------- | --------------------- | 31 | | LlamaIndex | :material-check-bold: | 32 | | LangChain | :material-check-bold: | 33 | | Haystack | | 34 | 35 | ## Knowledge Stores 36 | 37 | | Storage Solution | Status | 38 | |------------------|:-------------------------------------------------------------------------:| 39 | | Qdrant | :material-check-bold: | 40 | | ChromaDB | [:material-clock:](https://github.com/VectorInstitute/fed-rag/issues/293) | 41 | | FAISS | [:material-clock:](https://github.com/VectorInstitute/fed-rag/issues/292) | 42 | | PGVector | | 43 | 44 | !!! note "Contributing Integrations" 45 | We welcome community contributions for additional integrations. See our 46 | [CONTRIBUTING](https://github.com/VectorInstitute/fed-rag/blob/main/CONTRIBUTING.md) 47 | guidelines for more information on implementing and submitting new integrations. 48 | -------------------------------------------------------------------------------- /examples/knowledge_stores/ra-dit-ks/README.md: -------------------------------------------------------------------------------- 1 | # Example: Wikipedia Dec 2021 Knowledge Store 2 | 3 | A Docker image providing a pre-built Qdrant vector database with an Atlas corpus knowledge store for retrieval-augmented applications. 4 | 5 | ⚠️ **Note:** This Docker image is approximately 3.6GB in size due to the included Python environment and ML libraries. Ensure you have sufficient disk space and bandwidth when pulling the image. 6 | 7 | ## Quick Start 8 | 9 | ```bash 10 | # Pull the image 11 | docker pull vectorinstitute/qdrant-atlas-dec-wiki-2021:latest 12 | 13 | # Run the container with basic settings and gpu acceleration 14 | docker run -d \ 15 | --name qdrant-vector-db \ 16 | --gpus all \ 17 | -p 6333:6333 \ 18 | -p 6334:6334 \ 19 | -v qdrant_data:/qdrant_storage \ 20 | vectorinstitute/qdrant-atlas-dec-wiki-2021:latest 21 | ``` 22 | 23 | ## Using the Knowledge Store 24 | 25 | Once the container has `healthy` status, then we can use 26 | 27 | ### Using with fed-rag 28 | 29 | ```python 30 | from fed_rag.retriever.knowledge_store import QdrantKnowledgeStore 31 | from fed_rag.retrievers.huggingface.hf_sentence_transformer import ( 32 | HFSentenceTransformerRetriever, 33 | ) 34 | 35 | # build retriever for encoding queries 36 | retriever = HFSentenceTransformerRetriever( 37 | query_model_name="nthakur/dragon-plus-query-encoder", 38 | context_model_name="nthakur/dragon-plus-context-encoder", 39 | load_model_at_init=False, 40 | ) 41 | 42 | # Connect to the containerized knowledge store 43 | knowledge_store = QdrantKnowledgeStore( 44 | collection_name="nthakur.dragon-plus-context-encoder", 45 | host="localhost", 46 | port=6333, 47 | ) 48 | 49 | # Retrieve documents 50 | query = "What is the history of marine biology?" 
51 | query_emb = retriever.encode_query(query).tolist()
52 | 
53 | results = knowledge_store.retrieve(query_emb=query_emb, top_k=3)
54 | for node in results:
55 |     print(f"Score: {node.score}, Content: {str(node.node)}")
56 | ```
57 | 
58 | ## Acknowledgements
59 | 
60 | - [Qdrant](https://qdrant.tech/) - Vector Database
61 | - [Facebook AI Research](https://github.com/facebookresearch/atlas) - Atlas Corpus
62 | 
--------------------------------------------------------------------------------
/docs/examples/ra_dit/index.md:
--------------------------------------------------------------------------------
 1 | # A comprehensive implementation of RA-DIT
 2 | 
 3 | 
 4 | 
 5 | 
 6 | View in Github
 7 | 
 8 | 
 9 | We consider the paper "RA-DIT: Retrieval-Augmented Dual Instruction Tuning" by Lin,
10 | Xi Victoria et al. (2023)[^1] and implement simplified versions of their experiments
11 | using FedRAG. In this work, the authors build a RAG system and fine-tune both
12 | the generator and retriever using diverse question-answering (QA) datasets.
13 | Their experimental results demonstrate that a fine-tuned RAG system consistently
14 | outperforms two key baselines: a standalone generator LLM and an un-fine-tuned RAG
15 | system. These findings highlight the substantial benefits of applying the RA-DIT
16 | approach to enhance RAG system performance.
17 | 
18 | This comprehensive implementation demonstrates the key concepts and
19 | techniques from the original research while adapting them for practical use.
20 | More specifically, in this example, we:
21 | 
22 | 1. [Build a Qdrant Knowledge Store](./qdrant_knowledge_store_wikipedia.md) — Take
23 |    artifacts derived from Wikipedia to populate a `QdrantKnowledgeStore`.
24 | 
25 | 2. [Fine-tune with QA datasets](./finetune.md) — Build a
26 |    [`RAGSystem`](../../api_reference/rag_system/index.md) and fine-tune it with
27 |    some QA datasets using LSR and RALT trainers.
28 | 
29 | 3. [Evaluate with Benchmarks](./benchmarking.md) — Benchmark our fine-tuned RAG system
30 |    on MMLU and compare it to a few appropriate baselines.
31 | 
32 | 4. [Federated Fine-tuning](./federated_finetune.md) — Demonstrate how we can go
33 |    from centralized to federated fine-tuning of our RAG system.
34 | 
35 | !!! note
36 |     Federated fine-tuning was not considered in Lin, Xi Victoria et al. (2023)[^1].
37 | 
38 | 
39 | [^1]: Lin, Xi Victoria, et al. "Ra-dit: Retrieval-augmented dual instruction tuning."
40 |     The Twelfth International Conference on Learning Representations. 2023.
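To give a concrete feel for the benchmarking API used in step 3, the minimal sketch below loads the BoolQ benchmark that ships with FedRAG and inspects a single example. This is an illustrative sketch only: the MMLU benchmark used in this example is assumed to expose the same sequence-like interface, and HuggingFace-backed benchmarks require the corresponding HuggingFace evals extra to be installed.

```py
import fed_rag.evals.benchmarks as benchmarks

# Instantiating a HuggingFace-backed benchmark loads the underlying dataset.
boolq = benchmarks.HuggingFaceBoolQ()
print(boolq.num_examples)

# Benchmarks behave like sequences of BenchmarkExample objects.
example = boolq[0]
print(example.query)     # the yes/no question
print(example.response)  # "true" or "false"
print(example.context)   # the supporting passage

# They can also be consumed as a stream.
for ex in boolq.as_stream():
    print(ex.query)
    break
```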
41 | 
--------------------------------------------------------------------------------
/docs/getting_started/installation.md:
--------------------------------------------------------------------------------
 1 | # Installation
 2 | 
 3 | ## Installing from package managers
 4 | 
 5 | ### PyPi
 6 | 
 7 | As seen in the previous quickstart examples, we can install FedRAG via `pip`:
 8 | 
 9 | ``` sh
10 | pip install fed-rag
11 | ```
12 | 
13 | ### Conda
14 | 
15 | For `conda` users, `fed-rag` has been published to the
16 | [`conda-forge`](https://conda-forge.org/) channel, and thus can be installed
17 | with `conda` using the below command:
18 | 
19 | ``` sh
20 | conda install -c conda-forge fed-rag
21 | ```
22 | 
23 | ## Installing from source
24 | 
25 | To install from source, first clone the repository:
26 | 
27 | ``` sh
28 | # https
29 | git clone https://github.com/VectorInstitute/fed-rag.git
30 | 
31 | # ssh
32 | git clone git@github.com:VectorInstitute/fed-rag.git
33 | ```
34 | 
35 | After cloning the repository, you have a few options for installing the library.
36 | The next two subsections outline how to complete the installation using either
37 | `pip` or `uv`, respectively.
38 | 
39 | ### Using `pip`
40 | 
41 | To complete the installation, first `cd` into the `fed-rag` directory and then
42 | run the following `pip install` command:
43 | 
44 | ``` sh
45 | cd fed-rag
46 | pip install -e .
47 | ```
48 | 
49 | !!! tip
50 |     We recommend always using a fresh virtual environment for new projects.
51 |     Before running the above command, ensure that your dedicated virtual environment
52 |     is active.
53 | 
54 | ### Using `uv`
55 | 
56 | FedRAG uses [`uv`](https://docs.astral.sh/uv/) for dependency management, publishing
57 | to PyPi, and for setting up development environments.
58 | 
59 | Users can also use `uv` to complete the source installation of FedRAG.
60 | 
61 | !!! note
62 |     This method requires `uv` to be installed on the user's development machine.
63 |     For installation instructions visit `uv`'s [official documentation](https://docs.astral.sh/uv/getting-started/installation/).
64 | 
65 | ``` sh
66 | cd fed-rag
67 | uv sync
68 | ```
69 | 
70 | To install with desired extras and groups, add the flags `--extra`
71 | and `--group`, respectively.
As an example: 72 | 73 | ``` sh 74 | cd fed-rag 75 | uv sync --extra huggingface --group dev 76 | ``` 77 | -------------------------------------------------------------------------------- /src/fed_rag/base/evals/benchmark.py: -------------------------------------------------------------------------------- 1 | """Base Benchmark and Benchmarker""" 2 | 3 | from abc import ABC, abstractmethod 4 | from typing import Any, Generator, Iterator, Sequence 5 | 6 | from pydantic import BaseModel, ConfigDict, PrivateAttr, model_validator 7 | 8 | from fed_rag.data_structures.evals import BenchmarkExample 9 | from fed_rag.exceptions import BenchmarkGetExamplesError, BenchmarkParseError 10 | 11 | 12 | class BaseBenchmark(BaseModel, ABC): 13 | """Base Benchmark.""" 14 | 15 | _examples: Sequence[BenchmarkExample] = PrivateAttr() 16 | 17 | model_config = ConfigDict(arbitrary_types_allowed=True) 18 | 19 | # give it a sequence interface for accessing examples more easily 20 | def __getitem__(self, index: int) -> BenchmarkExample: 21 | return self._examples.__getitem__(index) 22 | 23 | def __len__(self) -> int: 24 | return self._examples.__len__() 25 | 26 | # shouldn't override Pydantic BaseModels' __iter__ 27 | def as_iterator(self) -> Iterator[BenchmarkExample]: 28 | return self._examples.__iter__() 29 | 30 | @model_validator(mode="after") 31 | def set_examples(self) -> "BaseBenchmark": 32 | try: 33 | self._examples = self._get_examples() 34 | except BenchmarkParseError as e: 35 | raise BenchmarkGetExamplesError( 36 | f"Failed to parse examples: {str(e)}" 37 | ) from e 38 | except Exception as e: 39 | raise ( 40 | BenchmarkGetExamplesError(f"Failed to get examples: {str(e)}") 41 | ) from e 42 | return self 43 | 44 | # abstractmethods 45 | @abstractmethod 46 | def _get_examples(self, **kwargs: Any) -> Sequence[BenchmarkExample]: 47 | """Method to get examples.""" 48 | 49 | @abstractmethod 50 | def as_stream(self) -> Generator[BenchmarkExample, None, None]: 51 | """Produce a stream of `BenchmarkExamples`.""" 52 | 53 | @property 54 | @abstractmethod 55 | def num_examples(self) -> int: 56 | """Number of examples in the benchmark. 57 | 58 | NOTE: if streaming, `_examples` is likely set to an empty list. Thus, 59 | we leave this implementation for the subclasses. 60 | """ 61 | -------------------------------------------------------------------------------- /tests/trainers/huggingface/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from datasets import Dataset 4 | from sentence_transformers import SentenceTransformer 5 | from transformers import Trainer 6 | 7 | from fed_rag import RAGSystem 8 | from fed_rag.base.trainer import BaseGeneratorTrainer, BaseRetrieverTrainer 9 | from fed_rag.data_structures.results import TestResult, TrainResult 10 | from fed_rag.trainers.huggingface.mixin import HuggingFaceTrainerMixin 11 | 12 | 13 | class TestHFRetrieverTrainer(HuggingFaceTrainerMixin, BaseRetrieverTrainer): 14 | __test__ = ( 15 | False # needed for Pytest collision. Avoids PytestCollectionWarning 16 | ) 17 | 18 | def train(self) -> TrainResult: 19 | return TrainResult(loss=0.42) 20 | 21 | def evaluate(self) -> TestResult: 22 | return TestResult(loss=0.42) 23 | 24 | def hf_trainer_obj(self) -> Trainer: 25 | return Trainer() 26 | 27 | 28 | class TestHFGeneratorTrainer(HuggingFaceTrainerMixin, BaseGeneratorTrainer): 29 | __test__ = ( 30 | False # needed for Pytest collision. 
Avoids PytestCollectionWarning 31 | ) 32 | 33 | def train(self) -> TrainResult: 34 | return TrainResult(loss=0.42) 35 | 36 | def evaluate(self) -> TestResult: 37 | return TestResult(loss=0.42) 38 | 39 | def hf_trainer_obj(self) -> Trainer: 40 | return Trainer() 41 | 42 | 43 | @pytest.fixture() 44 | def train_dataset() -> Dataset: 45 | return Dataset.from_dict( 46 | { 47 | "query": ["first query", "second query"], 48 | "response": ["first response", "second response"], 49 | } 50 | ) 51 | 52 | 53 | @pytest.fixture() 54 | def hf_rag_system(mock_rag_system: RAGSystem) -> RAGSystem: 55 | encoder = SentenceTransformer(modules=[torch.nn.Linear(5, 5)]) 56 | # Mock the tokenize method on the first module 57 | encoder.tokenizer = None 58 | encoder._first_module().tokenize = lambda texts: { 59 | "input_ids": torch.ones((len(texts), 10)) 60 | } 61 | encoder.encode = lambda texts, **kwargs: torch.ones( 62 | (len(texts) if isinstance(texts, list) else 1, 5) 63 | ) 64 | 65 | mock_rag_system.retriever.encoder = encoder 66 | return mock_rag_system 67 | -------------------------------------------------------------------------------- /src/fed_rag/base/generator.py: -------------------------------------------------------------------------------- 1 | """Base Generator""" 2 | 3 | from abc import ABC, abstractmethod 4 | 5 | import torch 6 | from pydantic import BaseModel, ConfigDict 7 | 8 | from fed_rag.base.tokenizer import BaseTokenizer 9 | from fed_rag.data_structures import Context, Prompt, Query 10 | 11 | DEFAULT_PROMPT_TEMPLATE = """ 12 | You are a helpful assistant. Given the user's query, provide a succinct 13 | and accurate response. If context is provided, use it in your answer if it helps 14 | you to create the most accurate response. 15 | 16 | 17 | {query} 18 | 19 | 20 | 21 | {context} 22 | 23 | 24 | 25 | 26 | """ 27 | 28 | 29 | class BaseGenerator(BaseModel, ABC): 30 | """Base Generator Class.""" 31 | 32 | model_config = ConfigDict(arbitrary_types_allowed=True) 33 | 34 | @abstractmethod 35 | def generate( 36 | self, 37 | query: str | list[str] | Query | list[Query], 38 | context: str | list[str] | Context | list[Context], 39 | **kwargs: dict, 40 | ) -> str | list[str]: 41 | """Generate an output from a given query and context.""" 42 | 43 | @abstractmethod 44 | def complete( 45 | self, prompt: str | list[str] | Prompt | list[Prompt], **kwargs: dict 46 | ) -> str | list[str]: 47 | """Completion interface for generator LLMs.""" 48 | 49 | @property 50 | @abstractmethod 51 | def model(self) -> torch.nn.Module: 52 | """Model associated with this generator.""" 53 | 54 | @property 55 | @abstractmethod 56 | def tokenizer(self) -> BaseTokenizer: 57 | """Tokenizer associated with this generator.""" 58 | 59 | @abstractmethod 60 | def compute_target_sequence_proba( 61 | self, prompt: str | Prompt, target: str 62 | ) -> torch.Tensor: 63 | """Compute P(target | prompt). 64 | 65 | NOTE: this is used in LM Supervised Retriever fine-tuning. 
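In LSR, this probability serves as the generator's relevance signal for a
retrieved context, and the retriever is trained to align its own retrieval
scores with these signals.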
66 | """ 67 | 68 | @property 69 | @abstractmethod 70 | def prompt_template(self) -> str: 71 | """Prompt template for formating query and context.""" 72 | 73 | @prompt_template.setter 74 | @abstractmethod 75 | def prompt_template(self, value: str) -> None: 76 | """Prompt template setter.""" 77 | -------------------------------------------------------------------------------- /tests/evals/benchmarks/_benchmarks.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Generator, Sequence 2 | 3 | from fed_rag.base.evals.benchmark import BaseBenchmark 4 | from fed_rag.data_structures import BenchmarkExample 5 | from fed_rag.evals.benchmarks.huggingface.mixin import ( 6 | HuggingFaceBenchmarkMixin, 7 | ) 8 | 9 | 10 | class TestBenchmark(BaseBenchmark): 11 | __test__ = ( 12 | False # needed for Pytest collision. Avoids PytestCollectionWarning 13 | ) 14 | 15 | def _get_examples(self, **kwargs: Any) -> Sequence[BenchmarkExample]: 16 | return [ 17 | BenchmarkExample(query="query 1", response="response 1"), 18 | BenchmarkExample(query="query 2", response="response 2"), 19 | BenchmarkExample(query="query 3", response="response 3"), 20 | ] 21 | 22 | def as_stream(self) -> Generator[BenchmarkExample, None, None]: 23 | for ex in self._get_examples(): 24 | yield ex 25 | 26 | @property 27 | def num_examples(self) -> int: 28 | return len(self._get_examples()) 29 | 30 | 31 | class TestHFBenchmark(HuggingFaceBenchmarkMixin, BaseBenchmark): 32 | __test__ = ( 33 | False # needed for Pytest collision. Avoids PytestCollectionWarning 34 | ) 35 | 36 | dataset_name = "test_benchmark" 37 | 38 | def _get_query_from_example(self, example: dict[str, Any]) -> str: 39 | return str(example["query"]) 40 | 41 | def _get_response_from_example(self, example: dict[str, Any]) -> str: 42 | return str(example["response"]) 43 | 44 | def _get_context_from_example(self, example: dict[str, Any]) -> str: 45 | return str(example["context"]) 46 | 47 | 48 | class TestBenchmarkBadExamples(BaseBenchmark): 49 | __test__ = ( 50 | False # needed for Pytest collision. 
Avoids PytestCollectionWarning 51 | ) 52 | 53 | def _get_examples(self, **kwargs: Any) -> Sequence[BenchmarkExample]: 54 | raise RuntimeError("Too bad, so sad.") 55 | 56 | def as_stream(self) -> Generator[BenchmarkExample, None, None]: 57 | for ex in self._get_examples(): 58 | yield ex 59 | 60 | @property 61 | def num_examples(self) -> int: 62 | return len(self._get_examples()) 63 | 64 | 65 | __all__ = ["TestBenchmark", "TestHFBenchmark"] 66 | -------------------------------------------------------------------------------- /tests/utils/test_asyncio.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import re 3 | 4 | import pytest 5 | 6 | from fed_rag.exceptions import FedRAGError 7 | from fed_rag.utils.asyncio import asyncio_run 8 | 9 | 10 | async def simple_async_function(value: int = 42) -> int: 11 | """Simple async function for testing.""" 12 | await asyncio.sleep(0.001) 13 | return value 14 | 15 | 16 | async def async_function_with_exception() -> None: 17 | """Async function that raises an exception.""" 18 | await asyncio.sleep(0.001) 19 | raise RuntimeError("Test exception") 20 | 21 | 22 | def test_simple_coroutine_execution() -> None: 23 | """Test running a simple coroutine.""" 24 | result = asyncio_run(simple_async_function(123)) 25 | assert result == 123 26 | 27 | 28 | def test_coroutine_with_default_args() -> None: 29 | """Test running a coroutine with default arguments.""" 30 | result = asyncio_run(simple_async_function()) 31 | assert result == 42 32 | 33 | 34 | def test_existing_but_not_running_loop() -> None: 35 | """Test behavior with an existing but not running loop.""" 36 | loop = asyncio.new_event_loop() 37 | asyncio.set_event_loop(loop) 38 | 39 | try: 40 | # Loop exists but is not running 41 | result = asyncio_run(simple_async_function(789)) 42 | assert result == 789 43 | finally: 44 | loop.close() 45 | 46 | 47 | def test_nested_asyncio_run_calls() -> None: 48 | """Test that nested calls work correctly.""" 49 | 50 | async def outer_async() -> int: 51 | # This will run in a separate thread due to nested context 52 | inner_result = asyncio_run(simple_async_function(111)) 53 | return int(inner_result * 2) 54 | 55 | result = asyncio_run(outer_async()) 56 | assert result == 222 57 | 58 | 59 | def test_coroutine_exception_propagation() -> None: 60 | """Test that exceptions from coroutines are properly propagated.""" 61 | msg = ( 62 | "Unable to execute async operation in current context. " 63 | "This may be due to nested event loops. Consider using nest_asyncio.apply() " 64 | "to allow nested event loops, or use async methods directly." 
65 | ) 66 | 67 | with pytest.raises(FedRAGError, match=re.escape(msg)): 68 | asyncio_run(async_function_with_exception()) 69 | -------------------------------------------------------------------------------- /src/fed_rag/inspectors/pytorch/tester.py: -------------------------------------------------------------------------------- 1 | """PyTorch Tester Inspector""" 2 | 3 | import inspect 4 | from typing import Any, Callable 5 | 6 | from fed_rag.data_structures import TestResult 7 | from fed_rag.exceptions import ( 8 | InvalidReturnType, 9 | MissingDataParam, 10 | MissingNetParam, 11 | ) 12 | from fed_rag.inspectors.common import TesterSignatureSpec 13 | 14 | 15 | def inspect_tester_signature(fn: Callable) -> TesterSignatureSpec: 16 | sig = inspect.signature(fn) 17 | 18 | # validate return type 19 | return_type = sig.return_annotation 20 | if (return_type is Any) or not issubclass(return_type, TestResult): 21 | msg = "Tester should return a fed_rag.data_structures.TestResult or a subclsas of it." 22 | raise InvalidReturnType(msg) 23 | 24 | # inspect fn params 25 | extra_tester_kwargs = [] 26 | net_param = None 27 | test_data_param = None 28 | net_parameter_class_name = None 29 | 30 | for name, t in sig.parameters.items(): 31 | if name in ("self", "cls"): 32 | continue 33 | 34 | if type_name := getattr(t.annotation, "__name__", None): 35 | if type_name == "Module" and net_param is None: 36 | net_param = name 37 | net_parameter_class_name = type_name 38 | continue 39 | 40 | if type_name == "DataLoader" and test_data_param is None: 41 | test_data_param = name 42 | continue 43 | 44 | extra_tester_kwargs.append(name) 45 | 46 | if net_param is None: 47 | msg = ( 48 | "Inspection failed to find a model param. " 49 | "For PyTorch this param must have type `nn.Module`." 50 | ) 51 | raise MissingNetParam(msg) 52 | 53 | if test_data_param is None: 54 | msg = ( 55 | "Inspection failed to find a data param for a test dataset." 56 | "For PyTorch this params must be of type `torch.utils.data.DataLoader`" 57 | ) 58 | raise MissingDataParam(msg) 59 | 60 | spec = TesterSignatureSpec( 61 | net_parameter=net_param, 62 | test_data_param=test_data_param, 63 | extra_test_kwargs=extra_tester_kwargs, 64 | net_parameter_class_name=net_parameter_class_name, 65 | ) 66 | return spec 67 | --------------------------------------------------------------------------------
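To close, here is a minimal usage sketch for the tester inspector above. The `my_tester` function is hypothetical and exists only to satisfy the constraints that `inspect_tester_signature` enforces: a parameter annotated with `nn.Module`, a parameter annotated with `DataLoader`, and a `TestResult` return annotation. Within the library this inspection appears to be invoked through the federation decorators (for example `@federate.tester.pytorch` in the test fixtures), but it can also be called directly.

```py
import torch.nn as nn
from torch.utils.data import DataLoader

from fed_rag.data_structures import TestResult
from fed_rag.inspectors.pytorch.tester import inspect_tester_signature


def my_tester(
    net: nn.Module, test_loader: DataLoader, temperature: float = 1.0
) -> TestResult:
    """A toy tester used only to illustrate signature inspection."""
    return TestResult(loss=0.0, metrics={})


spec = inspect_tester_signature(my_tester)
print(spec.net_parameter)      # "net"
print(spec.test_data_param)    # "test_loader"
print(spec.extra_test_kwargs)  # ["temperature"]
```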